import torch
import torch.nn.functional as F
import torch.nn as nn

from einops import rearrange
from . import utils


def dice_with_logits(input: torch.Tensor, target: torch.Tensor, dim=(-2, -1), reduction='mean', eps=1e-6):
    """Negative soft Dice coefficient between ``sigmoid(input)`` and ``target``.

    With ``p = sigmoid(input)``, every slice left after summing over ``dim`` gets the score
    ``-2 * sum(p * target) / max(sum(p ** 2) + sum(target ** 2), eps)``. Lower is better.

    Args:
        input: Logits.
        target: Target values, combined element-wise with ``sigmoid(input)``.
        dim: Dimension(s) summed over to produce one score per remaining slice.
        reduction: ``'mean'`` or ``'sum'`` over the per-slice scores, or ``'none'`` to return them as-is.
        eps: Lower bound applied to the denominator to avoid division by zero.

    Returns:
        The reduced scores, or the unreduced per-slice scores for ``reduction='none'``.

    Raises:
        NotImplementedError: If ``reduction`` is not one of ``'mean'``, ``'sum'``, ``'none'``.
    """
    probs = torch.sigmoid(input)
    overlap = 2 * (probs * target).sum(dim)
    total_mass = (probs ** 2).sum(dim) + (target ** 2).sum(dim)
    scores = -overlap / total_mass.clamp(min=eps)
    reducers = {
        'sum': torch.sum,
        'mean': torch.mean,
        'none': lambda values: values,
    }
    if reduction not in reducers:
        raise NotImplementedError
    return reducers[reduction](scores)


class PositionalEmbeddingCosineLoss(nn.Module):
    """Negative mean cosine similarity between two 4-D embedding tensors.

    Cosine similarity is taken over the last dimension and averaged over the three leading ones, so the loss lies in
    ``[-1, 1]`` and is minimised when ``input`` and ``target`` point in the same direction.
    """

    def forward(self, input: torch.Tensor, target: torch.Tensor):
        # Both tensors are expected as (N, T, A, C); the unpacking below fixes the rank of `input`.
        n, t, a, c = input.shape
        assert target.shape == (n, t, a, c)
        similarity = F.cosine_similarity(input, target, dim=-1)
        # Maximising similarity == minimising its negation.
        return -(similarity.sum() / (n * t * a))


class ReconAccuracy(nn.Module):
    """Element-wise accuracy of thresholded logits.

    A logit counts as a positive prediction when it is strictly greater than zero. The prediction (as 0./1.) is compared
    element-wise with ``target`` (presumably a {0, 1} mask — confirm with callers), and the fraction of matches is returned.
    """

    def forward(self, input: torch.Tensor, target: torch.Tensor):
        predicted = (input > 0.).float()
        hits = (predicted == target).float()
        return hits.mean()


class REF1Score(nn.Module):
    """Soft (differentiable) F1 score computed from probabilities instead of hard predictions.

    True positives, false positives and false negatives are accumulated as ``sum(t * p)``, ``sum((1 - t) * p)`` and
    ``sum(t * (1 - p))``. The resulting F1 is clamped into ``[epsilon, 1 - epsilon]``.

    Args:
        epsilon: Added to every denominator and used as the clamping margin.
        flatten: If True, both tensors are flattened and a single scalar F1 is returned. Otherwise the counts are
            reduced over ``dim`` only and the per-slice F1 tensor is returned.
        dim: Reduction dimension(s); required when ``flatten`` is False.
        with_logits: If True, ``input`` holds logits and is passed through a sigmoid first.
    """

    def __init__(self, epsilon=1e-7, flatten=True, dim=None, with_logits=True):
        super().__init__()
        self.epsilon = epsilon
        self.flatten = flatten
        self.dim = dim
        self.with_logits = with_logits

    def forward(self, input: torch.Tensor, target: torch.Tensor):
        # Presumably input.shape == target.shape == (N, T, A, C, h, w); only element-wise ops are used here.
        if self.flatten:
            input, target, reduce_dim = input.flatten(), target.flatten(), 0
        else:
            assert self.dim is not None
            reduce_dim = self.dim
        probs = input.sigmoid() if self.with_logits else input
        eps = self.epsilon
        true_pos = (target * probs).sum(dim=reduce_dim).to(torch.float32)
        false_pos = ((1 - target) * probs).sum(dim=reduce_dim).to(torch.float32)
        false_neg = (target * (1 - probs)).sum(dim=reduce_dim).to(torch.float32)
        precision = true_pos / (true_pos + false_pos + eps)
        recall = true_pos / (true_pos + false_neg + eps)
        f1 = (2 * (precision * recall) / (precision + recall + eps)).clamp(min=eps, max=1 - eps)
        return f1.mean() if self.flatten else f1


class FocalRegressionLoss(nn.Module):
    """Re-weight an element-wise regression loss so that large errors dominate.

    Each element's base loss ``l`` is multiplied by
    ``((clamp(l, cutoff, maxout) - cutoff) / (maxout - cutoff)) ** gamma``. For ``gamma > 0`` this weight lies in
    ``[0, 1]``: elements with ``l <= cutoff`` get weight 0 and elements with ``l >= maxout`` get weight 1.

    Args:
        base_loss: Callable ``(input, target) -> Tensor``, expected to return per-element (unreduced) losses.
        cutoff: Base-loss value at and below which an element gets zero weight.
        maxout: Base-loss value at and above which an element gets full weight; must be greater than ``cutoff``.
        gamma: Exponent applied to the normalised weights.
        detach_weights: If True, no gradient flows through the weights.
        reduction: ``'mean'``, ``'sum'`` or ``'none'``.
    """

    def __init__(self, base_loss, cutoff=0., maxout=1., gamma=1., detach_weights=True, reduction='mean'):
        super().__init__()
        assert maxout > cutoff
        assert reduction in ['mean', 'sum', 'none']
        self.base_loss = base_loss
        self.cutoff = cutoff
        self.maxout = maxout
        self.gamma = gamma
        self.detach_weights = detach_weights
        self.reduction = reduction

    def forward(self, input: torch.Tensor, target: torch.Tensor):
        per_element = self.base_loss(input, target)
        lo, hi = self.cutoff, self.maxout
        # Normalise the clamped loss into [0, 1].
        weights = (torch.clamp(per_element, lo, hi) - lo) / (hi - lo)
        if self.gamma != 1.:
            weights = weights.pow(self.gamma)
        if self.detach_weights:
            weights = weights.detach()
        weighted = per_element * weights
        if self.reduction == 'none':
            return weighted
        if self.reduction == 'sum':
            return weighted.sum()
        if self.reduction == 'mean':
            return weighted.mean()
        # Only reachable if `reduction` was changed after construction.
        raise NotImplementedError


class FocalMSELoss(FocalRegressionLoss):
    """Focal-weighted squared error: ``FocalRegressionLoss`` over ``nn.MSELoss(reduction='none')``.

    All arguments are forwarded unchanged to ``FocalRegressionLoss``.
    """

    def __init__(self, cutoff=0., maxout=1., gamma=1., detach_weights=True, reduction='mean'):
        squared_error = nn.MSELoss(reduction='none')
        super().__init__(squared_error, cutoff=cutoff, maxout=maxout, gamma=gamma,
                         detach_weights=detach_weights, reduction=reduction)


class FocalHuberLoss(FocalRegressionLoss):
    """Focal-weighted smooth-L1 (Huber) loss: ``FocalRegressionLoss`` over ``nn.SmoothL1Loss(reduction='none')``.

    All arguments are forwarded unchanged to ``FocalRegressionLoss``.
    """

    def __init__(self, cutoff=0., maxout=1., gamma=1., detach_weights=True, reduction='mean'):
        smooth_l1 = nn.SmoothL1Loss(reduction='none')
        super().__init__(smooth_l1, cutoff=cutoff, maxout=maxout, gamma=gamma,
                         detach_weights=detach_weights, reduction=reduction)


class FocalCrossEntropyLoss(nn.Module):
    """Multi-class cross-entropy with an optional focal modulation term.

    The per-element loss is ``(1 - p_t) ** gamma * CE``, where ``p_t`` is the softmax probability of the target
    class. With ``gamma == 0`` this is exactly ``nn.CrossEntropyLoss``, which is then used directly as a fast path.

    Args:
        gamma: Focusing exponent; 0 disables the focal term.
        detach_weights: If True, the focal weights are detached so gradients flow only through the CE term.
        reduction: ``'mean'``, ``'sum'`` or ``'none'``. With ``'none'`` the per-element loss is returned with the
            shape of ``target``, matching ``nn.CrossEntropyLoss``, for any ``gamma``.

    Raises:
        NotImplementedError: From ``forward`` (focal path) if ``reduction`` is not one of the values above.
    """

    def __init__(self, gamma=0., detach_weights=True, reduction='mean'):
        super(FocalCrossEntropyLoss, self).__init__()
        self.gamma = gamma
        self.detach_weights = detach_weights
        self.reduction = reduction
        self._ce = nn.CrossEntropyLoss(reduction=reduction)

    def forward(self, input: torch.Tensor, target: torch.Tensor):
        # input: (N, C) or (N, C, *S) logits; target: (N,) or (N, *S) class indices.
        # Faster code path if gamma = 0 (i.e. no focal loss), which is the case by default.
        if self.gamma == 0.:
            return self._ce(input, target)
        # Remember the original target shape so that reduction='none' can mirror nn.CrossEntropyLoss.
        target_shape = target.shape
        if input.dim() > 2:
            N, C, *S = input.shape
            assert tuple(target.shape) == ((N,) + tuple(S))
            # (N, C, *S) -> (N * prod(S), C): move the class axis last, then flatten (N-major, like the target).
            input = input.reshape(N, C, -1).transpose(1, 2).reshape(-1, C)
        # reshape (rather than view) also accepts non-contiguous targets.
        target = target.reshape(-1, 1)
        # Cross-entropy computed manually so the focal weight can be applied per element.
        logpt = F.log_softmax(input, dim=1).gather(1, target).view(-1)
        pt = logpt.exp()
        weights = (1 - pt) ** self.gamma
        if self.detach_weights:
            weights = weights.detach()
        loss = -1 * weights * logpt
        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        elif self.reduction == 'none':
            return loss.view(target_shape)
        else:
            raise NotImplementedError


class FocalBCEWithLogitsLoss(nn.Module):
    """Binary cross-entropy on logits with an optional focal modulation term.

    The per-element loss is ``(1 - p_t) ** gamma * BCE``, with ``p_t = exp(-BCE)``. For binary targets, ``p_t`` is the
    probability the model assigns to the target label. With ``gamma == 0`` this is plain ``BCEWithLogitsLoss``.

    Args:
        gamma: Focusing exponent; 0 disables the focal term.
        detach_weights: If True, the focal weights are detached so gradients flow only through the BCE term.
        reduction: ``'mean'``, ``'sum'`` or ``'none'``.

    Raises:
        NotImplementedError: From ``forward`` if ``reduction`` is not one of the values above.
    """

    def __init__(self, gamma=0., detach_weights=True, reduction='mean'):
        super(FocalBCEWithLogitsLoss, self).__init__()
        self._bce = nn.BCEWithLogitsLoss(reduction='none')
        self.detach_weights = detach_weights
        self.reduction = reduction
        self.gamma = gamma

    def forward(self, input: torch.Tensor, target: torch.Tensor):
        bce = self._bce(input, target)
        if self.gamma != 0.:
            pt = torch.exp(-bce)
            weights = (1 - pt) ** self.gamma
            if self.detach_weights:
                weights = weights.detach()
            # `weights` already carries the focal exponent; raising it to `gamma` again would use gamma ** 2.
            bce = weights * bce
        if self.reduction == 'mean':
            return bce.mean()
        elif self.reduction == 'sum':
            return bce.sum()
        elif self.reduction == 'none':
            return bce
        else:
            raise NotImplementedError


class SC2UnitTypeLoss(FocalCrossEntropyLoss):
    """(Focal) cross-entropy over SC2 unit-type maps, with an explicit "no unit" class.

    An extra target channel marking the absence of a unit is added (1 where no unit is present). The logits get an
    always-zero channel in the same position, which pins the otherwise softmax-invariant global shift of the logits.
    Both padding steps are done by ``utils`` helpers along dim 3. The (N, T, A) axes are then merged into one batch
    axis, and the target one-hot map is turned into class indices with ``argmax``.
    """

    def forward(self, input: torch.Tensor, target: torch.Tensor):
        # Expected shapes: input.shape = target.shape = (N, T, A, C, r, Θ).
        # Unpacking also enforces that `input` is 6-dimensional.
        n, t, a, _, r, th = input.shape
        padded_target = utils.add_null_class_to_onehot_map(target, dim=3)
        padded_logits = utils.add_zero_channel_to_logit_map(input, dim=3)
        merge_leading = 'n t a c h w -> (n t a) c h w'
        flat_logits = rearrange(padded_logits, merge_leading)
        class_indices = torch.argmax(rearrange(padded_target, merge_leading), 1)
        return super().forward(flat_logits, class_indices)


class SC2ReconLoss(nn.Module):
    """Sum of per-component reconstruction losses for SC2 state tensors.

    ``input`` and ``target`` are split with ``SC2Trajectories.split_state_tensor`` into their ``hecs``,
    ``friendly_marker``, ``unit_types``, ``terrain`` and ``spatial_markers`` components. Each component has its own loss,
    built by `parse_loss` with ``reduction='sum'`` unless overridden. Each component loss is divided by ``N * T * A``
    (the three leading dimensions of ``input``), and the results are added up.
    """

    # When True, `forward` raises NotImplementedError (no naive variant is implemented).
    NAIVE = False

    def __init__(self, loss_on_hecs=nn.MSELoss, loss_on_friendly_markers=nn.BCEWithLogitsLoss,
                 loss_on_unit_types=SC2UnitTypeLoss, loss_on_terrain=nn.MSELoss,
                 loss_on_spatial_markers=nn.MSELoss):
        """Each ``loss_on_*`` argument is a loss specification accepted by `parse_loss` (a class, a dict, or None)."""
        super(SC2ReconLoss, self).__init__()
        # Init losses
        self.loss_on_hecs = self.parse_loss(loss_on_hecs)
        self.loss_on_friendly_markers = self.parse_loss(loss_on_friendly_markers)
        self.loss_on_unit_types = self.parse_loss(loss_on_unit_types)
        self.loss_on_terrain = self.parse_loss(loss_on_terrain)
        self.loss_on_spatial_markers = self.parse_loss(loss_on_spatial_markers)

    def parse_loss(self, loss):
        """Build a loss callable from a specification.

        Args:
            loss: One of
                - a class, instantiated as ``loss(reduction='sum')``;
                - a dict with a ``'name'`` key naming a class defined in this module or in ``torch.nn``, plus an
                  optional ``'kwargs'`` dict merged over ``{'reduction': 'sum'}`` (a missing or ``None`` value means
                  no extra kwargs);
                - ``None``, which disables the term (the returned callable always yields ``torch.tensor(0.)``).

        Returns:
            The instantiated loss, or the zero-returning callable for ``None``.

        Raises:
            ValueError: If ``loss`` is none of the supported specification types.
        """
        no_loss = lambda *_, **__: torch.tensor(0.)
        loss_kwargs = {'reduction': 'sum'}
        if isinstance(loss, dict):
            assert 'name' in loss
            loss_name = loss['name']
            if loss_name in globals():
                loss_cls = globals()[loss_name]
            else:
                loss_cls = getattr(torch.nn, loss_name)
            # Config files often produce `kwargs: None` for an empty section; treat it like a missing key.
            loss_kwargs.update(loss.get('kwargs') or {})
        elif isinstance(loss, type):
            loss_cls = loss
        elif loss is None:
            loss_cls = None
        else:
            raise ValueError("Expected a loss class, a dict with a 'name' key, or None; "
                             "got {}: {!r}".format(type(loss).__name__, loss))
        return loss_cls(**loss_kwargs) if loss_cls is not None else no_loss

    def forward(self, input: torch.Tensor, target: torch.Tensor):
        if self.NAIVE:
            raise NotImplementedError
        from .envs.sc2.sc2_trajectories import SC2Trajectories
        N, T, A, *_ = input.shape
        # Component losses are summed by default, so normalise by the number of (N, T, A) entries.
        normalizer = N * T * A
        # split input and target
        input_components = SC2Trajectories.split_state_tensor(input)
        target_components = SC2Trajectories.split_state_tensor(target)
        # Apply the losses
        hecs_loss = self.loss_on_hecs(input_components.hecs, target_components.hecs).div(normalizer)
        friendly_marker_loss = self.loss_on_friendly_markers(input_components.friendly_marker,
                                                             target_components.friendly_marker).div(normalizer)
        unit_type_loss = self.loss_on_unit_types(input_components.unit_types,
                                                 target_components.unit_types).div(normalizer)
        terrain_loss = self.loss_on_terrain(input_components.terrain, target_components.terrain).div(normalizer)
        spatial_marker_loss = self.loss_on_spatial_markers(input_components.spatial_markers,
                                                           target_components.spatial_markers).div(normalizer)
        # Evaluate total loss
        total_loss = hecs_loss + friendly_marker_loss + unit_type_loss + terrain_loss + spatial_marker_loss
        return total_loss


class SC2HECSReconScore(nn.Module):
    """Negative mean-squared error between the ``hecs`` components of predicted and target SC2 state tensors.

    Higher is better; 0 means a perfect reconstruction of that component.
    """

    def forward(self, input: torch.Tensor, target: torch.Tensor):
        from .envs.sc2.sc2_trajectories import SC2Trajectories
        split = SC2Trajectories.split_state_tensor
        predicted, actual = split(input), split(target)
        mse = F.mse_loss(predicted.hecs, actual.hecs)
        return -mse


class SC2UnitTypeAccuracy(nn.Module):
    """Fraction of positions whose predicted unit type equals the target unit type.

    The ``unit_types`` components are padded along dim 3: logits via ``utils.add_zero_channel_to_logit_map``, the
    one-hot target via ``utils.add_null_class_to_onehot_map`` (presumably adding a "no unit" class — confirm in
    ``utils``). Classes are then compared by ``argmax`` over dim 3.
    """

    def forward(self, input: torch.Tensor, target: torch.Tensor):
        from .envs.sc2.sc2_trajectories import SC2Trajectories
        split = SC2Trajectories.split_state_tensor
        predicted, actual = split(input), split(target)
        padded_logits = utils.add_zero_channel_to_logit_map(predicted.unit_types, dim=3)
        padded_onehot = utils.add_null_class_to_onehot_map(actual.unit_types, dim=3)
        matches = torch.argmax(padded_logits, dim=3) == torch.argmax(padded_onehot, dim=3)
        return matches.float().mean()


class SC2UnitTypeConfusionMatrix(nn.Module):
    """Confusion matrix of predicted vs. target unit-type classes, computed with ``utils.confusion_matrix``.

    Both ``unit_types`` components are padded with one extra channel along dim 3 (as in ``SC2UnitTypeAccuracy``) and
    reduced to class indices with ``argmax`` over dim 3. ``num_classes`` is the original channel count plus one.
    """

    def forward(self, input: torch.Tensor, target: torch.Tensor):
        from .envs.sc2.sc2_trajectories import SC2Trajectories
        split = SC2Trajectories.split_state_tensor
        predicted, actual = split(input), split(target)
        padded_logits = utils.add_zero_channel_to_logit_map(predicted.unit_types, dim=3)
        padded_onehot = utils.add_null_class_to_onehot_map(actual.unit_types, dim=3)
        predicted_classes = torch.argmax(padded_logits, dim=3)
        actual_classes = torch.argmax(padded_onehot, dim=3)
        n_classes = predicted.unit_types.shape[3] + 1
        return utils.confusion_matrix(predicted_classes, actual_classes, num_classes=n_classes)


class SC2TerrainReconScore(nn.Module):
    """Negative mean-squared error between the ``terrain`` components of predicted and target SC2 state tensors.

    Higher is better; 0 means a perfect reconstruction of that component.
    """

    def forward(self, input: torch.Tensor, target: torch.Tensor):
        from .envs.sc2.sc2_trajectories import SC2Trajectories
        split = SC2Trajectories.split_state_tensor
        predicted, actual = split(input), split(target)
        mse = F.mse_loss(predicted.terrain, actual.terrain)
        return -mse


class SC2FriendlyMarkerF1Score(REF1Score):
    """``REF1Score`` applied only to the ``friendly_marker`` component of SC2 state tensors."""

    def forward(self, input: torch.Tensor, target: torch.Tensor):
        from .envs.sc2.sc2_trajectories import SC2Trajectories
        split = SC2Trajectories.split_state_tensor
        predicted, actual = split(input), split(target)
        return super().forward(predicted.friendly_marker, actual.friendly_marker)


class SC2FriendlyMarkerConfusionMatrix(nn.Module):
    """Two-class confusion matrix for the ``friendly_marker`` component, via ``utils.confusion_matrix``.

    Predictions are the friendly-marker logits thresholded at zero (``> 0``).
    """

    def forward(self, input: torch.Tensor, target: torch.Tensor):
        from .envs.sc2.sc2_trajectories import SC2Trajectories
        split = SC2Trajectories.split_state_tensor
        predicted, actual = split(input), split(target)
        predicted_positive = predicted.friendly_marker > 0.
        return utils.confusion_matrix(predicted_positive, actual.friendly_marker, num_classes=2)


class BBConfusionMatrix(nn.Module):
    """Two-class confusion matrix of logits thresholded at zero (``input > 0``) against ``target``.

    Computed with ``utils.confusion_matrix``.
    """

    def forward(self, input: torch.Tensor, target: torch.Tensor):
        predicted_positive = input > 0.
        return utils.confusion_matrix(predicted_positive, target, num_classes=2)
