Transformer Models
Here we describe some features that are common to most transformer models.
Changing input shape
Most parts of a transformer architecture are independent of input resolution. Changing the input resolution results in a different number of patches. The projection, self-attention and MLP layers work on arbitrary length inputs. The only part that needs to be adapted are the position embeddings.
Position embeddings can adjusted via 2D interpolation to the new input resolution.
However, since position embeddings are learnt, after interpolation they may no longer
be meaningful. Thus, by default, transformer models can only run inference at the
resolution specified by input_size.
If we want to fine-tune a model at a different resolution, we can specify the new resolution when creating the model. In that case, the position embeddings will be interpolated for the new resolution.
# Default model with `input_size=(224, 224)`
model_224 = create_model("vit_tiny_patch16_224")
# Model with interpolated position embeddings
model_384 = create_model("vit_tiny_patch16_224", input_size=(384, 384))
Transforming weights
Internally adjusting model input size is done via the transform_weights field in
the config. The field transform_weights is a dictionary of the form
cfg.transform_weights = {"pos_embed": ViT.transform}
The function ViT.transform is called as src_model.transform(tgt_cfg) and returns
the corresponding weight transformed for tgt_cfg.
Inference at arbitrary resolution
We can enable inference at arbitrary resolution, by setting the parameter
interpolate_input=True when constructing the model.
model = create_model("vit_tiny_patch16_224", interpolate_input=True)
logits = model(np.zeros((1, 256, 256, 3), dtype="float32"))
To avoid accidental inference at the wrong resolution, the default is
interpolate_input=False.