In this post, we will be looking at the main concept behind Recurrent Neural Networks, and will be deriving the back-propagation equations in detail.
My TensorFlow implementation of an RNN for MNIST can be found here.
Our focus will be more on building intuition and a mathematical understanding of the concept.
Recurrent Neural Networks are a class of networks that help us incorporate past states intelligently into decision making.
Why do we need this?
Take Google Translate as an example: when it translates a particular word, it needs to take context into account. The context in this case can be the words before the current word being translated. Bidirectional networks take as context both the words before and the words after the current word!
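To make the bidirectional idea concrete, here is a minimal sketch using two toy RNNs with a single scalar hidden unit each (a simplifying assumption; real bidirectional layers use vector hidden states and learned weights, and the names run_rnn and bidirectional_states are hypothetical). One RNN reads the sequence left to right, the other right to left, and each time step is represented by the pair of hidden states.

```python
import math


def run_rnn(xs, w, u, h0=0.0):
    """Toy scalar RNN: h_t = tanh(w * h_{t-1} + u * x_t). Returns all hidden states."""
    h, hs = h0, []
    for x in xs:
        h = math.tanh(w * h + u * x)
        hs.append(h)
    return hs


def bidirectional_states(xs, fwd_params, bwd_params):
    """Pair each step's forward state (past context) with its backward state (future context)."""
    fwd = run_rnn(xs, *fwd_params)
    # Run over the reversed sequence, then reverse the result so index t lines up again.
    bwd = run_rnn(xs[::-1], *bwd_params)[::-1]
    return list(zip(fwd, bwd))


# Hypothetical parameters (w, u) for each direction.
s1 = bidirectional_states([1.0, 0.0, 0.0], (0.5, 1.0), (0.5, 1.0))
s2 = bidirectional_states([1.0, 0.0, -1.0], (0.5, 1.0), (0.5, 1.0))
# Only the last word differs, yet the backward state at step 0 already reflects it.
print(s1[0], s2[0])
```

Changing only the final input leaves the forward state at the first step untouched but changes its backward state, which is exactly the "words after the current word" context.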
Naive way: A naive way to take context into account would be to simply feed states t-4, t-3, ..., t-1 along with t as the input to the network. The problem with this is that
1) The dimensionality of the input grows linearly with the number of states we want to take into account during decision making.
2) Ideally, we want to remember only what is important for making the decision at the current time step, and include only that. Adding everything into the input results in a very messy learning process and decreases performance.
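The naive windowing approach can be sketched in a few lines of Python (the function name windowed_input and the toy data are hypothetical). It concatenates the feature vectors of the last k steps plus the current one, so with d features per step the input has (k + 1) * d entries.

```python
def windowed_input(sequence, t, k):
    """Naive context: concatenate the feature vectors of steps t-k, ..., t into one input."""
    if t - k < 0:
        raise ValueError("not enough history: need t >= k")
    window = []
    for step in range(t - k, t + 1):
        window.extend(sequence[step])  # append this step's features
    return window


# Hypothetical sequence: 10 time steps, 3 features each.
seq = [[float(t), 0.5 * t, -float(t)] for t in range(10)]
print(len(windowed_input(seq, 9, 4)))  # 5 steps * 3 features
print(len(windowed_input(seq, 9, 9)))  # 10 steps * 3 features
```

Doubling the amount of context roughly doubles the input size, which is point 1) above; nothing in this scheme decides which past steps actually matter, which is point 2).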
What is an intelligent solution to this?
An RNN solves this problem by introducing a new state, called the hidden state.
This hidden state holds the information, i.e. the memory, of the network at time t, and is computed as a function of the previous hidden state and the current input.
At each time step, the network takes a weighted combination of the previous hidden state and the input at that time to produce the new hidden state, and computes its output from that hidden state.
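For concreteness, one standard way to write these steps is the usual Elman-style formulation below. The symbols W_xh, W_hh, W_hy, b_h and b_y are our notation and are not taken from the original figure.

```latex
\begin{aligned}
h_t &= \tanh\left(W_{hh}\, h_{t-1} + W_{xh}\, x_t + b_h\right) \\
y_t &= W_{hy}\, h_t + b_y
\end{aligned}
```

The same weight matrices are reused at every time step; only the hidden state h_t changes as the sequence is read.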
Voila! Our decision making now includes context.
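A minimal forward pass in plain Python shows how the hidden state carries context. Biases are omitted for brevity, and the weight names and toy values are hypothetical. Two sequences that end with the same input still produce different final outputs, because their earlier inputs left different traces in the hidden state.

```python
import math


def matvec(matrix, vector):
    """Multiply a matrix (list of rows) by a vector (list)."""
    return [sum(m * v for m, v in zip(row, vector)) for row in matrix]


def rnn_forward(xs, W_xh, W_hh, W_hy, h0):
    """Vanilla RNN: h_t = tanh(W_hh h_{t-1} + W_xh x_t), y_t = W_hy h_t (no biases).

    Returns the list of hidden states and the list of outputs, one per time step.
    """
    h = h0
    hs, ys = [], []
    for x in xs:
        # Weighted combination of the previous hidden state and the current input.
        pre = [a + b for a, b in zip(matvec(W_hh, h), matvec(W_xh, x))]
        h = [math.tanh(p) for p in pre]
        hs.append(h)
        ys.append(matvec(W_hy, h))
    return hs, ys


# Hypothetical toy weights: 1-d input, 2-d hidden state, 1-d output.
W_xh = [[1.0], [0.5]]
W_hh = [[0.5, 0.0], [0.0, 0.5]]
W_hy = [[1.0, -1.0]]
h0 = [0.0, 0.0]

hs_a, ys_a = rnn_forward([[1.0], [0.0]], W_xh, W_hh, W_hy, h0)
hs_b, ys_b = rnn_forward([[-1.0], [0.0]], W_xh, W_hh, W_hy, h0)
# Same final input (0.0), different earlier input, different final output.
print(ys_a[-1], ys_b[-1])
```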
How do we learn the weights?
The weights are learned using an algorithm called Backpropagation Through Time, aka BPTT. The goal is the same as for our normal feed-forward networks: choose the weights that best explain our data.
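As a sketch of where the derivation leads, assume the recurrence h_t = tanh(W_hh h_{t-1} + W_xh x_t + b_h) and a loss summed over time steps, L = L_1 + ... + L_T (our notation and assumptions). Applying the chain rule through every path from W_hh to each L_t gives:

```latex
\begin{aligned}
\frac{\partial L}{\partial W_{hh}}
  &= \sum_{t=1}^{T} \sum_{k=1}^{t}
     \frac{\partial L_t}{\partial h_t}\,
     \frac{\partial h_t}{\partial h_k}\,
     \frac{\partial^{+} h_k}{\partial W_{hh}}, \\
\frac{\partial h_t}{\partial h_k}
  &= \frac{\partial h_t}{\partial h_{t-1}}\,
     \frac{\partial h_{t-1}}{\partial h_{t-2}} \cdots
     \frac{\partial h_{k+1}}{\partial h_k},
  \qquad
  \frac{\partial h_j}{\partial h_{j-1}} = \operatorname{diag}\!\left(1 - h_j^{2}\right) W_{hh}.
\end{aligned}
```

Here ∂⁺h_k/∂W_hh is the "immediate" derivative of h_k that treats h_{k-1} as a constant, h_j^2 is taken elementwise, and the product of Jacobians is the identity when k = t. In practice the double sum is not expanded explicitly; it is accumulated in a single backward pass over the time steps.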
Hopefully the following writeup will clear things up!
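To make the derivation concrete, here is a minimal sketch of BPTT for a toy RNN with a single scalar hidden unit. The recurrence h_t = tanh(w*h_{t-1} + u*x_t), the output y_t = v*h_t, the squared-error loss and the parameter names w, u, v are all illustrative assumptions, not the original post's setup. The backward loop walks the unrolled sequence from the last step to the first, carrying the gradient that flows into each hidden state from the future.

```python
import math


def forward(xs, w, u, v, h0=0.0):
    """Scalar RNN: h_t = tanh(w*h_{t-1} + u*x_t), y_t = v*h_t. hs[0] is h0."""
    hs = [h0]
    for x in xs:
        hs.append(math.tanh(w * hs[-1] + u * x))
    ys = [v * h for h in hs[1:]]
    return hs, ys


def loss(xs, targets, w, u, v):
    """Sum of squared errors over all time steps."""
    _, ys = forward(xs, w, u, v)
    return sum(0.5 * (y - t) ** 2 for y, t in zip(ys, targets))


def bptt(xs, targets, w, u, v):
    """Gradients of the summed loss w.r.t. w, u, v via backpropagation through time."""
    hs, ys = forward(xs, w, u, v)
    dw = du = dv = 0.0
    dh_next = 0.0  # gradient reaching h_t from later time steps
    for t in reversed(range(len(xs))):
        h, h_prev = hs[t + 1], hs[t]
        dy = ys[t] - targets[t]   # dL_t/dy_t
        dv += dy * h
        dh = dy * v + dh_next     # total dL/dh_t: direct term plus the future
        dpre = dh * (1.0 - h * h)  # back through tanh
        dw += dpre * h_prev
        du += dpre * xs[t]
        dh_next = dpre * w        # pass the gradient on to h_{t-1}
    return dw, du, dv


# Sanity check against central finite differences (hypothetical toy data).
xs = [0.5, -1.0, 0.3, 0.8]
targets = [0.1, 0.0, -0.2, 0.4]
w, u, v = 0.7, 1.2, 0.9
eps = 1e-6
dw, du, dv = bptt(xs, targets, w, u, v)
num_dw = (loss(xs, targets, w + eps, u, v) - loss(xs, targets, w - eps, u, v)) / (2 * eps)
num_du = (loss(xs, targets, w, u + eps, v) - loss(xs, targets, w, u - eps, v)) / (2 * eps)
num_dv = (loss(xs, targets, w, u, v + eps) - loss(xs, targets, w, u, v - eps)) / (2 * eps)
print(dw - num_dw, du - num_du, dv - num_dv)
```

The line dh_next = dpre * w is the scalar version of multiplying by ∂h_t/∂h_{t-1}; repeating it once per step is what "unrolling" the network amounts to. Comparing against finite differences is a cheap way to catch mistakes in a hand derivation.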
When backpropagating through time to find the contribution of the weights at time t to the error at time t+k, the derivatives of successive hidden states with respect to their previous hidden states keep getting multiplied with each other.
For a tanh RNN, each of these factors is the derivative of tanh, which lies in (0, 1], multiplied by the recurrent weights. When these factors are smaller than 1 in magnitude, their product tends to zero exponentially as k increases. Hence, the gradient vanishes. This is called the vanishing gradient problem. It prevents us from learning long-term dependencies, since the contribution of the weights at time t to the gradient at time t+k goes to zero as k increases.
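A small numerical illustration, again assuming a scalar tanh RNN with hypothetical parameters w and u: the function below tracks |∂h_t/∂h_0|, i.e. the running product of the per-step factors w * (1 - h_t^2), for each t along a sequence.

```python
import math


def hidden_jacobian_product(xs, w, u, h0=0.0):
    """Return |dh_t/dh_0| for t = 1..T in a scalar RNN h_t = tanh(w*h_{t-1} + u*x_t).

    Each step multiplies the running product by dh_t/dh_{t-1} = w * (1 - h_t**2).
    """
    h, prod, out = h0, 1.0, []
    for x in xs:
        h = math.tanh(w * h + u * x)
        prod *= w * (1.0 - h * h)
        out.append(abs(prod))
    return out


# Hypothetical setting: 50 steps of constant input.
grads = hidden_jacobian_product([1.0] * 50, w=0.9, u=0.5)
print(grads[0], grads[9], grads[49])  # shrinks steadily with distance
```

With |w| < 1 every factor is below 1 in magnitude, so the product decays at least as fast as |w|^t; the tanh derivative only makes it shrink faster.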
One solution to this problem is the LSTM, an architecture that addresses it by introducing 'memory cells'. We will talk about this in more detail in a future post.
Overall, an RNN helps us include context in decision making by introducing a new state called the hidden state. Backpropagation through time is just backpropagation on the same network unrolled for t steps! The further back we backpropagate, the more distant the context we can take into account during decision making.
Hope this post was useful. Cheers :)