# Source code for mmpretrain.models.backbones.edgenext
# Copyright (c) OpenMMLab. All rights reserved.
import math
from itertools import chain
from typing import Sequence
import torch
import torch.nn as nn
from mmcv.cnn.bricks import DropPath
from mmengine.model import BaseModule, ModuleList, Sequential
from mmpretrain.registry import MODELS
from ..utils import (ChannelMultiheadAttention, PositionEncodingFourier,
                     build_norm_layer)
from .base_backbone import BaseBackbone
from .convnext import ConvNeXtBlock
class SDTAEncoder(BaseModule):
    """A PyTorch implementation of split depth-wise transpose attention (SDTA)
    encoder.

    Inspiration from
    https://github.com/mmaaz60/EdgeNeXt

    The block splits the input along the channel dimension, runs a cascade
    of depth-wise 3x3 convolutions over the splits, applies channel
    self-attention (with an optional Fourier position encoding) and finally
    an inverted-bottleneck MLP whose output is added to the block input.

    Args:
        in_channel (int): Number of input channels.
        drop_path_rate (float): Stochastic depth dropout rate.
            Defaults to 0.
        layer_scale_init_value (float): Initial value of layer scale. A
            value ``<= 0`` disables layer scale for both the attention and
            the MLP branch. Defaults to 1e-6.
        mlp_ratio (int): Number of channels ratio in the MLP.
            Defaults to 4.
        use_pos_emb (bool): Whether to use position encoding.
            Defaults to True.
        num_heads (int): Number of heads in the multihead attention.
            Defaults to 8.
        qkv_bias (bool): Whether to use bias in the multihead attention.
            Defaults to True.
        attn_drop (float): Dropout rate of the attention.
            Defaults to 0.
        proj_drop (float): Dropout rate of the projection.
            Defaults to 0.
        norm_cfg (dict): Dictionary to construct normalization layer.
            Defaults to ``dict(type='LN')``.
        act_cfg (dict): Dictionary to construct activation layer.
            Defaults to ``dict(type='GELU')``.
        scales (int): Number of scales. Default to 1.
        init_cfg (dict, optional): Config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 in_channel,
                 drop_path_rate=0.,
                 layer_scale_init_value=1e-6,
                 mlp_ratio=4,
                 use_pos_emb=True,
                 num_heads=8,
                 qkv_bias=True,
                 attn_drop=0.,
                 proj_drop=0.,
                 norm_cfg=dict(type='LN'),
                 act_cfg=dict(type='GELU'),
                 scales=1,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        # Width of every channel split. For positive integers the ceiling of
        # the true division is never smaller than the floor division, so it
        # is the only term that matters.
        conv_channels = int(math.ceil(in_channel / scales))
        self.conv_channels = conv_channels

        # One depth-wise conv per split except the last one, which is passed
        # through untouched (a single scale still gets one conv).
        self.num_convs = scales if scales == 1 else scales - 1

        self.conv_modules = ModuleList()
        for _ in range(self.num_convs):
            self.conv_modules.append(
                nn.Conv2d(conv_channels,
                          conv_channels,
                          kernel_size=3,
                          padding=1,
                          groups=conv_channels))

        self.pos_embed = PositionEncodingFourier(
            embed_dims=in_channel) if use_pos_emb else None

        # Channel self-attention branch.
        self.norm_csa = build_norm_layer(norm_cfg, in_channel)
        self.gamma_csa = nn.Parameter(
            layer_scale_init_value * torch.ones(in_channel),
            requires_grad=True) if layer_scale_init_value > 0 else None
        self.csa = ChannelMultiheadAttention(embed_dims=in_channel,
                                             num_heads=num_heads,
                                             qkv_bias=qkv_bias,
                                             attn_drop=attn_drop,
                                             proj_drop=proj_drop)

        # Inverted-bottleneck MLP branch.
        self.norm = build_norm_layer(norm_cfg, in_channel)
        self.pointwise_conv1 = nn.Linear(in_channel, mlp_ratio * in_channel)
        self.act = MODELS.build(act_cfg)
        self.pointwise_conv2 = nn.Linear(mlp_ratio * in_channel, in_channel)
        self.gamma = nn.Parameter(
            layer_scale_init_value * torch.ones(in_channel),
            requires_grad=True) if layer_scale_init_value > 0 else None

        self.drop_path = DropPath(
            drop_path_rate) if drop_path_rate > 0. else nn.Identity()

    def forward(self, x):
        """Run the SDTA block on a ``(B, C, H, W)`` tensor.

        Returns:
            torch.Tensor: ``shortcut + drop_path(branch)`` permuted back to
            the ``(B, C, H, W)`` layout of the input.
        """
        shortcut = x

        # Hierarchical depth-wise convolutions: every split is summed with
        # the output of the previous split before being convolved.
        splits = torch.split(x, self.conv_channels, dim=1)
        sp = None
        outs = []
        for conv, chunk in zip(self.conv_modules, splits):
            sp = chunk if sp is None else sp + chunk
            sp = conv(sp)
            outs.append(sp)
        # Any split without a conv is forwarded unchanged. There is none
        # when the channels divide into exactly ``num_convs`` chunks (e.g.
        # ``scales == 1``), which previously raised an IndexError.
        outs.extend(splits[self.num_convs:])
        x = torch.cat(outs, 1)

        # Channel Self-attention
        B, C, H, W = x.shape
        x = x.reshape(B, C, H * W).permute(0, 2, 1)  # (B, H*W, C)
        if self.pos_embed is not None:
            pos_encoding = self.pos_embed((B, H, W))
            pos_encoding = pos_encoding.reshape(B, -1,
                                                H * W).permute(0, 2, 1)
            x = x + pos_encoding
        attn = self.csa(self.norm_csa(x))
        # Layer scale is optional; ``gamma_csa`` is None when disabled.
        if self.gamma_csa is not None:
            attn = self.gamma_csa * attn
        x = x + self.drop_path(attn)
        x = x.reshape(B, H, W, C)

        # Inverted Bottleneck
        x = self.norm(x)
        x = self.pointwise_conv1(x)
        x = self.act(x)
        x = self.pointwise_conv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.permute(0, 3, 1, 2)  # (B, H, W, C) -> (B, C, H, W)

        x = shortcut + self.drop_path(x)
        return x
@MODELS.register_module()
class EdgeNeXt(BaseBackbone):
    """EdgeNeXt.

    A PyTorch implementation of: `EdgeNeXt: Efficiently Amalgamated
    CNN-Transformer Architecture for Mobile Vision Applications
    <https://arxiv.org/abs/2206.10589>`_

    Inspiration from
    https://github.com/mmaaz60/EdgeNeXt

    Args:
        arch (str | dict): The model's architecture. If string, it should be
            one of architectures in ``EdgeNeXt.arch_settings``.
            And if dict, it should include the following keys:

            - channels (list[int]): The number of channels at each stage.
            - depths (list[int]): The number of blocks at each stage.
            - num_heads (list[int]): The number of heads at each stage.

            Defaults to 'xxsmall'.
        in_channels (int): The number of input channels.
            Defaults to 3.
        global_blocks (list[int]): The number of global blocks.
            Defaults to [0, 1, 1, 1].
        global_block_type (list[str]): The type of global blocks.
            Defaults to ['None', 'SDTA', 'SDTA', 'SDTA'].
        drop_path_rate (float): Stochastic depth dropout rate.
            Defaults to 0.
        layer_scale_init_value (float): Initial value of layer scale.
            Defaults to 1e-6.
        linear_pw_conv (bool): Whether to use linear layer to do pointwise
            convolution. Defaults to True.
        mlp_ratio (int): The number of channel ratio in MLP layers.
            Defaults to 4.
        conv_kernel_sizes (list[int]): The kernel size of convolutional
            layers at each stage. Defaults to [3, 5, 7, 9].
        use_pos_embd_csa (list[bool]): Whether to use positional embedding in
            Channel Self-Attention. Defaults to [False, True, False, False].
        use_pos_embd_global (bool): Whether to use positional embedding for
            whole network. Defaults to False.
        d2_scales (list[int]): The number of channel groups used for SDTA at
            each stage. Defaults to [2, 2, 3, 4].
        norm_cfg (dict): The config of normalization layer.
            Defaults to ``dict(type='LN2d', eps=1e-6)``.
        out_indices (Sequence | int): Output from which stages. Negative
            values count from the last stage.
            Defaults to -1, means the last stage.
        frozen_stages (int): Stages to be frozen (all param fixed).
            Defaults to 0, which means not freezing any parameters.
        gap_before_final_norm (bool): Whether to globally average the feature
            map before the final norm layer. Defaults to True.
        act_cfg (dict): The config of activation layer.
            Defaults to ``dict(type='GELU')``.
        init_cfg (dict, optional): Config for initialization.
            Defaults to None.
    """
    arch_settings = {
        'xxsmall': {  # parameters: 1.3M
            'channels': [24, 48, 88, 168],
            'depths': [2, 2, 6, 2],
            'num_heads': [4, 4, 4, 4]
        },
        'xsmall': {  # parameters: 2.3M
            'channels': [32, 64, 100, 192],
            'depths': [3, 3, 9, 3],
            'num_heads': [4, 4, 4, 4]
        },
        'small': {  # parameters: 5.6M
            'channels': [48, 96, 160, 304],
            'depths': [3, 3, 9, 3],
            'num_heads': [8, 8, 8, 8]
        },
        'base': {  # parameters: 18.51M
            'channels': [80, 160, 288, 584],
            'depths': [3, 3, 9, 3],
            'num_heads': [8, 8, 8, 8]
        },
    }

    def __init__(self,
                 arch='xxsmall',
                 in_channels=3,
                 global_blocks=[0, 1, 1, 1],
                 global_block_type=['None', 'SDTA', 'SDTA', 'SDTA'],
                 drop_path_rate=0.,
                 layer_scale_init_value=1e-6,
                 linear_pw_conv=True,
                 mlp_ratio=4,
                 conv_kernel_sizes=[3, 5, 7, 9],
                 use_pos_embd_csa=[False, True, False, False],
                 use_pos_embd_global=False,
                 d2_scales=[2, 2, 3, 4],
                 norm_cfg=dict(type='LN2d', eps=1e-6),
                 out_indices=-1,
                 frozen_stages=0,
                 gap_before_final_norm=True,
                 act_cfg=dict(type='GELU'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        # Resolve the architecture. The instance attribute deliberately
        # shadows the class-level table with the selected setting.
        if isinstance(arch, str):
            arch = arch.lower()
            assert arch in self.arch_settings, \
                f'Arch {arch} is not in default archs ' \
                f'{set(self.arch_settings)}'
            self.arch_settings = self.arch_settings[arch]
        elif isinstance(arch, dict):
            essential_keys = {'channels', 'depths', 'num_heads'}
            assert set(arch) == essential_keys, \
                f'Custom arch needs a dict with keys {essential_keys}'
            self.arch_settings = arch
        else:
            # Fail early instead of with an unrelated KeyError below.
            raise TypeError(
                f'"arch" must be a str or a dict, get {type(arch)} instead.')

        self.channels = self.arch_settings['channels']
        self.depths = self.arch_settings['depths']
        self.num_heads = self.arch_settings['num_heads']
        self.num_layers = len(self.depths)
        self.use_pos_embd_global = use_pos_embd_global

        # The block types are only validated here; block selection below is
        # driven by ``global_blocks``.
        for g in global_block_type:
            assert g in ['None',
                         'SDTA'], f'Global block type {g} is not supported'

        self.num_stages = len(self.depths)

        if isinstance(out_indices, int):
            out_indices = [out_indices]
        assert isinstance(out_indices, Sequence), \
            f'"out_indices" must by a sequence or int, ' \
            f'get {type(out_indices)} instead.'
        # Work on a private list so that tuples are accepted and the
        # caller's sequence is never modified.
        out_indices = list(out_indices)
        for i, index in enumerate(out_indices):
            if index < 0:
                # Negative indices count from the last stage.
                out_indices[i] = self.num_stages + index
                assert out_indices[i] >= 0, f'Invalid out_indices {index}'
        self.out_indices = out_indices

        self.frozen_stages = frozen_stages
        self.gap_before_final_norm = gap_before_final_norm

        if self.use_pos_embd_global:
            self.pos_embed = PositionEncodingFourier(
                embed_dims=self.channels[0])
        else:
            self.pos_embed = None

        # stochastic depth decay rule
        dpr = torch.linspace(0, drop_path_rate, sum(self.depths)).tolist()

        self.downsample_layers = ModuleList()
        stem = nn.Sequential(
            nn.Conv2d(in_channels, self.channels[0], kernel_size=4, stride=4),
            build_norm_layer(norm_cfg, self.channels[0]),
        )
        self.downsample_layers.append(stem)

        self.stages = ModuleList()
        block_idx = 0
        for i in range(self.num_stages):
            depth = self.depths[i]
            channels = self.channels[i]

            if i >= 1:
                downsample_layer = nn.Sequential(
                    build_norm_layer(norm_cfg, self.channels[i - 1]),
                    nn.Conv2d(
                        self.channels[i - 1],
                        channels,
                        kernel_size=2,
                        stride=2,
                    ))
                self.downsample_layers.append(downsample_layer)

            # The last ``global_blocks[i]`` blocks of the stage are SDTA
            # encoders, the preceding ones are ConvNeXt blocks.
            first_global = depth - global_blocks[i]
            stage_blocks = []
            for j in range(depth):
                if j >= first_global:
                    # NOTE: layer scale, norm and activation settings are
                    # not forwarded here, so SDTAEncoder uses its defaults.
                    stage_blocks.append(
                        SDTAEncoder(
                            in_channel=channels,
                            drop_path_rate=dpr[block_idx + j],
                            mlp_ratio=mlp_ratio,
                            scales=d2_scales[i],
                            use_pos_emb=use_pos_embd_csa[i],
                            num_heads=self.num_heads[i],
                        ))
                else:
                    dw_conv_cfg = dict(
                        kernel_size=conv_kernel_sizes[i],
                        padding=conv_kernel_sizes[i] // 2,
                    )
                    stage_blocks.append(
                        ConvNeXtBlock(
                            in_channels=channels,
                            dw_conv_cfg=dw_conv_cfg,
                            norm_cfg=norm_cfg,
                            act_cfg=act_cfg,
                            linear_pw_conv=linear_pw_conv,
                            drop_path_rate=dpr[block_idx + j],
                            layer_scale_init_value=layer_scale_init_value,
                        ))
            block_idx += depth

            stage_blocks = Sequential(*stage_blocks)
            self.stages.append(stage_blocks)

            if i in self.out_indices:
                # After global average pooling the feature is flattened, so
                # a plain ``LN`` is built instead of ``norm_cfg``.
                out_norm_cfg = dict(type='LN') if self.gap_before_final_norm \
                    else norm_cfg
                norm_layer = build_norm_layer(out_norm_cfg, channels)
                self.add_module(f'norm{i}', norm_layer)

    def init_weights(self) -> None:
        # TODO: need to be implemented in the future
        return super().init_weights()

    def forward(self, x):
        """Return a tuple with one output per stage in ``out_indices``."""
        outs = []
        for i, stage in enumerate(self.stages):
            x = self.downsample_layers[i](x)
            x = stage(x)

            # The optional global position encoding is added once, to the
            # output of the first stage.
            if self.pos_embed is not None and i == 0:
                B, _, H, W = x.shape
                x = x + self.pos_embed((B, H, W))

            if i in self.out_indices:
                norm_layer = getattr(self, f'norm{i}')
                if self.gap_before_final_norm:
                    gap = x.mean([-2, -1], keepdim=True)
                    outs.append(norm_layer(gap.flatten(1)))
                else:
                    # The output of LayerNorm2d may be discontiguous, which
                    # may cause some problem in the downstream tasks
                    outs.append(norm_layer(x).contiguous())

        return tuple(outs)

    def _freeze_stages(self):
        """Put the first ``frozen_stages`` stages in eval mode and stop
        their gradients."""
        for i in range(self.frozen_stages):
            downsample_layer = self.downsample_layers[i]
            stage = self.stages[i]
            downsample_layer.eval()
            stage.eval()
            for param in chain(downsample_layer.parameters(),
                               stage.parameters()):
                param.requires_grad = False

    def train(self, mode=True):
        """Switch the training mode while keeping frozen stages frozen.

        Returns:
            EdgeNeXt: ``self``, as ``torch.nn.Module.train`` does, so that
            ``model.train()`` / ``model.eval()`` can be chained.
        """
        super().train(mode)
        self._freeze_stages()
        return self