I recently reacquainted myself with the chain rule. As a reminder, if you, like me, have not touched differentiation for many years: the chain rule is used to differentiate the composition of two differentiable functions. The beauty of this is that it allows us to differentiate very complex functions by breaking them down into their many smaller component functions.
As it happens, the chain rule also plays a pretty big part in machine learning. In machine learning we aim to optimise a function, and we traditionally do this by trying to reduce a loss value, a term that measures how well our function fits the data. We can work out how we should adjust the parameters of the function using gradient descent. The idea is that by differentiating our function we can work out how the loss term changes with respect to the parameters and inputs, and, since the goal is to reduce the loss, how we should alter the parameters to achieve this.
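As a rough sketch of that update loop (the single parameter and toy data here are arbitrary examples, not from any particular model):

```python
# minimal gradient descent sketch: fit a single parameter w so that w * x ≈ target
x, target = 3.0, 12.0          # toy data (hypothetical)
w = 0.0                        # the parameter we optimise
learning_rate = 0.01

for _ in range(200):
    prediction = w * x
    loss = (prediction - target) ** 2      # how badly we currently fit
    grad = 2 * (prediction - target) * x   # dL/dw, differentiated by hand
    w -= learning_rate * grad              # step against the gradient to reduce the loss

print(w)  # should approach 4.0
```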
In my research area, Gaussian Splatting, the function is fixed and we optimise the inputs instead, so we differentiate the function to calculate how the loss term changes with respect to the inputs. Take our two functions, the “render function” and the “loss function”.
To understand how to adjust our inputs x, we need to determine how the loss would change with respect to the inputs, dL/dx. We can rewrite the loss as the composition loss(render(x)).
Wait a minute, this is just the composition of two differentiable functions, which means we can apply the chain rule: the derivative of the composition is equal to the product of the derivative of the inside function and the derivative of the outside function.
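Written out, with r denoting the rendered output (a name introduced here just for clarity), the chain rule gives:

$$\frac{dL}{dx} = \frac{dL}{dr} \cdot \frac{dr}{dx}, \qquad \text{where } r = \mathrm{render}(x) \text{ and } L = \mathrm{loss}(r)$$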
We can apply this methodology to code. Often a forward pass (the inference) will include several consecutive operations, where the output of the previous operation acts as the input to the next.
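For example, a toy chain of three operations (the functions g, f, and h below are arbitrary stand-ins chosen for illustration, not from any real pipeline) might look like:

```python
def g(x):
    return x ** 2        # hypothetical first operation

def f(a):
    return 3 * a + 1     # hypothetical second operation

def h(b):
    return b ** 0.5      # hypothetical third operation

x = 2.0
a = g(x)   # the output of g feeds f
b = f(a)   # the output of f feeds h
y = h(b)
```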
This can be written as the single function y = h(f(g(x))). Or, in composition notation, y = (h ∘ f ∘ g)(x).
Now it might be a bit more obvious that this is just another application of the chain rule. If we want to know how y changes with respect to x (dy/dx), we can apply the chain rule to get dy/dx = dy/db * db/da * da/dx.
As we can differentiate the individual functions g, f, and h with respect to their inputs, we can work out that dy/dx = h'(b) * f'(a) * g'(x), where a = g(x) and b = f(a) are the intermediate values from the forward pass.
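To make that concrete, here is a sketch that applies the chain rule by hand to the toy g, f, and h from above and checks the result against a finite-difference approximation:

```python
def g(x):  return x ** 2
def dg(x): return 2 * x            # derivative of g w.r.t. its input

def f(a):  return 3 * a + 1
def df(a): return 3.0              # derivative of f w.r.t. its input

def h(b):  return b ** 0.5
def dh(b): return 0.5 * b ** -0.5  # derivative of h w.r.t. its input

x = 2.0
a = g(x)
b = f(a)
y = h(b)

# chain rule: dy/dx = dy/db * db/da * da/dx
dy_dx = dh(b) * df(a) * dg(x)

# numerical sanity check
eps = 1e-6
numeric = (h(f(g(x + eps))) - h(f(g(x - eps)))) / (2 * eps)
print(dy_dx, numeric)  # the two values should agree closely
```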
To summarise: if we have some larger function that is the composite of many smaller functions we can easily differentiate, then we can differentiate the larger function using the chain rule. In programming terms, this means we can take a function consisting of several operations and differentiate it easily. Libraries such as PyTorch will perform this for you, allowing you to write just the forward pass while the backward pass is calculated behind the scenes (using the chain rule). Let's now take a look at how we would approach slightly larger functions.
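For instance, the same toy chain written with PyTorch, where calling backward() triggers the behind-the-scenes chain rule (a sketch, using the same hypothetical operations as above):

```python
import torch

x = torch.tensor(2.0, requires_grad=True)

# forward pass: just write the computation
a = x ** 2
b = 3 * a + 1
y = torch.sqrt(b)

# backward pass: PyTorch applies the chain rule through the recorded operations
y.backward()
print(x.grad)  # dy/dx, stored on the input tensor
```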
When functions get more complex, understanding the flow of the gradient can be slightly overwhelming. This is where computational graphs come in: they provide a visual representation of the operations and allow us to work backwards and follow the flow of the gradient. Let's start with a new function: f(x, y, z) = (x + y) * z.
We now have three inputs, and we want to understand how the output of f changes with respect to each of these individual inputs. Let's first visualise this function as a computational graph.
Now that we have our graph, we can work backwards along it to calculate the gradient flow. We start at the right-hand side, where the starting gradient is the derivative of the output with respect to itself, df/df. To work out the gradient df/dz we can use the product rule.
This rule tells us that the gradient flowing across a product is equal to the incoming gradient, df/df, multiplied by the value of the opposing branch, q = x + y. As df/df is equal to one, we now know that the gradient df/dz is equal to x + y.
To calculate the gradients for the other inputs we need to pass the gradient down the top branch, using q as an intermediate value. Using the product rule again, we can work out that df/dq is equal to the product of df/df, the incoming gradient, and z, the value of the opposing branch. Once again we can use the chain rule to determine how the gradient should be passed further back.
We have already calculated df/dq, and dq/dy is equal to one, which means df/dy is just df/dq * dq/dy = z. In practice, whenever there is an addition in the computational graph it can be interpreted as passing the incoming gradient straight back unchanged. We can apply the same logic to calculate df/dx, which is also equal to z.
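Here is the same graph checked with PyTorch, using arbitrary example values for the three inputs (the numbers themselves are just for illustration):

```python
import torch

x = torch.tensor(-2.0, requires_grad=True)
y = torch.tensor(5.0, requires_grad=True)
z = torch.tensor(-4.0, requires_grad=True)

# forward pass through the graph
q = x + y      # intermediate node
f = q * z

# backward pass: propagate the gradient from f back to every input
f.backward()

print(x.grad)  # df/dx = z       -> -4
print(y.grad)  # df/dy = z       -> -4
print(z.grad)  # df/dz = x + y   ->  3
```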
This approach scales up to much larger functions. If you propagate the gradient back one node at a time, following the two basic rules above (a multiplication passes back the incoming gradient times the opposing branch's value; an addition passes the incoming gradient straight through), it can be relatively simple to calculate the derivative of even some of the most complex functions. Let's quickly work through a more complex case before we deal with recursion.
Let's take the function:
We can draw a computational graph like so:
Now if we follow the same process we can calculate the gradient of the output with respect to each of the inputs.
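More generally, whenever you derive gradients by hand like this, it can be handy to check them against autograd. A small sketch of that (the helper name check_grads is just an illustrative choice):

```python
import torch

def check_grads(fn, *values):
    """Evaluate fn at the given input values and return the gradient
    of the output with respect to each input."""
    inputs = [torch.tensor(v, requires_grad=True) for v in values]
    output = fn(*inputs)
    return torch.autograd.grad(output, inputs)

# example: the (x + y) * z graph from earlier
print(check_grads(lambda x, y, z: (x + y) * z, -2.0, 5.0, -4.0))
```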