Lesson 5: Becoming a Backprop Ninja
swole doge style
Starter code
Boilerplate
Imports
import random
import numpy
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
Read in all the words
words = open('../data/names.txt', 'r').read().splitlines()
print(len(words))
print(max(len(w) for w in words))
print(words[:8])
32033
15
['emma', 'olivia', 'ava', 'isabella', 'sophia', 'charlotte', 'mia', 'amelia']
Build the vocabulary of characters and mappings to/from integers
chars = sorted(list(set(''.join(words))))
stoi = {s: i+1 for i, s in enumerate(chars)}
stoi['.'] = 0
itos = {i:s for s, i in stoi.items()}
vocab_size = len(itos)
print(itos)
print(vocab_size)
{1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i', 10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q', 18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y', 26: 'z', 0: '.'}
27
Build the dataset
block_size = 3 # context length: how many characters do we take to predict the next one?
def build_dataset(words):
    X, Y = [], []
    
    for w in words:
        context = [0] * block_size
        for ch in w + '.':
            ix = stoi[ch]
            X.append(context)
            Y.append(ix)
            context = context[1:] + [ix] # crop and append
    
    X = torch.tensor(X)
    Y = torch.tensor(Y)
    print(X.shape, Y.shape)
    return X, Y
random.seed(2)
random.shuffle(words)
n1 = int(0.8 * len(words))
n2 = int(0.9 * len(words))
Xtr, Ytr = build_dataset(words[:n1]) # 80%
Xdev, Ydev = build_dataset(words[n1: n2]) # 10%
Xte, Yte = build_dataset(words[n2:]) # 10%
torch.Size([182481, 3]) torch.Size([182481])
torch.Size([22849, 3]) torch.Size([22849])
torch.Size([22816, 3]) torch.Size([22816])
ok boiler done, now get to the action
Let's get started
Utility function
To compare manual gradients to PyTorch gradients
def cmp(s, dt, t):
    ex = torch.all(dt == t.grad).item()
    app = torch.allclose(dt, t.grad)
    maxdiff = (dt - t.grad).abs().max().item()
    print(f'{s:15s} | exact: {str(ex):5s} | approximate: {str(app):5s} | maxdiff: {maxdiff}')
Initialize the layers
n_embd = 10 # the dimensionality for the character embedding vectors
n_hidden = 64 # the number of neurons in the hidden layer of the MLP
g = torch.Generator().manual_seed(2147483647)
C = torch.randn((vocab_size, n_embd),             generator = g)
# Layer 1
W1 = torch.randn((n_embd * block_size, n_hidden), generator=g) * (5/3)/((n_embd * block_size) ** 0.5)
b1 = torch.randn(n_hidden,                        generator=g) * 0.1 
# using b1 just for fun, it's useless because of batchnorm
# Layer 2
W2 = torch.randn((n_hidden, vocab_size),          generator=g) * 0.1
b2 = torch.randn(vocab_size,                      generator=g) * 0.1
# BatchNorm parameters
bngain = torch.randn((1, n_hidden)) * 0.1 + 1.0
bnbias = torch.randn((1, n_hidden)) * 0.1
# Note: The parameters here are initialized in non-standard ways
# because sometimes initializing with e.g. all zeros could mask an incorrect
# implementation of the backward pass
parameters = [C, W1, b1, W2, b2, bngain, bnbias]
print(sum(p.nelement() for p in parameters))
for p in parameters:
    p.requires_grad = True
4137
Construct a minibatch
bs = 32
n = bs
# construct a minibatch
ix = torch.randint(0, Xtr.shape[0], (bs,), generator=g)
Xb, Yb = Xtr[ix], Ytr[ix] # batch X, Y
An Epoch
Forward Pass
emb = C[Xb] # embed the characters into vectors
embcat = emb.view(emb.shape[0], -1) # concatenate the vectors
# Linear layer 1
hprebn = embcat @ W1 + b1
# Batchnorm layer
bnmeani = 1/n * hprebn.sum(0, keepdim=True)
bndiff = hprebn - bnmeani
bndiff2 = bndiff ** 2
bnvar = 1/(n - 1) * (bndiff2).sum(0, keepdim=True) # note: bessel's correction (dividing by n - 1, not n)
bnvar_inv = (bnvar + 1e-5) ** -0.5
bnraw = bndiff * bnvar_inv
hpreact = bngain * bnraw + bnbias
# Non linearity
h = torch.tanh(hpreact)
# Linear layer 2
logits = h @ W2 + b2 # output layer
# cross entropy loss ( same as F.cross_entropy(logits, Yb) )
logit_maxes = logits.max(1, keepdim = True).values
norm_logits = logits - logit_maxes # subtract max for numerical stability
counts = norm_logits.exp()
counts_sum = counts.sum(1, keepdims=True)
counts_sum_inv = counts_sum ** -1 # if I use (1.0 / counts_sum ) instead then I can't get backprop to be bit exact...
probs = counts * counts_sum_inv
logprobs = probs.log()
loss = -logprobs[range(n), Yb].mean()
# PyTorch backward pass
for p in parameters: p.grad = None
for t in [logprobs, probs, counts, counts_sum, counts_sum_inv, # afaik there is no cleaner way to retain grads on all of these intermediates
          norm_logits, logit_maxes, logits, h, hpreact, bnraw,
          bnvar_inv, bnvar, bndiff2, bndiff, hprebn, bnmeani,
          embcat, emb]:
    t.retain_grad()
loss.backward()
loss
tensor(3.4738, grad_fn=<NegBackward0>)
Exercise 1:
backprop through the whole thing manually, backpropagating through exactly all of the variables as they are defined in the forward pass above, one by one
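Before diving in, here is a minimal standalone sketch of the very first step (toy sizes and made-up targets, not the notebook's tensors; it only relies on the torch import at the top). Since loss = -(1/n) * sum_i logprobs[i, Yb[i]], the gradient with respect to logprobs is -1/n at the target positions and 0 everywhere else, which we can check against autograd:
# toy check: dloss/dlogprobs is -1/n at the (i, y[i]) positions, 0 elsewhere
n_toy, v_toy = 4, 5
lp = torch.randn(n_toy, v_toy, requires_grad=True)  # stand-in for logprobs
y = torch.tensor([1, 3, 0, 2])                      # stand-in for Yb
(-lp[range(n_toy), y].mean()).backward()
manual = torch.zeros_like(lp)
manual[range(n_toy), y] = -1.0 / n_toy
print(torch.all(manual == lp.grad).item())          # expected: True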
C.shape, Xb.shape, emb.shape
(torch.Size([27, 10]), torch.Size([32, 3]), torch.Size([32, 3, 10]))
dlogprobs = torch.zeros_like(logprobs)
dlogprobs[range(n), Yb] = -1.0/n
cmp('logprobs', dlogprobs, logprobs)
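# logprobs = probs.log(), so d(logprobs)/d(probs) = 1/probs; the chain rule multiplies by dlogprobs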
dprobs = (1/probs) * dlogprobs
cmp('probs', dprobs, probs)
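# probs = counts * counts_sum_inv broadcasts the (32, 1) column across all 27 columns,
# so the gradient w.r.t. counts_sum_inv must sum those replicated contributions over dim 1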
dcounts_sum_inv =  (counts * dprobs).sum(1, keepdim=True)
dcounts = counts_sum_inv * dprobs
cmp('counts_sum_inv', dcounts_sum_inv, counts_sum_inv)
dcounts_sum = (-1 * counts_sum ** -2) * dcounts_sum_inv
cmp('counts_sum', dcounts_sum, counts_sum)
dcounts += dcounts_sum
cmp('counts', dcounts, counts)
dnorm_logits = norm_logits.exp() * dcounts
cmp('norm_logits', dnorm_logits, norm_logits)
dlogits = dnorm_logits.clone()
dlogit_maxes = (-dnorm_logits).sum(1, keepdim=True)
cmp('logit_maxes', dlogit_maxes, logit_maxes)
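# logit_maxes depends on logits only through the max entry of each row,
# so its gradient routes back through a one-hot mask at the argmax positions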
dlogits_temp = torch.zeros_like(logits)
dlogits_temp[range(n), torch.argmax(logits, 1)] = 1
dlogits += dlogits_temp * dlogit_maxes
cmp('logits', dlogits, logits)
dh = dlogits @ W2.T
cmp('h', dh, h)
dW2 = h.T @ dlogits
cmp('W2', dW2, W2)
db2 = dlogits.sum(0)
cmp('b2', db2, b2)
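# tanh backward: h = tanh(hpreact) and d(tanh(x))/dx = 1 - tanh(x)**2 = 1 - h**2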
dhpreact = ((1 - h ** 2) * dh)
cmp('hpreact', dhpreact, hpreact)
dbngain = (bnraw * dhpreact).sum(0, keepdims=True)
cmp('bngain', dbngain, bngain)
dbnbias = dhpreact.sum(0, keepdims=True)
cmp('bnbias', dbnbias, bnbias)
dbnraw = bngain * dhpreact
cmp('bnraw', dbnraw, bnraw)
dbnvar_inv = (bndiff * dbnraw).sum(0, keepdims=True)
cmp('bnvar_inv', dbnvar_inv, bnvar_inv)
dbndiff = bnvar_inv * dbnraw
cmp('bndiff', dbndiff, bndiff)
dbnvar = (-0.5 * (bnvar + 1e-5) ** -1.5) * dbnvar_inv
cmp('bnvar', dbnvar, bnvar)
dbndiff2 = (1.0/(n-1)) * torch.ones_like(bndiff2) * dbnvar
cmp('bndiff2', dbndiff2, bndiff2)
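# bndiff appears twice in the forward pass (via bnraw and via bndiff2),
# so its gradient accumulates a second contribution here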
dbndiff += (2 * bndiff) * dbndiff2
cmp('bndiff', dbndiff, bndiff)
dbnmeani = (- 1 * dbndiff).sum(0, keepdims=True) 
cmp('bnmeani', dbnmeani, bnmeani)
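# hprebn also feeds into bnmeani, so its gradient gets two contributions as well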
dhprebn = dbndiff
dhprebn += 1/n * (torch.ones_like(hprebn)) * dbnmeani
cmp('hprebn', dhprebn, hprebn)
dembcat = dhprebn @ W1.T
cmp('embcat', dembcat, embcat)
dW1 = embcat.T @ dhprebn
cmp('W1', dW1, W1)
db1 = dhprebn.sum(0)
cmp('b1', db1, b1)
demb = dembcat.view(emb.shape)
cmp('emb', demb, emb)
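# emb = C[Xb] is a row lookup, so gradients scatter-add back into the looked-up rows of C
# (the explicit loop below is equivalent to an index_add over the flattened indices)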
dC = torch.zeros_like(C)
for k in range(Xb.shape[0]):
    for j in range(Xb.shape[1]):
        ix = Xb[k, j]
        dC[ix] += demb[k, j]
cmp('C', dC, C)
logprobs        | exact: True  | approximate: True  | maxdiff: 0.0
probs           | exact: True  | approximate: True  | maxdiff: 0.0
counts_sum_inv  | exact: True  | approximate: True  | maxdiff: 0.0
counts_sum      | exact: True  | approximate: True  | maxdiff: 0.0
counts          | exact: True  | approximate: True  | maxdiff: 0.0
norm_logits     | exact: True  | approximate: True  | maxdiff: 0.0
logit_maxes     | exact: True  | approximate: True  | maxdiff: 0.0
logits          | exact: True  | approximate: True  | maxdiff: 0.0
h               | exact: True  | approximate: True  | maxdiff: 0.0
W2              | exact: True  | approximate: True  | maxdiff: 0.0
b2              | exact: True  | approximate: True  | maxdiff: 0.0
hpreact         | exact: True  | approximate: True  | maxdiff: 0.0
bngain          | exact: True  | approximate: True  | maxdiff: 0.0
bnbias          | exact: True  | approximate: True  | maxdiff: 0.0
bnraw           | exact: True  | approximate: True  | maxdiff: 0.0
bnvar_inv       | exact: True  | approximate: True  | maxdiff: 0.0
bndiff          | exact: False | approximate: False | maxdiff: 0.0013724520104005933
bnvar           | exact: True  | approximate: True  | maxdiff: 0.0
bndiff2         | exact: True  | approximate: True  | maxdiff: 0.0
bndiff          | exact: True  | approximate: True  | maxdiff: 0.0
bnmeani         | exact: True  | approximate: True  | maxdiff: 0.0
hprebn          | exact: True  | approximate: True  | maxdiff: 0.0
embcat          | exact: True  | approximate: True  | maxdiff: 0.0
W1              | exact: True  | approximate: True  | maxdiff: 0.0
b1              | exact: True  | approximate: True  | maxdiff: 0.0
emb             | exact: True  | approximate: True  | maxdiff: 0.0
C               | exact: True  | approximate: True  | maxdiff: 0.0
dprobs.shape
torch.Size([32, 27])
probs.shape
torch.Size([32, 27])
counts.shape
torch.Size([32, 27])
dcounts_sum.shape
torch.Size([32, 1])
counts.shape
torch.Size([32, 27])
Exercise 2
backprop through cross_entropy, but all in one go. To complete this challenge, look at the mathematical expression of the loss, take the derivative, simplify the expression, and just write it out
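For reference, here is a sketch of that derivation in my own notation (not from the original notebook), with logits written as a matrix of shape (n, vocab_size) and targets y:

$$\mathcal{L} = -\frac{1}{n}\sum_{i=1}^{n}\log\,\mathrm{softmax}(\ell_i)_{y_i} = -\frac{1}{n}\sum_{i=1}^{n}\Big(\ell_{i,y_i} - \log\sum_{j} e^{\ell_{ij}}\Big)$$

$$\frac{\partial\mathcal{L}}{\partial\ell_{ij}} = \frac{1}{n}\Big(\mathrm{softmax}(\ell_i)_j - \mathbf{1}[j = y_i]\Big)$$

i.e. softmax of the logits, minus 1 at the target positions, divided by n, which is exactly what the backprop cell below writes out.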
Forward Pass
Before
# logits_maxes = logits.max(1, keepdim = True).values
# norm_logits = logits - logits_maxes # subtract max for numerical stability
# counts = norm_logits.exp()
# counts_sum = counts.sum(1, keepdims=True)
# counts_sum_inv = counts_sum ** -1 # If (1.0 / counts_sum) then it cannot be exactly backpropagated
# probs = counts * counts_sum_inv
# logprobs = probs.log()
# loss = -logprobs[range(n), Yb].mean()
Now
loss_fast = F.cross_entropy(logits, Yb)
print(loss_fast.item(), 'diff:', (loss_fast - loss).item())
3.4734604358673096 diff: -2.384185791015625e-07
Backprop
dlogits = F.softmax(logits, 1)
dlogits[range(n), Yb] -= 1
dlogits /= n
cmp('logits', dlogits, logits)
logits          | exact: False | approximate: True  | maxdiff: 8.381903171539307e-09
dlogits.shape
torch.Size([32, 27])
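The match is only approximate because the analytic softmax-minus-one-hot expression performs the same arithmetic in a different floating-point order than the decomposed chain that autograd differentiated, so bitwise equality is too strict; allclose passes with a maxdiff around 1e-9. As a final sanity check, here is a small standalone sketch (toy sizes, independent of the tensors above; it uses the torch and F imports from the top) comparing the shortcut against F.cross_entropy's own autograd gradient:
# toy check of the fused cross-entropy gradient: (softmax(logits) - one_hot(y)) / n
n_toy, v_toy = 8, 11
lg = torch.randn(n_toy, v_toy, requires_grad=True)   # stand-in for logits
y = torch.randint(0, v_toy, (n_toy,))                # stand-in for Yb
F.cross_entropy(lg, y).backward()
manual = F.softmax(lg.detach(), 1)
manual[range(n_toy), y] -= 1
manual /= n_toy
print(torch.allclose(manual, lg.grad))               # expected: True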