Segment Anything

We provide an implementation and pretrained weights for the Segment Anything Models.

Paper: Segment Anything. [arXiv:2304.02643].

The original PyTorch code and weights are from Facebook Research.

The following models are available.

  • sam_vit_b

  • sam_vit_l

  • sam_vit_h

In the code, we try to follow these conventions in comments and docstrings.

  • N is the batch dimension

  • (H0, W0) is the size of the input image to SAMPredictor. There are no constraints on this size, as the image will be resized and padded to the model input size.

  • (H, W) is the model input size. This is (1024, 1024) for the pretrained models.

  • (H’, W’) is the input size for mask prompts and the output size for predicted mask logits. For the pretrained models this is (256, 256). To be precise, it is calculated as H’=4*H’’ and W’=4*W’’.

  • (H’’, W’’) is the spatial dimension of image embeddings. For pretrained models this is (64, 64). This is calculated as H’’=H/patch_size with patch_size=16.

  • C is the number of image input channels. Usually C=3.

  • M1 is the number of point prompts given to the model.

  • M2 is the number of box prompts given to the model. The PyTorch code only supports M2={0, 1}, so the accuracy with multiple box prompts might be limited.

  • M3 is the number of mask prompts given to the model. The PyTorch code only supports M3={0, 1}, so the accuracy with multiple mask prompts might be limited.

  • M is the number of tokens in the sparse embeddings returned by the prompt embedder. The number depends on M1 and M2.

  • D is the embedding dimension, which is shared by the image embeddings and the sparse and dense prompt embeddings. For the pretrained models this is 256.

  • K is the number of masks returned by the model. It is controlled by the model parameter nb_multimask_outputs (set to 3 in the pretrained models) and by the parameter multimask_output when calling SAMPredictor.
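As a quick check of the size conventions above, the following sketch derives (H’’, W’’) and (H’, W’) from the model input size (H, W). The helper name embedding_and_mask_sizes is hypothetical and not part of the library, and it assumes that H and W are multiples of the patch size.

```python
from typing import Tuple


def embedding_and_mask_sizes(
    input_size: Tuple[int, int], patch_size: int = 16
) -> Tuple[Tuple[int, int], Tuple[int, int]]:
    """Return ((H'', W''), (H', W')) for a model input size (H, W).

    Hypothetical helper: assumes H and W are multiples of patch_size.
    """
    h, w = input_size
    if h % patch_size or w % patch_size:
        raise ValueError("input size must be divisible by patch_size")
    grid = (h // patch_size, w // patch_size)  # (H'', W'') = (H, W) / patch_size
    mask = (4 * grid[0], 4 * grid[1])  # (H', W') = 4 * (H'', W'')
    return grid, mask


print(embedding_and_mask_sizes((1024, 1024)))  # ((64, 64), (256, 256))
```

For the pretrained models this reproduces the documented (64, 64) embedding grid and (256, 256) mask size.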

class SegmentAnythingModelConfig(name='', url='', in_channels=3, input_size=(1024, 1024), fixed_input_size=True, embed_dim=256, nb_multimask_outputs=3, mask_threshold=0.0, encoder_patch_size=16, encoder_embed_dim=768, encoder_nb_blocks=12, encoder_nb_heads=12, encoder_mlp_ratio=4.0, encoder_drop_rate=0.0, encoder_attn_drop_rate=0.0, encoder_drop_path_rate=0.0, encoder_norm_layer='layer_norm_eps_1e-6', encoder_act_layer='gelu', encoder_qkv_bias=True, encoder_global_attn_indices=(2, 5, 8, 11), encoder_window_size=14, prompt_mask_hidden_dim=16, decoder_nb_blocks=2, decoder_nb_heads=8, decoder_mlp_channels=2048, decoder_iou_head_depth=3, decoder_iou_hidden_dim=256, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), first_conv='image_encoder/patch_embed/proj')

Configuration class for SAM models.

Parameters:
  • name (str) – Name of the model.

  • url (str) – URL for pretrained weights.

  • in_channels (int) – Number of input image channels.

  • input_size (Tuple[int, int]) – Input image size (height, width)

  • fixed_input_size (bool) – If True, the model only accepts inputs of size input_size. If False, the model accepts arbitrary input sizes by interpolating the positional encodings to account for the new input size.

  • embed_dim (int) – The shared embedding dimension of image and prompt embeddings.

  • nb_multimask_outputs (int) – Number of masks predicted by the model for each prompt.

  • mask_threshold (float) – Threshold applied to mask logits to obtain a boolean mask.

  • encoder_patch_size (int) – Patchifying the image is implemented via a convolutional layer with kernel size and stride equal to patch_size.

  • encoder_embed_dim (int) – Hidden feature dimension of the image encoder. The output dimension (which has to be compatible with the prompt embedding dimension) is given by embed_dim.

  • encoder_nb_blocks (int) – Number of attention blocks in the image encoder.

  • encoder_nb_heads (int) – Number of self-attention heads in the image encoder.

  • encoder_mlp_ratio (float) – Ratio of MLP hidden dimension to embedding dimension.

  • encoder_drop_rate (float) – Dropout rate.

  • encoder_attn_drop_rate (float) – Attention dropout rate.

  • encoder_drop_path_rate (float) – Dropout rate for stochastic depth.

  • encoder_norm_layer (str) – Normalization layer. See norm_layer_factory() for possible values.

  • encoder_act_layer (str) – Activation function. See act_layer_factory() for possible values.

  • encoder_qkv_bias (bool) – If True, add bias for qkv projection layers.

  • encoder_global_attn_indices (Tuple) – Indices of blocks using global attention. All other blocks use window attention.

  • encoder_window_size (int) – Window size for window attention blocks.

  • prompt_mask_hidden_dim (int) – Hidden dimension in the mask encoder network.

  • decoder_nb_blocks (int) – Number of attention blocks in the mask decoder.

  • decoder_nb_heads (int) – Number of self-attention heads in the mask decoder.

  • decoder_mlp_channels (int) – Number of channels in the MLP layers.

  • decoder_iou_head_depth (int) – Number of layers in the score predictor network.

  • decoder_iou_hidden_dim (int) – Hidden dimension of the score predictor network.

  • mean (Tuple[float, float, float]) – Defines preprocessing function. If x is an image with pixel values in (0, 1), the preprocessing function is (x - mean) / std.

  • std (Tuple[float, float, float]) – Defines the preprocessing function (see mean).

  • first_conv (str) – Name of the first convolutional layer. Used by create_model() to adapt the number of input channels when loading pretrained weights.
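The mean/std preprocessing can be illustrated with a short NumPy sketch. It is not the library's preprocessing function: the helper name normalize is hypothetical, and it assumes an (H, W, 3) image with values in [0, 255] (the range accepted by SAMPredictor.set_image) that is first divided by 255 to bring it into [0, 1].

```python
import numpy as np

# Default mean and std from SegmentAnythingModelConfig.
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def normalize(image: np.ndarray) -> np.ndarray:
    """Map an (H, W, 3) image with values in [0, 255] to (x - mean) / std.

    Assumption: values are divided by 255 first, so x lies in [0, 1].
    """
    x = image.astype(np.float32) / 255.0  # now in [0, 1]
    return (x - MEAN) / STD  # broadcasts over the channel (last) axis


img = np.full((4, 4, 3), 255, dtype=np.uint8)  # an all-white image
out = normalize(img)
print(out.shape, out.dtype)  # (4, 4, 3) float32
```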

class SegmentAnythingModel(*args, **kwargs)
Parameters:

cfg (SegmentAnythingModelConfig) – Model configuration.

call(inputs, training=False, multimask_output=False, return_logits=False)

Predicts masks end-to-end from provided images and prompts. If prompts are not known in advance, using SAMPredictor is recommended over calling the model directly.

Parameters:
  • inputs

    A dictionary with the following entries:

    • images: An (N, H, W, C) tensor of preprocessed input images.

    • points: An (N, M1, 2) tensor of point prompts with coordinates in pixel space, i.e., values between 0 and H or W.

    • labels: An (N, M1) tensor of labels for point prompts. 1 indicates a foreground point and 0 indicates a background point.

    • boxes: An (N, M2, 4) tensor of box prompts of form (left, top, right, bottom) with coordinates in pixel space.

    • masks: An (N, M3, H’, W’) tensor of mask inputs, where M3 is either 1 or 0 (no mask provided).

  • training – Whether the model is called in training or inference mode.

  • multimask_output – If True, we return multiple nested masks for each prompt.

  • return_logits – If True, we don’t threshold the upscaled mask. This is useful if we want to resize the mask back to original image size first and then apply the threshold.

Returns:

  • Masks, an (N, K, H, W) bool tensor of binary mask predictions, where K is determined by the multimask_output parameter.

  • Scores, an (N, K) tensor with the model’s predictions of mask quality.

  • Logits, an (N, K, H’, W’) tensor with low-resolution logits, where usually H’=H/4 and W’=W/4. This can be passed as mask input to subsequent iterations of prediction.
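For orientation, the sketch below assembles an inputs dictionary with the shapes listed above for a batch of N=1 image, one foreground point and one box. The array contents are placeholders, and the dtypes (float32 for images and coordinates, int32 for labels) are assumptions; the model may also expect tensors rather than NumPy arrays. No mask prompt is given, so M3=0.

```python
import numpy as np

N, H, W, C = 1, 1024, 1024, 3  # batch size and model input size
H_mask, W_mask = H // 4, W // 4  # (H', W') for the pretrained models

inputs = {
    # Preprocessed images, (N, H, W, C); zeros as a placeholder.
    "images": np.zeros((N, H, W, C), dtype=np.float32),
    # One point prompt per image, (N, M1, 2), in pixel coordinates.
    "points": np.array([[[500.0, 375.0]]], dtype=np.float32),
    # Label 1 marks the point as foreground, (N, M1).
    "labels": np.array([[1]], dtype=np.int32),
    # One box (left, top, right, bottom), (N, M2, 4).
    "boxes": np.array([[[425.0, 300.0, 700.0, 875.0]]], dtype=np.float32),
    # No mask prompt: M3 = 0, so the shape is (N, 0, H', W').
    "masks": np.zeros((N, 0, H_mask, W_mask), dtype=np.float32),
}

for key, value in inputs.items():
    print(key, value.shape)
```

With a SegmentAnythingModel instance model, such a dictionary would then be passed as, e.g., model(inputs, multimask_output=True).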

property dummy_inputs

Returns a (nested) tensor of the correct shape for inference.

grid_size(input_size=None)

Compute the grid size of image embeddings for the model given an image input size.

Parameters:

input_size (Tuple[int, int] | None) – Image input size (H, W). If not provided, the input size from the model config is used.

Returns:

Spatial size of image embeddings (H’’, W’’).

Return type:

Tuple[int, int]

mask_size(input_size=None)

Compute the size of (low res) masks for the model given an image input size.

Parameters:

input_size (Tuple[int, int] | None) – Image input size (H, W). If not provided, the input size from the model config is used.

Returns:

Spatial size of low resolution masks (H’, W’).

Return type:

Tuple[int, int]

property mask_threshold

Threshold applied to logit masks to obtain boolean masks.
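Converting logits to a boolean mask is an elementwise comparison against mask_threshold (0.0 in the default configuration). Since sigmoid(0) = 0.5, a threshold of 0.0 on logits corresponds to a probability threshold of 0.5. The NumPy sketch below assumes a strict > comparison, as in the original PyTorch code; threshold_logits is a hypothetical name.

```python
import numpy as np


def threshold_logits(logits: np.ndarray, mask_threshold: float = 0.0) -> np.ndarray:
    """Convert mask logits to a boolean mask (assumes a strict '>' comparison)."""
    return logits > mask_threshold


logits = np.array([[-2.0, 0.0], [0.5, 3.0]])
print(threshold_logits(logits))
# [[False False]
#  [ True  True]]
```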

class SAMPredictor(model, preprocessing=None)

User-friendly interface to the Segment Anything model. Uses SAM to calculate the image embedding for an image, and then allows repeated, efficient mask prediction given prompts.

While TensorFlow is used internally for inference, the inputs and return values of this class are NumPy arrays for ease of use.

Parameters:
  • model (SegmentAnythingModel) – The model used for mask prediction.

  • preprocessing (Callable | None) – Preprocessing function for the model. If not provided we will query tfimm using the model name.

__call__(points=None, labels=None, boxes=None, masks=None, multimask_output=True, return_logits=False)

Predicts masks end-to-end for the given prompts. We assume that the image has already been set via set_image.

The original image size is (H0, W0). After resizing and padding the image size becomes (H, W) given by input_size (usually (1024, 1024)). Mask input and logit output will have shape (H’, W’) given by mask_size (usually H’=H/4).

One can use preprocess_masks to transform an input mask from (H0, W0) to (H’, W’).

Prompts can also be batched, i.e., have the shape (N, M1, 2) for points, (N, M1) for point labels, (N, M2, 4) for boxes and (N, M3, H’, W’) for mask prompts. Note that in this case the number and type of prompts must be the same for each batch element. The return values will have the same batch dimension, i.e., (N, K, H, W) for predicted masks, etc.

Parameters:
  • points (ndarray | None) – An (M1, 2) array of point prompts with coordinates in pixel space of the original image (H0, W0).

  • labels (ndarray | None) – An (M1,) array of labels for point prompts. 1 indicates a foreground point and 0 indicates a background point.

  • boxes (ndarray | None) – An (M2, 4) array of box prompts of form (left, top, right, bottom) with coordinates in pixel space of the original image (H0, W0).

  • masks (ndarray | None) – An (M3, H’, W’) array of mask inputs, where (H’, W’) is the mask size (usually H’=H/4).

  • multimask_output (bool) – If True, we return multiple nested masks for each prompt.

  • return_logits (bool) – If True, we don’t threshold the upscaled mask.

Returns:

  • Masks, a (K, H, W) bool array of binary mask predictions, where K is determined by the multimask_output parameter: K=1 if multimask_output=False, and otherwise K is given by the nb_multimask_outputs parameter in the model configuration.

  • Scores, a (K,) array with the model’s predictions of mask quality.

  • Logits, a (K, H’, W’) array with low-resolution logits, where usually H’=H/4 and W’=W/4. This can be passed as mask input to subsequent iterations of prediction.
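With multimask_output=True, one common pattern is to keep the mask with the highest predicted quality score and feed its low-resolution logits back as a mask prompt. The sketch below uses random placeholder arrays with the shapes documented above instead of real predictor outputs; the sizes and score values are made up for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
K, H, W = 3, 64, 64  # placeholder sizes
H_low, W_low = H // 4, W // 4  # (H', W')

# Placeholder outputs with the documented shapes.
masks = rng.random((K, H, W)) > 0.5  # (K, H, W) bool
scores = np.array([0.71, 0.93, 0.88])  # (K,)
logits = rng.normal(size=(K, H_low, W_low))  # (K, H', W')

best = int(np.argmax(scores))  # index of the highest-scoring mask
best_mask = masks[best]  # (H, W)
mask_prompt = logits[best][None]  # (1, H', W'), i.e. M3 = 1

print(best, best_mask.shape, mask_prompt.shape)  # 1 (64, 64) (1, 16, 16)
```

mask_prompt could then be passed back, e.g. predictor(points=points, labels=labels, masks=mask_prompt, multimask_output=False), where predictor, points and labels are assumed to exist from a previous call.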

clear_image()

Unsets the image and forgets the embedding.

preprocess_masks(mask)

Preprocesses a mask from the pixel space of the original image (H0, W0) to the correct input size of the model. Note that the mask should be a mask of logits and not the thresholded version.

Parameters:

mask (ndarray) – An array of shape (M, H0, W0) or (N, M, H0, W0), where (H0, W0) is the original size of the input image.

Returns:

Preprocessed mask of shape (M, H’, W’) or (N, M, H’, W’), as given by mask_size.

Return type:

ndarray

set_image(image)

Calculates and stores the image embeddings for the provided image, allowing masks to be predicted much faster.

Parameters:

image (ndarray) – An array of shape (H0, W0, C) with pixel values in [0, 255]. The image can be any size; it will be resized and padded to the model input shape as necessary.

Returns:

Nothing. The image embedding and resizing information are stored in the class.

class ImageResizer(src_size, dst_size, pad_only=False)

Utility class to resize images to the largest size that fits in a given shape while preserving the aspect ratio. It also provides methods to resize coordinates and bounding boxes and to pad images.

Parameters:
  • src_size (Tuple[int, int]) – Size of image before resizing. The resize object is image specific, i.e., for each source image size it is recommended to create a new ImageResizer object.

  • dst_size (Tuple[int, int]) – The target size after resizing (and padding).

  • pad_only (bool) – If True, we don’t do any resizing and only pad the image to dst_size.
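The resize geometry can be sketched as follows. The helper resize_geometry is hypothetical; it assumes a single scale factor min(dst_h/src_h, dst_w/src_w) for both axes, so the rescaled image fits inside dst_size with its aspect ratio preserved, and the remaining area is later filled by padding.

```python
from typing import Tuple


def resize_geometry(
    src_size: Tuple[int, int], dst_size: Tuple[int, int], pad_only: bool = False
) -> Tuple[float, Tuple[int, int]]:
    """Return (scale, rescaled_size) for an aspect-ratio preserving resize.

    Hypothetical helper: one scale factor for both axes, chosen so that the
    rescaled image fits inside dst_size.
    """
    src_h, src_w = src_size
    dst_h, dst_w = dst_size
    if pad_only:
        return 1.0, (src_h, src_w)  # no resizing, only padding later
    scale = min(dst_h / src_h, dst_w / src_w)
    rescaled_size = (int(round(src_h * scale)), int(round(src_w * scale)))
    return scale, rescaled_size


print(resize_geometry((480, 640), (1024, 1024)))  # (1.6, (768, 1024))
```

A 480x640 image is therefore scaled to 768x1024, leaving 256 rows to be padded.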

pad_image(image, channels_last=True)

Apply zero padding to an image to size dst_size.

Parameters:
  • image (ndarray) – Image to be padded. Can be a 3D or 4D array with the channel dimension before or after the spatial dimensions.

  • channels_last (bool) – If True, images are in HWC format and if False in CHW format.

Returns:

Zero padded image of size dst_size.

Return type:

ndarray
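Zero padding to dst_size can be sketched with numpy.pad. The helper pad_to is hypothetical, and it assumes the padding is added at the bottom and right so the image content stays in the top-left corner (as in the original PyTorch implementation).

```python
from typing import Tuple

import numpy as np


def pad_to(
    image: np.ndarray, dst_size: Tuple[int, int], channels_last: bool = True
) -> np.ndarray:
    """Zero-pad the spatial axes of a 3D or 4D array to dst_size (bottom/right)."""
    nd = image.ndim
    # Spatial axes come before the channel axis (HWC) or are the last two (CHW).
    if channels_last:
        h_axis, w_axis = nd - 3, nd - 2
    else:
        h_axis, w_axis = nd - 2, nd - 1
    pad = [(0, 0)] * nd
    pad[h_axis] = (0, dst_size[0] - image.shape[h_axis])
    pad[w_axis] = (0, dst_size[1] - image.shape[w_axis])
    return np.pad(image, pad, mode="constant", constant_values=0)


image = np.ones((3, 4, 3))  # (H, W, C) with content everywhere
padded = pad_to(image, (4, 4))  # one row of zeros added at the bottom
print(padded.shape)  # (4, 4, 3)
```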

postprocess_mask(mask, threshold=None)

Convert an upscaled segmentation mask from dst_size back to src_size by removing padding and unscaling.

Parameters:
  • mask (ndarray) – Segmentation mask, i.e., image with channels_first ordering, of size dst_size. Should be a logit-mask, i.e., before thresholding.

  • threshold (float | None) – Optionally, we can apply thresholding after resizing to obtain a boolean mask.

Returns:

Mask of size src_size.

Return type:

ndarray
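The steps of postprocess_mask (remove the padding, resize back to src_size, optionally threshold) can be sketched in NumPy. To stay dependency-free, this sketch swaps in nearest-neighbour resizing, which is not necessarily the interpolation the library uses. It also assumes padding at the bottom and right and that rescaled_size is known; postprocess_mask_sketch and resize_nearest are hypothetical names.

```python
from typing import Optional, Tuple

import numpy as np


def resize_nearest(mask: np.ndarray, size: Tuple[int, int]) -> np.ndarray:
    """Nearest-neighbour resize of the last two (spatial) axes."""
    in_h, in_w = mask.shape[-2:]
    rows = np.arange(size[0]) * in_h // size[0]
    cols = np.arange(size[1]) * in_w // size[1]
    return mask[..., rows[:, None], cols[None, :]]


def postprocess_mask_sketch(
    mask: np.ndarray,
    rescaled_size: Tuple[int, int],
    src_size: Tuple[int, int],
    threshold: Optional[float] = None,
) -> np.ndarray:
    """Crop padding from a (C, H, W) logit mask, resize to src_size, optionally threshold."""
    cropped = mask[..., : rescaled_size[0], : rescaled_size[1]]  # drop bottom/right padding
    resized = resize_nearest(cropped, src_size)
    return resized > threshold if threshold is not None else resized


logits = np.zeros((1, 1024, 1024), dtype=np.float32)
logits[:, :384, :512] = 5.0  # positive logits in the top-left quarter of the content
out = postprocess_mask_sketch(logits, (768, 1024), (480, 640), threshold=0.0)
print(out.shape, out.mean())
```

Thresholding after resizing, rather than before, is what the logit-mask requirement above is for: interpolating logits gives smoother boundaries than interpolating a boolean mask.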

scale_boxes(boxes)

Scale bounding boxes by the same factor as the image.

Parameters:

boxes (ndarray) – Boxes to be scaled.

Returns:

Scaled boxes.

Return type:

ndarray
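If the image is resized with one scale factor for both axes (an assumption carried over from the aspect-ratio preserving resize), points and boxes can be scaled by plain multiplication, since a box (left, top, right, bottom) is just two corner points. The helpers scale_points and scale_boxes below are hypothetical stand-ins for the methods of the same name.

```python
import numpy as np


def scale_points(points: np.ndarray, scale: float) -> np.ndarray:
    """Scale (..., 2) pixel coordinates by the image resize factor."""
    return points.astype(np.float32) * scale


def scale_boxes(boxes: np.ndarray, scale: float) -> np.ndarray:
    """Scale (..., 4) boxes (left, top, right, bottom) as two corner points."""
    corners = boxes.reshape(*boxes.shape[:-1], 2, 2)  # (..., 2 corners, 2 coords)
    return scale_points(corners, scale).reshape(boxes.shape)


boxes = np.array([[10.0, 20.0, 110.0, 220.0]])  # one box in original-image pixels
print(scale_boxes(boxes, 1.6))
```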

scale_image(image, channels_last=True)

Applies aspect-ratio preserving scaling to an image.

Parameters:
  • image (ndarray) – Image to be resized. Can be a 3D array (H, W, C) or a 4D array (N, H, W, C). We also accept channels-first ordering (used for segmentation masks), i.e., (C, H, W) or (N, C, H, W).

  • channels_last (bool) – If True, images are in HWC format and if False in CHW format.

Returns:

Resized image with spatial dimensions given by rescaled_size. The image fits within dst_size; for a square dst_size, such as (1024, 1024), its longest edge equals the side length of dst_size.

Return type:

ndarray

scale_points(points)

Scale points by the same factor as the image.

Parameters:

points (ndarray) – Points to be scaled.

Returns:

Scaled points.

Return type:

ndarray

static scale_to_size(image, size, channels_last)

Scales an image to a given size. In this method we ignore dst_size and do not attempt to preserve the aspect ratio.

Parameters:
  • image (ndarray) – Image to be resized. Can be a 3D array (H, W, C) or a 4D array (N, H, W, C). We also accept channels-first ordering (used for segmentation masks), i.e., (C, H, W) or (N, C, H, W).

  • size (Tuple[int, int]) – Target size.

  • channels_last (bool) – If True, images are in HWC format and if False in CHW format.

Returns:

Resized image array.

Return type:

ndarray

unscale_image(image, channels_last=True)

Reverses the scaling operation.

Parameters:
  • image (ndarray) – Image to be rescaled back to the original size given by src_size. We assume that the image has size rescaled_size; otherwise the aspect ratio will not be preserved. The image can be 3D or 4D with channels before or after the spatial dimensions.

  • channels_last (bool) – If True, images are in HWC format and if False in CHW format.

Returns:

Resized image with size src_size.

Return type:

ndarray