CosineEMA
- class mmpretrain.models.utils.CosineEMA(model, momentum=0.004, end_momentum=0.0, interval=1, device=None, update_buffers=False)
CosineEMA implements an exponential moving average whose momentum parameter is annealed with a cosine schedule, as used in BYOL, MoCoV3, etc.
All parameters are updated by the formula below:
\[X'_{t+1} = (1 - m) * X'_t + m * X_t\]
where \(m\) is the momentum parameter, which is itself updated with cosine annealing:
\[m = m_{end} - (m_{end} - m_{start}) * (\cos\frac{k\pi}{K} + 1) / 2\]
where \(k\) is the current step and \(K\) is the total number of steps, so that \(m\) decays from \(m_{start}\) at step 0 to \(m_{end}\) at the final step.
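As a quick numeric illustration of this schedule (a standalone sketch; the helper name cosine_momentum is illustrative and not part of the API), with the defaults momentum=0.004 and end_momentum=0.0 the momentum decays from 0.004 to 0 over \(K\) steps:

```python
import math

def cosine_momentum(k: int, K: int,
                    m_start: float = 0.004, m_end: float = 0.0) -> float:
    """Momentum at step k out of K total steps, per the annealing formula above."""
    return m_end - (m_end - m_start) * (math.cos(math.pi * k / K) + 1) / 2

K = 100
print(cosine_momentum(0, K))    # 0.004  (the start momentum)
print(cosine_momentum(50, K))   # ~0.002 (halfway through annealing)
print(cosine_momentum(100, K))  # 0.0    (the end momentum)
```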
Note
This momentum argument is different from the one used in optimizer classes and from the conventional notion of momentum. Mathematically, \(X'_t\) is the moving average and \(X_t\) is the newly observed value. The momentum is usually a small number, so the observed values only slowly update the EMA parameters. See also torch.nn.BatchNorm2d.
- Parameters:
model (nn.Module) – The model to be averaged.
momentum (float) – The start momentum value. Defaults to 0.004.
end_momentum (float) – The end momentum value for cosine annealing. Defaults to 0.
interval (int) – Interval between two updates. Defaults to 1.
device (torch.device, optional) – If provided, the averaged model will be stored on the device. Defaults to None.
update_buffers (bool) – If True, compute running averages for both the parameters and the buffers of the model. Defaults to False.
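In mmpretrain the class is normally built from a config inside self-supervised algorithms such as BYOL or MoCoV3, but a rough standalone usage sketch looks like the following. It assumes the averaged-model interface inherited from mmengine (update_parameters() and the wrapped .module), and that the total step count \(K\) is supplied by the training loop or runner rather than by this constructor; the toy network is a placeholder.

```python
import torch
import torch.nn as nn
from mmpretrain.models.utils import CosineEMA

# Toy online network standing in for a real backbone/projector.
online_net = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 8))

# Target network kept as a cosine-annealed EMA copy of the online network.
target_net = CosineEMA(online_net, momentum=0.004, end_momentum=0.0)

# Called once per training iteration (normally by the algorithm or an EMA hook).
# The first call copies the source parameters; later calls apply avg_func.
target_net.update_parameters(online_net)

# The averaged copy of the wrapped model is exposed via `.module`.
with torch.no_grad():
    features = target_net.module(torch.randn(2, 16))
```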
- avg_func(averaged_param, source_param, steps)
Compute the moving average of the parameters using the cosine momentum strategy.
- Parameters:
averaged_param (Tensor) – The averaged parameters.
source_param (Tensor) – The source parameters.
steps (int) – The number of times the parameters have been updated.
- Returns:
The averaged parameters.
- Return type:
Tensor
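For intuition, here is a standalone sketch of the update avg_func performs, reconstructed from the formulas above rather than copied from the library. The total_steps argument stands in for \(K\) (which the training setup supplies in practice), the in-place tensor operations are an assumed implementation detail, and the function name is hypothetical:

```python
import math
import torch

def cosine_ema_update(averaged_param: torch.Tensor,
                      source_param: torch.Tensor,
                      steps: int,
                      total_steps: int,
                      momentum: float = 0.004,
                      end_momentum: float = 0.0) -> torch.Tensor:
    """In-place X'_{t+1} = (1 - m) * X'_t + m * X_t with cosine-annealed m."""
    m = end_momentum - (end_momentum - momentum) * (
        math.cos(math.pi * steps / total_steps) + 1) / 2
    return averaged_param.mul_(1 - m).add_(source_param, alpha=m)

# One update early in training: m is still close to the start value 0.004.
avg = torch.zeros(3)
src = torch.ones(3)
cosine_ema_update(avg, src, steps=1, total_steps=100)
print(avg)  # ≈ tensor([0.0040, 0.0040, 0.0040])
```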