Skip to main content

Customize Losses

MMDetection为用户提供了不同的损失函数。但默认配置可能不适用于不同的数据集或模型,因此用户可能希望修改特定的损失以适应新情况。

本教程首先阐述了损失的计算 pipeline,然后给出了一些关于如何修改每个步骤的说明。这些修改可以分为调整和加权。

计算 loss 的 pipeline

同时输入预测值和目标值,损失函数将输入张量映射到最终的损失标量。该映射可分为五个步骤。

  1. 设置采样方法,对正负样本进行采样
  2. 通过损失核函数得到 element-wise 或者 kernel-wise 的 loss 值
  3. 用带权的 tensor 从元素层面上地加权
  4. 将 tensor 转为 scalar
  5. 对 scalar 加权

step1.设置采样方法

对于某些损失函数,需要采用抽样策略来避免正负样本之间的不平衡。例如,在 RPN 头中使用 CrossEntropyLoss 时,我们需要在 train_cfg中 设置 RandomSampler

train_cfg=dict(
rpn=dict(
sampler=dict(
type='RandomSampler',
num=256,
pos_fraction=0.5,
neg_pos_ub=-1,
add_gt_as_proposals=False))

对于其他一些具有正负采样平衡机制的损失,如 Focal Loss、GHMC 和 QualityFocalLoss,采样器就没有必要了。

调整 loss

调整一个损失与第2、4、5步更相关,而且大多数修改可以在配置中指定。这里我们以Focal Loss(FL)为例。下面的代码分别是 FL 的构造方法和配置,它们实际上是一对一的对应关系。

@LOSSES.register_module()
class FocalLoss(nn.Module):

def __init__(self,
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
reduction='mean',
loss_weight=1.0):
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0)

step2.调整超参数

gammabeta 是 Focal Loss 中的两个超参数。例如,如果我们想把 gamma 的值改为 1.5,把 alpha 改为 0.5,那么我们可以在配置中指定它们,如下所示。

loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=1.5,
alpha=0.5,
loss_weight=1.0)

step3.调整 reduction 的方式

FL 的默认 reduction 是平均值。假设我们想把 reduction 从求均值改为求和,我们可以在配置中指定如下。

loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0,
reduction='sum')
info

加权损失是指我们对损失元素层面上的重新加权。更具体地说,我们将损失张量与一个具有相同形状的权重张量相乘。因此,损失的不同条目可以有不同的标量,所以称为 element-wisely 加权。损失权重在不同的模型中是不同的,而且与上下文高度相关,但总体上有两种损失权重,分类损失的 label_weights 和 bbox_weights 用于 bbox 回归损失。你可以在相应头部的 get_target 方法中找到它们。这里我们以 ATSSHead 为例,它继承了 AnchorHead,但是覆盖了它的 get_targets 方法,产生了不同的 label_weights 和 bbox_weights。

class ATSSHead(AnchorHead):

...

def get_targets(self,
anchor_list,
valid_flag_list,
gt_bboxes_list,
img_metas,
gt_bboxes_ignore_list=None,
gt_labels_list=None,
label_channels=1,
unmap_outputs=True):

step4.调整 loss 权重

这里的损失权重是一个标量,用于控制多任务学习中不同损失的权重,例如分类损失和回归损失。例如,如果我们想把分类损失的权重改为0.5,我们可以在配置中指定它,如下。

loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=0.5)