Overcoming Catastrophic Forgetting in Neural Networks

Overcoming Catastrophic Forgetting in Neural Networks
   James Kirkpatrick, Raia Hadsell, et al.
   https://arxiv.org/abs/1612.00796

TLDR; Catastrophic forgetting is forgetting key information needed to solve a previous task when training on a new task. However, there are several approaches to combat this issue and allow for continual learning. In this post we will look at DeepMind’s elastic weight consolidation (EWC) technique, which takes a probabilistic perspective, but we will also cover earlier approaches.

Note: I will cover earlier approaches to combat catastrophic forgetting first and then focus on EWC and a few other recent techniques.

Previous/Alternative Approaches to Combat Catastrophic Forgetting

  • We want multi-task learning models that can solve many different tasks well. A single model can benefit all of the involved tasks through shared representations. However, a common obstacle is catastrophic forgetting: forgetting key information needed to solve a previous task when training on a new task.
  • There has been a surge of techniques to combat catastrophic forgetting and to allow for continual learning. One of them is as simple as an ensemble of DNNs: for every new task, a new network is used, sharing representations from the previous task. This approach actually works quite well for all of the tasks, but we cannot expect it to scale to a large set of different tasks (for training and especially inference).
  • A more recent approach is PathNet, which uses an evolutionary strategy to deal with forgetting. In PathNet, each layer of the DNN can have ~20 modules, and a particular task may choose ~4 modules in each layer. This is essentially an extension of the ensemble idea, but it avoids the complexity blow-up as the number of tasks increases.
  • There have also been approaches based on regularization. The Joint Many-Task (JMT) model adds a successive regularization term to each task’s loss in order to combat forgetting. The term compares the embedding parameters of the previous task at the current epoch and the previous epoch, which prevents the model from forgetting information learned for the previous task(s) while training on the current task. The successive regularization is the last term in the expression below (a small code sketch follows it).

J(θ_current) = L_current(θ_current) + λ‖W_current‖² + δ‖θ′_prev − θ_prev‖²

(L_current is the current task’s loss, λ‖W_current‖² is standard L2 regularization, and the last term, with strength δ, anchors the previous task’s parameters θ_prev to their values θ′_prev from the earlier pass.)
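To make the idea concrete, here is a minimal PyTorch-style sketch of a successive-regularization term (not the exact JMT objective; the function name and the `delta` coefficient are placeholders):

```python
import torch

def loss_with_successive_reg(current_task_loss, prev_task_params, prev_task_snapshot, delta=1e-2):
    """Current task's loss plus a squared penalty that keeps the previous task's
    parameters close to a snapshot of their values from the earlier training pass."""
    penalty = sum(((p - s) ** 2).sum() for p, s in zip(prev_task_params, prev_task_snapshot))
    return current_task_loss + delta * penalty
```

The snapshot would be taken before the current task’s updates, mirroring the “previous epoch” anchor described above.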

  • Compared to the techniques above, elastic weight consolidation (EWC) takes a more probabilistic approach and will be the main focus of this post. EWC preserves the key information needed to solve previous tasks when training on new tasks, which allows for good performance on all tasks and makes way for continual learning.

Elastic Weight Consolidation (EWC)

  • The intuition behind EWC comes from task-specific synaptic consolidation in our neocortical circuits. The knowledge needed for a particular task is embedded in less pliable synapses, which lets us remember how to solve previously learned tasks. When we learn a new task, the key synapses from previous tasks can still change, but they are less prone to change because of the vital role they previously played. In terms of DNNs, we simply slow down learning for the weights that were important to the previous task(s).
  • Let’s say we have two tasks: A and B. How do we determine which weights are most important to task A? And by how much do we slow down the learning for those weights? To answer these questions, we need to think about the learning process. Learning involves finding parameters θ that optimize some objective. One key fact to consider is that there exist many combinations of θ that give the same optimal performance. This is good news, because it means we can find some configuration that performs well on both task A and task B. So when we train on task B, we need to constrain the parameters such that they still produce a small error on task A. EWC employs a quadratic penalty as this constraint, as we will see in the overall loss expression.
  • But before writing down the loss, we need to figure out how to weight the parameters so that the weights important to task A are less prone to change; if training on task B does move these key weights, there should be a penalty. We can address this from a probabilistic perspective. Instead of only finding point estimates of the optimal weights θ with gradient descent, we can think in terms of a Bayesian posterior distribution over the weights that worked well for previous tasks. When we move on to the new task (B), we update the posterior to account for all tasks seen so far (p(θ|D)).

[Diagram: low-error parameter regions for tasks A and B; EWC moves to parameters that stay within task A’s low-error region while reaching task B’s, unlike plain gradient descent or a uniform L2 penalty.]

  • What we want is to find the optimal weights θ given some input data D (which includes both task A and task B data). We can use Bayes’ rule to expand this into the following:

log p(θ|D) = log p(D|θ) + log p(θ) − log p(D)

  • If we split the data D into task A’s data D_A and task B’s data D_B, and take log p(D_B|θ) to be the negative of the loss for the current task (B), we can rewrite the expression above as:

log p(θ|D) = log p(D_B|θ) + log p(θ|D_A) − log p(D_B)

  • From the expression above, we can see that the only term that accounts for task A is the posterior distribution p(θ|D_A). This posterior is intractable, so we approximate it as a Gaussian with mean given by the task A solution θ*_A and precision given by the diagonal of the Fisher information matrix F. The Fisher matrix is the key to figuring out which weights influence task A the most, and understanding its basics will let us see how it fits into the overall loss expression (a small code sketch for estimating it follows the expression below).

p(θ|D_A) ≈ N(θ; θ*_A, F⁻¹),   with   F_i = E[(∂ log p(D_A|θ) / ∂θ_i)²] evaluated at θ = θ*_A
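As a rough illustration of how the diagonal Fisher could be estimated in practice (a sketch only, not the paper’s implementation; `model` and `data_loader_A` are placeholders), one can average squared gradients of the log-likelihood over task A examples at the learned parameters θ*_A:

```python
import torch
import torch.nn.functional as F

def estimate_diagonal_fisher(model, data_loader_A, n_samples=1000):
    """Estimate the diagonal of the Fisher information matrix at the current
    parameters (theta*_A) by averaging squared gradients of the log-likelihood
    over individual task A examples. Assumes data_loader_A yields (x, y) batches
    of size 1 for a classification model."""
    fisher = {n: torch.zeros_like(p) for n, p in model.named_parameters()}
    model.eval()
    seen = 0
    for x, y in data_loader_A:
        if seen >= n_samples:
            break
        model.zero_grad()
        log_probs = F.log_softmax(model(x), dim=1)
        # Log-likelihood of the observed label; the EWC paper instead samples the
        # label from the model's own predictive distribution.
        nll = F.nll_loss(log_probs, y)
        nll.backward()
        for n, p in model.named_parameters():
            if p.grad is not None:
                fisher[n] += p.grad.detach() ** 2
        seen += 1
    return {n: f / max(seen, 1) for n, f in fisher.items()}
```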

  • So our total loss is the expression below. The Fisher matrix tells us which parameters were most important to task A, and the quadratic term penalizes shifting the weights that were key to task A’s performance. This loss lets us find a balanced configuration that performs well on both task A and task B. Note: we only use the diagonal of the Fisher information matrix, which amounts to treating each parameter as an independent component in the penalty (a diagonal Laplace approximation). A minimal code sketch of this penalized loss follows the expression.

L(θ) = L_B(θ) + Σ_i (λ/2) F_i (θ_i − θ*_{A,i})²
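A minimal PyTorch sketch of this penalized loss, assuming the Fisher diagonal and the task A solution were saved as dictionaries keyed by parameter name (as in the sketch above); λ is a hyperparameter balancing old and new tasks:

```python
def ewc_loss(task_b_loss, model, fisher_A, theta_star_A, lam=1000.0):
    """Task B loss plus the EWC quadratic penalty:
    L(theta) = L_B(theta) + (lambda / 2) * sum_i F_i * (theta_i - theta*_A,i)^2"""
    penalty = 0.0
    for n, p in model.named_parameters():
        penalty = penalty + (fisher_A[n] * (p - theta_star_A[n]) ** 2).sum()
    return task_b_loss + (lam / 2.0) * penalty
```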

Thoughts:

  • EWC is just one of the recent successful techniques for combating catastrophic forgetting. As we saw with some of the other papers, notably JMT with its successive regularization technique, there are many ways to tackle this issue. There has been quite a bit of investment in probabilistic approaches (which generally assume that the posterior distribution over DNN weights can be approximated with a Gaussian), which is why I focused on EWC in this post.
  • With the rise of multi-task learning models and the clear benefits they provide for all the involved tasks, catastrophic forgetting is certainly an issue worth solving. I think it will be quite some time before standards emerge for dealing with it, and even then we may end up with different options depending on the nature of the datasets (differences among the involved tasks, etc.) and the importance we give to shared representations.

Code:
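A rough end-to-end sketch of how the pieces above could fit together (the model, optimizer, data loaders, and hyperparameters are illustrative placeholders, not taken from the paper; it reuses the `estimate_diagonal_fisher` and `ewc_loss` helpers sketched earlier):

```python
import torch
import torch.nn.functional as F

# Placeholder model and optimizer; any classifier would do.
model = torch.nn.Sequential(torch.nn.Linear(784, 256), torch.nn.ReLU(), torch.nn.Linear(256, 10))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# 1. Train on task A as usual (loop omitted), then snapshot theta*_A and the Fisher diagonal.
theta_star_A = {n: p.detach().clone() for n, p in model.named_parameters()}
fisher_A = estimate_diagonal_fisher(model, task_a_loader)  # task_a_loader: placeholder task A data

# 2. Train on task B with the EWC penalty anchoring the weights important to task A.
for x, y in task_b_loader:  # task_b_loader: placeholder task B data
    optimizer.zero_grad()
    task_b_loss = F.cross_entropy(model(x), y)
    loss = ewc_loss(task_b_loss, model, fisher_A, theta_star_A, lam=1000.0)
    loss.backward()
    optimizer.step()
```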

One thought on “Overcoming Catastrophic Forgetting in Neural Networks”

  1. Andrea Soltoggio says:

    Short- and long-term plasticity as ways to consolidate weights have also been shown to address catastrophic forgetting in:
    Short-term plasticity as cause-effect hypothesis testing in distal reward learning
    A. Soltoggio
    Biological Cybernetics, 2015
