In the last post, we built a framework that can define a computation graph, and perform a forward pass. In this post, we will work on the core of most deep learning frameworks: the backwards pass.
Working off the code from the last post, every type of node already knew how to compute its own value by evaluating the computation graph; now we just have to teach each one how to pass gradients backwards. Let’s think about how this could work by revisiting our computation graph for our linear model:
Our goal is to find the value of ∂Error / ∂w and ∂Error / ∂b. Let’s solve this numerically first to get an idea of how we can create a program that does this for us.
Recall that our model is:
y’ = wx + b
Let’s rewrite this a bit to help us understand the relation to the computation graph above:
z = wx
y’ = z + b
And our error is:
Error = (y’ - y)²
So then, by the chain rule, to get ∂Error / ∂w and ∂Error / ∂b:
∂Error / ∂w = ∂Error / ∂y’ × ∂y’ / ∂z × ∂z / ∂w
∂Error / ∂b = ∂Error / ∂y’ × ∂y’ / ∂b
And now, hopefully we can see how this set of equations maps to the graph that we drew out above. We walk up the tree, expanding the chain rule over and over again, from our error back to each variable of interest (w and b).
Some things to note:
While the overall function might seem complex, each term in the expansion is a trivial derivative.
∂Error / ∂w and ∂Error / ∂b share almost all of the same terms. Indeed, this is what we would expect, since their paths up the computation graph are almost identical.
Let’s take a look at a quick example before we move on to the implementation, and work out an end-to-end flow of how our model would behave:
Let’s say we have a single data point: (x = 2, y = 4).
Our model’s variables can be initialized randomly; let’s say for this example that w = 0.3 and b = 0.4.
Then our prediction would be:
y’ = wx + b = 0.3 × 2 + 0.4 = 1
Our error would be:
Error = (y’ - y)² = (1 - 4)² = 9
Now that we have our error, let’s figure out all the gradients piece by piece:
∂Error / ∂y’ = 2 × (y’ - y) = 2 × (1 - 4) = -6
∂y’ / ∂z = 1
∂z / ∂w = x = 2
∂y’ / ∂b = 1
So now, we can compute:
∂Error / ∂w = ∂Error / ∂y’ × ∂y’ / ∂z × ∂z / ∂w = (-6) × (1) × (2) = -12
∂Error / ∂b = ∂Error / ∂y’ × ∂y’ / ∂b = (-6) × (1) = -6
Now that we have these two values, we know how to update w and b to make the prediction a little closer to the correct value. Assuming a learning rate of 0.01:
w = w - 0.01 × ∂Error / ∂w = 0.3 - 0.01 × (-12) = 0.42
b = b - 0.01 × ∂Error / ∂b = 0.4 - 0.01 × (-6) = 0.46
With these two new values, our next prediction for the same data point will be:
y’ = wx + b = 0.42 × 2 + 0.46 = 1.3
Error = (y’ - y)² = (1.3 - 4)² = 7.29
And voila! Our error is now lower.
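As a quick sanity check, here is the same arithmetic written out as a few lines of Python (the variable names are just for illustration):

```python
x, y = 2.0, 4.0
w, b = 0.3, 0.4
learning_rate = 0.01

y_pred = w * x + b                       # 1.0
error = (y_pred - y) ** 2                # 9.0

d_error_d_ypred = 2 * (y_pred - y)       # -6.0
d_error_d_w = d_error_d_ypred * 1 * x    # -12.0
d_error_d_b = d_error_d_ypred * 1        # -6.0

w = w - learning_rate * d_error_d_w      # 0.42
b = b - learning_rate * d_error_d_b      # 0.46

new_error = (w * x + b - y) ** 2         # 7.29
print(error, new_error)                  # the error goes down
```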
Implementation
Now that we’ve seen how this could work, let’s actually implement it. We’ll be building off the progress we made in our last post.
For each of our nodes, let’s teach it how to take its own derivative in the chain rule expansion. One important thing to realize is that the derivative is different depending on which direction we walk up the tree. That is, if we have a tree that represents a = b - c:
∂a / ∂b = 1
∂a / ∂c = -1
So every node needs to know how to walk up the left path, and also the right path:
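As a rough sketch, each binary node can expose the local derivative with respect to its left child and its right child. The class and method names below (gradient_wrt_left, gradient_wrt_right, and so on) are placeholders of mine, building on the idea that the node classes from the last post cache their value during the forward pass:

```python
class Node:
    """Base node: every node caches its forward value."""
    def __init__(self):
        self.value = None

    def forward(self):
        # leaf nodes (variables and constants) simply return their stored value
        return self.value


class VariableNode(Node):
    """A trainable parameter, e.g. w or b."""
    def __init__(self, value):
        super().__init__()
        self.value = value


class ConstantNode(Node):
    """A fixed input, e.g. a data point x or a label y."""
    def __init__(self, value):
        super().__init__()
        self.value = value


class BinaryNode(Node):
    """A node with a left child and a right child."""
    def __init__(self, left, right):
        super().__init__()
        self.left = left
        self.right = right


class AddNode(BinaryNode):
    def forward(self):
        self.value = self.left.forward() + self.right.forward()
        return self.value

    # for a = l + r:  ∂a/∂l = 1  and  ∂a/∂r = 1
    def gradient_wrt_left(self):
        return 1

    def gradient_wrt_right(self):
        return 1


class SubtractNode(BinaryNode):
    def forward(self):
        self.value = self.left.forward() - self.right.forward()
        return self.value

    # for a = l - r:  ∂a/∂l = 1  and  ∂a/∂r = -1
    def gradient_wrt_left(self):
        return 1

    def gradient_wrt_right(self):
        return -1


class MultiplyNode(BinaryNode):
    def forward(self):
        self.value = self.left.forward() * self.right.forward()
        return self.value

    # for a = l × r:  ∂a/∂l = r  and  ∂a/∂r = l
    # (these read the cached values, so forward() must run first)
    def gradient_wrt_left(self):
        return self.right.value

    def gradient_wrt_right(self):
        return self.left.value
```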
We haven’t included the implementation of SquaredNode in here, since that one doesn’t have a left and right path. We’ll deal with that later.
Now that we know how to walk up the tree, we want to write a function that can recursively walk up the tree and collect the chain rule expansions at each step. Let’s implement a method path_to_variables. We can also fill in this method for SquaredNode too.
Admittedly this code isn’t the greatest, but I was too lazy to clean it up. The gist of it is that the path_to_variables method returns a list of lists, where each list is the path from the cost node to each of the variable nodes.
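As a rough sketch of the same idea, the walk can be written as one recursive helper, together with the SquaredNode that squares its single child (the original implements this as a path_to_variables method on each node class, so treat the shape of this code as an approximation):

```python
class SquaredNode(Node):
    """A unary node that squares its single child (used for the squared error)."""
    def __init__(self, child):
        super().__init__()
        self.child = child

    def forward(self):
        self.value = self.child.forward() ** 2
        return self.value

    # for a = c²:  ∂a/∂c = 2c
    def gradient_wrt_child(self):
        return 2 * self.child.value


def path_to_variables(node):
    """Collect every path from `node` down to a VariableNode, as a list of lists."""
    if isinstance(node, VariableNode):
        return [[node]]          # a variable is the end of a path
    if isinstance(node, ConstantNode):
        return []                # constants are not trainable, so no paths end here
    if isinstance(node, SquaredNode):
        children = [node.child]
    else:                        # binary nodes have a left and a right child
        children = [node.left, node.right]
    paths = []
    for child in children:
        for path in path_to_variables(child):
            paths.append([node] + path)   # prepend ourselves to the child's path
    return paths
```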
Now let’s finish the full implementation of our Session class:
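A minimal sketch of what the Session could look like, building on the snippets above: run the forward pass, then for each path from the error node down to a variable, multiply the local derivatives along the path (the chain rule) and take a gradient descent step:

```python
class Session:
    """Forward pass, chain-rule backward pass, and a gradient descent update."""

    def run(self, error_node, learning_rate=0.01):
        error = error_node.forward()   # the forward pass caches every node's value

        for path in path_to_variables(error_node):
            gradient = 1.0
            # walk from the error node down to the variable,
            # multiplying the local derivative at each step (the chain rule)
            for parent, child in zip(path, path[1:]):
                if isinstance(parent, SquaredNode):
                    gradient *= parent.gradient_wrt_child()
                elif child is parent.left:
                    gradient *= parent.gradient_wrt_left()
                else:
                    gradient *= parent.gradient_wrt_right()

            variable = path[-1]
            variable.value -= learning_rate * gradient   # gradient descent step

        return error
```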
Now that we have a full implementation, let’s test it against some data we generate:
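For the sake of a runnable example, assume something like a noisy line; the slope, intercept, and noise level below are arbitrary choices of mine:

```python
import random

random.seed(0)

# roughly y = 2x + 3 with a little noise; the numbers here are arbitrary
data = [(x, 2 * x + 3 + random.uniform(-1, 1)) for x in range(10)]
```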
And now, let’s create a linear model:
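A sketch of a linear model built out of the node classes above (the helper name build_linear_error is my own, not necessarily how the original code is structured):

```python
# trainable parameters, randomly initialized
w = VariableNode(random.uniform(-1, 1))
b = VariableNode(random.uniform(-1, 1))

def build_linear_error(x_value, y_value):
    """Build the graph Error = ((w·x + b) - y)² for a single data point."""
    x = ConstantNode(x_value)
    y = ConstantNode(y_value)
    prediction = AddNode(MultiplyNode(w, x), b)       # y' = wx + b
    return SquaredNode(SubtractNode(prediction, y))   # Error = (y' - y)²
```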
Then, let’s train our model for 500 epochs and plot the results:
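A sketch of the training loop, reusing the Session from above; the learning rate and the matplotlib plotting are assumptions on my part:

```python
import matplotlib.pyplot as plt

session = Session()

for epoch in range(500):
    total_error = 0.0
    for x_value, y_value in data:
        total_error += session.run(build_linear_error(x_value, y_value),
                                    learning_rate=0.001)
    if epoch % 100 == 0:
        print(f"epoch {epoch}: average error {total_error / len(data):.4f}")

# plot the data points against the learned line
xs = [x for x, _ in data]
plt.scatter(xs, [y for _, y in data], label="data")
plt.plot(xs, [w.value * x + b.value for x in xs], color="red", label="model")
plt.legend()
plt.show()
```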
And now, let’s build a quadratic model and train it against our generated data:
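A sketch of a quadratic model y’ = w₁x² + w₂x + b built from the same pieces; the smaller learning rate is only there to keep the larger x² terms from blowing up the updates:

```python
# a new set of parameters for the quadratic model (b2 so it doesn't clash with b above)
w1 = VariableNode(random.uniform(-1, 1))
w2 = VariableNode(random.uniform(-1, 1))
b2 = VariableNode(random.uniform(-1, 1))

def build_quadratic_error(x_value, y_value):
    """Build the graph Error = ((w1·x² + w2·x + b2) - y)² for a single data point."""
    x = ConstantNode(x_value)
    y = ConstantNode(y_value)
    prediction = AddNode(AddNode(MultiplyNode(w1, MultiplyNode(x, x)),
                                 MultiplyNode(w2, x)),
                         b2)
    return SquaredNode(SubtractNode(prediction, y))

for epoch in range(500):
    for x_value, y_value in data:
        session.run(build_quadratic_error(x_value, y_value), learning_rate=0.0001)

print(f"learned w1 = {w1.value:.3f}, w2 = {w2.value:.3f}, b = {b2.value:.3f}")
```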
Yay! Our framework works! In our next post, let’s implement a SigmoidNode or ReLUNode and build a small feed-forward neural network.