Adding Adapters to a Model
This document gives an overview of how new model architectures of Hugging Face Transformers can be supported by adapters.
Before delving into implementation details, you should familiarize yourself with the main design philosophies of adapters:
Adapters should integrate seamlessly with existing model classes: If a model architecture supports adapters, it should be possible to use them with all model classes of this architecture.
Copied code should be minimal: adapters extensively uses Python mixins to add adapter support to HF models. Functions that cannot be sufficiently modified by mixins are copied and then modified. Try to avoid copying functions as much as possible.
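The following minimal, self-contained sketch makes the mixin idea concrete. ToyLayer, ToyLayerAdaptersMixin and ToyLayerWithAdapters are hypothetical stand-ins, not real Hugging Face or adapters classes, and the "adapter" is only a constant offset. Because the mixin comes first in the list of base classes, its forward() runs first and reuses the original logic through super() instead of copying it.

```python
# Toy illustration of the mixin pattern; all classes are hypothetical stand-ins.


class ToyLayer:
    """Stands in for an unmodified Hugging Face layer class."""

    def forward(self, hidden_states):
        # Pretend the original layer simply doubles its inputs.
        return [2 * x for x in hidden_states]


class ToyLayerAdaptersMixin:
    """Adds adapter behavior on top of ToyLayer without copying its code."""

    def init_adapters(self):
        # Hypothetical adapter state: a constant added after the layer.
        self.adapter_offset = 1

    def forward(self, hidden_states):
        # Reuse the original forward via super(), then apply the "adapter".
        output = super().forward(hidden_states)
        return [x + self.adapter_offset for x in output]


# Listing the mixin first puts its methods ahead of ToyLayer's in the MRO.
class ToyLayerWithAdapters(ToyLayerAdaptersMixin, ToyLayer):
    pass


layer = ToyLayerWithAdapters()
layer.init_adapters()
print(layer.forward([1, 2, 3]))  # [3, 5, 7]
```

If the original method cannot be reused through super() in this way, the function has to be copied and modified instead.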
Relevant Classes
Adding adapter support to an existing model architecture requires modifying some parts of the model's forward pass logic. These modifications are implemented in the four files in the src/adapters/models/<model_type>/ directory. Let's examine the purpose of these files using BERT as an example. Note that we are adapting the original Hugging Face model, implemented in transformers/models/bert/modeling_bert.py. The files in src/adapters/models/bert/ are:
src/adapters/models/bert/mixin_bert.py: This file contains mixins for each class we want to change. For example, the BertSelfAttention class needs changes for LoRA and Prefix Tuning. For this, we create a BertSelfAttentionAdaptersMixin that implements these changes. We will discuss how this works in detail below.
src/adapters/models/bert/modeling_bert.py: For some classes of the BERT implementation (e.g. BertModel or BertLayer), the code can be sufficiently customized via mixins. For other classes (like BertSelfAttention), we need to edit the original code directly. These classes are copied into src/adapters/models/bert/modeling_bert.py and modified.
src/adapters/models/bert/adapter_model.py: This file defines the adapter model class. This class allows flexibly adding and switching between multiple prediction heads of different types. It looks about the same for each model, except that each model has different heads and therefore different add_..._head() functions.
src/adapters/models/bert/__init__.py: Defines Python's import structure.
Implementation Steps 📝
Now that we have discussed the purpose of every file in src/adapters/models/<model_type>/, we go through the integration of adapters into an existing model architecture step by step. Not every step below is necessarily applicable to every model architecture.
Files:
Create the src/adapters/models/<model_type>/ directory and, inside it, the four files mixin_<model_type>.py, modeling_<model_type>.py, adapter_model.py and __init__.py.
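If you prefer to script this step, the following sketch uses pathlib to create the directory and the four empty modules. The function name create_model_skeleton and the placeholder model type "mymodel" are hypothetical. Existing files are left untouched.

```python
from pathlib import Path


def create_model_skeleton(repo_root, model_type):
    """Create src/adapters/models/<model_type>/ and its four module files.

    model_type is a placeholder such as "mymodel"; existing files are kept.
    """
    model_dir = Path(repo_root) / "src" / "adapters" / "models" / model_type
    model_dir.mkdir(parents=True, exist_ok=True)
    file_names = [
        f"mixin_{model_type}.py",
        f"modeling_{model_type}.py",
        "adapter_model.py",
        "__init__.py",
    ]
    created = []
    for name in file_names:
        path = model_dir / name
        path.touch(exist_ok=True)  # creates an empty file only if it is missing
        created.append(path)
    return created
```

Run from the repository root, for example create_model_skeleton(".", "mymodel").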
Mixins:
In src/adapters/models/<model_type>/mixin_<model_type>.py, create mixins for every class you want to change and for which you cannot reuse an existing mixin from another class.
To figure out which classes to change, think about where to insert LoRA, Prefix Tuning, and bottleneck adapters.
You can use similar model implementations for guidance.
Often, existing mixins of another class can be reused. For example, BertLayer, RobertaLayer, XLMRobertaLayer, DebertaLayer, DebertaV2Layer and BertGenerationLayer (the layer classes of models derived from BERT) all use the BertLayerAdaptersMixin.
To additionally support Prefix Tuning, the forward pass of the respective attention layer must call the PrefixTuningLayer module (see the "Copied functions" step below for how to modify the code of a Hugging Face class).
Make sure the calls to bottleneck_layer_forward() are added in the right places.
The mixin for the whole base model class (e.g. BertModel) should derive from ModelBaseAdaptersMixin and, if possible, from EmbeddingAdaptersMixin and/or InvertibleAdaptersMixin. This mixin should at least implement the iter_layers() method but might require additional modifications depending on the architecture.
If the model is a combination of different models, such as the EncoderDecoderModel, use ModelUsingSubmodelsAdaptersMixin instead of ModelBaseAdaptersMixin.
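The role of iter_layers() can be sketched with stand-in classes. SimplifiedModelBaseAdaptersMixin is a hypothetical, heavily reduced stand-in for ModelBaseAdaptersMixin (the real class does much more), and add_adapter_to_all_layers is an illustrative helper, not a library method. ToyModel mimics a BERT-like model that stores its layers in encoder.layer. Generic code only walks the layers through iter_layers(), so each architecture just has to state where its layers live.

```python
from typing import Iterator, List, Tuple


class SimplifiedModelBaseAdaptersMixin:
    """Hypothetical, heavily reduced stand-in for ModelBaseAdaptersMixin."""

    def iter_layers(self) -> Iterator[Tuple[int, "ToyLayer"]]:
        # Every architecture has to say where its layers live.
        raise NotImplementedError

    def add_adapter_to_all_layers(self, name: str) -> None:
        # Generic code relies only on iter_layers(), not on attribute names.
        for _, layer in self.iter_layers():
            layer.adapters.append(name)


class ToyLayer:
    def __init__(self) -> None:
        self.adapters: List[str] = []


class ToyEncoder:
    def __init__(self, num_layers: int) -> None:
        self.layer = [ToyLayer() for _ in range(num_layers)]


class ToyModel:
    """Stands in for a Hugging Face base model class such as BertModel."""

    def __init__(self, num_layers: int = 3) -> None:
        self.encoder = ToyEncoder(num_layers)


class ToyModelAdaptersMixin(SimplifiedModelBaseAdaptersMixin):
    def iter_layers(self) -> Iterator[Tuple[int, ToyLayer]]:
        # BERT-like models keep their layers in self.encoder.layer.
        for i, layer in enumerate(self.encoder.layer):
            yield i, layer


class ToyModelWithAdapters(ToyModelAdaptersMixin, ToyModel):
    pass
```

For example, ToyModelWithAdapters(num_layers=2).add_adapter_to_all_layers("my_adapter") records the adapter on both layers.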
Copied functions:
For classes where a mixin is not enough to achieve the desired behavior, you must:
Create a new class named <class>WithAdapters in src/adapters/models/<model_type>/modeling_<model_type>.py. This class should derive from the corresponding mixin and the HF class.
Copy the function you want to change into this class and modify it.
For example, the forward method of the BertSelfAttention class must be adapted to support Prefix Tuning. We therefore create a class BertSelfAttentionWithAdapters(BertSelfAttentionAdaptersMixin, BertSelfAttention), copy the forward method into it and modify it.
If the forward method of a module is copied and modified, make sure to call adapters.utils.patch_forward() in the module's init_adapters() method. This ensures that adapters work correctly with the accelerate package.
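The following sketch mirrors the BertSelfAttentionWithAdapters pattern using hypothetical toy classes. The patch_forward function here is a no-op stand-in that only records that it was called. It is not the real adapters.utils.patch_forward(), whose internals are not shown. The modified forward() prepends prefix entries to the keys and values, roughly at the point where a real implementation would invoke its PrefixTuningLayer.

```python
class ToySelfAttention:
    """Stands in for an unmodified Hugging Face attention class."""

    def forward(self, keys, values):
        # Original logic (greatly simplified): pass keys and values through.
        return list(keys), list(values)


def patch_forward(module):
    """Hypothetical no-op stand-in for adapters.utils.patch_forward().

    It only records the call; the real helper's behavior is not reproduced.
    """
    module.forward_patched = True


class ToySelfAttentionAdaptersMixin:
    def init_adapters(self, prefix_keys, prefix_values):
        # Hypothetical prefix-tuning state: extra key/value entries.
        self.prefix_keys = list(prefix_keys)
        self.prefix_values = list(prefix_values)
        # Needed because forward() is copied and modified in the class below.
        patch_forward(self)


class ToySelfAttentionWithAdapters(ToySelfAttentionAdaptersMixin, ToySelfAttention):
    def forward(self, keys, values):
        # Copy of ToySelfAttention.forward, modified to prepend the prefix
        # entries, roughly where a real model would call its PrefixTuningLayer.
        keys = self.prefix_keys + list(keys)
        values = self.prefix_values + list(values)
        return keys, values
```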
Modify MODEL_MIXIN_MAPPING:
For each mixin whose class was not copied into modeling_<model_type>.py, add the mixin/class combination to MODEL_MIXIN_MAPPING in the file src/adapters/models/__init__.py.
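The following sketch shows one way a class-to-mixin mapping like MODEL_MIXIN_MAPPING could be consumed: each matching module's class is swapped for a dynamically created subclass with the mixin in front. The stub classes, the add_mixin helper and the assumption that the mapping is keyed by class name are all illustrative. Check src/adapters/models/__init__.py for the actual layout and for how the library applies it.

```python
class BertLayerStub:
    """Hypothetical stand-in for a Hugging Face layer class."""


class RobertaLayerStub:
    """Another hypothetical layer class from a BERT-derived model."""


class LayerAdaptersMixin:
    def init_adapters(self):
        # Illustrative adapter setup.
        self.adapters_initialized = True


# Simplified stand-in for MODEL_MIXIN_MAPPING (assumed: class name -> mixin).
# Several BERT-derived layer classes can share a single mixin.
MODEL_MIXIN_MAPPING = {
    "BertLayerStub": LayerAdaptersMixin,
    "RobertaLayerStub": LayerAdaptersMixin,
}


def add_mixin(module):
    """Swap the module's class for a subclass that puts the mapped mixin first."""
    mixin = MODEL_MIXIN_MAPPING.get(type(module).__name__)
    if mixin is None:
        return module  # class not listed in the mapping: leave it untouched
    cls = type(module)
    # Keep the original class name so the object still looks like the HF class.
    module.__class__ = type(cls.__name__, (mixin, cls), {})
    module.init_adapters()
    return module
```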
Create the adapter model:
Adapter-supporting architectures should provide a new model class <model_type>AdapterModel. This class allows flexibly adding and switching between multiple prediction heads of different types.
This is done in the adapter_model.py file: this module should implement the <model_type>AdapterModel class, deriving from ModelWithFlexibleHeadsAdaptersMixin and <model_type>PreTrainedModel.
In the model class, add methods for the prediction heads that make sense for the new model architecture.
Again, have a look at existing implementations.
Add <model_type>AdapterModel to the ADAPTER_MODEL_MAPPING_NAMES mapping in src/adapters/models/auto/adapter_model.py and to src/adapters/__init__.py.
Define the classes to be added to Python's import structure in src/adapters/models/<model_type>/__init__.py. This will most likely be just <model_type>AdapterModel.
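The following sketch uses hypothetical toy classes to illustrate the two ideas in this step. The first is an AdapterModel that keeps several named prediction heads and tracks which one is active. The second is a model-type-to-class-name registry in the spirit of ADAPTER_MODEL_MAPPING_NAMES. ToyFlexibleHeadsMixin is a much-reduced stand-in for ModelWithFlexibleHeadsAdaptersMixin, the heads are plain dicts rather than neural network modules, and the real add_..._head() methods take more parameters.

```python
from collections import OrderedDict


class ToyFlexibleHeadsMixin:
    """Hypothetical, much-reduced stand-in for ModelWithFlexibleHeadsAdaptersMixin."""

    def _init_heads(self):
        self.heads = {}
        self.active_head = None

    def _register_head(self, head_name, head, set_active=True):
        # Keep every head by name; optionally make the new one active.
        if head_name in self.heads:
            raise ValueError(f"head '{head_name}' already exists")
        self.heads[head_name] = head
        if set_active:
            self.active_head = head_name


class ToyPreTrainedModel:
    """Stands in for <model_type>PreTrainedModel."""


class ToyAdapterModel(ToyFlexibleHeadsMixin, ToyPreTrainedModel):
    def __init__(self):
        self._init_heads()

    # One add_..._head() method per head type that suits the architecture.
    def add_classification_head(self, head_name, num_labels=2):
        self._register_head(head_name, {"type": "classification", "num_labels": num_labels})

    def add_tagging_head(self, head_name, num_labels=2):
        self._register_head(head_name, {"type": "tagging", "num_labels": num_labels})


# Stand-in for ADAPTER_MODEL_MAPPING_NAMES: model_type -> class name.
ADAPTER_MODEL_MAPPING_NAMES = OrderedDict([("toy", "ToyAdapterModel")])

# Local name -> class table used for the lookup in this sketch only.
_CLASSES_BY_NAME = {"ToyAdapterModel": ToyAdapterModel}


def adapter_model_class_for(model_type):
    # Resolve model_type -> registered class name -> class.
    return _CLASSES_BY_NAME[ADAPTER_MODEL_MAPPING_NAMES[model_type]]
```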
Adapt the config classes:
Adapt the config class to the requirements of adapters in src/adapters/wrappers/configuration.py.
Different model architectures use different names for some config attributes. The adapter implementation requires certain attributes to be available under specific names. These currently are num_attention_heads, hidden_size, hidden_dropout_prob and attention_probs_dropout_prob, as in the BertConfig class. If your model config does not provide these, add corresponding mappings to CONFIG_CLASS_KEYS_MAPPING.
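The following sketch illustrates the idea behind CONFIG_CLASS_KEYS_MAPPING. It assumes a hypothetical config whose native attribute names differ from the BERT-style names. The mapping translates each expected name to the native one, and a lookup falls back to the expected name when no mapping entry exists. ToyConfig, its attribute names, the "toy" entry and get_adapter_attr are illustrative only.

```python
# BERT-style attribute names the adapter code expects (as listed above).
REQUIRED_KEYS = (
    "num_attention_heads",
    "hidden_size",
    "hidden_dropout_prob",
    "attention_probs_dropout_prob",
)

# Simplified stand-in for CONFIG_CLASS_KEYS_MAPPING:
# model_type -> {expected name: native name}. The "toy" entry is hypothetical.
CONFIG_CLASS_KEYS_MAPPING = {
    "toy": {
        "num_attention_heads": "n_heads",
        "hidden_size": "d_model",
        "hidden_dropout_prob": "dropout",
        "attention_probs_dropout_prob": "attention_dropout",
    },
}


class ToyConfig:
    model_type = "toy"

    def __init__(self):
        self.n_heads = 8
        self.d_model = 512
        self.dropout = 0.1
        self.attention_dropout = 0.0


def get_adapter_attr(config, key):
    """Read an attribute by its BERT-style name, following the mapping if present."""
    mapping = CONFIG_CLASS_KEYS_MAPPING.get(config.model_type, {})
    return getattr(config, mapping.get(key, key))
```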
Additional (optional) implementation steps 📝
Parallel adapter inference via the Parallel composition block (cf. documentation, PR#150).
Provide mappings from an architecture's existing (static) prediction heads to adapters flex heads (cf. implementation).
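Conceptually, a Parallel composition runs several adapters on the same input in a single forward pass by replicating the input batch once per adapter. The following toy sketch uses plain Python functions as stand-ins for adapters and shows only this replicate-and-split bookkeeping. It is not the library's implementation.

```python
def parallel_forward(batch, adapters):
    """Toy illustration of Parallel composition (not the library's code).

    batch: list of inputs; adapters: dict mapping adapter name -> callable.
    """
    names = list(adapters)
    n = len(batch)
    # Replicate the input once per parallel adapter along the batch dimension.
    replicated = batch * len(names)
    outputs = []
    for i, name in enumerate(names):
        # The i-th slice of the enlarged batch goes through adapter i.
        chunk = replicated[i * n:(i + 1) * n]
        outputs.extend(adapters[name](x) for x in chunk)
    # Split the enlarged output back into one result per adapter,
    # e.g. so that each adapter's output can feed its own prediction head.
    return {name: outputs[i * n:(i + 1) * n] for i, name in enumerate(names)}
```

In the library itself, Parallel is used as a composition block when activating adapters, for example model.active_adapters = Parallel("a", "b") with Parallel imported from adapters.composition.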
Testing
❓ In addition to the general Hugging Face model tests, there are adapter-specific test cases. All tests are executed from the tests folder. You need to add two different test modules.
📝 Steps
Add a new test_<model_type>.py module in tests/.
This file tests that everything related to the usage of adapters (adding, removing, activating, …) works.
This module typically holds two test classes and a test base class:
<model_type>AdapterTestBase: This class contains the tokenizer_name, config_class and config.
<model_type>AdapterTest derives from a collection of test mixins that hold various adapter tests (depending on the implementation).
(Optional) <model_type>ClassConversionTest runs tests for correct class conversion if conversion of prediction heads is implemented.
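The following sketch shows the layout of such a test module using only the standard library unittest module. ToyConfig, AddAdapterTestMixin and the tokenizer name are hypothetical placeholders. The real test base class and test mixins live in the tests folder and exercise actual adapter operations.

```python
import unittest
from dataclasses import dataclass


@dataclass
class ToyConfig:
    # Tiny dimensions keep adapter tests fast.
    hidden_size: int = 32
    num_hidden_layers: int = 2
    num_attention_heads: int = 4


class ToyAdapterTestBase:
    """Stands in for <model_type>AdapterTestBase: shared test settings."""

    config_class = ToyConfig
    config = ToyConfig()
    tokenizer_name = "toy-tokenizer"  # hypothetical tokenizer checkpoint name


class AddAdapterTestMixin:
    """Hypothetical stand-in for one of the adapter test mixins."""

    def test_config_is_small(self):
        self.assertLessEqual(self.config.hidden_size, 64)

    def test_heads_divide_hidden_size(self):
        self.assertEqual(self.config.hidden_size % self.config.num_attention_heads, 0)


class ToyAdapterTest(AddAdapterTestMixin, ToyAdapterTestBase, unittest.TestCase):
    """Stands in for <model_type>AdapterTest: test mixins plus the shared base."""
```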
Add a new test_<model_type>.py module in tests/models/.
This file tests the AdapterModel class.
This module typically holds one test class named <model_type>AdapterModelTest.
<model_type>AdapterModelTest derives directly from Hugging Face's existing model test class <model_type>ModelTest and adds <model_type>AdapterModel as a class to test.
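The following stand-in sketch shows the AdapterModelTest pattern. The subclass inherits every test from the existing model test class and only changes all_model_classes, so the inherited tests run against the adapter model. ToyModel, ToyAdapterModel and ToyModelTest are hypothetical. Whether you replace or extend all_model_classes depends on the existing test class.

```python
import unittest


class ToyModel:
    """Hypothetical stand-in for a Hugging Face model class."""


class ToyAdapterModel:
    """Hypothetical stand-in for <model_type>AdapterModel."""


class ToyModelTest(unittest.TestCase):
    """Stands in for Hugging Face's existing <model_type>ModelTest."""

    all_model_classes = (ToyModel,)

    def test_models_instantiate(self):
        # Generic test that runs for every class listed in all_model_classes.
        for model_class in self.all_model_classes:
            self.assertIsNotNone(model_class())


class ToyAdapterModelTest(ToyModelTest):
    # Reuse every inherited test, but run it against the adapter model class.
    all_model_classes = (ToyAdapterModel,)
```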
Documentation
❓ The documentation for adapters lives in the docs folder.
📝 Steps
Add docs/classes/models/<model_type>.rst (modeled on the corresponding doc file in the HF docs). Make sure to include an autodoc entry for <model_type>AdapterModel. Finally, list the file in index.rst.
Add a new row for the model to the model table on the overview page at docs/model_overview.md, listing all methods implemented for the new model.
Training Example Adapters
❓ To make sure the new adapter implementation works properly, it is useful to train some example adapters and compare the training results to full model fine-tuning. Ideally, this would include training adapters on one (or more) tasks that are good for demonstrating the new model architecture (e.g. GLUE benchmark for BERT, summarization for BART) and uploading them to AdapterHub.
We provide training scripts for many tasks here: https://github.com/Adapter-Hub/adapters/tree/main/examples/pytorch/