#!/usr/bin/env python3

from src import const
from torch import nn
import torch


class ContrastiveLoss(nn.Module):
    """Classification-on-ablation loss with an optional CAM/mask KL-divergence term.

    ``forward`` returns ``const.LAMBDAS[2] * ace + const.LAMBDAS[3] * divergence``
    where ``ace`` is a (binary) cross-entropy computed on per-class "ablation"
    scores derived from class activation maps and a foreground mask, and
    ``divergence`` is the value returned by :meth:`kld` (only computed when
    ``const.LAMBDAS[-1] != 0``).

    Three modes are supported, selected at construction time:

    * ``multilabel=True``: the raw maps ``y_pred[1]`` are used and the criterion
      is ``BCEWithLogitsLoss``.
    * ``is_label_mask=True``: ``y[0]`` is compared against the arg-max class of
      ``y[1]`` to build the foreground mask; contrastive maps are used.
    * otherwise: ``y[0]`` is used directly as a foreground mask.

    The last computed ``(ace, divergence)`` pair is stored as Python numbers in
    ``self.prev`` after every forward pass.
    """

    def __init__(self, get_contrastive_cams_fn, debug=False, is_label_mask=False, target_probability_swd=const.LAMBDAS[0],
                 multilabel=False, pos_weight=None, pos_only=False):
        """Configure the loss.

        Args:
            get_contrastive_cams_fn: callable invoked as ``fn(y[1], y_pred[1])``
                in the non-multilabel modes; its result is moved to
                ``const.DEVICE``.
            debug: accepted for backward compatibility; not used by this class.
            is_label_mask: selects the label-mask mode (ignored when
                ``multilabel`` is true).
            target_probability_swd: accepted for backward compatibility; not
                used by this class.
            multilabel: use ``BCEWithLogitsLoss`` and the multilabel mask logic.
            pos_weight: forwarded to ``BCEWithLogitsLoss``. When ``None`` the
                criterion is un-reduced and ``forward`` balances the positive
                and negative groups itself.
            pos_only: in multilabel mode, restrict the divergence term to maps
                whose target entry is non-zero.
        """
        super().__init__()

        if multilabel:
            # Without pos_weight the per-element losses are kept so that
            # forward() can average positives and negatives separately.
            self.ce = nn.BCEWithLogitsLoss(pos_weight=pos_weight,
                                           reduction='none' if pos_weight is None else 'mean')
        else:
            self.ce = nn.CrossEntropyLoss(label_smoothing=const.LABEL_SMOOTHING)

        self.get_contrastive_cams = get_contrastive_cams_fn
        self.divergence = const.LAMBDAS[-1] != 0
        self.is_label_mask = is_label_mask
        self.multilabel = multilabel
        self.pos_only = pos_only

    def kld(self, cc, fg_mask, y):
        """Return the mean KL divergence between mask- and CAM-derived distributions.

        Both inputs are scaled (``fg_mask`` by ``const.LAMBDAS[1]``, ``cc`` by
        ``const.LAMBDAS[0]``) and soft-maxed over their last two dimensions
        taken together. The CAM probabilities are clamped to at least ``1e-6``
        before the logarithm.

        Args:
            cc: activation maps; the last two dims are treated as spatial.
            fg_mask: mask with the same shape as ``cc``.
            y: targets; only read in the non-multilabel mode, where
                ``y.argmax(1)`` selects, per sample, the map that is excluded
                from the divergence.

        Returns:
            A scalar tensor. In multilabel mode the sum is divided by the size
            of the first dimension; otherwise by the number of maps that were
            not excluded. A zero denominator is replaced by 1 so an empty
            selection yields 0 rather than NaN.
        """
        # flatten(-2) instead of view(..., -1): also valid for zero-element
        # input (e.g. pos_only with no positive target in the batch).
        fg_mask_probs = (fg_mask * const.LAMBDAS[1]).flatten(start_dim=-2).to(torch.float).softmax(dim=-1).view(cc.shape)
        cc_probs = (cc * const.LAMBDAS[0]).flatten(start_dim=-2).softmax(dim=-1).clamp(min=1E-6)
        cc_log_probs = cc_probs.view(cc.shape).log()
        fg_mask_log_probs = fg_mask_probs.log()

        divergence = fg_mask_probs * (fg_mask_log_probs - cc_log_probs)
        # 0 * log(0) produces NaN; treat those entries as contributing nothing.
        divergence.nan_to_num_()

        if self.multilabel:
            return divergence.sum() / max(divergence.size(0), 1)

        # Zero the map of each sample's arg-max class; rows of the flattened
        # (sample * maps_per_sample) axis are addressed as sample * maps + class.
        n_samples, maps_per_sample = divergence.size(0), divergence.size(1)
        target_rows = y.argmax(1) + torch.arange(n_samples).to(const.DEVICE) * maps_per_sample
        divergence = divergence.reshape(-1, *cc.shape[2:]).index_fill(0, target_rows, 0.0)
        return divergence.sum() / max(divergence.size(0) - cc.size(0), 1)

    @staticmethod
    def _balanced_mean(per_element, targets):
        """Average the mean loss of the ``targets == 0`` and ``targets == 1`` groups.

        Groups with no members are skipped, so a batch without positives (or
        without negatives) does not produce NaN.
        """
        group_means = [per_element[group].mean()
                       for group in (targets == 0, targets == 1) if group.any()]
        if group_means:
            return sum(group_means) / len(group_means)
        # No hard 0/1 targets at all: fall back to a plain mean (0 if empty).
        return per_element.mean() if per_element.numel() else per_element.sum()

    def _multilabel_ablation(self, cams, mask, targets):
        """Build the per-class foreground mask and ablation logits (multilabel mode)."""
        # Channel c is foreground where the mask equals c + 1. Broadcasting
        # replaces the former `.T` on a 3-D tensor (deprecated, and only
        # shape-correct for square maps).
        class_ids = torch.arange(1, const.N_CLASSES + 1, device=const.DEVICE).view(1, -1, 1, 1)
        fg_mask = (class_ids == mask.unsqueeze(1)).to(torch.int)
        # A class flagged in the targets but absent from the mask gets the
        # whole map as foreground.
        fg_mask[targets.to(torch.bool) & (fg_mask.sum(2).sum(2) == 0)] = 1

        present = (fg_mask * cams - (1 - fg_mask) * cams.abs()).sum(dim=[2, 3])
        absent = cams.sum(dim=[2, 3])
        return present * targets + absent * (1 - targets), fg_mask

    def _label_mask_ablation(self, cc, mask, targets):
        """Build the foreground mask and ablation logits (label-mask mode)."""
        target_class = targets.argmax(1)
        # Foreground = pixels whose value equals the sample's arg-max class,
        # replicated across every map of `cc`.
        foreground = (mask == target_class.view(-1, 1, 1)).to(torch.int)
        fg_mask = foreground.unsqueeze(1).repeat(1, cc.shape[1], 1, 1).to(const.DEVICE)

        # Pixels that are neither foreground nor equal to -1 (presumably an
        # ignore value; confirm against the data pipeline).
        mixup_mask = 1 - fg_mask[:, 0] - (mask == -1).to(torch.int)
        # Per sample: largest (mask value + 1) among those pixels, 0 if none.
        mixed_idx = (mixup_mask * (mask + 1)).to(torch.int).flatten(start_dim=1).max(dim=1).values
        mixed_samples = mixed_idx.nonzero().flatten()

        # For each such sample, mark those pixels in the map of class mixed_idx - 1.
        mixup_sparse_mask = torch.zeros(cc.shape, device=const.DEVICE, dtype=torch.int)
        rows = mixed_idx[mixed_samples] - 1 + mixed_samples * const.N_CLASSES
        mixup_sparse_mask.view(-1, *const.CAM_SIZE)[rows] = mixup_mask[mixed_samples]

        ablation = (-cc * fg_mask + cc.abs() * (1 - fg_mask) + (-cc.abs() + cc) * mixup_sparse_mask).sum(dim=[2, 3])
        return ablation, fg_mask

    def _binary_mask_ablation(self, cc, mask):
        """Build the foreground mask and ablation logits (plain mask mode)."""
        # Work on a copy so the caller's mask tensor is not modified in place.
        mask = mask.clone()
        # Samples with an all-zero mask are treated as entirely foreground.
        mask[mask.sum(1).sum(1) == 0] = 1
        fg_mask = mask.unsqueeze(1).repeat(1, const.N_CLASSES, 1, 1).to(torch.int)
        ablation = (-cc * fg_mask + cc.abs() * (1 - fg_mask)).sum(dim=[2, 3])
        return ablation, fg_mask

    def forward(self, y_pred, y):
        """Compute the weighted sum of the ablation loss and the divergence term.

        Args:
            y_pred: indexable; ``y_pred[1]`` holds the activation maps.
            y: indexable; ``y[0]`` is the mask and ``y[1]`` the targets.
                NOTE(review): ``y[0]`` is indexed as a 3-D tensor and
                ``y_pred[1]`` as a 4-D tensor; confirm shapes with the caller.

        Returns:
            ``const.LAMBDAS[2] * ace + const.LAMBDAS[3] * divergence``.
        """
        mask, targets = y[0], y[1]
        cams = y_pred[1]
        cc = None

        if self.multilabel:
            ablation, fg_mask = self._multilabel_ablation(cams, mask, targets)
        else:
            cc = self.get_contrastive_cams(targets, cams).to(const.DEVICE)
            if self.is_label_mask:
                ablation, fg_mask = self._label_mask_ablation(cc, mask, targets)
            else:
                ablation, fg_mask = self._binary_mask_ablation(cc, mask)

        ace = self.ce(ablation, targets)
        if self.multilabel and self.ce.pos_weight is None:
            # Criterion is un-reduced here: balance positives and negatives.
            ace = self._balanced_mean(ace, targets)

        if self.divergence:
            if self.multilabel:
                if self.pos_only:
                    # Keep only the maps whose target entry is non-zero.
                    target_idx = targets.flatten().nonzero()
                    cc = cams.view(-1, *cams.shape[-2:])[target_idx][:, 0]
                    fg_mask = fg_mask.view(-1, *fg_mask.shape[-2:])[target_idx][:, 0]
                else:
                    cc = cams.view(-1, *const.CAM_SIZE).clone()
                    fg_mask = fg_mask.view(-1, *const.CAM_SIZE).clone()

            divergence = self.kld(cc, fg_mask, targets)
        else:
            divergence = torch.tensor(0)

        # Detached Python numbers of the last step, kept for external inspection.
        self.prev = (ace.item(), divergence.item())
        return const.LAMBDAS[2] * ace + const.LAMBDAS[3] * divergence
