LoRA

Low-Rank Adaptation (LoRA) is a parameter-efficient fine-tuning method originally developed for large language models and adapted here for vision models. This module contains TensorFlow code for LoRA layers and their integration with tfimm models. For details on LoRA, see the paper:

LoRA: Low-Rank Adaptation of Large Language Models. Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen. Paper: [arXiv:2106.09685]
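As a rough illustration of the idea in the paper, the sketch below computes the output of a dense layer with a frozen kernel W plus a trainable low-rank update A @ B, scaled by lora_alpha / lora_rank. It uses NumPy instead of TensorFlow. The shape convention (kernel of shape (in, out), as in Keras) and the initialization (random A, zero B) are assumptions based on the paper, not a copy of the tfimm layers.

```python
import numpy as np

rng = np.random.default_rng(0)

in_dim, out_dim, rank, alpha = 8, 16, 2, 4.0

# Frozen pretrained kernel, Keras convention: y = x @ W with W of shape (in, out).
W = rng.normal(size=(in_dim, out_dim))

# Trainable low-rank factors (assumed init): A is random and B starts at zero,
# so the low-rank update A @ B is exactly zero at initialization.
A = rng.normal(scale=0.01, size=(in_dim, rank))
B = np.zeros((rank, out_dim))
scale = alpha / rank


def lora_dense(x):
    """Frozen dense output plus the scaled low-rank update (bias omitted)."""
    return x @ W + scale * (x @ A) @ B


x = rng.normal(size=(4, in_dim))
# With B = 0 the LoRA layer reproduces the frozen layer exactly.
assert np.allclose(lora_dense(x), x @ W)

# Parameter count: full kernel vs. low-rank factors.
print(W.size, A.size + B.size)  # 128 48
```

Only A and B would be trained; for rank r the update costs r * (in + out) parameters instead of in * out.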

Usage

For supported architectures we can use tfimm.architectures.lora.create_model() instead of tfimm.create_model() to create a LoRA model.

>>> from tfimm.architectures import lora
>>> model = lora.create_model(
...    model_name="convnext_tiny", pretrained=True, lora_rank=2
... )

When we look at the model summary, we see that most model parameters are non-trainable and only the low-rank weight updates are trainable.

>>> model.summary()
...
=================================================================
Total params: 28,721,608
Trainable params: 132,480
Non-trainable params: 28,589,128
_________________________________________________________________
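The trainable count in the summary can be checked by hand. This sketch assumes the standard ConvNeXt-T configuration (stage widths 96, 192, 384, 768; depths 3, 3, 9, 3; MLP ratio 4), that LoRA is applied only to the two dense layers of each MLP block (as described under Implementation), and that each LoRA update adds rank * (in + out) parameters with no extra bias.

```python
# Assumed ConvNeXt-T configuration: channels and number of blocks per stage.
embed_dims = (96, 192, 384, 768)
depths = (3, 3, 9, 3)
mlp_ratio = 4
lora_rank = 2


def lora_dense_params(in_dim, out_dim, rank):
    # Low-rank update A @ B with A: (in_dim, rank) and B: (rank, out_dim).
    return rank * (in_dim + out_dim)


trainable = 0
for dim, depth in zip(embed_dims, depths):
    hidden = mlp_ratio * dim
    # Each MLP block has two dense layers: dim -> hidden and hidden -> dim.
    per_block = lora_dense_params(dim, hidden, lora_rank) + lora_dense_params(
        hidden, dim, lora_rank
    )
    trainable += depth * per_block

print(trainable)  # 132480
# The frozen parameters plus the LoRA updates give the total in the summary.
print(28_589_128 + trainable)  # 28721608
```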

LoRA models can be converted back to regular models.

>>> type(model)
<class 'tfimm.architectures.lora.convnext.LoRAConvNeXt'>
>>> regular_model = lora.convert_to_regular_model(model)
>>> type(regular_model)
<class 'tfimm.architectures.convnext.ConvNeXt'>

Supported architectures

Currently, we support the following architectures:

  • ConvNeXt

And the following layers:

  • Dense

  • Conv2D

Implementation

To perform LoRA training, we first have to convert a regular model to its LoRA version. For tfimm architectures, we do this by subclassing and replacing layers in __init__. For example, LoRAConvNeXt subclasses ConvNeXt and replaces the dense layers in each MLP block with their LoRA counterparts.

We use a registry system to track model classes and their LoRA counterparts. A simplified example:

from dataclasses import dataclass

import tensorflow as tf

from tfimm.architectures import lora

@dataclass
class ResNetConfig:
    nb_blocks: tuple = (3, 4, 6, 3)

class ResNet(tf.keras.Model):
    cfg_class = ResNetConfig  # Config class is given by a class attribute

    def __init__(self, cfg, **kwargs):
        super().__init__(**kwargs)
        ...

@dataclass
class LoRAResNetConfig(ResNetConfig):
    lora_rank: int = 2  # LoRA config extends the base config

@lora.register_lora_architecture
class LoRAResNet(ResNet):
    cfg_class = LoRAResNetConfig

    def __init__(self, cfg, **kwargs):
        super().__init__(cfg, **kwargs)  # Create the original model
        ...  # Then replace layers with LoRA versions

We make the following assumptions:

  • Model parameters are specified via a configuration dataclass, and the configuration class of each model is given by the cfg_class class attribute.

  • The configuration of the LoRA model is a superset of the configuration of the base model.

Under these assumptions, we can use the register_lora_architecture() decorator to register LoRAResNet as the LoRA variant of the ResNet class.

Now, given an instance of ResNet, we can use convert_to_lora_model() to convert it to a LoRAResNet instance and transfer all weights.

model = ResNet(cfg=ResNetConfig())
... # Build model or load pre-trained weights

lora_model = lora.convert_to_lora_model(model, lora_rank=2)
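Conversion relies on the superset assumption above: the LoRA config can be built from every field of the base config plus the LoRA kwargs, and the base config can be recovered by dropping the extra fields. The following pure-Python sketch covers only that config step. The helpers to_lora_config() and to_regular_config() and the lora_alpha default are hypothetical and not part of tfimm, and no weights are transferred here.

```python
from dataclasses import dataclass, fields


@dataclass
class ResNetConfig:
    nb_blocks: tuple = (3, 4, 6, 3)


@dataclass
class LoRAResNetConfig(ResNetConfig):
    lora_rank: int = 2
    lora_alpha: float = 1.0  # hypothetical default


def to_lora_config(cfg, lora_cfg_cls, **kwargs):
    """Hypothetical helper: copy every base field, then add the LoRA kwargs."""
    base_values = {f.name: getattr(cfg, f.name) for f in fields(cfg)}
    # This works only because the LoRA config is a superset of the base config.
    return lora_cfg_cls(**base_values, **kwargs)


def to_regular_config(lora_cfg, cfg_cls):
    """Hypothetical helper: keep only the fields known to the base config."""
    return cfg_cls(**{f.name: getattr(lora_cfg, f.name) for f in fields(cfg_cls)})


cfg = ResNetConfig(nb_blocks=(2, 2, 2, 2))
lora_cfg = to_lora_config(cfg, LoRAResNetConfig, lora_rank=4)
print(lora_cfg)  # LoRAResNetConfig(nb_blocks=(2, 2, 2, 2), lora_rank=4, lora_alpha=1.0)
print(to_regular_config(lora_cfg, ResNetConfig))  # ResNetConfig(nb_blocks=(2, 2, 2, 2))
```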

The lora_model.trainable_weights property correctly returns only the LoRA trainable weights, i.e., the low-rank updates. Optionally, the biases can be trained too, either only for LoRA layers or for all layers. This is controlled by the lora_train_bias parameter, which accepts the values "none", "lora_only" or "all".

lora_model = lora.convert_to_lora_model(
    model, lora_rank=2, lora_train_bias="lora_only"
)

Sequential and functional models

The current implementation focuses on models created by subclassing, which covers all tfimm models. In particular, the registry system works only for subclassed models. However, some of the functionality also works for sequential and functional models.

  • LoRA layers are the basic building blocks for both subclassed and functional models.

  • Transferring weights works for all models, regardless of type, provided the regular model and its LoRA variant have the same architecture apart from the LoRA layers. Use the tfimm.models.transfer_weights() function to transfer weights to the LoRA model.

    from tfimm.architectures import lora
    from tfimm.models import transfer_weights
    
    # Transfer weights into the LoRA model
    transfer_weights(
        regular_model, lora_model, weights_to_ignore=lora.LORA_WEIGHT_NAMES
    )
    
  • After training, we need to merge the LoRA weights manually and then transfer them back to the regular model.

    lora.merge_lora_weights(lora_model)
    transfer_weights(lora_model, regular_model)
    
  • The functions lora_trainable_weights() and lora_non_trainable_weights() work for all models, regardless of type, and return the list of weights to be trained with LoRA or the list of all remaining weights, respectively.
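To see why the ignore list is needed in only one direction, here is a name-based transfer sketch in which models are plain dicts mapping weight names to values. The function transfer_by_name() and the "lora_a"/"lora_b" suffixes are hypothetical stand-ins; they are not the contents of lora.LORA_WEIGHT_NAMES or the implementation of tfimm.models.transfer_weights().

```python
# Hypothetical LoRA weight-name suffixes, for illustration only.
LORA_SUFFIXES = ("lora_a", "lora_b")


def transfer_by_name(src, dst, weights_to_ignore=()):
    """Copy values from src to dst, matching weights by name.

    Destination weights whose name ends with one of weights_to_ignore are left
    untouched; every other destination weight must exist in src.
    """
    for name in dst:
        if any(name.endswith(suffix) for suffix in weights_to_ignore):
            continue
        if name not in src:
            raise KeyError(f"{name} missing in source model")
        dst[name] = src[name]  # overwriting existing keys while iterating is safe


regular = {"dense/kernel": 1.0, "dense/bias": 0.5}
lora_model = {"dense/kernel": 0.0, "dense/bias": 0.0, "dense/lora_a": 0.1, "dense/lora_b": 0.0}

# Regular -> LoRA: the LoRA weights have no counterpart and must be ignored.
transfer_by_name(regular, lora_model, weights_to_ignore=LORA_SUFFIXES)

# LoRA -> regular (after merging): every regular weight exists in the LoRA model.
lora_model["dense/kernel"] = 2.0
transfer_by_name(lora_model, regular)
```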

Interface

All functions are available under tfimm.architectures.lora.

Factory

convert_to_lora_model(model, **kwargs)

Creates a LoRA version of a model.

Parameters:
  • model (Model) – Source model. Has to be an instance of a class that has a corresponding LoRA architecture registered.

  • **kwargs – LoRA parameters, such as lora_rank and lora_alpha, need to be passed as kwargs and will be added to the model config.

Returns:

LoRA model.

Return type:

Model

convert_to_regular_model(model)

Converts a LoRA model to a regular model.

Parameters:

model (Model) – LoRA model to be converted. Has to be an instance of a class that has been registered as a LoRA architecture.

Returns:

The converted model.

Return type:

Model

create_model(model_name, pretrained=False, model_path='', **kwargs)

Creates a LoRA model from a tfimm model name.

Parameters:
  • model_name (str) – Name of model to instantiate.

  • pretrained (bool) – If True, load pretrained weights as specified by the url field in the model config. If url is [timm], the weights will be downloaded from timm and converted to TensorFlow. See tfimm.create_model() for details.

  • model_path (str) – Path of model weights to load after the model is initialized. This takes precedence over pretrained.

  • **kwargs – LoRA parameters, such as lora_rank and lora_alpha, need to be passed as kwargs and will be added to the model config.

Returns:

The created model.

Return type:

Model

merge_lora_weights(model)

Recursively merge weights in all LoRA layers in the given model. The model is modified in place.

Parameters:

model (Model) – Model for merging weights.
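Merging works because the low-rank update is linear in the input, so it can be folded into the frozen kernel. The NumPy sketch below checks that a dense layer with kernel W + scale * A @ B gives the same output as the unmerged LoRA layer. The (in, out) kernel layout and scale = lora_alpha / lora_rank are assumptions taken from the paper, and the sketch does not track whether a layer has already been merged.

```python
import numpy as np

rng = np.random.default_rng(1)
in_dim, out_dim, rank, alpha = 6, 5, 2, 4.0
scale = alpha / rank

W = rng.normal(size=(in_dim, out_dim))  # frozen kernel
bias = rng.normal(size=out_dim)
A = rng.normal(size=(in_dim, rank))  # trained low-rank factors
B = rng.normal(size=(rank, out_dim))


def lora_forward(x):
    # Unmerged: frozen path plus scaled low-rank path.
    return x @ W + scale * (x @ A) @ B + bias


# Merging folds the update into the kernel; A and B are no longer needed.
W_merged = W + scale * A @ B


def merged_forward(x):
    return x @ W_merged + bias


x = rng.normal(size=(3, in_dim))
assert np.allclose(lora_forward(x), merged_forward(x))
```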

lora_non_trainable_weights(model, train_bias='none', trainable_layers=None)

Returns a list of non-trainable weights for the LoRA model. This function complements lora_trainable_weights().

Parameters:
  • model (Model) – A keras model.

  • train_bias (str) – If "none", no bias weights are trainable; if "all", all bias weights are trainable; if "lora_only", only the bias weights of LoRA layers are trainable.

  • trainable_layers (List[str] | None) – A list of layer names that should be kept trainable.

Returns:

List of LoRA non-trainable weights.

Return type:

List[Variable]

lora_trainable_weights(model, train_bias='none', trainable_layers=None)

Returns a list of variables to be used instead of model.trainable_weights when doing LoRA training.

Parameters:
  • model (Model) – A keras model.

  • train_bias (str) – If "none", no bias weights are trainable; if "all", all bias weights are trainable; if "lora_only", only the bias weights of LoRA layers are trainable.

  • trainable_layers (List[str] | None) – A list of layer names that should be kept trainable.

Returns:

List of LoRA trainable weights.

Return type:

List[Variable]

Layers

class LoRAConv2D(*args, **kwargs)

LoRA version of the Conv2D layer.

Parameters:
  • lora_rank (int) – Rank of the low-rank weight update.

  • lora_alpha (float) – Scaling factor for the low-rank weight update.

class LoRADense(*args, **kwargs)

LoRA version of the Dense layer.

Parameters:
  • lora_rank (int) – Rank of the low-rank weight update.

  • lora_alpha (float) – Scaling factor for the low-rank weight update.

convert_to_lora_layer(layer, **kwargs)

Convenience function to convert supported layer types to their LoRA counterparts.

Parameters:
  • layer (Layer) – Layer to be converted.

  • **kwargs – LoRA-specific parameters, such as lora_rank, have to be passed as kwargs.

Returns:

LoRA layer instance.

Return type:

Layer
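One plausible way to convert layers is to dispatch on the exact layer type and rebuild the LoRA layer from the original layer's configuration plus the LoRA kwargs, similar in spirit to Keras get_config(). The sketch below uses stand-in Dense and LoRADense classes and a hypothetical LORA_LAYERS mapping instead of real Keras layers; it only builds the new layer and does not copy any weights.

```python
class Dense:
    """Stand-in for a Keras Dense layer (configuration only, no weights)."""

    def __init__(self, units, use_bias=True, name=None):
        self.units, self.use_bias, self.name = units, use_bias, name

    def get_config(self):
        return {"units": self.units, "use_bias": self.use_bias, "name": self.name}


class LoRADense(Dense):
    """Stand-in LoRA layer: the base configuration plus LoRA parameters."""

    def __init__(self, lora_rank=4, lora_alpha=1.0, **kwargs):
        super().__init__(**kwargs)
        self.lora_rank, self.lora_alpha = lora_rank, lora_alpha


# Hypothetical mapping from supported layer types to their LoRA counterparts.
LORA_LAYERS = {Dense: LoRADense}


def convert_layer(layer, **kwargs):
    # Exact type lookup, so an existing LoRADense is not silently reconverted.
    lora_cls = LORA_LAYERS.get(type(layer))
    if lora_cls is None:
        raise ValueError(f"Unsupported layer type: {type(layer).__name__}")
    # Rebuild from the original configuration plus the LoRA-specific kwargs.
    return lora_cls(**layer.get_config(), **kwargs)


lora_layer = convert_layer(Dense(64, name="fc1"), lora_rank=2)
print(type(lora_layer).__name__, lora_layer.units, lora_layer.lora_rank)  # LoRADense 64 2
```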

Registry

lora_architecture(model_cls)

Returns the LoRA model class registered for a given base model class.

lora_base_architecture(lora_cls)

Returns the base class corresponding to the given registered LoRA model class.

lora_config(model_cls)

Returns the config class for the LoRA model associated with the given base model class.

register_lora_architecture(lora_cls=None, *, base_cls=None)

Decorator to register a LoRA variant of a model architecture. It is used as follows:

@register_lora_architecture
class LoRAResNet(ResNet):
    ...

This will register the class LoRAResNet as the LoRA version of the class ResNet. If the LoRA architecture is not created by subclassing the base model class, the base class can be specified explicitly.

@register_lora_architecture(base_cls=ResNet)
class LoRAResNet(tf.keras.Model):
    ...

A model class can be its own LoRA variant if the model can be created with either regular or LoRA layers, depending on the config. In that case, this function needs to be called after the class has been defined.

class FlexibleModel(tf.keras.Model):
    ...

register_lora_architecture(FlexibleModel, base_cls=FlexibleModel)

Parameters:
  • lora_cls – LoRA model class. We assume that it has a cfg_class class attribute.

  • base_cls – Regular model class. If not provided, we use the base class of lora_cls.
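A registry with these semantics can be kept in two dictionaries. The following is a minimal sketch with hypothetical names (register_lora, get_lora_cls, get_base_cls, get_lora_cfg), not tfimm's source. It supports the three call forms shown above and, as assumed in the parameter description, defaults base_cls to the first parent class of lora_cls.

```python
_BASE_TO_LORA = {}
_LORA_TO_BASE = {}


def register_lora(lora_cls=None, *, base_cls=None):
    """Register lora_cls as the LoRA variant of base_cls (default: its parent)."""

    def _register(cls):
        base = base_cls if base_cls is not None else cls.__bases__[0]
        _BASE_TO_LORA[base] = cls
        _LORA_TO_BASE[cls] = base
        return cls

    # Supports @register_lora, @register_lora(base_cls=...) and
    # register_lora(Cls, base_cls=...).
    if lora_cls is None:
        return _register
    return _register(lora_cls)


def get_lora_cls(model_cls):
    return _BASE_TO_LORA[model_cls]


def get_base_cls(lora_cls):
    return _LORA_TO_BASE[lora_cls]


def get_lora_cfg(model_cls):
    # Assumes every registered LoRA class defines a cfg_class attribute.
    return get_lora_cls(model_cls).cfg_class


class ResNetConfig: ...
class LoRAResNetConfig(ResNetConfig): ...


class ResNet:
    cfg_class = ResNetConfig


@register_lora
class LoRAResNet(ResNet):
    cfg_class = LoRAResNetConfig


class FlexibleModel:
    cfg_class = LoRAResNetConfig


# Without base_cls, the base would default to object here.
register_lora(FlexibleModel, base_cls=FlexibleModel)
```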