Using Fast Weights to Attend to the Recent Past

https://arxiv.org/abs/1610.06258

TLDR; An RNN's hidden state holds memory of the previous states, and the input and hidden weights shape how the input and previous hidden state produce the next hidden state. However, those weights are only updated at the end of a batch. We can add fast weights, which update after each input in a sequence, to influence the subsequent hidden state. This is analogous to working memory in our brains, which we use for associative recall. With fast weights we can quickly capture meaningful features in the data without having to wait for the slow weights to update.

Detailed Notes:

  • RNNs have two types of memory. The hidden state carries memory of all the previous states; it is like our short-term memory and is updated at every time step. The second type, the long-term memory, is the weights: they tell us how to update the hidden state and how to transform inputs into outputs. These are updated much less frequently (at the end of sequences) via BPTT (backpropagation through time).
  • A third kind of memory could prove useful. Synapses in our brain operate at many different time scales, so it doesn't make sense to restrict ourselves to just these two kinds of memory (small but rapidly updated, and large but slowly updated). We can add this third type of memory, which has higher storage capacity than the hidden state and updates faster than the slow weights.

Fast associative memory:

  • Essentially, this is a local set of weights that updates after every input and helps shape the composite hidden state at each step. We do not have to wait for the slow weights (the input and hidden weight matrices) to update at the end of a batch before this memory can affect processing.
  • Fast weights let the recent hidden states influence how the current input is processed, effectively applying a different, rapidly changing set of weights. The slow weights still produce the next hidden state as usual, but the fast weights now also contribute to that subsequent hidden state.
Using Fast Weights to Attend to the Recent Past
Jimmy Ba, Geoffrey Hinton, Volodymyr Mnih, Joel Z. Leibo, Catalin Ionescu
NIPS 2016, https://arxiv.org/abs/1610.06258

Physiological Motivations

How do we store memories? We don't store memories by keeping track of the exact neural activity that occurred at the time of the memory. Instead, we try to recreate that neural activity through a set of associative weights which can map to many other memories as well. This allows for efficient storage of many memories without storing separate weights for each instance. Such an associative network also supports associative learning: the ability to learn and recall the relationship between initially unrelated stimuli. (1)


Concept

  • In a traditional recurrent architecture we have our slow weights. These weights combine the input and the previous hidden state to determine the next hidden state, and they are responsible for the long-term knowledge of the system. Because they are only updated at the end of a batch, they are slow to change and slow to decay.
  • We introduce fast weights, used in conjunction with the slow weights, to account for short-term knowledge. These weights update and decay quickly as new hidden states arrive.
  • For each connection in the network, the effective weight is the sum of the slow and fast weight contributions, so the hidden state at each time step is shaped by both. We use a fast memory weight matrix A to alter the hidden states so the network can keep track of the features required for associative learning tasks.
  • The fast weights memory matrix, A(t), starts at zero at the beginning of each sequence. After the inputs for a time step are processed, A(t) is updated by decaying the previous A(t-1) with a scalar decay rate λ and adding the outer product of the current hidden state, scaled by a fast learning rate η:

A(t) = \lambda A(t-1) + \eta \, h(t) h(t)^\top
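
A minimal numpy sketch of this update rule (the decay lam and fast learning rate eta values below are illustrative, not the paper's hyperparameters):

import numpy as np

num_hidden = 50
lam, eta = 0.95, 0.5                    # decay rate lambda, fast learning rate eta (illustrative)

A = np.zeros((num_hidden, num_hidden))  # fast weights start at zero for each sequence
h = np.random.randn(num_hidden)         # current hidden state h(t)

# Decay the old fast memory and store the outer product of the new hidden state
A = lam * A + eta * np.outer(h, h)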

  • Within a time step, the same A(t) is used for all S steps of the "inner loop". This inner loop transforms the previous hidden state h(t) into the next hidden state h(t+1). For our toy task S=1 suffices, but for more complicated tasks that require more associative learning we need to increase S.

[Figure: diagram1.png — the fast weights model unrolled over the inner loop]

Efficient Implementation

  • In a real neural network we would implement the fast weights matrix A with actual synaptic changes, but for an efficient computer simulation we can exploit the fact that A will not be full rank as long as the number of time steps t is smaller than the number of hidden units. This means we may not need to compute the fast weights matrix explicitly at all.

A(t) = \eta \sum_{\tau=1}^{t} \lambda^{t-\tau} \, h(\tau) h(\tau)^\top

  • We can rewrite A(t) by recursively applying the update (assuming A = 0 at the beginning of the sequence). This also lets us compute the third term needed for the inner loop's hidden state vector without ever explicitly forming the fast weights matrix A. The real advantage is that we no longer have to store a fast weights matrix for each sequence in the minibatch, which becomes a major space issue when doing forward/backprop on many sequences at once. Instead, we just keep track of the previous hidden states (one vector per time step):

A(t) \, h_s(t+1) = \eta \sum_{\tau=1}^{t} \lambda^{t-\tau} \, h(\tau) \left[ h(\tau)^\top h_s(t+1) \right]

Notice the last two terms when computing the inner loop's next hidden vector: the term in brackets is just the scalar product of an earlier hidden state vector, h(τ), with the current inner loop hidden vector, h_s(t+1). So you can think of each inner loop iteration as attending to the past hidden vectors in proportion to their similarity with the current inner loop hidden vector.
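
Here is a small numpy sketch of this memory-augmented view, computing A(t) h_s(t+1) directly from the stored hidden states without ever materializing A(t) (the shapes and lam/eta values are illustrative assumptions):

import numpy as np

lam, eta = 0.95, 0.5
t, num_hidden = 9, 50
H = np.random.randn(t, num_hidden)       # past hidden states h(1)..h(t), one row per time step
h_s = np.random.randn(num_hidden)        # current inner-loop hidden vector h_s(t+1)

decay = lam ** np.arange(t - 1, -1, -1)  # lambda^(t - tau) for tau = 1..t
attention = H @ h_s                      # scalar products h(tau)^T h_s(t+1)
Ah_s = eta * (decay * attention) @ H     # equals A(t) @ h_s(t+1) from the equation above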

We do not use this method in the basic implementation because I wanted to explicitly show what the fast weights matrix looks like, and the "memory augmented" view does not prevent the use of minibatches (as you can see above). The problem an explicit fast weights matrix does create is space, so this efficient implementation really helps there.

Note that this "efficient implementation" becomes costly if the sequence length is greater than the hidden state dimensionality, since the computation now scales quadratically: we have to attend to all previous hidden states with the current inner loop's hidden representation.

Execution

  • To see the advantage of fast weights, Ba et al. used a very simple toy task.

Given: g1o2k3??g we need to predict 1.

  • You can think of each letter-number pair as a key/value pair. We are given a key at the end of the sequence and we need to predict the appropriate value. Fast associative memory is needed here to keep track of the key/value pairs the network has just seen and to retrieve the proper value given a key. After training, the fast memory will hold a hidden state vector that, for example after seeing g and 1, contains a part for g and a part for 1 and associates the two. (A hypothetical sketch of such a sequence generator appears after the commands below.)
  • For training:
python train.py train <model_name: RNN-LN-FW | RNN-LN | CONTROL | GRU-LN >
  • For sampling:
python train.py test <model_name: RNN-LN-FW | RNN-LN | CONTROL | GRU-LN >
  • For plotting results:
python train.py plot
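
The repo contains its own data pipeline; the snippet below is just a hypothetical sketch of how key/value sequences like "g1o2k3??g" could be produced (the function name and format are assumptions for illustration):

import random
import string

def generate_example(num_pairs=3):
    # Hypothetical generator for associative retrieval examples like ('g1o2k3??g', '1')
    letters = random.sample(string.ascii_lowercase, num_pairs)
    digits = [random.choice(string.digits) for _ in range(num_pairs)]
    query = random.randrange(num_pairs)
    sequence = "".join(l + d for l, d in zip(letters, digits)) + "??" + letters[query]
    return sequence, digits[query]

print(generate_example())  # e.g. ('g1o2k3??g', '1')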

Code Breakdown

Initializing our fast weights matrix A(t) and the initial hidden state.

# fast weights and hidden state initialization
self.A = tf.zeros(
    [FLAGS.batch_size, FLAGS.num_hidden_units, FLAGS.num_hidden_units],
    dtype=tf.float32)
self.h = tf.zeros(
    [FLAGS.batch_size, FLAGS.num_hidden_units],
    dtype=tf.float32)

For each time step, we first compute a candidate hidden state from the input and the previous hidden state using the slow weights.

# hidden state
self.h = tf.nn.relu((tf.matmul(self.X[:, t, :], self.W_x)+self.b_x) +
    (tf.matmul(self.h, self.W_h)))

After we determine this hidden state, we combine it with the fast weights component to create the final composite hidden state for this time step. Since the fast weights matrix is built from outer products of hidden states, we need to reshape the hidden state into a three-dimensional tensor. We first compute A(t) as a decay of the previous A(t-1) plus the outer product of the previous composite hidden state. The value of A(t) then remains fixed for all S "inner loop" steps of this time step.

Within the inner loop, we conduct S changes to our composite hidden state. The updates are as follows:

h_{s+1}(t+1) = f\left( \mathrm{LN}\left[ W_x x(t) + W_h h(t) + A(t) \, h_s(t+1) \right] \right)

This continues for S steps, and the last inner loop hidden state becomes the new hidden state for this time step. Note that layer normalization is crucial here, as the scalar products of hidden vectors can explode or vanish. Layer normalization uses each sample's mean and variance (across all hidden components) to produce a hidden state with zero mean and unit variance, which is then scaled and shifted by learned gain and bias parameters. Then the nonlinearity is applied and the updates continue. Our control model, which employed neither layer normalization nor fast weights, failed to converge.

# Reshape h to use with A
self.h_s = tf.reshape(self.h,
    [FLAGS.batch_size, 1, FLAGS.num_hidden_units])

# Update A(t): decay the previous fast weights and add the outer product of h
# (A then stays fixed for all S inner loop steps of this time step)
self.A = tf.add(tf.scalar_mul(self.l, self.A),
    tf.scalar_mul(self.e, tf.batch_matmul(tf.transpose(
        self.h_s, [0, 2, 1]), self.h_s)))

# Inner loop: S updates of the composite hidden state
for _ in range(FLAGS.S):
    # slow-weight input term + slow-weight hidden term + fast-weight term A(t) h_s
    self.h_s = tf.reshape(
        tf.matmul(self.X[:, t, :], self.W_x)+self.b_x,
        tf.shape(self.h_s)) + tf.reshape(
        tf.matmul(self.h, self.W_h), tf.shape(self.h_s)) + \
        tf.batch_matmul(self.h_s, self.A)

    # Apply layer norm: mean/variance per sample, across the hidden units
    mu = tf.reduce_mean(self.h_s, reduction_indices=[2], keep_dims=True)
    sigma = tf.sqrt(tf.reduce_mean(tf.square(self.h_s - mu),
        reduction_indices=[2], keep_dims=True))
    self.h_s = tf.div(tf.mul(self.gain, (self.h_s - mu)), sigma) + \
        self.bias

    # Apply nonlinearity
    self.h_s = tf.nn.relu(self.h_s)

# Reshape h_s into h
self.h = tf.reshape(self.h_s,
    [FLAGS.batch_size, FLAGS.num_hidden_units])

Results

Control: RNN without layer normalization (LN) or fast weights (FW)

[Figures: results1, results2]

Bag of Tricks

  • Initialize the slow hidden-to-hidden weights with an identity matrix in the RNN to avoid gradient issues (2); a small sketch of this follows the list.
  • Layer norm is required for convergence when using an RNN on this task.
  • Weights should be initialized so that activations have unit variance after the dot product, prior to the nonlinearity, or else things can blow up very quickly.
  • Keep track of the gradient norm and tune accordingly.
  • There is no need to add the extra input processing and the extra layer after the softmax that Jimmy Ba used; he did that only to compare with another task, so blindly following it just adds extra computation.
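
A tiny sketch of the identity-initialization trick for the slow hidden-to-hidden weights (the 0.05 scale and the hidden size below are illustrative assumptions, not necessarily the repo's exact values):

import numpy as np
import tensorflow as tf

num_hidden_units = 50
# Slow hidden-to-hidden weights initialized to a scaled identity matrix
W_h = tf.Variable(0.05 * np.identity(num_hidden_units).astype(np.float32))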

Extensions

  • I will be releasing my code comparing fast weights with an attention interface on language-related sequence-to-sequence tasks.

Citations

  1. Suzuki, Wendy A. “Associative Learning and the Hippocampus.” APA. American Psychological Association, Feb. 2005. Web. 04 Dec. 2016.
  2. Hinton, Geoffrey. “FieldsLive Video Archive.” Fields Institute for Research in Mathematical Sciences. University of Toronto, 13 Oct. 2016. Web. 04 Dec. 2016.

All Code

GitHub Repo

Training Points:

  • There's a really cool example displaying the utility of a working memory for associative recall, and it shows how a simple RNN with fast weights outperforms an LSTM despite the LSTM's long-term storage capacity!

Unique Points:

  • With fast weights, we are able to store useful information, and the network learns when to store it. This differs from an NTM because there are no explicit READ and WRITE operations; the fast weights are always changing, and they alter the hidden state as needed so the network can learn faster.
  • Really nice paper, and I enjoyed the neuroscience foundation. I think there are many more biological mechanisms to exploit that could make our systems more efficient. I've been thinking about a few architectural ideas and will write some blog posts soon about how to create and test novel architectures like this.
