import numpy as np
import math
from math import pi
import torch
from matplotlib import pyplot as plt
from torch.distributions import MultivariateNormal
Meanshift Practice
Without peeking at the reference implementation.
torch.manual_seed(256)
torch.set_printoptions(precision=5, linewidth=140)

n_clusters = 6
n_samples = 250

centroids = torch.randn(n_clusters, 2) * 70 - 35

def sample(c): return MultivariateNormal(c, torch.diag(torch.tensor([5., 5.]))).sample((n_samples,))

data = torch.concat([sample(c) for c in centroids], axis=0)
data.shape
torch.Size([1500, 2])
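Since the covariance passed to MultivariateNormal is diagonal with variance 5 on each axis, the same data could be drawn without torch.distributions at all. A minimal sketch of an equivalent sampler (sample_simple and data_alt are illustrative names, not from the original):

def sample_simple(c):
    # variance 5 per axis -> standard deviation sqrt(5)
    return c + torch.randn(n_samples, 2) * math.sqrt(5.)

data_alt = torch.concat([sample_simple(c) for c in centroids], axis=0)
data_alt.shape   # torch.Size([1500, 2]), same as data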
def plot_data(centroids, data, ax=None):
    if ax is None: fig, ax = plt.subplots()
    for i, centroid in enumerate(centroids):
        samples = data[i*n_samples:(i+1)*n_samples]
        ax.scatter(samples[:, 0], samples[:, 1], s=1)
        ax.plot(*centroid, markersize=10, marker="x", color="k", mew=5)
        ax.plot(*centroid, markersize=5, marker="x", color="m", mew=3)
plot_data(centroids, data)
def gaussian(dist, bw=2.5): return (1/(math.sqrt(2*pi)*bw))*torch.exp(-0.5*(dist/bw)**2)
fig, ax = plt.subplots()
x = torch.linspace(-10, 10, 100)
ax.plot(x, gaussian(x))
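The plot shows how quickly the Gaussian kernel decays: a point at zero distance gets the maximum weight 1/(sqrt(2*pi)*bw) ≈ 0.15958, and anything beyond a few bandwidths contributes essentially nothing. A quick numeric check (values approximate):

gaussian(torch.tensor([0., 2.5, 5., 10.]))
# roughly tensor([1.60e-01, 9.68e-02, 2.16e-02, 5.35e-05])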
def tri(dist, bw=10): return (-dist + bw).clamp(min=0)
fig, ax = plt.subplots()
x = torch.linspace(-50, 50, 100)
ax.plot(x, tri(x))
X = data.clone()
x = X[0]
gaussian(((x - X)**2).sum(1).sqrt())
tensor([0.15958, 0.07880, 0.06436, ..., 0.00000, 0.00000, 0.00000])
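These weights are what drive a mean shift step: the point x moves to the weighted average of every point in X, with nearby points (large weight) pulling hardest. A sketch of the update for this single point, mirroring what one_update below does per point:

w = gaussian(((x - X)**2).sum(1).sqrt())
x_new = (w[:, None] * X).sum(0) / w.sum()   # kernel-weighted mean of all points
x, x_new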
def one_update(X, weight_func=gaussian):
    for i, x in enumerate(X):
        dist = ((x - X)**2).sum(1).sqrt()       # distance from x to every point
        weight = weight_func(dist, bw=2.5)      # kernel turns distances into weights
        X[i] = (weight[..., None] * X).sum(0)/weight.sum(0)   # move x to the weighted mean
def meanshift(data, weight_func=gaussian):
    X = data.clone()
    for i in range(5): one_update(X, weight_func)
    return X
645 ms ± 520 µs per loop (mean ± std. dev. of 7 runs, 5 loops each)
Y = meanshift(data)
Y
tensor([[ -28.42868, -102.61228],
[ -28.42868, -102.61228],
[ -28.42867, -102.61228],
...,
[ -26.45944, 106.64550],
[ -26.45944, 106.64550],
[ -26.45944, 106.64550]])
plot_data(centroids+5, Y)
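After five iterations every point has collapsed onto (essentially) the mode of its cluster, which is why the printed tensor contains long runs of identical rows. A quick, hedged way to count the recovered modes, rounding to absorb the tiny remaining differences:

modes = ((Y * 10).round() / 10).unique(dim=0)   # distinct rows after rounding to one decimal
modes.shape                                     # expect roughly (n_clusters, 2)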
from matplotlib.animation import FuncAnimation
from IPython.display import HTML
def animate(d):
    # Frame 0 just shows the starting data; each later frame runs one mean shift step.
    if not d: return plot_data(centroids + 5, X, ax=ax)
    one_update(X)
    ax.clear()
    plot_data(centroids + 5, X, ax=ax)

X = data.clone()
fig, ax = plt.subplots()
f = FuncAnimation(fig, animate, frames=5, interval=500, repeat=False)
plt.close()
HTML(f.to_jshtml())
GPU Batched algorithm
bs = 5
X = data.clone()
x = X[:bs]
x.shape, X.shape
(torch.Size([5, 2]), torch.Size([1500, 2]))
dist = ((x[:,None] - X[None]) ** 2).sum(-1).sqrt()
dist.shape, dist
(torch.Size([5, 1500]),
tensor([[ 0.00000, 2.96982, 3.36906, ..., 209.04686, 210.20853, 209.90276],
[ 2.96982, 0.00000, 4.58362, ..., 208.38335, 209.50293, 209.21017],
[ 3.36906, 4.58362, 0.00000, ..., 212.35605, 213.50829, 213.20551],
[ 3.42341, 1.80829, 3.42423, ..., 210.16139, 211.27600, 210.98482],
[ 2.33995, 1.62542, 2.96569, ..., 209.84781, 210.97752, 210.68167]]))
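The (5, 1500) distance matrix comes from broadcasting: x[:, None] has shape (5, 1, 2) and X[None] has shape (1, 1500, 2), so the difference broadcasts to (5, 1500, 2), and summing over the last axis leaves one distance per (batch point, data point) pair. The intermediate shape can be checked directly:

diff = x[:, None] - X[None]   # (5, 1, 2) - (1, 1500, 2) -> (5, 1500, 2)
diff.shape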
weight = gaussian(dist)
weight
tensor([[0.15958, 0.07880, 0.06436, ..., 0.00000, 0.00000, 0.00000],
[0.07880, 0.15958, 0.02972, ..., 0.00000, 0.00000, 0.00000],
[0.06436, 0.02972, 0.15958, ..., 0.00000, 0.00000, 0.00000],
[0.06249, 0.12285, 0.06246, ..., 0.00000, 0.00000, 0.00000],
[0.10298, 0.12917, 0.07896, ..., 0.00000, 0.00000, 0.00000]])
weight.shape
torch.Size([5, 1500])
(weight @ X).shape
torch.Size([5, 2])
weight.sum(1, keepdims=True).shape
torch.Size([5, 1])
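So weight @ X gives one weighted sum of coordinates per batch point, and dividing by the (5, 1) row sums turns it into a weighted mean, the same quantity the per-point loop computed. A sanity-check sketch for the first point:

manual  = (weight[0][:, None] * X).sum(0) / weight[0].sum()   # loop-style weighted mean
batched = (weight @ X) / weight.sum(1, keepdim=True)          # matrix form for the whole batch
torch.allclose(manual, batched[0])                            # expected: True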
for i in range(0, 10, 2):
print(i)
0
2
4
6
8
def one_update(X, bs):
    n = len(X)
    for i in range(0, n, bs):
        s = slice(i, min(i+bs, n))                              # current mini-batch of points
        x = X[s]
        dist = ((x[:,None] - X[None]) ** 2).sum(-1).sqrt()      # (bs, n) pairwise distances
        weight = gaussian(dist)
        X[s] = (weight @ X) / weight.sum(1, keepdims=True)      # weighted means for the whole batch
def meanshift(data, bs=16):
    X = data.clone()
    for i in range(10): one_update(X, bs)
    return X
data = data.cuda()
48.2 ms ± 422 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
4.36 ms ± 4.38 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
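The two timings above are presumably %timeit runs of the batched meanshift before and after moving the data to the GPU. CUDA kernel launches are asynchronous, so a hand-rolled measurement should synchronize around the call; a sketch (the batch size and the .cpu() move are illustrative, not from the original):

import time
torch.cuda.synchronize()                       # finish any pending GPU work first
t0 = time.perf_counter()
X = meanshift(data, bs=1024).cpu()             # illustrative batch size; bring result back for printing/plotting
torch.cuda.synchronize()
print(f"{(time.perf_counter() - t0) * 1000:.2f} ms")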
X
tensor([[ -30.65150, -102.48758],
[ -27.74988, -101.85477],
[ -30.07228, -105.80648],
...,
[ -26.98898, 106.52718],
[ -23.92055, 107.61317],
[ -24.87365, 107.33564]])
plot_data(centroids+5, X)