
The LSTM Reference Card

Greg Condit | Resource | Apr 10, 2019

Summary

The goal of The LSTM Reference Card is to demonstrate how an LSTM forward pass works using just vanilla Python and NumPy. Studying these simple functions alongside the diagram above will build a strong intuition for how and why LSTM networks work. This exercise does not cover backpropagation; the focus is on understanding how the cell uses prior events to make predictions.

Functions for an LSTM Forward Pass using NumPy

<p>CODE: https://gist.github.com/conditg/0b59c24fea92db9a13a70af92f8bd6fd.js</p>
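The gist above contains the actual reference-card functions. As a rough, simplified sketch of the same idea (the function and variable names below are illustrative, not necessarily the ones the gist uses), a single LSTM step in plain NumPy boils down to four gates and two state updates:

```python
import numpy as np

def sigmoid(x):
    return 1 / (1 + np.exp(-x))

def lstm_step(x, h_prev, c_prev, Wx, Wh, bx, bh):
    """One forward step. Wx/Wh stack the four gates' input and hidden
    weights vertically; bx/bh are the matching stacked biases."""
    gates = Wx @ x + bx + Wh @ h_prev + bh   # one linear combination for all four gates
    H = h_prev.shape[0]
    i = sigmoid(gates[0*H:1*H])   # ignore (input) gate
    f = sigmoid(gates[1*H:2*H])   # forget gate
    g = np.tanh(gates[2*H:3*H])   # learn (candidate) gate
    o = sigmoid(gates[3*H:4*H])   # output gate
    c = f * c_prev + i * g        # new cell state: keep some old memory, add some new
    h = o * np.tanh(c)            # new hidden state
    return h, c
```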

Example Usage and Comparison to PyTorch Output

To demonstrate how an LSTM makes predictions, below is a small LSTM network. We'll allow PyTorch to randomly initialize the weights, but they could be initialized any way - the point here is just to ensure that the PyTorch LSTM and our NumPy LSTM both start from the same weights so that their outputs can be compared.

Since our desired output size does not typically match the size of the hidden state, we'll also add a fully connected layer to the network that receives the LSTM output and returns our desired output size.

<p>CODE: https://gist.github.com/conditg/7bb8a70e187dcafaa21a2dcde24b21b6.js</p>
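For reference, a minimal sketch of such a network in PyTorch (the class name and layer sizes here are placeholders, not necessarily what the gist uses) is just an nn.LSTM followed by an nn.Linear:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # illustrative seed so the random weights are reproducible

input_size, hidden_size, output_size = 2, 4, 1  # placeholder sizes

class SmallLSTM(nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size)    # the LSTM cell
        self.fc = nn.Linear(hidden_size, output_size)   # maps hidden output to desired size

    def forward(self, x, states=None):
        out, states = self.lstm(x, states)  # out: hidden state after every event
        return self.fc(out), states
```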

Next, initialize an LSTM model in PyTorch - also with that final fully connected layer - and examine the state dictionary to see the weights it initialized:

<p>CODE: https://gist.github.com/conditg/df41a6a7327bbafc81d64b990beec057.js</p>

<p>CODE: https://gist.github.com/conditg/b66c58d17fcb9d3413fe06fb3e2d50f5.js</p>
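The state dictionary printed above lists every parameter PyTorch created. With the placeholder SmallLSTM sketched earlier, the same inspection would look roughly like this (the key names are what PyTorch generates for a module with lstm and fc attributes; the shapes depend on the sizes chosen):

```python
model = SmallLSTM()  # PyTorch randomly initializes every weight here

# Print each parameter name and its shape
for name, tensor in model.state_dict().items():
    print(name, tuple(tensor.shape))

# lstm.weight_ih_l0  (4*hidden_size, input_size)
# lstm.weight_hh_l0  (4*hidden_size, hidden_size)
# lstm.bias_ih_l0    (4*hidden_size,)
# lstm.bias_hh_l0    (4*hidden_size,)
# fc.weight          (output_size, hidden_size)
# fc.bias            (output_size,)
```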

Don't get overwhelmed! The PyTorch documentation explains all we need to break this down:

  • The weights for each gate are stored in this order: ignore, forget, learn, output
  • Keys with 'ih' in the name are the weights/biases for the input, or Wx_ and Bx_
  • Keys with 'hh' in the name are the weights/biases for the hidden state, or Wh_ and Bh_

We can therefore extract the weights for the NumPy LSTM to use in this way:

<p>CODE: https://gist.github.com/conditg/77d09e506ada2a5cef48172959ec0239.js</p>
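As a hedged illustration using the placeholder SmallLSTM from above (the gist may organize this differently, and the variable names here are my own), the extraction amounts to reading the stacked arrays out of the state dictionary and, if desired, splitting them per gate:

```python
import numpy as np

sd = model.state_dict()

# PyTorch stacks the four gates vertically in the order: ignore, forget, learn, output
Wx = sd['lstm.weight_ih_l0'].numpy()   # input weights,  shape (4*hidden_size, input_size)
Wh = sd['lstm.weight_hh_l0'].numpy()   # hidden weights, shape (4*hidden_size, hidden_size)
bx = sd['lstm.bias_ih_l0'].numpy()     # input biases
bh = sd['lstm.bias_hh_l0'].numpy()     # hidden biases

# Optionally split each stacked array into its four per-gate pieces
Wxi, Wxf, Wxg, Wxo = np.split(Wx, 4)
Whi, Whf, Whg, Who = np.split(Wh, 4)

# Weights for the final fully connected layer
W_fc = sd['fc.weight'].numpy()
b_fc = sd['fc.bias'].numpy()
```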

Now we have two networks - one in PyTorch, one in NumPy - with access to the same starting weights. We'll put some time series data through each to ensure they are identical. To do a forward pass with our network, we'll pass the data into the LSTM gates in sequence, and print the output after each event:

<p>CODE: https://gist.github.com/conditg/086b31260935728db17ae8ef220324cb.js</p>

<p>CODE: https://gist.github.com/conditg/6cad615229eac2b30cf821e0e874414b.js</p>
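Continuing the sketch from the blocks above (the sequence below is made-up placeholder data, not the data used in the gists), the NumPy forward pass is just a loop over the events:

```python
# Three made-up 2-dimensional events
data = np.array([[1.0, 1.0],
                 [0.5, 0.9],
                 [1.1, 0.1]])

h = np.zeros(hidden_size)  # hidden state starts at zero
c = np.zeros(hidden_size)  # cell state starts at zero

for x in data:
    h, c = lstm_step(x, h, c, Wx, Wh, bx, bh)  # the step function sketched earlier
    y = W_fc @ h + b_fc                        # final fully connected layer
    print(y)                                   # prediction after this event
```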

Putting the same data through the PyTorch model shows that it returns identical output:

<p>CODE: https://gist.github.com/conditg/2643d98184322d838fd1634bb84f83d6.js</p>

<p>CODE: https://gist.github.com/conditg/d8fb8af8850e0fd133c83d133d22132a.js</p>
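Under the same sketched setup, the PyTorch side is a single call; the only wrinkle is that nn.LSTM expects input shaped (seq_len, batch, input_size) by default:

```python
# Shape the same placeholder data as (seq_len, batch=1, input_size)
x_t = torch.tensor(data, dtype=torch.float32).unsqueeze(1)

with torch.no_grad():
    y_t, (h_n, c_n) = model(x_t)   # forward pass through LSTM + fully connected layer

print(y_t.squeeze(1))              # should match the NumPy predictions event by event
```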

We can additionally verify that after the data has gone through the LSTM cells, the two models have the same hidden and cell states:

<p>CODE: https://gist.github.com/conditg/979510a48d00cd4b9a5d7de82765976e.js</p>

<p>CODE: https://gist.github.com/conditg/6fc00f3945c78ae78a55ac905dfafab6.js</p>
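In the sketched version, that check is a couple of np.allclose calls (the PyTorch states carry extra layer and batch dimensions, so squeeze them first):

```python
print(np.allclose(h, h_n.squeeze().numpy(), atol=1e-6))  # hidden states match -> True
print(np.allclose(c, c_n.squeeze().numpy(), atol=1e-6))  # cell states match   -> True
```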

I hope this helps build an intuition for how LSTM networks make predictions. Below is the full example code:

<p>CODE: https://gist.github.com/conditg/47eb195eb1d5b80ea299c567c8d0f3bf.js</p>

Issues or Questions? Please reach out! I'd love to hear from you.

Hungry to keep the fun going with backpropagation? Check out Backpropogating an LSTM: A Numerical Example.
