import torch
from torch.autograd import Variable
import torch.nn.functional as F
import numpy as np
from torch import nn
try:
    from itertools import ifilterfalse
except ImportError:  # Python 3: ifilterfalse was renamed to filterfalse
    from itertools import filterfalse as ifilterfalse
class TorchDiceLoss(nn.Module):
    """Soft Dice loss; wraps :func:`soft_dice_loss` as an ``nn.Module``.

    ``weight`` and ``size_average`` are stored for API compatibility but
    are not currently used.
    """

    def __init__(self, weight=None, size_average=True,
                 per_image=False, logits=False):
        super().__init__()
        self.size_average = size_average
        self.register_buffer('weight', weight)
        self.per_image = per_image
        self.logits = logits

    def forward(self, input, target):
        if self.logits:
            input = torch.sigmoid(input)
        return soft_dice_loss(input, target, per_image=self.per_image)
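
# Usage sketch (hypothetical tensors, not part of the original module): with
# ``logits=True`` the raw model outputs are passed through a sigmoid
# internally before the soft Dice loss is computed.
#
#   criterion = TorchDiceLoss(logits=True)
#   preds = torch.randn(4, 1, 128, 128, requires_grad=True)  # raw logits
#   masks = torch.randint(0, 2, (4, 1, 128, 128)).float()
#   loss = criterion(preds, masks)
#   loss.backward()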
class TorchFocalLoss(nn.Module):
    """Implementation of Focal Loss [1]_, modified from Catalyst [2]_ .

    Arguments
    ---------
    gamma : :class:`int` or :class:`float`
        Focusing parameter. See [1]_ .
    reduce : bool, optional
        If ``True`` (default), return the mean of the per-element losses;
        otherwise return the unreduced loss tensor.
    logits : bool, optional
        If ``True``, treat ``outputs`` as raw logits and apply
        :func:`torch.nn.functional.binary_cross_entropy_with_logits`.
        Defaults to ``False``.

    References
    ----------
    .. [1] https://arxiv.org/pdf/1708.02002.pdf
    .. [2] https://catalyst-team.github.io/catalyst/
    """

    def __init__(self, gamma=2, reduce=True, logits=False):
        super().__init__()
        self.gamma = gamma
        self.reduce = reduce
        self.logits = logits

    # TODO refactor
    def forward(self, outputs, targets):
        """Calculate the loss function between `outputs` and `targets`.

        Arguments
        ---------
        outputs : :class:`torch.Tensor`
            The output tensor from a model.
        targets : :class:`torch.Tensor`
            The training target.

        Returns
        -------
        loss : :class:`torch.Tensor`
            The loss value.
        """
        if self.logits:
            BCE_loss = F.binary_cross_entropy_with_logits(outputs, targets,
                                                          reduction='none')
        else:
            BCE_loss = F.binary_cross_entropy(outputs, targets,
                                              reduction='none')
        # pt is the model's estimated probability of the true class;
        # the (1 - pt)**gamma factor down-weights well-classified examples
        pt = torch.exp(-BCE_loss)
        F_loss = (1 - pt)**self.gamma * BCE_loss
        if self.reduce:
            return torch.mean(F_loss)
        else:
            return F_loss
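
# Usage sketch (hypothetical tensors): with ``reduce=True`` (the default)
# the forward pass returns a scalar suitable for ``backward()``.
#
#   criterion = TorchFocalLoss(gamma=2, logits=True)
#   outputs = torch.randn(8, 1, 64, 64, requires_grad=True)
#   targets = torch.randint(0, 2, (8, 1, 64, 64)).float()
#   loss = criterion(outputs, targets)
#   loss.backward()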
def torch_lovasz_hinge(logits, labels, per_image=False, ignore=None):
    """Lovasz Hinge Loss. Implementation edited from Maxim Berman's GitHub.

    References
    ----------
    https://github.com/bermanmaxim/LovaszSoftmax/
    https://arxiv.org/abs/1705.08790

    Arguments
    ---------
    logits: :class:`torch.Variable`
        logits at each pixel (between -inf and +inf)
    labels: :class:`torch.Tensor`
        binary ground truth masks (0 or 1)
    per_image: bool, optional
        compute the loss per image instead of per batch. Defaults to ``False``.
    ignore: optional void class id.

    Returns
    -------
    loss : :class:`torch.Variable`
        Lovasz loss value for the input logits and labels. Compatible with
        ``loss.backward()`` as it's a :class:`torch.Variable` .
    """
    # TODO: Restructure into a class like TorchFocalLoss for compatibility
    if per_image:
        loss = mean(
            lovasz_hinge_flat(*flatten_binary_scores(log.unsqueeze(0),
                                                     lab.unsqueeze(0),
                                                     ignore))
            for log, lab in zip(logits, labels))
    else:
        loss = lovasz_hinge_flat(*flatten_binary_scores(logits,
                                                        labels,
                                                        ignore))
    return loss
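
# Usage sketch (hypothetical [B, H, W] tensors): ``per_image=True`` scores
# each image separately and averages the per-image losses.
#
#   logits = torch.randn(4, 64, 64, requires_grad=True)
#   masks = torch.randint(0, 2, (4, 64, 64))
#   loss = torch_lovasz_hinge(logits, masks, per_image=True)
#   loss.backward()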
def lovasz_hinge_flat(logits, labels):
    """Binary Lovasz hinge loss.

    Arguments
    ---------
    logits: :class:`torch.Variable`
        Logits at each prediction (between -inf and +inf)
    labels: :class:`torch.Tensor`
        binary ground truth labels (0 or 1)

    Returns
    -------
    loss : :class:`torch.Variable`
        Lovasz loss value for the input logits and labels.
    """
    if len(labels) == 0:
        # only void pixels, the gradients should be 0
        return logits.sum() * 0.
    signs = 2. * labels.float() - 1.
    errors = (1. - logits * Variable(signs))
    errors_sorted, perm = torch.sort(errors, dim=0, descending=True)
    perm = perm.data
    gt_sorted = labels[perm]
    grad = lovasz_grad(gt_sorted)
    loss = torch.dot(F.relu(errors_sorted), Variable(grad))
    return loss
def flatten_binary_scores(scores, labels, ignore=None):
    """
    Flattens predictions in the batch (binary case).
    Removes labels equal to ``ignore``.
    """
    scores = scores.view(-1)
    labels = labels.view(-1)
    if ignore is None:
        return scores, labels
    valid = (labels != ignore)
    vscores = scores[valid]
    vlabels = labels[valid]
    return vscores, vlabels
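
# Example (toy values): pixels labeled with the ``ignore`` id are dropped
# from both tensors before any loss is computed.
#
#   scores = torch.tensor([0.2, -1.3, 0.7])
#   labels = torch.tensor([1, 255, 0])
#   flatten_binary_scores(scores, labels, ignore=255)
#   # -> (tensor([0.2000, 0.7000]), tensor([1, 0]))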
class TorchJaccardLoss(nn.Module):
    """Jaccard (IoU) loss; modified from XD_XD's implementation."""

    def __init__(self):
        super().__init__()

    def forward(self, outputs, targets):
        eps = 1e-15  # smoothing term to avoid division by zero
        jaccard_target = (targets == 1).float()
        jaccard_output = torch.sigmoid(outputs)
        intersection = (jaccard_output * jaccard_target).sum()
        union = jaccard_output.sum() + jaccard_target.sum()
        jaccard_score = ((intersection + eps) / (union - intersection + eps))
        self._stash_jaccard = jaccard_score
        loss = 1. - jaccard_score
        return loss
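
# Usage sketch (hypothetical tensors): expects raw logits (sigmoid is
# applied internally); the most recent IoU estimate is stashed on the
# instance for logging.
#
#   criterion = TorchJaccardLoss()
#   outputs = torch.randn(4, 1, 64, 64, requires_grad=True)
#   targets = torch.randint(0, 2, (4, 1, 64, 64)).float()
#   loss = criterion(outputs, targets)
#   last_iou = criterion._stash_jaccard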
class TorchStableBCELoss(nn.Module):
    """Numerically stable BCE-with-logits, computed via the identity
    ``max(x, 0) - x*y + log(1 + exp(-|x|))``."""

    def __init__(self):
        super().__init__()

    def forward(self, input, target):
        neg_abs = - input.abs()
        loss = input.clamp(min=0) - input * target + (1 + neg_abs.exp()).log()
        return loss.mean()
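
# Sanity-check sketch: this formulation should agree with PyTorch's own
# ``F.binary_cross_entropy_with_logits`` up to floating-point error.
#
#   x = torch.randn(10)
#   y = torch.randint(0, 2, (10,)).float()
#   assert torch.allclose(TorchStableBCELoss()(x, y),
#                         F.binary_cross_entropy_with_logits(x, y),
#                         atol=1e-6)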
def binary_xloss(logits, labels, ignore=None):
    """
    Binary cross-entropy loss.
    logits: [B, H, W] Variable, logits at each pixel (between -inf and +inf)
    labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
    ignore: void class id
    """
    logits, labels = flatten_binary_scores(logits, labels, ignore)
    loss = TorchStableBCELoss()(logits, Variable(labels.float()))
    return loss
def lovasz_grad(gt_sorted):
    """
    Computes gradient of the Lovasz extension w.r.t sorted errors.
    See Alg. 1 in the paper.
    """
    p = len(gt_sorted)
    gts = gt_sorted.sum()
    intersection = gts - gt_sorted.float().cumsum(0)
    union = gts + (1 - gt_sorted).float().cumsum(0)
    jaccard = 1. - intersection / union
    if p > 1:  # cover the 1-pixel case
        jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
    return jaccard
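
# Example (toy input): for sorted ground truth [1, 1, 0] this yields
# tensor([0.5000, 0.5000, 0.0000]) -- the first (largest) errors receive
# the largest weights.
#
#   lovasz_grad(torch.tensor([1, 1, 0]))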
def iou_binary(preds, labels, EMPTY=1., ignore=None, per_image=True):
    """
    IoU for the foreground class.
    binary: 1 foreground, 0 background
    """
    if not per_image:
        preds, labels = (preds,), (labels,)
    ious = []
    for pred, label in zip(preds, labels):
        intersection = ((label == 1) & (pred == 1)).sum()
        union = ((label == 1) | ((pred == 1) & (label != ignore))).sum()
        if not union:
            iou = EMPTY
        else:
            iou = float(intersection) / float(union)
        ious.append(iou)
    iou = mean(ious)  # mean across images if per_image
    return 100 * iou
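
# Example (toy prediction/label pair, scored as one batch): intersection 2,
# union 3, so the returned percentage is ~66.7.
#
#   preds = torch.tensor([[1, 0], [1, 1]])
#   labels = torch.tensor([[1, 0], [0, 1]])
#   iou_binary(preds, labels, per_image=False)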
def iou(preds, labels, C, EMPTY=1., ignore=None, per_image=False):
    """
    Array of IoU for each (non-ignored) class.
    """
    if not per_image:
        preds, labels = (preds,), (labels,)
    ious = []
    for pred, label in zip(preds, labels):
        iou = []
        for i in range(C):
            if i != ignore:
                intersection = ((label == i) & (pred == i)).sum()
                union = ((label == i) | ((pred == i) & (label != ignore))).sum()
                if not union:
                    iou.append(EMPTY)
                else:
                    iou.append(float(intersection) / float(union))
        ious.append(iou)
    ious = [mean(iou) for iou in zip(*ious)]  # mean across images if per_image
    return 100 * np.array(ious)
# helper functions
def isnan(x):
    return x != x
def mean(l, ignore_nan=False, empty=0):
    """
    nanmean compatible with generators.
    """
    l = iter(l)
    if ignore_nan:
        l = ifilterfalse(isnan, l)
    try:
        n = 1
        acc = next(l)
    except StopIteration:
        if empty == 'raise':
            raise ValueError('Empty mean')
        return empty
    for n, v in enumerate(l, 2):
        acc += v
    if n == 1:
        return acc
    return acc / n
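
# Examples: works on plain lists and on generators alike.
#
#   mean([1., 2., 3.])                                       # -> 2.0
#   mean((x for x in [1., float('nan')]), ignore_nan=True)   # -> 1.0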
def dice_round(preds, trues):
    # despite the name, this computes soft dice on the float predictions;
    # round or threshold ``preds`` before calling if hard dice is desired
    preds = preds.float()
    return soft_dice_loss(preds, trues)
def soft_dice_loss(outputs, targets, per_image=False):
    """Soft (differentiable) Dice loss: 1 - 2|X∩Y| / (|X| + |Y|)."""
    batch_size = outputs.size()[0]
    eps = 1e-5
    if not per_image:
        # treat the whole batch as a single sample
        batch_size = 1
    dice_target = targets.contiguous().view(batch_size, -1).float()
    dice_output = outputs.contiguous().view(batch_size, -1)
    intersection = torch.sum(dice_output * dice_target, dim=1)
    union = torch.sum(dice_output, dim=1) + torch.sum(dice_target, dim=1) + eps
    loss = (1 - (2 * intersection + eps) / union).mean()
    return loss
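
# Example (toy tensors): a perfect prediction drives the loss to 0.
#
#   outputs = torch.tensor([[1., 0., 1.]])
#   targets = torch.tensor([[1., 0., 1.]])
#   soft_dice_loss(outputs, targets)  # -> tensor(0.)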
torch_losses = {
'l1loss': nn.L1Loss,
'l1': nn.L1Loss,
'mae': nn.L1Loss,
'mean_absolute_error': nn.L1Loss,
'smoothl1loss': nn.SmoothL1Loss,
'smoothl1': nn.SmoothL1Loss,
'mean_squared_error': nn.MSELoss,
'mse': nn.MSELoss,
'mseloss': nn.MSELoss,
'categorical_crossentropy': nn.CrossEntropyLoss,
'cce': nn.CrossEntropyLoss,
'crossentropyloss': nn.CrossEntropyLoss,
'negative_log_likelihood': nn.NLLLoss,
'nll': nn.NLLLoss,
'nllloss': nn.NLLLoss,
'poisson_negative_log_likelihood': nn.PoissonNLLLoss,
'poisson_nll': nn.PoissonNLLLoss,
'poissonnll': nn.PoissonNLLLoss,
'kullback_leibler_divergence': nn.KLDivLoss,
'kld': nn.KLDivLoss,
'kldivloss': nn.KLDivLoss,
'binary_crossentropy': nn.BCELoss,
'bce': nn.BCELoss,
'bceloss': nn.BCELoss,
'bcewithlogits': nn.BCEWithLogitsLoss,
'bcewithlogitsloss': nn.BCEWithLogitsLoss,
'hinge': nn.HingeEmbeddingLoss,
'hingeembeddingloss': nn.HingeEmbeddingLoss,
'multiclass_hinge': nn.MultiMarginLoss,
'multimarginloss': nn.MultiMarginLoss,
'softmarginloss': nn.SoftMarginLoss,
'softmargin': nn.SoftMarginLoss,
'multiclass_softmargin': nn.MultiLabelSoftMarginLoss,
'multilabelsoftmarginloss': nn.MultiLabelSoftMarginLoss,
'cosine': nn.CosineEmbeddingLoss,
'cosineloss': nn.CosineEmbeddingLoss,
'cosineembeddingloss': nn.CosineEmbeddingLoss,
'lovaszhinge': torch_lovasz_hinge,
'focalloss': TorchFocalLoss,
'focal': TorchFocalLoss,
'jaccard': TorchJaccardLoss,
'jaccardloss': TorchJaccardLoss,
'dice': TorchDiceLoss,
'diceloss': TorchDiceLoss
}
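
# Usage sketch: losses are looked up by lowercased name; the ``nn.Module``
# entries must be instantiated, while ``torch_lovasz_hinge`` is a plain
# function that is called directly.
#
#   criterion = torch_losses['focal'](gamma=2, logits=True)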