Lesson 5: Becoming a Backprop Ninja

swole doge style

Starter code

Boilerplate

Imports

import random
import numpy
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

Read in all the words

# Read the dataset: one name per line. The context manager closes the file handle
# deterministically, and the explicit encoding avoids depending on the platform locale.
with open('../data/names.txt', 'r', encoding='utf-8') as f:
    words = f.read().splitlines()
print(len(words))                  # number of names
print(max(len(w) for w in words))  # length of the longest name
print(words[:8])                   # a peek at the first few names
32033
15
['emma', 'olivia', 'ava', 'isabella', 'sophia', 'charlotte', 'mia', 'amelia']

Build the vocabulary of characters and mappings to/from integers

# Give every distinct character an integer id starting at 1; id 0 is reserved
# for '.', which marks both the start padding and the end of a name.
chars = sorted(set(''.join(words)))
stoi = {ch: idx for idx, ch in enumerate(chars, start=1)}
stoi['.'] = 0
itos = {idx: ch for ch, idx in stoi.items()}  # inverse mapping: id -> character
vocab_size = len(itos)
print(itos)
print(vocab_size)
{1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i', 10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q', 18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y', 26: 'z', 0: '.'}
27

Build the dataset

block_size = 3 # context length: how many characters do we take to predict the next one?

def build_dataset(words):
    """Turn a list of words into (context, next-character) training pairs.

    Each word is preceded by ``block_size`` start tokens (id 0) and terminated
    with '.', so every character -- including the terminator -- becomes one
    target whose input is the ``block_size`` character ids before it.

    Args:
        words: iterable of strings whose characters are all keys of ``stoi``.

    Returns:
        (X, Y): X is an int64 tensor of shape (N, block_size) holding context ids,
        Y is an int64 tensor of shape (N,) holding target ids, where N is the sum
        of len(w) + 1 over all words. Both shapes are printed before returning.
    """
    X, Y = [], []

    for w in words:
        context = [0] * block_size  # start-of-word padding
        for ch in w + '.':
            ix = stoi[ch]
            X.append(context)
            Y.append(ix)
            context = context[1:] + [ix] # crop and append (a new list, so rows of X never alias)

    # Explicit dtype and shape: with an empty word list, torch.tensor([]) would be a
    # 1-D float tensor, which has the wrong rank and cannot be used to index C.
    X = torch.tensor(X, dtype=torch.long).view(-1, block_size)
    Y = torch.tensor(Y, dtype=torch.long)
    print(X.shape, Y.shape)
    return X, Y
# Shuffle the names with a fixed seed, then split them 80% / 10% / 10%.
random.seed(2)
random.shuffle(words)
n1, n2 = (int(frac * len(words)) for frac in (0.8, 0.9))

Xtr, Ytr = build_dataset(words[:n1])      # training split (80%)
Xdev, Ydev = build_dataset(words[n1:n2])  # dev / validation split (10%)
Xte, Yte = build_dataset(words[n2:])      # test split (10%)
torch.Size([182481, 3]) torch.Size([182481])
torch.Size([22849, 3]) torch.Size([22849])
torch.Size([22816, 3]) torch.Size([22816])

OK, the boilerplate is done; now on to the action.

Let's get started

Utility function

To compare manual gradients to PyTorch gradients

def cmp(s, dt, t):
    """Print how a manually derived gradient compares with PyTorch's autograd result.

    Args:
        s: label printed in the first column.
        dt: the manually computed gradient; it must have exactly ``t.grad``'s shape.
        t: the tensor whose ``.grad`` was filled in by ``loss.backward()``
            (non-leaf tensors need ``retain_grad()`` before the backward call).

    Prints one row: bit-exact equality, ``torch.allclose`` equality, and the
    maximum absolute difference. A shape mismatch is reported as a failure
    instead of being hidden by broadcasting (e.g. a (1, 3) result checked
    against a (2, 3) gradient would otherwise compare as equal).

    Raises:
        ValueError: if ``t.grad`` is None, i.e. autograd stored no gradient for ``t``.
    """
    if t.grad is None:
        raise ValueError(f'{s}: t.grad is None; call t.retain_grad() before loss.backward()')
    same_shape = dt.shape == t.grad.shape
    ex = same_shape and torch.all(dt == t.grad).item()
    app = same_shape and torch.allclose(dt, t.grad)
    # Only diff tensors of matching shape; broadcasting would make the number meaningless.
    maxdiff = (dt - t.grad).abs().max().item() if same_shape else float('nan')
    line = f'{s:15s} | exact: {str(ex):5s} | approximate: {str(app):5s} | maxdiff: {maxdiff}'
    if not same_shape:
        line += f' | shape mismatch: {tuple(dt.shape)} vs {tuple(t.grad.shape)}'
    print(line)

Initialize the layers

n_embd = 10 # the dimensionality for the character embedding vectors
n_hidden = 64 # the number of neurons in the hidden layer of the MLP

# One seeded generator drives every random draw below (and the minibatch sampling
# later on), so the whole run is reproducible.
g = torch.Generator().manual_seed(2147483647)
C = torch.randn((vocab_size, n_embd),             generator=g)

# Layer 1: weights scaled by the tanh gain 5/3 divided by sqrt(fan_in)
W1 = torch.randn((n_embd * block_size, n_hidden), generator=g) * (5/3)/((n_embd * block_size) ** 0.5)
b1 = torch.randn(n_hidden,                        generator=g) * 0.1
# b1 is kept only as an extra gradient to derive: BatchNorm subtracts the batch
# mean right after this layer, so b1 has no effect on the forward output.

# Layer 2
W2 = torch.randn((n_hidden, vocab_size),          generator=g) * 0.1
b2 = torch.randn(vocab_size,                      generator=g) * 0.1

# BatchNorm parameters, also drawn from g (not the unseeded global RNG) so the
# initialization does not change from run to run.
bngain = torch.randn((1, n_hidden), generator=g) * 0.1 + 1.0
bnbias = torch.randn((1, n_hidden), generator=g) * 0.1

# Note: the parameters here are initialized in non-standard ways (e.g. small random
# biases instead of zeros) because an all-zero initialization could mask an
# incorrect implementation of the backward pass.

parameters = [C, W1, b1, W2, b2, bngain, bnbias]
print(sum(p.nelement() for p in parameters))  # total number of parameters
for p in parameters:
    p.requires_grad = True
4137

Construct a minibatch

bs = 32
n = bs  # the formulas below refer to the batch size as n
# draw bs random training-row indices using the seeded generator g
ix = torch.randint(low=0, high=Xtr.shape[0], size=(bs,), generator=g)
Xb = Xtr[ix]  # batch inputs: context ids, one row per example
Yb = Ytr[ix]  # batch targets: the next-character id for each row

One Forward/Backward Pass on a Single Minibatch

Forward Pass
emb = C[Xb] # embed the characters into vectors: (bs, block_size, n_embd)
embcat = emb.view(emb.shape[0], -1) # concatenate the vectors: (bs, block_size * n_embd)

# Linear layer 1
hprebn = embcat @ W1 + b1

# Batchnorm layer, written out as small named steps so each one can be backpropagated by hand
bnmeani = 1/n * hprebn.sum(0, keepdim=True) # per-feature batch mean, shape (1, n_hidden)
bndiff = hprebn - bnmeani
bndiff2 = bndiff ** 2
bnvar = 1/(n - 1) * (bndiff2).sum(0, keepdim=True) # note: Bessel's correction (dividing by n - 1, not n)
bnvar_inv = (bnvar + 1e-5) ** -0.5 # 1/sqrt(var + eps); eps keeps this finite when a feature's variance is 0
bnraw = bndiff * bnvar_inv # normalized activations
hpreact = bngain * bnraw + bnbias # learnable scale and shift

# Non linearity
h = torch.tanh(hpreact)

# Linear layer 2
logits = h @ W2 + b2 # output layer: (bs, vocab_size)

# cross entropy loss ( same as F.cross_entropy(logits, Yb) ), expanded step by step
logit_maxes = logits.max(1, keepdim = True).values
norm_logits = logits - logit_maxes # subtract max for numerical stability
counts = norm_logits.exp()
counts_sum = counts.sum(1, keepdims=True)
counts_sum_inv = counts_sum ** -1 # power form on purpose: the author found 1.0 / counts_sum kept the manual backprop from matching autograd bit-exactly
probs = counts * counts_sum_inv # softmax
logprobs = probs.log()
loss = -logprobs[range(n), Yb].mean() # mean negative log-probability of the correct next character

# PyTorch backward pass
for p in parameters: p.grad = None # clear any gradients left over from earlier runs

# The intermediates below are non-leaf tensors, whose .grad autograd does not keep
# unless retain_grad() is called; keeping them lets cmp() check each manual gradient.
for t in [logprobs, probs, counts, counts_sum, counts_sum_inv,
          norm_logits, logit_maxes, logits, h, hpreact, bnraw,
          bnvar_inv, bnvar, bndiff2, bndiff, hprebn, bnmeani,
          embcat, emb]:
    t.retain_grad()
loss.backward()
loss # display the loss value
tensor(3.4738, grad_fn=<NegBackward0>)

Exercise 1:

Backprop through the whole thing manually, backpropagating through exactly all of the variables as they are defined in the forward pass above, one by one.

C.size(), Xb.size(), emb.size()  # embedding table, batch of context ids, embedded batch
(torch.Size([27, 10]), torch.Size([32, 3]), torch.Size([32, 3, 10]))
# Exercise 1: backpropagate by hand through every forward-pass variable, in reverse
# order. Each gradient has the same shape as its tensor; a tensor used in several
# places accumulates (+=) the gradient from each use. cmp() checks each result
# against autograd once it is complete.

# loss = -mean(logprobs[i, Yb[i]]): only the selected entries receive gradient -1/n
dlogprobs = torch.zeros_like(logprobs)
dlogprobs[range(n), Yb] = -1.0/n
cmp('logprobs', dlogprobs, logprobs)

# logprobs = log(probs), and d/dx log(x) = 1/x
dprobs = (1/probs) * dlogprobs
cmp('probs', dprobs, probs)

# probs = counts * counts_sum_inv, where counts_sum_inv has one column broadcast across
# the row, so its gradient is summed back over dim 1
dcounts_sum_inv =  (counts * dprobs).sum(1, keepdim=True)
dcounts = counts_sum_inv * dprobs  # first of two contributions to dcounts
cmp('counts_sum_inv', dcounts_sum_inv, counts_sum_inv)

# counts_sum_inv = counts_sum ** -1, and d/dx x**-1 = -x**-2
dcounts_sum = (-1 * counts_sum ** -2) * dcounts_sum_inv
cmp('counts_sum', dcounts_sum, counts_sum)

# counts_sum = counts.sum(1): every element of a row receives that row's gradient
dcounts += dcounts_sum
cmp('counts', dcounts, counts)

# counts = exp(norm_logits), and d/dx exp(x) = exp(x), which counts already holds
dnorm_logits = counts * dcounts
cmp('norm_logits', dnorm_logits, norm_logits)

# norm_logits = logits - logit_maxes (logit_maxes broadcast across each row)
dlogits = dnorm_logits.clone()  # first contribution; clone so the += below does not alias dnorm_logits
dlogit_maxes = (-dnorm_logits).sum(1, keepdim=True)
cmp('logit_maxes', dlogit_maxes, logit_maxes)

# logit_maxes = logits.max(1): gradient flows only to the maximal entry of each row
dlogits_temp = torch.zeros_like(logits)
dlogits_temp[range(n), torch.argmax(logits, 1)] = 1
dlogits += dlogits_temp * dlogit_maxes
cmp('logits', dlogits, logits)

# logits = h @ W2 + b2
dh = dlogits @ W2.T
cmp('h', dh, h)

dW2 = h.T @ dlogits
cmp('W2', dW2, W2)

db2 = dlogits.sum(0)  # b2 is broadcast over the batch, so sum over rows
cmp('b2', db2, b2)

# h = tanh(hpreact), and d/dx tanh(x) = 1 - tanh(x)**2
dhpreact = ((1 - h ** 2) * dh)
cmp('hpreact', dhpreact, hpreact)

# hpreact = bngain * bnraw + bnbias (bngain and bnbias broadcast over the batch)
dbngain = (bnraw * dhpreact).sum(0, keepdims=True)
cmp('bngain', dbngain, bngain)

dbnbias = dhpreact.sum(0, keepdims=True)
cmp('bnbias', dbnbias, bnbias)

dbnraw = bngain * dhpreact
cmp('bnraw', dbnraw, bnraw)

# bnraw = bndiff * bnvar_inv
dbnvar_inv = (bndiff * dbnraw).sum(0, keepdims=True)
cmp('bnvar_inv', dbnvar_inv, bnvar_inv)

# First of bndiff's two contributions; the second arrives through bndiff2 below,
# so dbndiff is not compared with autograd until it is complete.
dbndiff = bnvar_inv * dbnraw

# bnvar_inv = (bnvar + 1e-5) ** -0.5
dbnvar = (-0.5 * (bnvar + 1e-5) ** -1.5) * dbnvar_inv
cmp('bnvar', dbnvar, bnvar)

# bnvar = 1/(n-1) * bndiff2.sum(0): every row receives the same scaled gradient
dbndiff2 = (1.0/(n-1)) * torch.ones_like(bndiff2) * dbnvar
cmp('bndiff2', dbndiff2, bndiff2)

# bndiff2 = bndiff ** 2: second contribution, after which dbndiff is complete
dbndiff += (2 * bndiff) * dbndiff2
cmp('bndiff', dbndiff, bndiff)

# bndiff = hprebn - bnmeani (bnmeani broadcast over the batch)
dbnmeani = (- 1 * dbndiff).sum(0, keepdims=True)
cmp('bnmeani', dbnmeani, bnmeani)

# hprebn feeds bndiff directly and also bnmeani = hprebn.sum(0) / n.
# clone() so the += below does not silently modify dbndiff in place.
dhprebn = dbndiff.clone()
dhprebn += 1/n * (torch.ones_like(hprebn)) * dbnmeani
cmp('hprebn', dhprebn, hprebn)

# hprebn = embcat @ W1 + b1
dembcat = dhprebn @ W1.T
cmp('embcat', dembcat, embcat)

dW1 = embcat.T @ dhprebn
cmp('W1', dW1, W1)

db1 = dhprebn.sum(0)
cmp('b1', db1, b1)

# embcat is a view of emb, so undo the reshape
demb = dembcat.view(emb.shape)
cmp('emb', demb, emb)

# emb = C[Xb]: add each embedding gradient back into the row of C it was read from
# (a character that appears several times accumulates several gradients).
# char_ix, not ix, so the global minibatch indices in `ix` are not overwritten.
dC = torch.zeros_like(C)
for k in range(Xb.shape[0]):
    for j in range(Xb.shape[1]):
        char_ix = Xb[k, j]
        dC[char_ix] += demb[k, j]
cmp('C', dC, C)
logprobs        | exact: True  | approximate: True  | maxdiff: 0.0
probs           | exact: True  | approximate: True  | maxdiff: 0.0
counts_sum_inv  | exact: True  | approximate: True  | maxdiff: 0.0
counts_sum      | exact: True  | approximate: True  | maxdiff: 0.0
counts          | exact: True  | approximate: True  | maxdiff: 0.0
norm_logits     | exact: True  | approximate: True  | maxdiff: 0.0
logit_maxes     | exact: True  | approximate: True  | maxdiff: 0.0
logits          | exact: True  | approximate: True  | maxdiff: 0.0
h               | exact: True  | approximate: True  | maxdiff: 0.0
W2              | exact: True  | approximate: True  | maxdiff: 0.0
b2              | exact: True  | approximate: True  | maxdiff: 0.0
hpreact         | exact: True  | approximate: True  | maxdiff: 0.0
bngain          | exact: True  | approximate: True  | maxdiff: 0.0
bnbias          | exact: True  | approximate: True  | maxdiff: 0.0
bnraw           | exact: True  | approximate: True  | maxdiff: 0.0
bnvar_inv       | exact: True  | approximate: True  | maxdiff: 0.0
bndiff          | exact: False | approximate: False | maxdiff: 0.0013724520104005933
bnvar           | exact: True  | approximate: True  | maxdiff: 0.0
bndiff2         | exact: True  | approximate: True  | maxdiff: 0.0
bndiff          | exact: True  | approximate: True  | maxdiff: 0.0
bnmeani         | exact: True  | approximate: True  | maxdiff: 0.0
hprebn          | exact: True  | approximate: True  | maxdiff: 0.0
embcat          | exact: True  | approximate: True  | maxdiff: 0.0
W1              | exact: True  | approximate: True  | maxdiff: 0.0
b1              | exact: True  | approximate: True  | maxdiff: 0.0
emb             | exact: True  | approximate: True  | maxdiff: 0.0
C               | exact: True  | approximate: True  | maxdiff: 0.0
dprobs.size()  # a gradient always has the same shape as the tensor it belongs to
torch.Size([32, 27])
probs.size()
torch.Size([32, 27])
counts.size()
torch.Size([32, 27])
dcounts_sum.size()  # one column, because counts_sum was summed over dim 1 with keepdims
torch.Size([32, 1])
counts.size()  # full width: adding the one-column dcounts_sum to dcounts broadcasts across each row
torch.Size([32, 27])

Exercise 2

Backprop through cross_entropy, but all in one go. To complete this challenge, look at the mathematical expression of the loss, take the derivative, simplify the expression, and just write it out.

Forward Pass

Before

# logit_maxes = logits.max(1, keepdim = True).values
# norm_logits = logits - logit_maxes # subtract max for numerical stability
# counts = norm_logits.exp()
# counts_sum = counts.sum(1, keepdims=True)
# counts_sum_inv = counts_sum ** -1 # with 1.0 / counts_sum instead, the manual backprop cannot match autograd bit-exactly
# probs = counts * counts_sum_inv
# logprobs = probs.log()
# loss = -logprobs[range(n), Yb].mean()

Now

# PyTorch's fused cross entropy; it should agree with the step-by-step loss up to float rounding
loss_fast = F.cross_entropy(input=logits, target=Yb)
print(loss_fast.item(), 'diff:', (loss_fast - loss).item())
3.4734604358673096 diff: -2.384185791015625e-07

Backprop

# d(loss)/d(logits) = (softmax(logits) - onehot(Yb)) / n: the softmax probabilities,
# minus 1 at each row's correct class, averaged over the batch.
dlogits = (F.softmax(logits, 1) - F.one_hot(Yb, num_classes=logits.shape[1])) / n
cmp('logits', dlogits, logits)
logits          | exact: False | approximate: True  | maxdiff: 8.381903171539307e-09
dlogits.size()  # same shape as logits
torch.Size([32, 27])