Your Neural Network Will Forget What It’s Learned
Our brains are ever-evolving objects, continuously revising their structures to retain new information and discard the old based on our interactions with the environment. Artificial neural networks (ANNs), with their webs of activations and connections, were originally modeled after the brain. Most ANNs, however, are static in structure and rely on batch learning: they are fed many batches of independent and identically distributed (IID) data at training time, and their learned parameters are fixed upon deployment. This is inconsistent with the way our brains learn. We do not learn by processing random batches of data all at once, but by processing the continuous streams of causally related information we receive about our environment from our senses.
While batch learning works fine for tasks in which the precise nature of the data a model will encounter at inference is known beforehand, such as classifying photos into a fixed set of classes, this is an unrealistic assumption in many real-world applications where a model may encounter new data after deployment. In such instances, we would want the model to adapt to the new data on the fly, a process known as online learning. Additionally, in most cases the data encountered is not random but occurs in sequences of related instances, such as frames of a video or fluctuations in a stock’s price, a property that must also be taken into account when developing online learning solutions.
There have been several attempts to create evolving neural networks for online learning, but they inevitably run into the problem of catastrophic forgetting (sometimes called catastrophic interference), where adapting to perform new tasks causes the network to “forget” those it previously learned. This phenomenon was first described by McCloskey and Cohen in 1989, when they tested a network’s ability to sequentially learn paired-associate list tasks. In their experiment, the first task consisted of learning pairs of arbitrary words from two sets A and B, such as “locomotive – dishtowel, window – reason, bicycle – tree, etc.” The network was trained until it could perfectly recall the B associate for each A item. It then began learning a second task in which set A was paired with different words from set C, such as “locomotive – cloud, window – book, bicycle – couch, etc.” and it was tested on its ability to remember pairings from the AB list after 1, 5, 10 and 20 iterations of learning the AC list. Panel (b) of the following graph shows how the network rapidly forgets the AB task after beginning to learn the AC task, as compared to human performance in the same experimental setup, which demonstrates that our brains are able to retain knowledge of previous tasks much more effectively.
It’s undoubtedly challenging to construct a network that is finite in structure but that can retain knowledge of past experiences given a continuous stream of data. Initial tactics for overcoming catastrophic forgetting relied on allocating progressively more resources to networks as new classes were learned, an approach that is ultimately unsustainable for most real-world applications. Let’s now take a look at some more recent strategies for compelling networks to remember.
One heavily researched mechanism for dealing with catastrophic forgetting is regularization. As we know, a network adapts to learn new tasks by adjusting the weights of its connections, and regularization involves varying the changeability, or plasticity, of weights based on how important they are determined to be for previous tasks.
In a highly cited 2017 paper, Kirkpatrick et al. introduced a regularization technique called Elastic Weight Consolidation (EWC). EWC maintains the fidelity of connections important for previously learned tasks by constraining weights to stay close to their learned values as new tasks are encountered.
To illustrate how EWC works, let’s say we are learning some classification task A, for which our network is learning a set of weights θ. There are, in fact, multiple configurations of θ that will yield good performance on A—a general ballpark of weights represented by the gray ellipse in the diagram above. Catastrophic forgetting occurs when the network moves on to learn a different task B associated with a different ballpark of weights—the cream ellipse—and its weights are consequently adjusted such that they fall outside the ballpark for good performance on A, as illustrated by the blue arrow.
In EWC, a quadratic penalty is introduced to constrain the network parameters to stay within the low-error region for task A when learning to perform B, depicted by the red arrow. The quadratic penalty acts as a “spring” of sorts to anchor the parameters to previously learned solutions, hence the name Elastic Weight Consolidation. The spring’s stiffness, i.e. the strength of the quadratic penalty, differs between weights depending on how “important” the weights are determined to be for prior tasks. In the diagram, for example, task A’s 2D low-error ellipse is longer along the x dimension than the y dimension, indicating that performance on A tolerates larger changes to the x weight than to the y weight. The y weight is therefore more important for A and is held by a stiffer spring, while the x weight is left freer to move when adjusting to learn B. Failing to make the spring adaptable in this way and applying the same coefficient to each weight would result in weights not very well suited to either task, as indicated by the green arrow in the diagram.
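To make the spring analogy concrete, here is a minimal sketch of the penalty term, assuming the form used by Kirkpatrick et al.: each weight's squared distance from its task A value is scaled by a per-weight importance, which the paper takes from the diagonal of the Fisher information matrix. The function names, the strength `lam`, and the importance values below are illustrative assumptions, not figures from the paper.

```python
import numpy as np

def ewc_penalty(theta, theta_star_a, importance, lam):
    """Quadratic penalty: (lam / 2) * sum_i importance_i * (theta_i - theta*_A,i)^2."""
    diff = np.asarray(theta, dtype=float) - np.asarray(theta_star_a, dtype=float)
    return 0.5 * lam * np.sum(np.asarray(importance, dtype=float) * diff ** 2)

def ewc_penalty_grad(theta, theta_star_a, importance, lam):
    """Gradient of the penalty; this is added to the task B gradient during training."""
    diff = np.asarray(theta, dtype=float) - np.asarray(theta_star_a, dtype=float)
    return lam * np.asarray(importance, dtype=float) * diff

# Hypothetical two-weight example: the second weight matters far more for task A.
theta_star_a = np.array([1.0, 1.0])   # weights found after training on task A
importance = np.array([0.1, 10.0])    # assumed per-weight importance values

# Moving each weight by the same amount costs very different penalties.
penalty_w0 = ewc_penalty(theta_star_a + np.array([0.5, 0.0]), theta_star_a, importance, lam=1.0)
penalty_w1 = ewc_penalty(theta_star_a + np.array([0.0, 0.5]), theta_star_a, importance, lam=1.0)
print(penalty_w0, penalty_w1)
```

The same displacement is cheap for the unimportant weight and expensive for the important one, which is what steers learning on task B toward directions that task A does not care about.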
The EWC model is trained on a sequence of tasks, each with its own set of training data. Each task is defined by a fixed random permutation of the pixels in the MNIST handwritten digit images. Once the model has trained on the data for one task, it moves on to the next task, and data from previous tasks is not encountered again, which allows testing of how well EWC “remembers” how to perform previously learned tasks. The following plot shows EWC’s test performance on a sequence of three tasks A, B, and C as it proceeds through training on them.
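As a rough illustration of how such tasks can be built, the sketch below draws one fixed pixel permutation per task and applies it to every image. The helper names are hypothetical, and random arrays stand in for the real MNIST digits so the snippet stays self-contained.

```python
import numpy as np

def make_task_permutations(num_tasks, num_pixels=784, seed=0):
    """Draw one fixed pixel permutation per task."""
    rng = np.random.default_rng(seed)
    return [rng.permutation(num_pixels) for _ in range(num_tasks)]

def apply_task(images, permutation):
    """Flatten each image and reorder its pixels with the task's permutation."""
    flat = images.reshape(len(images), -1)
    return flat[:, permutation]

# Stand-in data shaped like MNIST images (random values, not real digits).
images = np.random.default_rng(1).random((4, 28, 28))
permutations = make_task_permutations(num_tasks=3)
task_a = apply_task(images, permutations[0])
task_b = apply_task(images, permutations[1])
print(task_a.shape, task_b.shape)
```

Every task contains exactly the same pixel values, just in a different order, so the tasks are equally hard but require different input weights.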
We can see how the performance of EWC remains fairly stable across previously learned tasks even as it learns new ones, as opposed to an approach that uses the same quadratic penalty for all weights (the green line) and one that doesn’t include a penalty at all, just using standard stochastic gradient descent (the blue line)—these both demonstrate catastrophic forgetting of task A, for example, as tasks B and C are learned.
Replay is another popular method for mitigating forgetting that involves storing some representation of previously encountered training data in what is referred to as a replay buffer. An influential application of this technique to deep class-incremental learning is the paper “iCaRL: Incremental Classifier and Representation Learning” by Rebuffi et al., released in late 2016. In its replay buffer, iCaRL stores sets of images for each class encountered during training, referred to as “exemplar” images. The goal is for these images to be as representative of their respective classes as possible. During training, iCaRL processes classes in batches. When new classes arrive, a combined training set is created from all the stored exemplars and the new data. All of this data is first run through the current network, and its outputs for the previously learned classes are stored for the next step, in which the network’s parameters are updated. The update minimizes a loss function that combines a classification loss, which prompts the network to output correct labels for the newly encountered classes, and a distillation loss, which encourages it to reproduce the outputs stored for the previously learned classes.
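The combined loss can be sketched as follows. This is a simplified illustration rather than the authors' code: it assumes sigmoid outputs with a binary cross-entropy for both terms, and the function names and array layout (old classes first, new classes last) are choices made for this example.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def binary_cross_entropy(targets, probs, eps=1e-12):
    """Summed binary cross-entropy between targets and probabilities in [0, 1]."""
    probs = np.clip(probs, eps, 1.0 - eps)
    return -np.sum(targets * np.log(probs) + (1.0 - targets) * np.log(1.0 - probs))

def combined_loss(logits, labels, old_outputs, num_old):
    """Classification loss on new-class outputs plus distillation loss on old-class outputs.

    logits:      (n, num_old + num_new) outputs of the network being updated
    labels:      (n,) integer class ids
    old_outputs: (n, num_old) outputs stored before the update, in [0, 1]
    """
    probs = sigmoid(logits)
    one_hot = np.eye(logits.shape[1])[labels]
    classification = binary_cross_entropy(one_hot[:, num_old:], probs[:, num_old:])
    distillation = binary_cross_entropy(old_outputs, probs[:, :num_old])
    return (classification + distillation) / len(logits)

# Two old classes (0, 1) and one new class (2); a single sample of the new class.
stored = sigmoid(np.array([[-20.0, -20.0]]))  # what the old network said about classes 0 and 1
good = combined_loss(np.array([[-20.0, -20.0, 20.0]]), np.array([2]), stored, num_old=2)
bad = combined_loss(np.array([[20.0, -20.0, -20.0]]), np.array([2]), stored, num_old=2)
print(good, bad)
```

The `bad` logits are penalized twice: once for missing the new class and once for drifting away from the stored outputs on an old class.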
The network determines how to classify a given image by consulting the stored sets of exemplar images. Specifically, at inference time the exemplar images for a particular class are passed through the network to yield a set of feature vectors, which are averaged to produce a representative feature vector for that class. This is repeated for all classes, and the feature vector for the test instance is compared to each class’s representative vector and assigned the label of the class it is most similar to. Importantly, a limit is imposed on the total number of exemplar images that are stored, so that if a new class is encountered after the limit is reached, images are removed from the sets of the other classes to make room for the new one. This prevents the computational requirements and memory footprint of the model from growing without bound as new classes are encountered.
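The classification rule and the memory cap described above can be sketched in a few lines. This is a simplified illustration, not iCaRL's implementation: the feature vectors are assumed to have already been produced by the network, plain Euclidean distance stands in for the similarity measure, and keeping the first exemplars of each class when trimming is an assumption of this sketch.

```python
import numpy as np

def class_means(exemplar_features):
    """Average each class's exemplar feature vectors into one representative vector."""
    return {label: feats.mean(axis=0) for label, feats in exemplar_features.items()}

def classify(feature, means):
    """Assign the label whose representative vector is closest to the test feature."""
    return min(means, key=lambda label: np.linalg.norm(feature - means[label]))

def reduce_exemplar_sets(exemplar_features, budget):
    """Split a fixed total budget evenly, keeping the first exemplars of each class."""
    per_class = budget // len(exemplar_features)
    return {label: feats[:per_class] for label, feats in exemplar_features.items()}

# Hypothetical 2-D feature vectors for two classes.
exemplars = {
    "cat": np.array([[0.0, 0.0], [0.0, 2.0]]),
    "dog": np.array([[4.0, 4.0], [6.0, 4.0]]),
}
means = class_means(exemplars)
prediction = classify(np.array([1.0, 1.0]), means)
trimmed = reduce_exemplar_sets(exemplars, budget=2)
print(prediction, {label: len(feats) for label, feats in trimmed.items()})
```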
With this strategy, we can see evidence of iCaRL’s retention of previous classes (a) among the confusion matrices above. Illuminated pixels on the diagonal represent correct class predictions, and we can see that iCaRL achieves the most correct predictions as well as an even distribution of incorrect predictions as compared to other networks that are biased towards predicting initial classes (c) or more recently encountered classes (b and d).
iCaRL popularized the idea of storing training instances for replay to remember learned tasks and sparked a lot of research on different applications of this technique. However, these approaches all relied on storing raw representations of the training data in replay buffers, whereas the brain is known to store and replay memories as compressed representations of neocortical activity patterns. REMIND, which stands for Replay using Memory INDexing, is a streaming learning model introduced by Hayes et al. in late 2019 that aims to mimic this functionality of the brain by storing compressed feature maps of image data for replay, the first of its kind to do so. Further, many models for streaming learning process input data in batches (for example, training on a batch of cat images and then on a batch of dog images), which is representative neither of how the brain works nor of most real-world deployment scenarios, where data is encountered one instance at a time in a continuous stream. Batch processing is also more resource intensive, making it unsuitable for many mobile applications. For these reasons, REMIND learns from instances one by one.
The REMIND network is split into two parts, as depicted in the diagram above: a series of frozen layers followed by a series of plastic layers, with a replay buffer in between. Training of the network begins with a “base initialization period,” in which all layers are trained on a certain number of classes in a normal offline manner to initialize their weights. After this, the weights of the frozen layers remain unchanged for the remainder of training. The idea behind this is that the early layers of a neural network learn general features that transfer well across varied inputs, so it’s not necessary to update their weights as new data is encountered, as they would not change significantly anyway. The feature map representations of the input images produced by the frozen layers of the network are used to train a vector quantization model that compresses the feature maps and learns how to faithfully reconstruct them. The compressed representations are what get stored in REMIND’s replay buffer, mimicking the brain’s mechanisms for storing memories as well as decreasing the size of the data so that more training instances can be stored in the buffer.
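To give a feel for the compression step, here is a toy vector quantizer: it learns a small codebook with a few k-means iterations, stores each feature vector as the index of its nearest codebook entry, and reconstructs by looking that entry up again. It is a deliberately simplified stand-in, not REMIND's quantization model, and the function names and sizes are assumptions of this sketch.

```python
import numpy as np

def encode(vectors, codebook):
    """Replace each vector with the index of its nearest codebook entry."""
    dists = np.linalg.norm(vectors[:, None, :] - codebook[None, :, :], axis=2)
    return dists.argmin(axis=1).astype(np.uint8)  # uint8 assumes at most 256 codes

def decode(codes, codebook):
    """Reconstruct approximate vectors by looking the indices up in the codebook."""
    return codebook[codes]

def fit_codebook(vectors, num_codes, iters=10, seed=0):
    """Learn a codebook with a few plain k-means iterations."""
    rng = np.random.default_rng(seed)
    codebook = vectors[rng.choice(len(vectors), num_codes, replace=False)].copy()
    for _ in range(iters):
        codes = encode(vectors, codebook)
        for k in range(num_codes):
            members = vectors[codes == k]
            if len(members):  # leave a code unchanged if no vector is assigned to it
                codebook[k] = members.mean(axis=0)
    return codebook

# Random stand-ins for feature vectors coming out of the frozen layers.
features = np.random.default_rng(1).random((100, 8))
codebook = fit_codebook(features, num_codes=16)
codes = encode(features, codebook)
reconstructed = decode(codes, codebook)
print(features.nbytes, codes.nbytes, reconstructed.shape)
```

Only `codes` (one byte per vector here) needs to live in the buffer, plus the shared codebook, which is why far more instances fit in the same memory.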
After the base initialization period, each new training instance is run through the frozen layers of the network and combined with a number of instances that have been sampled uniformly at random from the replay buffer and reconstructed via the learned quantization model. The mixture is then used to train the plastic layers of the network. Quantized training examples and their labels are stored in the replay buffer until it reaches its maximum capacity. From that point on, each time a new example is added, a randomly chosen example from the class with the most instances is removed; this allows the model to learn new classes without its memory growing without bound.
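The buffer policy just described can be sketched as a small class. The class and method names are hypothetical, and the stored examples are plain placeholders where REMIND would hold quantized feature maps.

```python
import random
from collections import defaultdict

class ClassBalancedBuffer:
    """Fixed-capacity replay buffer that evicts from the largest class when full."""

    def __init__(self, capacity, seed=0):
        self.capacity = capacity          # assumed to be a positive integer
        self.items = defaultdict(list)    # label -> stored (compressed) examples
        self.rng = random.Random(seed)

    def __len__(self):
        return sum(len(examples) for examples in self.items.values())

    def add(self, example, label):
        if len(self) >= self.capacity:
            # Remove a random example from the class that currently holds the most.
            largest = max(self.items, key=lambda key: len(self.items[key]))
            victims = self.items[largest]
            victims.pop(self.rng.randrange(len(victims)))
        self.items[label].append(example)

    def sample(self, n):
        """Draw up to n (example, label) pairs uniformly at random for replay."""
        pool = [(x, y) for y, xs in self.items.items() for x in xs]
        return self.rng.sample(pool, min(n, len(pool)))

buffer = ClassBalancedBuffer(capacity=4)
for i in range(4):
    buffer.add(f"cat_{i}", "cat")
for i in range(2):
    buffer.add(f"dog_{i}", "dog")
replayed = buffer.sample(3)
print(len(buffer), len(buffer.items["cat"]), len(buffer.items["dog"]))
```

Evicting from the largest class keeps the buffer roughly balanced, so a newly arrived class is not crowded out by classes seen earlier.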
In the following plot, we can see how the test accuracy of REMIND’s online learning technique compares to others as it progresses to learn classes in an image classification task. We see that it achieves the best accuracy among the online approaches and is second only to the offline batch learning approach, where the network is retrained on random batches of all the previously encountered data as each new class is learned.
REMIND’s superior results indicate that mimicking the brain’s means of storing compressed memory representations in neural networks may be a key factor in compelling them to remember.
Bi-level Continual Learning (BCL) is another online learning technique modeled after the way our brains operate. Proposed by Pham et al. in 2020, it involves two different models—a “fast-weight” model and a base model—which mirror the functionality of the hippocampus and neocortex, respectively. In our brains, the hippocampus is responsible for “rapid learning and acquiring new experiences” and the neocortex is tasked with “capturing common knowledge of all observed tasks.” To facilitate this in their model, Pham et al. employ both a generalization memory and an episodic memory buffer. The small episodic memory’s purpose is to store data from recent tasks for training the fast weight model, while the generalization memory stores data across all encountered tasks. The job of the fast weight model is to consolidate information from new samples for transfer to the base model.
For BCL, data arrives in mini-batches corresponding to a particular class and a sample of it is put in generalization memory while most of it goes on to train the fast weight model. The fast weight model is initialized with the weights θ of the base model and is trained on the current batch of data for the particular class mixed with recently seen data from the episodic memory to arrive at a new set of weights Φ. The base model’s weights are then adjusted to factor in those from the fast weight model. While the fast weight model’s job is to quickly adapt to learn new tasks, a knowledge distillation regularizer is employed to encourage it to minimize the distance between new weights Φ it learns and its initial weights θ from the base model so that updates to the base model won’t cause it to lose too much generalization ability across tasks. After a mini batch is processed, the fast weights are discarded and re-initialized with the base model’s weights for learning the next batch of data.
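The two-level loop can be illustrated with a toy linear model. This is only a sketch under strong assumptions and not the update rule of Pham et al.: a plain L2 pull toward θ stands in for the distillation regularizer, the base update is a simple interpolation toward Φ, and all function names and hyperparameters are made up for the example. The arrays `x` and `y` play the role of the current mini-batch mixed with samples from episodic memory.

```python
import numpy as np

def mse(weights, x, y):
    """Mean squared error of a linear model y ~ x @ weights."""
    return float(np.mean((x @ weights - y) ** 2))

def mse_grad(weights, x, y):
    """Gradient of the mean squared error with respect to the weights."""
    return 2.0 * x.T @ (x @ weights - y) / len(x)

def fast_weight_update(theta, x, y, inner_steps=5, inner_lr=0.1, reg=0.1):
    """Adapt fast weights phi on the batch while pulling them back toward theta."""
    phi = theta.copy()  # fast weights start from the base model's weights
    for _ in range(inner_steps):
        grad = mse_grad(phi, x, y) + reg * (phi - theta)
        phi -= inner_lr * grad
    return phi

def base_update(theta, phi, outer_lr=0.5):
    """Assumed rule: move the base weights part of the way toward the fast weights."""
    return theta + outer_lr * (phi - theta)

theta = np.zeros(2)                       # base model weights
x = np.array([[1.0, 0.0], [0.0, 1.0]])    # toy inputs
y = np.array([1.0, 2.0])                  # toy targets
phi = fast_weight_update(theta, x, y)
new_theta = base_update(theta, phi)       # phi is then discarded and re-initialized
print(mse(theta, x, y), mse(phi, x, y), mse(new_theta, x, y))
```

The fast weights adapt quickly to the batch, while the base weights take a smaller step in the same direction, which is the division of labor the paragraph above describes.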
Pham et al. compare their BCL technique against several other continual learning models on a series of classification tasks and find that it generally outperforms them all. On the CIFAR-100 dataset, a typical offline batch learning approach achieves 74.11% test accuracy and BCL achieves 67.75%, compared to 64.36% for the next best performing model and only 48.43% for iCaRL. The following graph shows how BCL and other replay techniques generally improve in performance on the CIFAR task as their episodic memory size increases, which makes sense given that a larger memory allows for a more accurate representation of the original dataset.
Unlike REMIND, though, BCL stores the raw, uncompressed data in its memory buffers. It would therefore be interesting to compare the performances of these two different approaches to modeling the brain’s structures for memory on the same task.
Even among the incremental learning scenarios discussed here, the conditions for training are still not very representative of real-world scenarios: there aren’t many situations in which arriving data is so clearly delineated into incremental batches for each class. Given this, a new research direction is to investigate online learning in more realistic streaming scenarios, where the distribution of incoming training instances varies over time. Such research would be particularly relevant for space-based applications of online learning, for example, where spacecraft may have to learn on the fly to avoid collisions with previously unseen objects.