import random
import numpy
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
Lesson 5: Becoming a Backprop Ninja
swole doge style
Starter code
Boilerplate
Imports
Read in all the words
words = open('../data/names.txt', 'r').read().splitlines()
print(len(words))
print(max(len(w) for w in words))
print(words[:8])
32033
15
['emma', 'olivia', 'ava', 'isabella', 'sophia', 'charlotte', 'mia', 'amelia']
Build the vocabulary of characters and mappings to/from integers
chars = sorted(list(set(''.join(words))))
stoi = {s: i+1 for i, s in enumerate(chars)}
stoi['.'] = 0
itos = {i: s for s, i in stoi.items()}
vocab_size = len(itos)
print(itos)
print(vocab_size)
{1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i', 10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q', 18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y', 26: 'z', 0: '.'}
27
Build the dataset
block_size = 3 # context length: how many characters do we take to predict the next one?

def build_dataset(words):
    X, Y = [], []
    for w in words:
        context = [0] * block_size
        for ch in w + '.':
            ix = stoi[ch]
            X.append(context)
            Y.append(ix)
            context = context[1:] + [ix] # crop and append
    X = torch.tensor(X)
    Y = torch.tensor(Y)
    print(X.shape, Y.shape)
    return X, Y

random.seed(2)
random.shuffle(words)
n1 = int(0.8 * len(words))
n2 = int(0.9 * len(words))

Xtr, Ytr = build_dataset(words[:n1])     # 80%
Xdev, Ydev = build_dataset(words[n1:n2]) # 10%
Xte, Yte = build_dataset(words[n2:])     # 10%
torch.Size([182481, 3]) torch.Size([182481])
torch.Size([22849, 3]) torch.Size([22849])
torch.Size([22816, 3]) torch.Size([22816])
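For intuition, here is a small illustrative sketch (not part of the original code) of how build_dataset expands a single word into (context, target) pairs, using the stoi/itos mappings and block_size defined above:

# illustration only: one word becomes len(word) + 1 training examples
word = 'emma'
context = [0] * block_size
for ch in word + '.':
    ix = stoi[ch]
    print(''.join(itos[i] for i in context), '-->', itos[ix])
    context = context[1:] + [ix] # crop and append
# ... --> e
# ..e --> m
# .em --> m
# emm --> a
# mma --> .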
OK, boilerplate done, now we get to the action.
Let's get started
Utility function
To compare manual gradients to PyTorch gradients
def cmp(s, dt, t):
    ex = torch.all(dt == t.grad).item()
    app = torch.allclose(dt, t.grad)
    maxdiff = (dt - t.grad).abs().max().item()
    print(f'{s:15s} | exact: {str(ex):5s} | approximate: {str(app):5s} | maxdiff: {maxdiff}')
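For each tensor, cmp reports three things: whether the manual gradient matches PyTorch's t.grad bit for bit (exact), whether it matches within floating-point tolerance via torch.allclose (approximate), and the maximum absolute difference between the two.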
Initialize the layers
n_embd = 10   # the dimensionality for the character embedding vectors
n_hidden = 64 # the number of neurons in the hidden layer of the MLP

g = torch.Generator().manual_seed(2147483647)
C = torch.randn((vocab_size, n_embd), generator=g)

# Layer 1
W1 = torch.randn((n_embd * block_size, n_hidden), generator=g) * (5/3) / ((n_embd * block_size) ** 0.5)
b1 = torch.randn(n_hidden, generator=g) * 0.1 # using b1 just for fun, it's useless because of batchnorm

# Layer 2
W2 = torch.randn((n_hidden, vocab_size), generator=g) * 0.1
b2 = torch.randn(vocab_size, generator=g) * 0.1

# BatchNorm parameters
bngain = torch.randn((1, n_hidden)) * 0.1 + 1.0
bnbias = torch.randn((1, n_hidden)) * 0.1

# Note: the parameters here are initialized in non-standard ways because
# sometimes initializing with e.g. all zeros could mask an incorrect
# implementation of the backward pass
parameters = [C, W1, b1, W2, b2, bngain, bnbias]
print(sum(p.nelement() for p in parameters))
for p in parameters:
    p.requires_grad = True
4137
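As the comment in the code notes, b1 is useless here because batchnorm subtracts the per-batch mean immediately after Layer 1, so any constant added to every example cancels out. A minimal standalone check (illustration only, not part of the lesson code):

# illustration only: a bias added before normalization is removed by the mean subtraction
x = torch.randn(32, 64)
b = torch.randn(64)
normalize = lambda z: (z - z.mean(0, keepdim=True)) / ((z.var(0, keepdim=True, unbiased=True) + 1e-5) ** 0.5)
print(torch.allclose(normalize(x), normalize(x + b), atol=1e-6)) # True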
Construct a minibatch
bs = 32; n = bs
# construct a minibatch
ix = torch.randint(0, Xtr.shape[0], (bs,), generator=g)
Xb, Yb = Xtr[ix], Ytr[ix] # batch X, Y
A Single Forward/Backward Pass
Forward Pass
emb = C[Xb] # embed the characters into vectors
embcat = emb.view(emb.shape[0], -1) # concatenate the vectors

# Linear layer 1
hprebn = embcat @ W1 + b1

# Batchnorm layer
bnmeani = 1/n * hprebn.sum(0, keepdim=True)
bndiff = hprebn - bnmeani
bndiff2 = bndiff ** 2
bnvar = 1/(n - 1) * (bndiff2).sum(0, keepdim=True) # note: Bessel's correction (dividing by n - 1, not n)
bnvar_inv = (bnvar + 1e-5) ** -0.5
bnraw = bndiff * bnvar_inv
hpreact = bngain * bnraw + bnbias

# Non-linearity
h = torch.tanh(hpreact)

# Linear layer 2
logits = h @ W2 + b2 # output layer

# cross entropy loss (same as F.cross_entropy(logits, Yb))
logit_maxes = logits.max(1, keepdim=True).values
norm_logits = logits - logit_maxes # subtract max for numerical stability
counts = norm_logits.exp()
counts_sum = counts.sum(1, keepdims=True)
counts_sum_inv = counts_sum ** -1 # if I use (1.0 / counts_sum) instead then I can't get backprop to be bit exact...
probs = counts * counts_sum_inv
logprobs = probs.log()
loss = -logprobs[range(n), Yb].mean()

# PyTorch backward pass
for p in parameters:
    p.grad = None
for t in [logprobs, probs, counts, counts_sum, counts_sum_inv, # afaik there is no clear way
          norm_logits, logit_maxes, logits, h, hpreact, bnraw,
          bnvar_inv, bnvar, bndiff2, bndiff, hprebn, bnmeani,
          embcat, emb]:
    t.retain_grad()
loss.backward()
loss
tensor(3.4738, grad_fn=<NegBackward0>)
Exercise 1:
Backprop through the whole thing manually, backpropagating through exactly all of the variables as they are defined in the forward pass above, one by one.
C.shape, Xb.shape, emb.shape
(torch.Size([27, 10]), torch.Size([32, 3]), torch.Size([32, 3, 10]))
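The loss only reads logprobs at the label positions, so its gradient is nonzero only there. As a sketch of the reasoning behind the first step:

$$\text{loss} = -\frac{1}{n}\sum_{i} \text{logprobs}[i, Y_i]
\quad\Rightarrow\quad
\frac{\partial\,\text{loss}}{\partial\,\text{logprobs}[i,j]} =
\begin{cases}-\frac{1}{n} & \text{if } j = Y_i\\ 0 & \text{otherwise}\end{cases}$$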
dlogprobs = torch.zeros_like(logprobs)
dlogprobs[range(n), Yb] = -1.0/n
cmp('logprobs', dlogprobs, logprobs)

dprobs = torch.zeros_like(probs)
dprobs = (1/probs) * dlogprobs
cmp('probs', dprobs, probs)

dcounts_sum_inv = (counts * dprobs).sum(1, keepdim=True)
dcounts = counts_sum_inv * dprobs
cmp('counts_sum_inv', dcounts_sum_inv, counts_sum_inv)

dcounts_sum = (-1 * counts_sum ** -2) * dcounts_sum_inv
cmp('counts_sum', dcounts_sum, counts_sum)

dcounts += dcounts_sum
cmp('counts', dcounts, counts)

dnorm_logits = norm_logits.exp() * dcounts
cmp('norm_logits', dnorm_logits, norm_logits)

dlogits = dnorm_logits.clone()
dlogit_maxes = (-dnorm_logits).sum(1, keepdim=True)
cmp('logit_maxes', dlogit_maxes, logit_maxes)

dlogits_temp = torch.zeros_like(logits)
dlogits_temp[range(n), torch.argmax(logits, 1)] = 1
dlogits += dlogits_temp * dlogit_maxes
cmp('logits', dlogits, logits)

dh = dlogits @ W2.T
cmp('h', dh, h)
dW2 = h.T @ dlogits
cmp('W2', dW2, W2)
db2 = dlogits.sum(0)
cmp('b2', db2, b2)

dhpreact = (1 - h ** 2) * dh
cmp('hpreact', dhpreact, hpreact)

dbngain = (bnraw * dhpreact).sum(0, keepdims=True)
cmp('bngain', dbngain, bngain)
dbnbias = dhpreact.sum(0, keepdims=True)
cmp('bnbias', dbnbias, bnbias)
dbnraw = bngain * dhpreact
cmp('bnraw', dbnraw, bnraw)

dbnvar_inv = (bndiff * dbnraw).sum(0, keepdims=True)
cmp('bnvar_inv', dbnvar_inv, bnvar_inv)
dbndiff = bnvar_inv * dbnraw
cmp('bndiff', dbndiff, bndiff) # not exact yet: bndiff also feeds into bndiff2, that branch is added below
dbnvar = (-0.5 * (bnvar + 1e-5) ** -1.5) * dbnvar_inv
cmp('bnvar', dbnvar, bnvar)
dbndiff2 = (1.0/(n-1)) * torch.ones_like(bndiff2) * dbnvar
cmp('bndiff2', dbndiff2, bndiff2)
dbndiff += (2 * bndiff) * dbndiff2
cmp('bndiff', dbndiff, bndiff) # now exact: both branches accumulated
dbnmeani = (-1 * dbndiff).sum(0, keepdims=True)
cmp('bnmeani', dbnmeani, bnmeani)
dhprebn = dbndiff.clone() # clone so the += below doesn't modify dbndiff in place
dhprebn += 1/n * torch.ones_like(hprebn) * dbnmeani
cmp('hprebn', dhprebn, hprebn)

dembcat = dhprebn @ W1.T
cmp('embcat', dembcat, embcat)
dW1 = embcat.T @ dhprebn
cmp('W1', dW1, W1)
db1 = dhprebn.sum(0)
cmp('b1', db1, b1)

demb = dembcat.view(emb.shape)
cmp('emb', demb, emb)

dC = torch.zeros_like(C)
for k in range(Xb.shape[0]):
    for j in range(Xb.shape[1]):
        ix = Xb[k, j]
        dC[ix] += demb[k, j]
cmp('C', dC, C)
logprobs | exact: True | approximate: True | maxdiff: 0.0
probs | exact: True | approximate: True | maxdiff: 0.0
counts_sum_inv | exact: True | approximate: True | maxdiff: 0.0
counts_sum | exact: True | approximate: True | maxdiff: 0.0
counts | exact: True | approximate: True | maxdiff: 0.0
norm_logits | exact: True | approximate: True | maxdiff: 0.0
logit_maxes | exact: True | approximate: True | maxdiff: 0.0
logits | exact: True | approximate: True | maxdiff: 0.0
h | exact: True | approximate: True | maxdiff: 0.0
W2 | exact: True | approximate: True | maxdiff: 0.0
b2 | exact: True | approximate: True | maxdiff: 0.0
hpreact | exact: True | approximate: True | maxdiff: 0.0
bngain | exact: True | approximate: True | maxdiff: 0.0
bnbias | exact: True | approximate: True | maxdiff: 0.0
bnraw | exact: True | approximate: True | maxdiff: 0.0
bnvar_inv | exact: True | approximate: True | maxdiff: 0.0
bndiff | exact: False | approximate: False | maxdiff: 0.0013724520104005933
bnvar | exact: True | approximate: True | maxdiff: 0.0
bndiff2 | exact: True | approximate: True | maxdiff: 0.0
bndiff | exact: True | approximate: True | maxdiff: 0.0
bnmeani | exact: True | approximate: True | maxdiff: 0.0
hprebn | exact: True | approximate: True | maxdiff: 0.0
embcat | exact: True | approximate: True | maxdiff: 0.0
W1 | exact: True | approximate: True | maxdiff: 0.0
b1 | exact: True | approximate: True | maxdiff: 0.0
emb | exact: True | approximate: True | maxdiff: 0.0
C | exact: True | approximate: True | maxdiff: 0.0
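As a side note (an alternative sketch, not from the original notebook): the explicit double loop that scatters demb into dC can also be written with index_add_, which accumulates rows at repeated indices just like the loop does:

# illustration only: vectorized scatter-add, equivalent to the k, j loop above
dC_alt = torch.zeros_like(C)
dC_alt.index_add_(0, Xb.reshape(-1), demb.reshape(-1, n_embd))
print(torch.allclose(dC_alt, dC)) # True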
dprobs.shape
torch.Size([32, 27])
probs.shape
torch.Size([32, 27])
counts.shape
torch.Size([32, 27])
dcounts_sum.shape
torch.Size([32, 1])
counts.shape
torch.Size([32, 27])
Exercise 2
Backprop through cross_entropy but all in one go. To complete this challenge, look at the mathematical expression of the loss, take the derivative, simplify the expression, and just write it out.
Forward Pass
Before
# logit_maxes = logits.max(1, keepdim=True).values
# norm_logits = logits - logit_maxes # subtract max for numerical stability
# counts = norm_logits.exp()
# counts_sum = counts.sum(1, keepdims=True)
# counts_sum_inv = counts_sum ** -1 # If (1.0 / counts_sum) then it cannot be exactly backpropagated
# probs = counts * counts_sum_inv
# logprobs = probs.log()
# loss = -logprobs[range(n), Yb].mean()
Now
loss_fast = F.cross_entropy(logits, Yb)
print(loss_fast.item(), 'diff:', (loss_fast - loss).item())
3.4734604358673096 diff: -2.384185791015625e-07
Backprop
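A sketch of the simplification (the standard softmax cross-entropy derivative, with $p = \mathrm{softmax}(\text{logits})$ and the mean taken over the batch of size $n$):

$$\text{loss} = -\frac{1}{n}\sum_{i}\log p_{i,Y_i}
\quad\Rightarrow\quad
\frac{\partial\,\text{loss}}{\partial\,\text{logits}_{i,j}} = \frac{1}{n}\left(p_{i,j} - \mathbf{1}[j = Y_i]\right)$$

which is exactly what the three lines below implement: take the softmax, subtract 1 at the label positions, and divide by n.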
dlogits = F.softmax(logits, 1)
dlogits[range(n), Yb] -= 1
dlogits /= n
cmp('logits', dlogits, logits)
logits | exact: False | approximate: True | maxdiff: 8.381903171539307e-09
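The match is not bit-exact because F.softmax reaches the same values through a different sequence of floating-point operations than the long-hand forward pass, but the maximum difference is on the order of 1e-9, well within float32 tolerance.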
dlogits.shape
torch.Size([32, 27])