LoRA
Low-Rank Adaptation (LoRA) is a parameter-efficient fine-tuning method originally developed
for large language models and adapted here for vision models. This module
contains TensorFlow code for LoRA layers and their integration with tfimm models.
For details on LoRA, see the paper
LoRA: Low-Rank Adaptation of Large Language Models. Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen. Paper: [arXiv:2106.09685]
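For reference, the paper's update rule for a frozen pretrained weight matrix W_0 is shown below. Here r plays the role of lora_rank and α the role of lora_alpha. That mapping and the α / r scaling come from the paper and are assumptions about tfimm, not statements about its source code.

```latex
h = W_0 x + \Delta W x,
\qquad
\Delta W = \frac{\alpha}{r} B A,
\qquad
W_0 \in \mathbb{R}^{d \times k},\;
B \in \mathbb{R}^{d \times r},\;
A \in \mathbb{R}^{r \times k},\;
r \ll \min(d, k)
```

Only A and B are trained, which is r(d + k) parameters per adapted matrix instead of dk. In the paper, B is initialized to zero, so ΔW = 0 and the adapted model starts out identical to the pretrained one.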
Usage
For supported architectures, we can use tfimm.architectures.lora.create_model()
instead of tfimm.create_model() to create a LoRA model.
>>> from tfimm.architectures import lora
>>> model = lora.create_model(
... model_name="convnext_tiny", pretrained=True, lora_rank=2
... )
When we look at the model summary, we see that most model parameters are non-trainable and only the low-rank weight updates are trainable.
>>> model.summary()
...
=================================================================
Total params: 28,721,608
Trainable params: 132,480
Non-trainable params: 28,589,128
_________________________________________________________________
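The trainable parameter count in the summary can be reproduced by hand. This sketch assumes the following:
- ConvNeXt-T uses stage widths 96, 192, 384, 768, depths 3, 3, 9, 3, and an MLP expansion ratio of 4.
- LoRA replaces both dense layers of every MLP block.
- Each LoRA dense layer adds rank * (d_in + d_out) parameters for its two low-rank factors.

```python
# Reproduce the "Trainable params" figure for convnext_tiny with lora_rank=2.
# Assumptions: ConvNeXt-T widths/depths, MLP ratio 4, and LoRA applied to
# both dense layers (fc1: d -> 4d, fc2: 4d -> d) in every block.


def lora_dense_params(d_in: int, d_out: int, rank: int) -> int:
    """Parameters of the low-rank factors A (d_in x rank) and B (rank x d_out)."""
    return rank * (d_in + d_out)


def convnext_lora_params(widths, depths, rank: int, mlp_ratio: int = 4) -> int:
    total = 0
    for width, depth in zip(widths, depths):
        hidden = mlp_ratio * width
        per_block = lora_dense_params(width, hidden, rank) + lora_dense_params(
            hidden, width, rank
        )
        total += depth * per_block
    return total


trainable = convnext_lora_params(
    widths=(96, 192, 384, 768), depths=(3, 3, 9, 3), rank=2
)
non_trainable = 28_589_128  # from the summary above (the frozen base weights)
print(f"{trainable:,}")                  # 132,480
print(f"{trainable + non_trainable:,}")  # 28,721,608
```

Both numbers agree with the summary. This is consistent with only the MLP dense layers receiving low-rank updates in this configuration.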
LoRA models can be converted back to regular models.
>>> type(model)
<class 'tfimm.architectures.lora.convnext.LoRAConvNeXt'>
>>> regular_model = lora.convert_to_regular_model(model)
>>> type(regular_model)
<class 'tfimm.architectures.convnext.ConvNeXt'>
Supported architectures
Currently, we support the following architectures:
- ConvNeXt
We also support the following layers:
- Dense
Implementation
To perform LoRA training, we first convert a regular model to its
LoRA version. For tfimm architectures, we do this by subclassing the model and replacing layers
in __init__. For example, LoRAConvNeXt subclasses
ConvNeXt and replaces the dense layers in each MLP block with their LoRA
counterparts.
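Here is a tiny, framework-free sketch of what a LoRA counterpart of a dense layer computes. TinyLoRADense is hypothetical and is not tfimm's LoRADense. Following the LoRA paper, it assumes two things:
- The low-rank update is scaled by lora_alpha / lora_rank.
- The factor B starts at zero, so a freshly converted layer reproduces the original layer's output.

It uses the Keras kernel convention (d_in x d_out) with the input as a row vector.

```python
import random


class TinyLoRADense:
    """Hypothetical LoRA dense layer acting on a single input vector."""

    def __init__(self, kernel, bias, lora_rank, lora_alpha, seed=0):
        d_in, d_out = len(kernel), len(kernel[0])
        self.kernel, self.bias = kernel, bias  # frozen base weights
        self.scale = lora_alpha / lora_rank  # assumed paper-style scaling
        rng = random.Random(seed)
        # A (d_in x rank) gets small random values; B (rank x d_out) starts at
        # zero, so the initial update A @ B is exactly zero.
        self.a = [[rng.gauss(0.0, 0.02) for _ in range(lora_rank)] for _ in range(d_in)]
        self.b = [[0.0] * d_out for _ in range(lora_rank)]

    def __call__(self, x):
        d_out, rank = len(self.bias), len(self.b)
        # Frozen path: x @ W + bias
        base = [sum(xi * row[j] for xi, row in zip(x, self.kernel)) + self.bias[j]
                for j in range(d_out)]
        # Low-rank path: (x @ A) @ B, scaled by lora_alpha / lora_rank
        low = [sum(xi * row[k] for xi, row in zip(x, self.a)) for k in range(rank)]
        update = [sum(low[k] * self.b[k][j] for k in range(rank)) for j in range(d_out)]
        return [base[j] + self.scale * update[j] for j in range(d_out)]

    def num_lora_params(self):
        # rank * (d_in + d_out) trainable values in A and B
        return len(self.a) * len(self.a[0]) + len(self.b) * len(self.b[0])


layer = TinyLoRADense(
    kernel=[[1.0, 2.0], [3.0, 4.0]], bias=[0.5, -0.5], lora_rank=1, lora_alpha=2.0
)
print(layer([1.0, 1.0]))        # [4.5, 5.5], same as the frozen dense layer
print(layer.num_lora_params())  # 4
```

Because only A and B would be trainable, the base kernel and bias stay untouched. This is why the LoRA model can later be converted back to a regular model.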
We use a registry system to track model classes and their LoRA counterparts. A simplified example:
from dataclasses import dataclass
import tensorflow as tf
from tfimm.architectures import lora

@dataclass
class ResNetConfig:
    nb_blocks: tuple = (3, 4, 6, 3)

class ResNet(tf.keras.Model):
    cfg_class = ResNetConfig

    def __init__(self, cfg: ResNetConfig, **kwargs):
        super().__init__(**kwargs)
        ...  # Build the layers from cfg

@dataclass
class LoRAResNetConfig(ResNetConfig):
    lora_rank: int = 2

@lora.register_lora_architecture
class LoRAResNet(ResNet):
    cfg_class = LoRAResNetConfig

    def __init__(self, cfg: LoRAResNetConfig, **kwargs):
        super().__init__(cfg, **kwargs)  # Create the original model
        ...  # Then replace layers with LoRA versions
We make the following assumptions:
- Model parameters are specified via a configuration dataclass, and the configuration class of each model is defined via the cfg_class class attribute.
- The configuration of the LoRA model is a superset of the configuration of the base model.
Under these assumptions, we can use the
register_lora_architecture() decorator to
register LoRAResNet as the LoRA variant of the ResNet class.
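Both assumptions can be checked mechanically with the standard dataclasses module. The helper check_lora_assumptions is hypothetical and not part of tfimm. Plain classes stand in for Keras models so the sketch runs without TensorFlow.

```python
from dataclasses import dataclass, fields


@dataclass
class ResNetConfig:
    nb_blocks: tuple = (3, 4, 6, 3)


@dataclass
class LoRAResNetConfig(ResNetConfig):
    lora_rank: int = 2


class ResNet:  # stand-in for a tf.keras.Model subclass
    cfg_class = ResNetConfig


class LoRAResNet(ResNet):
    cfg_class = LoRAResNetConfig


def check_lora_assumptions(base_cls, lora_cls) -> bool:
    """Hypothetical helper: True if both classes declare a cfg_class and the
    LoRA config contains every field of the base config."""
    if not (hasattr(base_cls, "cfg_class") and hasattr(lora_cls, "cfg_class")):
        return False
    base_fields = {f.name for f in fields(base_cls.cfg_class)}
    lora_fields = {f.name for f in fields(lora_cls.cfg_class)}
    return base_fields <= lora_fields  # subset check on field names


print(check_lora_assumptions(ResNet, LoRAResNet))  # True
```

Subclassing the base config dataclass, as LoRAResNetConfig does, satisfies the superset assumption automatically.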
Now, given an instance of ResNet, we can use
convert_to_lora_model() to convert
it to a LoRAResNet instance and transfer all weights.
model = ResNet(cfg=ResNetConfig())
... # Build model or load pre-trained weights
lora_model = lora.convert_to_lora_model(model, lora_rank=2)
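Because the LoRA config is a superset of the base config, one plausible way to build it during conversion is to copy the base config's fields and add the LoRA keyword arguments. This sketches the idea under that assumption and is not tfimm's actual code. The function name make_lora_config and the lora_alpha default are illustrative.

```python
from dataclasses import asdict, dataclass


@dataclass
class ResNetConfig:
    nb_blocks: tuple = (3, 4, 6, 3)


@dataclass
class LoRAResNetConfig(ResNetConfig):
    lora_rank: int = 2
    lora_alpha: float = 1.0  # illustrative default


def make_lora_config(cfg, lora_cfg_class, **lora_kwargs):
    """Hypothetical: copy all base fields, then add or override LoRA fields."""
    return lora_cfg_class(**{**asdict(cfg), **lora_kwargs})


cfg = ResNetConfig(nb_blocks=(2, 2, 2, 2))
lora_cfg = make_lora_config(cfg, LoRAResNetConfig, lora_rank=4)
print(lora_cfg)  # LoRAResNetConfig(nb_blocks=(2, 2, 2, 2), lora_rank=4, lora_alpha=1.0)
```

A keyword that is not a field of the LoRA config makes the dataclass constructor raise a TypeError. With this design, typos such as lora_rnk=4 fail early.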
The lora_model.trainable_weights property correctly returns only the LoRA trainable
weights, i.e., the low-rank updates. Optionally, biases can be trained as well, either
only for LoRA layers or for all layers. This is controlled by the
lora_train_bias parameter, which accepts the values "none", "lora_only" or "all".
lora_model = lora.convert_to_lora_model(
    model, lora_rank=2, lora_train_bias="lora_only"
)
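The three lora_train_bias modes amount to a filter over the model's weights. This sketch uses plain records instead of tf.Variable objects. The Weight type and its kind labels ("kernel", "bias", "lora") are hypothetical, and the sketch only illustrates the selection logic described above.

```python
from dataclasses import dataclass
from typing import List


@dataclass(frozen=True)
class Weight:
    layer: str           # name of the owning layer
    kind: str            # "kernel", "bias" or "lora" (a low-rank factor)
    in_lora_layer: bool  # whether the owning layer is a LoRA layer


def select_trainable(weights: List[Weight], train_bias: str = "none") -> List[Weight]:
    """Hypothetical re-implementation of the lora_train_bias selection logic."""
    if train_bias not in ("none", "lora_only", "all"):
        raise ValueError(f"Unknown train_bias value: {train_bias!r}")
    selected = []
    for w in weights:
        if w.kind == "lora":
            selected.append(w)  # low-rank updates are always trainable
        elif w.kind == "bias" and (
            train_bias == "all" or (train_bias == "lora_only" and w.in_lora_layer)
        ):
            selected.append(w)
    return selected


weights = [
    Weight("fc1", "kernel", True),
    Weight("fc1", "bias", True),
    Weight("fc1", "lora", True),
    Weight("stem", "kernel", False),
    Weight("stem", "bias", False),
]
for mode in ("none", "lora_only", "all"):
    trainable = select_trainable(weights, mode)
    frozen = [w for w in weights if w not in trainable]  # the complement
    print(mode, [(w.layer, w.kind) for w in trainable], len(frozen))
```

The complement computed in the loop corresponds to what lora_non_trainable_weights() returns next to lora_trainable_weights(). Kernels of the original layers are never selected.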
Sequential and functional models
The current implementation focuses on models created by subclassing, which is the case
for all tfimm models. In particular, the registry system works only for subclassed
models. However, some of the functionality also works for functional models.
LoRA layers are the basic building blocks for both subclassed and functional models.
Transferring weights works for all models, regardless of type, provided that the regular model and its LoRA variant have the same architecture apart from the LoRA layers. Use the
tfimm.models.transfer_weights() function to transfer weights to the LoRA model.
from tfimm.architectures import lora
from tfimm.models import transfer_weights

# Transfer weights into the LoRA model
transfer_weights(
    regular_model, lora_model, weights_to_ignore=lora.LORA_WEIGHT_NAMES
)
After training, we need to merge the LoRA weights manually and then transfer them back to the regular model.
lora.merge_lora_weights(lora_model)  # Merges low-rank updates in place
transfer_weights(lora_model, regular_model)
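Merging folds the low-rank update into the frozen kernel, so the merged regular layer computes the same function as the unmerged LoRA layer. This dependency-free sketch assumes the following:
- The paper's formulation, written for a Keras-style kernel of shape (d_in, d_out): W' = W + (alpha / rank) * A @ B.
- A has shape (d_in, rank) and B has shape (rank, d_out).

These shapes and the scaling are assumptions, not taken from tfimm's source.

```python
def matmul(x, y):
    """Multiply two matrices given as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*y)] for row in x]


def add(x, y, scale=1.0):
    """Element-wise x + scale * y."""
    return [[a + scale * b for a, b in zip(r1, r2)] for r1, r2 in zip(x, y)]


def lora_forward(x, kernel, a, b, alpha, rank):
    # Unmerged LoRA dense layer (no bias): x W + (alpha / rank) * (x A) B
    return add(matmul(x, kernel), matmul(matmul(x, a), b), alpha / rank)


def merge(kernel, a, b, alpha, rank):
    # Fold the low-rank update into the kernel: W + (alpha / rank) * A B
    return add(kernel, matmul(a, b), alpha / rank)


x = [[1.0, 2.0, 3.0]]                          # one input row, d_in = 3
kernel = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # d_in x d_out = 3 x 2
a = [[1.0], [0.0], [2.0]]                      # d_in x rank, rank = 1
b = [[0.5, -1.0]]                              # rank x d_out
alpha, rank = 2.0, 1

merged = merge(kernel, a, b, alpha, rank)
print(lora_forward(x, kernel, a, b, alpha, rank))  # [[11.0, -9.0]]
print(matmul(x, merged))                           # [[11.0, -9.0]]
```

After merging, the A and B factors are no longer needed. The merged kernel alone can therefore be copied into a regular model that has no LoRA layers.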
The functions
lora_trainable_weights() and lora_non_trainable_weights() work for all models, regardless of type, and return the list of weights to be trained with LoRA (or all other weights, respectively).
Interface
All functions are available under tfimm.architectures.lora.
Factory
- convert_to_lora_model(model, **kwargs)
Creates a LoRA version of a model.
- Parameters:
model (Model) – Source model. Has to be an instance of a class that has a corresponding LoRA architecture registered.
**kwargs – LoRA parameters, such as lora_rank and lora_alpha, need to be passed as kwargs and will be added to the model config.
- Returns:
LoRA model.
- Return type:
Model
- convert_to_regular_model(model)
Converts a LoRA model to a regular model.
- Parameters:
model (Model) – LoRA model to be converted. Has to be an instance of a class that has been registered as a LoRA architecture.
- Returns:
The converted model.
- Return type:
Model
- create_model(model_name, pretrained=False, model_path='', **kwargs)
Creates a LoRA model from a tfimm model name.
- Parameters:
model_name (str) – Name of model to instantiate.
pretrained (bool) – If True, load pretrained weights as specified by the url field in config. If url is [timm], the weights will be downloaded from timm and converted to TensorFlow. See tfimm.create_model() for details.
model_path (str) – Path of model weights to load after model is initialized. This takes precedence over pretrained.
**kwargs – LoRA parameters, such as lora_rank and lora_alpha, need to be passed as kwargs and will be added to the model config.
- Returns:
The created model.
- Return type:
Model
- merge_lora_weights(model)
Recursively merge weights in all LoRA layers in the given model. The model is modified in place.
- Parameters:
model (Model) – Model for merging weights.
- lora_non_trainable_weights(model, train_bias='none', trainable_layers=None)
Returns a list of non-trainable weights for the LoRA model. This function complements
lora_trainable_weights().
- Parameters:
model (Model) – A Keras model.
train_bias (str) – If "none" or "all", no or all bias weights are trainable, respectively. If "lora_only", only the bias weights of LoRA layers are set to trainable.
trainable_layers (List[str] | None) – A list of layer names that should be kept trainable.
- Returns:
List of LoRA non-trainable weights.
- Return type:
List[Variable]
- lora_trainable_weights(model, train_bias='none', trainable_layers=None)
Returns a list of variables to be used instead of
model.trainable_weights when doing LoRA training.
- Parameters:
model (Model) – A Keras model.
train_bias (str) – If "none" or "all", no or all bias weights are trainable, respectively. If "lora_only", only the bias weights of LoRA layers are set to trainable.
trainable_layers (List[str] | None) – A list of layer names that should be kept trainable.
- Returns:
List of LoRA trainable weights.
- Return type:
List[Variable]
Layers
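- class LoRAConv2D(*args, **kwargs)
LoRA version of the Conv2D layer.
- Parameters:
lora_rank (int) – Rank of the low-rank weight update.
lora_alpha (float) – Scaling factor for the low-rank weight update.
- class LoRADense(*args, **kwargs)
LoRA version of the Dense layer.
- Parameters:
lora_rank (int) – Rank of the low-rank weight update.
lora_alpha (float) – Scaling factor for the low-rank weight update.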
- convert_to_lora_layer(layer, **kwargs)
Convenience function to convert supported layer types to their LoRA counterparts.
- Parameters:
layer (Layer) – Layer to be converted.
**kwargs – LoRA-specific parameters, such as lora_rank, have to be passed as kwargs.
- Returns:
LoRA layer instance.
- Return type:
Layer
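A conversion function like convert_to_lora_layer() can be written as a lookup from layer type to LoRA class. The lookup copies the layer's configuration and adds the LoRA kwargs. This is a hypothetical, framework-free sketch of that pattern. The stand-in classes, the get_config-style dict, the convert_layer name and the error on unsupported types are assumptions, not tfimm behavior.

```python
class Dense:  # stand-in for tf.keras.layers.Dense
    def __init__(self, units, use_bias=True):
        self.units, self.use_bias = units, use_bias

    def get_config(self):
        return {"units": self.units, "use_bias": self.use_bias}


class LoRADense(Dense):  # stand-in for a LoRA dense layer
    def __init__(self, units, use_bias=True, lora_rank=4, lora_alpha=1.0):
        super().__init__(units, use_bias)
        self.lora_rank, self.lora_alpha = lora_rank, lora_alpha


# Maps supported base layer types to their LoRA counterparts.
LORA_LAYERS = {Dense: LoRADense}


def convert_layer(layer, **lora_kwargs):
    """Hypothetical: build the LoRA counterpart from the layer's config."""
    lora_cls = LORA_LAYERS.get(type(layer))
    if lora_cls is None:
        raise ValueError(f"No LoRA version for {type(layer).__name__}")
    return lora_cls(**layer.get_config(), **lora_kwargs)


lora_layer = convert_layer(Dense(64, use_bias=False), lora_rank=2)
print(type(lora_layer).__name__, lora_layer.units, lora_layer.lora_rank)  # LoRADense 64 2
```

With real Keras layers, the kernel and bias weights would also have to be copied once the new layer is built. The sketch omits that step.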
Registry
- lora_architecture(model_cls)
Returns the LoRA model class registered for a given base model class.
- lora_base_architecture(lora_cls)
Returns the base class corresponding to the given registered LoRA model class.
- lora_config(model_cls)
Returns the config class for the LoRA model associated with the given base model class.
- register_lora_architecture(lora_cls=None, *, base_cls=None)
Decorator to register a LoRA variant of a model architecture. It is used as follows:
@register_lora_architecture
class LoRAResNet(ResNet):
    ...
This registers the class LoRAResNet as the LoRA version of the class ResNet. If the LoRA architecture is not created via subclassing, the base class can be specified explicitly:
@register_lora_architecture(base_cls=ResNet)
class LoRAResNet(tf.keras.Model):
    ...
A model class can be its own LoRA variant if the model can be created with regular or LoRA layers depending on the config. In that case, this function needs to be invoked after the model has been defined:
class FlexibleModel(tf.keras.Model):
    ...

register_lora_architecture(FlexibleModel, base_cls=FlexibleModel)
- Parameters:
lora_cls – LoRA model class. We assume that it has a cfg_class class attribute.
base_cls – Regular model class. If not provided, we use the base class of lora_cls.
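A small registry can support all three calling conventions above: the bare decorator, the decorator with base_cls, and a direct call for self-registration. This is a hypothetical pure-Python sketch, not tfimm's implementation. It reuses the documented function names for readability and assumes the base class defaults to the first direct base of lora_cls.

```python
_LORA_REGISTRY = {}  # base class -> LoRA class


def register_lora_architecture(lora_cls=None, *, base_cls=None):
    """Hypothetical sketch of a registry decorator with an optional base_cls."""

    def register(cls):
        # Default to the first direct base class when base_cls is not given.
        base = base_cls if base_cls is not None else cls.__bases__[0]
        _LORA_REGISTRY[base] = cls
        return cls

    if lora_cls is not None:  # bare decorator or direct call
        return register(lora_cls)
    return register  # called with keyword arguments only: return the decorator


def lora_architecture(model_cls):
    return _LORA_REGISTRY[model_cls]


def lora_base_architecture(lora_cls):
    for base, registered in _LORA_REGISTRY.items():
        if registered is lora_cls:
            return base
    raise KeyError(lora_cls)


class ResNet:  # stand-ins for tf.keras.Model subclasses
    pass


@register_lora_architecture
class LoRAResNet(ResNet):
    pass


class VGG:
    pass


@register_lora_architecture(base_cls=VGG)
class LoRAVGG:  # not a subclass of VGG
    pass


class FlexibleModel:
    pass


register_lora_architecture(FlexibleModel, base_cls=FlexibleModel)

print(lora_architecture(ResNet).__name__)         # LoRAResNet
print(lora_base_architecture(LoRAVGG).__name__)   # VGG
print(lora_architecture(FlexibleModel).__name__)  # FlexibleModel
```

Because the decorator returns the class unchanged, registration has no effect on how the LoRA class itself behaves.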