In this blog, we will dive into the concept of “catastrophic forgetting” in LLM training and explore how it impacts a model’s abilities during re-training. We compare full fine-tuning and LoRA (Low-Rank Adaptation), discussing their applications, benefits, and drawbacks. You’ll learn why certain abilities can weaken during secondary training and discover techniques to minimize this effect. Perfect for AI enthusiasts and professionals who want a deeper understanding of model adaptation and efficient training methods.
For detailed information, please watch the YouTube video: Understanding Catastrophic Forgetting in LLM: Simple Explained
In large language model training, we often encounter the term “catastrophic forgetting.” Simply put, catastrophic forgetting occurs when a model loses some of its previously learned abilities during additional training.
As illustrated here, we have a pre-trained model built on vast amounts of data, making it a large model with a range of abilities that meet certain standards. Now, let’s say we want to enhance specific abilities within the model and undertake additional training. After this secondary training, the targeted abilities may improve, but at the cost of potentially weakening other abilities. This phenomenon is common in the re-training process.
We can imagine that a pre-trained model achieves a balanced state across its various abilities through significant training efforts. When we proceed with further training, this balance inevitably gets disrupted, as strengthening certain abilities may lead to the weakening of others. There are two main types of additional training here: continual pre-training and fine-tuning. Both modify the model through training, fundamentally changing its parameters.
Continual pre-training is generally used to adapt large models to specific domains by feeding them large volumes of domain-specific data, such as financial or medical data, to create models that excel in those areas. This training method resembles the original pre-training but continues from the previously trained state, and it requires substantial data.
Fine-tuning, on the other hand, relies on smaller datasets and is often conducted with instruction fine-tuning or alignment techniques to adapt the model. Essentially, catastrophic forgetting is unavoidable; what we can do is employ certain strategies to reduce its likelihood. Let’s look at methods to mitigate catastrophic forgetting, using fine-tuning as an example (although many of these techniques also apply to continual pre-training).
- Ignore Forgetting: In some cases, such as building a model solely focused on sentiment classification, we don’t necessarily need to worry about other abilities diminishing. Here, we can largely ignore catastrophic forgetting.
- Combine General and Task-Specific Data: A common approach is to blend general instruction data with task-specific data during training, which helps the model retain its general abilities while it learns the new task. For continual pre-training, combining general and domain-specific data is similarly beneficial, though finding the right ratio between the two (how much general vs. domain-specific data to include) is crucial; a minimal data-mixing sketch follows this list.
- Use Adapters: Adapters act like plugins that enhance specific abilities without altering the base model. They attach small trainable modules to particular parts of the model while the original weights stay frozen, minimizing the impact on the rest of the model. Techniques like LoRA fall under this category; see the LoRA sketch after this list.
- Regularization: Regularization techniques limit how far parameters can move during training, keeping them close to their original values and reducing the impact on previously learned abilities. Mathematically, this can mean adding a penalty on the parameter delta, i.e., the distance between the new weights and the pre-trained weights; the training-loop sketch after this list shows one way to do this.
- Adjust Learning Rate: Lowering the learning rate during fine-tuning also helps control forgetting, since smaller updates keep parameters from drifting too far from their pre-trained values (the same training-loop sketch uses a deliberately small learning rate).
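To make the data-mixing idea concrete, here is a minimal sketch using the Hugging Face `datasets` library. The two JSON files, `general_instructions.json` and `task_specific.json`, are hypothetical stand-ins for your general instruction data and task-specific data, and the 3:1 ratio is only an illustrative starting point, not a recommendation.

```python
from datasets import concatenate_datasets, load_dataset

# Hypothetical files: replace with your own general-purpose instruction
# data and task-specific data (they must share the same columns).
general_ds = load_dataset("json", data_files="general_instructions.json", split="train")
task_ds = load_dataset("json", data_files="task_specific.json", split="train")

# Keep roughly 3 general examples for every 1 task example.
# The ratio is a hyperparameter you would tune empirically.
general_ratio = 3
n_general = min(len(general_ds), general_ratio * len(task_ds))

mixed_ds = concatenate_datasets([
    general_ds.shuffle(seed=42).select(range(n_general)),
    task_ds,
]).shuffle(seed=42)

print(f"Mixed training set: {len(mixed_ds)} examples "
      f"({n_general} general + {len(task_ds)} task-specific)")
```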
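For the adapter approach, here is a minimal LoRA sketch using the Hugging Face `transformers` and `peft` libraries. The model name and the LoRA hyperparameters (rank, alpha, target modules) are placeholder assumptions to adapt to your own setup; the key point is that the base weights stay frozen and only the small low-rank matrices are trained, which limits how much the original abilities can be disturbed.

```python
import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

base_model_name = "meta-llama/Llama-2-7b-hf"  # placeholder; any causal LM works

model = AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype=torch.bfloat16)

# LoRA config: low-rank update matrices are injected into the attention
# projections while the original weights remain frozen.
lora_config = LoraConfig(
    r=8,                                   # rank of the low-rank update
    lora_alpha=16,                         # scaling factor
    target_modules=["q_proj", "v_proj"],   # which layers get adapters (model-dependent)
    lora_dropout=0.05,
    task_type="CAUSAL_LM",
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # typically well under 1% of all parameters
```

The wrapped model can then be trained on a mixed dataset like the one above with an ordinary training loop or `Trainer`, and the resulting adapter can be kept or dropped without touching the base model.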
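Finally, a rough training-loop sketch that combines the regularization and learning-rate ideas: a plain PyTorch loop with an L2 penalty pulling every trainable parameter back toward its pre-trained value (similar in spirit to L2-SP or EWC, though much simpler) and a deliberately small learning rate. The function name, the `reg_lambda` value, and the assumption that each batch is a dict the model accepts with labels included are all illustrative.

```python
import torch

def finetune_with_anchor(model, dataloader, reg_lambda=0.01, lr=1e-5, epochs=1):
    """Fine-tune while penalizing drift from the original (pre-trained) weights."""
    # Snapshot the pre-trained parameters as the "anchor" we regularize toward.
    # (Note: this keeps a full extra copy of the weights in memory.)
    anchor = {name: p.detach().clone() for name, p in model.named_parameters()}

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)  # small LR limits drift

    model.train()
    for _ in range(epochs):
        for batch in dataloader:
            outputs = model(**batch)   # assumes the batch already includes labels
            loss = outputs.loss

            # L2 penalty on the parameter delta: keeps the new weights close to
            # their pre-trained values, reducing catastrophic forgetting.
            drift = sum(
                ((p - anchor[name]) ** 2).sum()
                for name, p in model.named_parameters()
                if p.requires_grad
            )
            loss = loss + reg_lambda * drift

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
    return model
```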
Catastrophic forgetting is a complex and challenging issue without a definitive solution, often requiring trial and error. However, it’s a problem that must be taken seriously, as it can be one of the most costly issues in model training. Fine-tuning on data that diverges significantly from the model’s original domain tends to cause especially severe forgetting.