Customize Losses
MMDetection为用户提供了不同的损失函数。但默认配置可能不适用于不同的数据集或模型,因此用户可能希望修改特定的损失以适应新情况。
本教程首先阐述了损失的计算 pipeline,然后给出了一些关于如何修改每个步骤的说明。这些修改可以分为调整和加权。
计算 loss 的 pipeline
同时输入预测值和目标值,损失函数将输入张量映射到最终的损失标量。该映射可分为五个步骤。
- 设置采样方法,对正负样本进行采样
- 通过损失核函数得到 element-wise 或者 kernel-wise 的 loss 值
- 用带权的 tensor 从元素层面上地加权
- 将 tensor 转为 scalar
- 对 scalar 加权
step1.设置采样方法
对于某些损失函数,需要采用抽样策略来避免正负样本之间的不平衡。例如,在 RPN 头中使用 CrossEntropyLoss 时,我们需要在 train_cfg中 设置 RandomSampler
train_cfg=dict(
rpn=dict(
sampler=dict(
type='RandomSampler',
num=256,
pos_fraction=0.5,
neg_pos_ub=-1,
add_gt_as_proposals=False))
对于其他一些具有正负采样平衡机制的损失,如 Focal Loss、GHMC 和 QualityFocalLoss,采样器就没有必要了。
调整一个损失与第2、4、5步更相关,而且大多数修改可以在配置中指定。这里我们以Focal Loss(FL)为例。下面的代码分别是 FL 的构造方法和配置,它们实际上是一对一的对应关系。
@LOSSES.register_module()
class FocalLoss(nn.Module):
def __init__(self,
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
reduction='mean',
loss_weight=1.0):
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0)
step2.调整超参数
gamma
和 beta
是 Focal Loss 中的两个超参数。例如,如果我们想把 gamma
的值改为 1.5,把 alpha
改为 0.5,那么我们可以在配置中指定它们,如下所示。
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=1.5,
alpha=0.5,
loss_weight=1.0)
step3.调整 reduction 的方式
FL 的默认 reduction 是平均值。假设我们想把 reduction 从求均值改为求和,我们可以在配置中指定如下。
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0,
reduction='sum')
加权损失是指我们对损失元素层面上的重新加权。更具体地说,我们将损失张量与一个具有相同形状的权重张量相乘。因此,损失的不同条目可以有不同的标量,所以称为 element-wisely 加权。损失权重在不同的模型中是不同的,而且与上下文高度相关,但总体上有两种损失权重,分类损失的 label_weights 和 bbox_weights 用于 bbox 回归损失。你可以在相应头部的 get_target 方法中找到它们。这里我们以 ATSSHead
为例,它继承了 AnchorHead
,但是覆盖了它的 get_targets 方法,产生了不同的 label_weights 和 bbox_weights。
class ATSSHead(AnchorHead):
...
def get_targets(self,
anchor_list,
valid_flag_list,
gt_bboxes_list,
img_metas,
gt_bboxes_ignore_list=None,
gt_labels_list=None,
label_channels=1,
unmap_outputs=True):
step4.调整 loss 权重
这里的损失权重是一个标量,用于控制多任务学习中不同损失的权重,例如分类损失和回归损失。例如,如果我们想把分类损失的权重改为0.5,我们可以在配置中指定它,如下。
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=0.5)