Efficient Locality Sensitive Hashing: Loop-Free Implementation

My implementation of Locality Sensitive Hashing with zero loops, only tensor operations
Author

Anubhav Maity

Published

September 11, 2023

Locality Sensitive Hashing (LSH)

In this blog post I explain and implement locality sensitive hashing with zero loops, i.e. only tensor operations. Let’s dive in.

Locality sensitive hashing (LSH) is a technique used in approximate nearest neighbor search and similarity-based retrieval tasks. LSH helps find similar items efficiently by reducing the search space for similarity queries.

LSH works by hashing similar items into the same or nearby hash buckets with a high probability. It operates on the principle that if two items are similar, they are likely to collide (hash to the same bucket) under a certain hash function.

LSH is “locality sensitive” because it ensures that nearby items have a higher probability of being hashed into the same bucket, while items that are far apart have a lower probability of colliding. This property allows for efficient pruning of the search space, as we can focus the search on the items within the same hash buckets.
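
To see this property concretely, here is a minimal sketch (a toy example of my own, separate from the implementation below) using random hyperplane hashes: a point close to a query agrees with it on far more hash bits than a distant point does.

import torch

torch.manual_seed(0)
dim, n_bits = 2, 10_000                       # many bits so the agreement rates are stable
planes = torch.randn(n_bits, dim)             # one random hyperplane per hash bit

query = torch.tensor([1.0, 0.0])
near = torch.tensor([0.95, 0.05])             # almost the same direction as the query
far = torch.tensor([-1.0, 0.5])               # points roughly the opposite way

def bits(x): return planes @ x >= 0           # sign of the projection onto each hyperplane

print((bits(query) == bits(near)).float().mean())  # close to 1: the near point collides on most bits
print((bits(query) == bits(far)).float().mean())   # much smaller: the far point rarely collides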

There are different LSH families designed for various data types and similarity measures. Common examples include MinHash for Jaccard similarity between sets (e.g. document shingles), SimHash for cosine similarity, and L2-LSH for Euclidean distance-based similarity.
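
As a small illustration of one of these families (a toy sketch of my own, unrelated to the implementation below), MinHash estimates the Jaccard similarity of two sets from the fraction of matching entries in their min-hash signatures:

import torch

torch.manual_seed(0)
n_hashes, prime = 200, 2_147_483_647                     # number of hash functions and a large prime
a = torch.randint(1, prime, (n_hashes,))                 # random coefficients for h(x) = (a*x + b) mod prime
b = torch.randint(0, prime, (n_hashes,))

def minhash(items):
    return ((a[:, None] * items + b[:, None]) % prime).min(dim=1).values  # one min per hash function

x = torch.tensor([1, 3, 5, 7, 9, 11])
y = torch.tensor([1, 3, 5, 7, 20, 30])                   # Jaccard similarity with x is 4/8 = 0.5

print((minhash(x) == minhash(y)).float().mean())         # fraction of matching entries, roughly 0.5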

LSH is particularly useful in scenarios where traditional exact search methods become impractical due to high-dimensional data or large dataset sizes. It allows for approximate similarity search with reduced computational complexity, making it a valuable tool in various applications like recommendation systems, image retrieval, and DNA sequence matching.

#!pip install fastcore
import matplotlib.pyplot as plt
import torch
from fastcore.all import patch
torch.manual_seed(42)
torch.set_printoptions(precision=3, linewidth=140, sci_mode=False)

I defined the class FastLSH below, which takes the following arguments:

- dim: the number of features (dimensions) in the dataset. The higher the dimensionality, the more work each hash computation involves.
- hl: the hash length, i.e. the number of bits in each hash code. A longer code makes the buckets more selective and the search more accurate, but takes longer to compute.
- nht: the number of hash tables used by the algorithm. More tables make it more likely that two similar items collide in at least one of them, at the cost of more computation per query.

The constructor also builds hash_table, the tensor of random hyperplanes that defines all the hash functions; every item is hashed against it so that similar items can be found quickly.

class FastLSH:
    def __init__(self, dim, nht, hl):
        self.dimensions = dim # dimensions of the data
        self.num_hash_tables = nht # number of hash tables
        self.hash_length = hl # hash length
        self.hash_table = torch.randn(self.num_hash_tables, self.hash_length, self.dimensions) # the hashtable
fastlsh = FastLSH(dim=2, nht=5, hl=10)
fastlsh.hash_table.shape
torch.Size([5, 10, 2])
data = torch.randn(150_000, 2) # data
query = data[0][None] # query, adding a unit axis at the start
data.shape, query.shape
(torch.Size([150000, 2]), torch.Size([1, 2]))

The purpose of the following hashing code is to apply every hash function to each data point in the input tensor. Each hash bit is the sign of the dot product between a point and one of the random hyperplanes stored in the hash table (the random-projection scheme behind cosine-similarity LSH). The resulting tensor contains a binary hash code per hash table for every input point.

The patch decorator patches the function onto the FastLSH class; see the fastcore documentation for more about patch.

@patch
def hashing(self:FastLSH, query): 
    return (((query[:, None, None]) * self.hash_table).sum(-1) >= 0).long()
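
To make the broadcasting explicit, here is the shape arithmetic the one-liner relies on, checked on a few toy points with the parameters used in this post:

# query[:, None, None] : (n, 1, 1, dim)   e.g. (150000, 1, 1, 2)
# self.hash_table      : (nht, hl, dim)   i.e. (5, 10, 2)
# product broadcasts to  (n, nht, hl, dim); sum(-1) projects each point onto each hyperplane -> (n, nht, hl)
# ">= 0" keeps only the sign of each projection, giving one bit per hyperplane
x = torch.randn(3, 2)            # 3 toy points with the same dimensionality as the data
fastlsh.hashing(x).shape         # torch.Size([3, 5, 10])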

Let’s utilize the hashing function mentioned above to hash both the query and data points.

data_hash = fastlsh.hashing(data)
data_hash.shape
torch.Size([150000, 5, 10])

From data_hash.shape we can see that each of the 150_000 data points has been hashed by all 5 hash tables, each producing a hash code of length 10.

query_hash = fastlsh.hashing(query)
query_hash.shape
torch.Size([1, 5, 10])

Now let’s obtain the indices of the data points whose hash code matches the query’s hash code in at least one hash table.

To test for a match, we compute the dot product of the two hash codes along the last axis and divide it by the number of set bits in the data point’s code. A ratio of 1 means every bit that is set in the data hash is also set in the query hash, which this implementation treats as a match; any other value (including the nan produced by an all-zero data hash) means no match.
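
Here is a tiny hand-checkable example of that ratio (toy codes of my own, not the real hashes). Note that the second row also gets a ratio of 1 even though its code differs from the query’s, because its set bits are a subset of the query’s; the check is therefore a slight relaxation of exact equality.

q = torch.tensor([1, 0, 1, 1, 0])               # toy query hash code
d = torch.tensor([[1, 0, 1, 1, 0],              # identical to the query code
                  [1, 0, 1, 0, 0],              # set bits are a subset of the query's
                  [0, 1, 1, 0, 1]])             # mostly different code
(q * d).sum(-1) / d.sum(-1)                     # ratios: 1, 1 and 1/3

Now apply the same check to the real hashes.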

(query_hash * data_hash).sum(-1)
tensor([[4, 5, 2, 6, 5],
        [1, 3, 2, 2, 3],
        [4, 3, 0, 4, 3],
        ...,
        [2, 3, 2, 5, 4],
        [1, 3, 2, 3, 3],
        [0, 1, 2, 2, 1]])
data_hash.sum(-1)
tensor([[ 4,  5,  2,  6,  5],
        [ 4,  8,  9,  6,  6],
        [ 6,  3,  0,  4,  5],
        ...,
        [ 5,  6,  4,  8,  4],
        [ 4,  7,  7,  7,  6],
        [ 4,  6, 10,  6,  4]])
(query_hash * data_hash).shape
torch.Size([150000, 5, 10])
(query_hash * data_hash).sum(-1) / data_hash.sum(-1) 
tensor([[1.000, 1.000, 1.000, 1.000, 1.000],
        [0.250, 0.375, 0.222, 0.333, 0.500],
        [0.667, 1.000,   nan, 1.000, 0.600],
        ...,
        [0.400, 0.500, 0.500, 0.625, 1.000],
        [0.250, 0.429, 0.286, 0.429, 0.500],
        [0.000, 0.167, 0.200, 0.333, 0.250]])
result = ( (query_hash * data_hash).sum(-1) / data_hash.sum(-1) ) == 1
result
tensor([[ True,  True,  True,  True,  True],
        [False, False, False, False, False],
        [False,  True, False,  True, False],
        ...,
        [False, False, False, False,  True],
        [False, False, False, False, False],
        [False, False, False, False, False]])

We can obtain the indices of the rows that contain at least one True, i.e. the points that match the query in at least one hash table, using the following code.

result_indices = torch.nonzero(torch.any(result, dim=1)).flatten()
result_indices.shape
torch.Size([50337])
data[result_indices].shape
torch.Size([50337, 2])

Now that we have obtained the indices, let’s proceed to compute the Euclidean distance (L2 norm).

To compute the Euclidean distance, we don’t need to measure the distance to every one of the 150_000 points. We only need the distances to the candidate points whose hash codes matched the query’s, i.e. the indices obtained above.

((query - data[result_indices])**2).sum(-1).sqrt().shape
torch.Size([50337])
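
As a quick cross-check (my own aside), torch.cdist computes the same pairwise Euclidean distances as the manual expression above:

dist_manual = ((query - data[result_indices]) ** 2).sum(-1).sqrt()
dist_cdist = torch.cdist(query, data[result_indices]).squeeze(0)   # shape (num_candidates,)
torch.allclose(dist_manual, dist_cdist, atol=1e-5)                 # expected to be True up to float tolerance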

We can consolidate all of the operations above into a single function:

@patch
def query_neigbours(self:FastLSH, query, data, data_hash, neighbours=10):
    query_hash = self.hashing(query) # hash the query point
    result = ( (query_hash * data_hash).sum(-1) / data_hash.sum(-1) ) == 1 # match per hash table
    result_indices = torch.nonzero(torch.any(result, dim=1)).flatten() # candidates matching in any table

    dist = ((query - data[result_indices]) ** 2).sum(-1).sqrt() # L2 distance to candidates only
    sorted_dist, idx = torch.sort(dist)

    return sorted_dist[:neighbours], result_indices[idx[:neighbours]] # nearest distances and their indices
query = data[0][None]; query.shape
torch.Size([1, 2])
data_hash = fastlsh.hashing(data)
fastlsh.query_neigbours(query, data, data_hash, 10)
(tensor([0.000, 0.004, 0.004, 0.005, 0.007, 0.007, 0.009, 0.009, 0.009, 0.010]),
 tensor([     0,  16866,  13708,  29511,  37183,  31814,   9378, 122251, 131806,   4483]))
%timeit -n 5 _=fastlsh.query_neigbours(query, data, data_hash, 10)
11.9 ms ± 643 µs per loop (mean ± std. dev. of 7 runs, 5 loops each)
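
Since LSH is approximate, a quick sanity check (my own addition) is to compare the result against exact brute-force nearest neighbors over all 150_000 points and measure the recall:

exact_dist = ((query - data) ** 2).sum(-1).sqrt()                  # distance from the query to every point
exact_top10 = torch.topk(exact_dist, k=10, largest=False)          # the true 10 nearest neighbors
lsh_dist, lsh_idx = fastlsh.query_neigbours(query, data, data_hash, 10)
torch.isin(exact_top10.indices, lsh_idx).float().mean()            # fraction of the true top-10 recovered (recall@10)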

GPU

I haven’t yet worked out how to pass batches of query points for computing the hashes and obtaining the distances. As a temporary workaround, I move a single query point and the data to CUDA. Let’s see whether this yields better results than the CPU version above.

data_cuda = data.cuda()
query_cuda = query.cuda()
fastlsh.hash_table = fastlsh.hash_table.cuda()
data_hash_cuda = fastlsh.hashing(data_cuda)
fastlsh.query_neigbours(query_cuda, data_cuda, data_hash_cuda, 10)
(tensor([0.000, 0.004, 0.004, 0.005, 0.007, 0.007, 0.009, 0.009, 0.009, 0.010], device='cuda:0'),
 tensor([     0,  16866,  13708,  29511,  37183,  31814,   9378, 122251, 131806,   4483], device='cuda:0'))
%timeit -n 5 _=fastlsh.query_neigbours(query_cuda, data_cuda, data_hash_cuda, 10)
739 µs ± 30.9 µs per loop (mean ± std. dev. of 7 runs, 5 loops each)

The GPU computation is roughly 16x faster.