import numpy as np
import math
from math import pi
import torch
from matplotlib import pyplot as plt
from torch.distributions import MultivariateNormal
from miniai.lsh import FastLSH
# Meanshift Using LSH
# Seed the global RNG so the synthetic clusters are reproducible, and set
# tensor printing to 5 decimal places on lines up to 140 characters wide.
torch.manual_seed(256)
torch.set_printoptions(precision=5, linewidth=140)

n_clusters = 6   # number of Gaussian blobs to generate
n_samples = 250  # points drawn per blob

# Cluster centres: standard-normal draws scaled by 70 and shifted by -35.
centroids = torch.randn(n_clusters, 2) * 70 - 35
def sample(c):
    """Draw ``n_samples`` points from a 2-D Gaussian centred on ``c``.

    The covariance is diagonal with variance 5 on each axis. The number of
    points comes from the module-level ``n_samples``.

    Args:
        c: Mean of the distribution (a 2-element tensor).

    Returns:
        Tensor of shape ``(n_samples, 2)``.
    """
    covariance = torch.diag(torch.tensor([5., 5.]))
    cluster = MultivariateNormal(c, covariance)
    return cluster.sample((n_samples,))
# Stack the per-centroid samples into a single tensor; rows stay grouped by
# centroid, in centroid order (plot_data relies on this layout).
data = torch.concat([sample(c) for c in centroids], axis=0)
data.shape
# torch.Size([1500, 2])
def plot_data(centroids, data, ax=None):
    """Scatter-plot each cluster's points and mark its centroid.

    ``data`` is read in consecutive chunks of ``n_samples`` rows (module-level
    constant); chunk ``i`` is drawn as the points belonging to
    ``centroids[i]``.

    Args:
        centroids: Iterable of 2-element (x, y) centroid positions.
        data: 2-D tensor/array of points, column 0 = x and column 1 = y.
        ax: Matplotlib axes to draw on. A new figure and axes are created
            with ``plt.subplots()`` when omitted.
    """
    if ax is None:
        _, ax = plt.subplots()
    for i, centroid in enumerate(centroids):
        samples = data[i * n_samples:(i + 1) * n_samples]
        ax.scatter(samples[:, 0], samples[:, 1], s=1)
        # Draw the centroid twice: a thick black "x" with a thinner magenta
        # "x" on top, so the marker stays visible over the scatter points.
        ax.plot(*centroid, markersize=10, marker="x", color="k", mew=5)
        ax.plot(*centroid, markersize=5, marker="x", color="m", mew=3)
plot_data(centroids, data)

# Work on a copy so the generated data stays untouched.
X = data.clone()
x = X[0]

# NOTE(review): the meaning of FastLSH's three positional arguments is not
# visible in this file -- confirm against miniai.lsh before changing them.
lsh = FastLSH(2, 2, 3)
# Hash every point once up front; one_update reads this module-level Xh.
Xh = lsh.hashing(X)
def one_update(X):
    """Run one mean shift pass over ``X``, updating it in place.

    Each row is replaced by the unweighted mean of the rows selected by
    ``lsh.query_neigbours``. Rows are updated sequentially, so later rows see
    the already-shifted values of earlier rows.

    Relies on the module-level ``lsh`` and ``Xh``.

    Args:
        X: 2-D tensor of points; modified in place.
    """
    # NOTE(review): Xh is computed once at module level and is not refreshed
    # as X moves -- confirm that FastLSH tolerates stale hashes here.
    for i, x in enumerate(X):
        # Only the neighbour indices are needed; the distances are ignored.
        _, idx = lsh.query_neigbours(x[None], X, Xh, 150)
        # removing weighting because lsh already returns in a sorted order
        # and we are requesting for nearest neighbours
        X[i] = X[idx].sum(0) / len(idx)
def meanshift(data, n_iters=5):
    """Return a shifted copy of ``data`` after ``n_iters`` mean shift passes.

    Args:
        data: 2-D tensor of points; left unmodified (a clone is updated).
        n_iters: Number of ``one_update`` passes to run. Defaults to 5.

    Returns:
        A new tensor with the same shape as ``data`` holding the shifted
        points.
    """
    X = data.clone()
    for _ in range(n_iters):
        one_update(X)
    return X
# Timing cell output: 1.56 s ± 11.2 ms per loop (mean ± std. dev. of 7 runs, 5 loops each)
Y = meanshift(data)
Y
# tensor([[ -28.38632, -102.51652],
#         [ -28.34384, -102.53816],
#         [ -28.34400, -102.53811],
#         ...,
#         [ -26.40479,  106.70017],
#         [ -26.35939,  106.70464],
#         [ -26.35950,  106.70462]])

# The centroid markers are drawn offset by +5 on both axes.
plot_data(centroids + 5, Y)
from matplotlib.animation import FuncAnimation
from IPython.display import HTML
def animate(d):
    """Draw frame ``d`` of the mean shift animation.

    Frame 0 plots the current state of the module-level ``X`` without
    updating it. Every later frame runs one ``one_update`` pass on ``X``,
    clears the axes and redraws.

    Relies on the module-level ``X``, ``ax`` and ``centroids``.

    Args:
        d: Frame index supplied by ``FuncAnimation``.
    """
    if not d:
        return plot_data(centroids + 5, X, ax=ax)
    one_update(X)
    ax.clear()
    plot_data(centroids + 5, X, ax=ax)
# Restart from the generated data so the animation shows the shift from the
# initial state.
X = data.clone()
fig, ax = plt.subplots()
f = FuncAnimation(fig, animate, frames=5, interval=500, repeat=False)
# Close the figure before rendering the animation as HTML/JS.
plt.close()
HTML(f.to_jshtml())