import math
from math import pi

import numpy as np
import torch
from matplotlib import pyplot as plt
from torch.distributions import MultivariateNormal

from miniai.lsh import FastLSH

# Meanshift using LSH
torch.manual_seed(256)
torch.set_printoptions(precision=5, linewidth=140)

# Synthetic data: `n_clusters` Gaussian blobs with `n_samples` points each.
n_clusters = 6
n_samples = 250
# Cluster centres: standard normal scaled by 70 and shifted by -35 per coordinate.
centroids = torch.randn(n_clusters, 2) * 70 - 35


def sample(c, n=n_samples):
    """Draw `n` points from a 2-D Gaussian centred at `c` with covariance 5*I.

    Args:
        c: length-2 tensor, the mean of the distribution.
        n: number of points to draw (defaults to the module-level `n_samples`).

    Returns:
        Tensor of shape (n, 2).
    """
    return MultivariateNormal(c, torch.diag(torch.tensor([5., 5.]))).sample((n,))


# Clusters are stacked in centroid order: rows [i*n_samples, (i+1)*n_samples)
# belong to centroid i.
data = torch.concat([sample(c) for c in centroids], dim=0)
# data.shape == torch.Size([1500, 2])
def plot_data(centroids, data, ax=None, n_samples=None):
    """Scatter-plot each cluster's points and mark its centroid with an 'x'.

    `data` is expected to hold the clusters as consecutive, equally sized runs
    of rows, in the same order as `centroids`.

    Args:
        centroids: sequence of 2-D points, one per cluster.
        data: 2-D array/tensor of points; columns 0 and 1 are plotted.
        ax: Matplotlib axes to draw on; a new figure/axes is created when None.
        n_samples: points per cluster. Defaults to
            ``len(data) // len(centroids)`` instead of relying on a global.
    """
    if ax is None:
        _, ax = plt.subplots()
    if n_samples is None:
        n_samples = len(data) // max(len(centroids), 1)
    for i, centroid in enumerate(centroids):
        samples = data[i * n_samples:(i + 1) * n_samples]
        ax.scatter(samples[:, 0], samples[:, 1], s=1)
        # Two stacked markers: a thick black 'x' with a thinner magenta 'x' on top.
        ax.plot(*centroid, markersize=10, marker="x", color="k", mew=5)
        ax.plot(*centroid, markersize=5, marker="x", color="m", mew=3)


plot_data(centroids, data)
# Build the LSH index once over the initial points; `one_update` reuses
# these module-level hashes (`Xh`) on every call.
X = data.clone()
lsh = FastLSH(2, 2, 3)
Xh = lsh.hashing(X)


def one_update(X, n_neighbours=150):
    """Run one in-place mean-shift sweep over `X` using LSH neighbour lookup.

    Each point is replaced by the unweighted mean of the neighbours returned
    by `lsh.query_neigbours`. Points are updated one at a time, so later
    points in the sweep see the already-moved earlier ones.

    Args:
        X: tensor of points, one per row; modified in place.
        n_neighbours: number of neighbours requested per query.
    """
    # NOTE(review): `Xh` was computed from the initial point positions and is
    # never refreshed, while `X` moves on every update. Confirm that stale
    # hashes are acceptable for FastLSH's candidate search.
    for i, x in enumerate(X):
        _, idx = lsh.query_neigbours(x[None], X, Xh, n_neighbours)
        # Flat kernel: the query already restricts to the nearest neighbours,
        # so every returned neighbour gets equal weight and the new position
        # is their plain mean.
        X[i] = X[idx].mean(0)


def meanshift(data, n_iter=5):
    """Return a copy of `data` after `n_iter` LSH mean-shift sweeps.

    `data` itself is left unchanged; the sweeps run on a clone.
    """
    X = data.clone()
    for _ in range(n_iter):
        one_update(X)
    return X

# Timing recorded in the notebook:
# 1.56 s ± 11.2 ms per loop (mean ± std. dev. of 7 runs, 5 loops each)
Y = meanshift(data)
Y
# Notebook output (abridged): points collapse onto a few shared locations, e.g.
# tensor([[ -28.38632, -102.51652],
#         [ -28.34384, -102.53816],
#         [ -28.34400, -102.53811],
#         ...,
#         [ -26.40479,  106.70017],
#         [ -26.35939,  106.70464],
#         [ -26.35950,  106.70462]])

# NOTE(review): centroids are offset by +5, presumably so their markers do not
# hide the converged points. Confirm the intent.
plot_data(centroids + 5, Y)
from matplotlib.animation import FuncAnimation
from IPython.display import HTML


def animate(d):
    """Draw animation frame `d` of the mean-shift progression.

    Frame 0 shows the starting points. Every later frame first applies one
    mean-shift sweep to the module-level `X` (mutated in place). The axes are
    always cleared before redrawing, so repeated draws of frame 0 (the init
    pass and the frame itself) do not stack duplicate artists.

    Relies on the module-level `X` and `ax`.
    """
    if d:
        one_update(X)
    ax.clear()
    plot_data(centroids + 5, X, ax=ax)


# Fresh copy of the points for the animation; `animate` mutates it in place.
X = data.clone()
fig, ax = plt.subplots()
# frames=5 -> frame 0 shows the initial state and frames 1-4 each apply one
# update, i.e. 4 sweeps versus the 5 that `meanshift` runs.
f = FuncAnimation(fig, animate, frames=5, interval=500, repeat=False)
# Close the figure so it is not also displayed as a static image; the
# animation is shown via the HTML below.
plt.close(fig)
HTML(f.to_jshtml())