from dataclasses import dataclass, field
from fairseq.criterions.label_smoothed_cross_entropy import LabelSmoothedCrossEntropyCriterionConfig
from fairseq.criterions.label_smoothed_cross_entropy import register_criterion, LabelSmoothedCrossEntropyCriterion
import torch

@dataclass
class UtilizationCriterionConfig(LabelSmoothedCrossEntropyCriterionConfig):
    """Configuration for the ``utilization_nll_loss`` criterion.

    Extends the label-smoothed cross-entropy configuration with the weights
    of three auxiliary loss terms plus their hyper-parameters. A weight of 0
    removes the corresponding term's contribution to the training loss.
    """

    oversmoothing_weight: float = field(
        default=0.0,
        metadata={"help": "weight of the oversmoothing (EOS vs. suffix log-ratio hinge) loss"},
    )
    oversmoothing_margin: float = field(
        default=0.0,
        metadata={"help": "margin added inside the oversmoothing hinge loss"},
    )
    unlikelihood_weight: float = field(
        default=0.0,
        metadata={"help": "weight of the token-level unlikelihood loss that penalises "
                          "assigning probability to earlier target tokens"},
    )
    utilization_weight: float = field(
        default=0.0,
        metadata={"help": "weight of the concept utilization loss"},
    )
    utilization_type: str = field(
        default='u',
        metadata={"help": "reduction of the utilization loss: 'c' = weighted by concept "
                          "weights, 's' = weighted by semantic-type weights, any other "
                          "value = unweighted mean"},
    )


def compute_unlikelihood_loss(lprobs, pad_idx, target):
    """Token-level unlikelihood loss over previously seen target tokens.

    For every target position ``i``, the negative candidates are the distinct
    tokens ``target[..., j]`` with ``j < i`` in the *same* sequence, excluding
    the current gold token ``target[..., i]`` and the pad token. The loss is
    ``sum_i sum_{c in candidates(i)} -log(1 - p(c | i))`` summed over the
    whole input (it is not normalised).

    Args:
        lprobs: log-probabilities whose last dimension is the vocabulary.
            Either ``(seq_len, vocab)`` paired with a 1-D ``target`` (the whole
            tensor is treated as one sequence), or ``(..., seq_len, vocab)``
            paired with a ``(..., seq_len)`` ``target`` (one sequence per row).
        pad_idx: vocabulary index of the padding token. Pad positions in
            ``target`` contribute nothing, and pad is never a negative.
        target: gold token ids (integer tensor).

    Returns:
        A scalar tensor holding the summed unlikelihood loss.
    """
    if target.dim() == 1:
        # Single sequence: add a leading batch dim so both layouts share one path.
        lprobs = lprobs.unsqueeze(0)
        target = target.unsqueeze(0)
    seq_len = target.size(-1)

    with torch.no_grad():
        # ctx_cands[..., i, j] = target[..., j]
        ctx_cands = target.unsqueeze(-2).expand(*target.shape, seq_len)
        # Keep only strictly earlier positions (j < i); everything else -> pad.
        is_prev = torch.ones(seq_len, seq_len, device=target.device).tril(-1).bool()
        ctx_cands = ctx_cands.masked_fill(~is_prev, pad_idx)
        # The gold token at position i must not be pushed down.
        ctx_cands = ctx_cands.masked_fill(ctx_cands == target.unsqueeze(-1), pad_idx)
        negative_targets = torch.zeros_like(lprobs).scatter_(-1, ctx_cands, 1)
        # Pad is only a filler for "no candidate"; it is never a real negative.
        negative_targets[..., pad_idx] = 0
        # Padded target positions carry no training signal.
        negative_targets.masked_fill_((target == pad_idx).unsqueeze(-1), 0)

    # Clamp avoids log(0) when the model is fully confident in a candidate.
    one_minus_probs = torch.clamp((1.0 - lprobs.exp()), min=1e-5)

    unl_loss = -torch.log(one_minus_probs) * negative_targets
    return unl_loss.sum()


def compute_oversmoothing_logratio(logits, target, non_pad_mask, eos_idx, margin=1e-5):
    """Compute the oversmoothing hinge loss and the per-sequence oversmoothing rate.

    At each target position ``t`` the log-probability of emitting EOS is
    compared with the suffix log-probability ``sum_{s >= t} log p(y_s)`` of
    the gold continuation. The loss is the hinge
    ``max(0, log p(eos)_t - suffix_t + margin)``, summed over non-pad
    positions, divided by each sequence's non-pad length, then averaged over
    the batch.

    Args:
        logits: unnormalised scores, ``(batch, seq_len, vocab)``.
        target: gold token ids, ``(batch, seq_len)``.
        non_pad_mask: float mask, ``(batch, seq_len, 1)``; 1 for real tokens.
        eos_idx: vocabulary index of the EOS token.
        margin: hinge margin.

    Returns:
        ``(oversmoothing_loss, osr)``. ``oversmoothing_loss`` is a scalar
        tensor. ``osr`` has shape ``(batch,)``: the fraction of non-pad,
        non-EOS positions at which EOS is more likely than the suffix. It is
        computed without gradients.
    """
    full_lprobs = torch.log_softmax(logits, dim=-1)
    mask = non_pad_mask.squeeze(-1)  # (batch, seq_len)
    target_lprobs = torch.gather(full_lprobs, dim=-1, index=target.unsqueeze(-1)).squeeze(-1) * mask

    # Reverse cumulative sum: suffix_lprob[:, t] = sum_{s >= t} target_lprobs[:, s].
    suffix_lprob = (target_lprobs
                    + torch.sum(target_lprobs, dim=-1, keepdim=True)
                    - torch.cumsum(target_lprobs, dim=-1))
    eos_lprobs = full_lprobs[:, :, eos_idx] * mask

    hinge = torch.maximum(eos_lprobs - suffix_lprob + margin, torch.zeros_like(suffix_lprob))
    # Without this, every padded position would add `margin` to the sum.
    hinge = hinge * mask
    oversmoothing_loss = (hinge.sum(dim=1) / mask.sum(dim=1)).mean()

    with torch.no_grad():
        oversmoothed = (eos_lprobs > suffix_lprob).float() * mask
        # The true-EOS position is not counted as an oversmoothing opportunity.
        oversmoothed = oversmoothed * (target != eos_idx).float()

        # One position per sequence (its EOS) is excluded from the count; clamp
        # so an EOS-only sequence yields 0 instead of 0/0 = NaN.
        num_osr_per_seq = (mask.sum(-1) - 1).clamp(min=1)
        osr = oversmoothed.sum(-1) / num_osr_per_seq

    return oversmoothing_loss, osr


@register_criterion("utilization_nll_loss",
                    dataclass=UtilizationCriterionConfig)
class UtilizationCriterion(LabelSmoothedCrossEntropyCriterion):
    """Label-smoothed cross entropy plus optional auxiliary losses.

    The training loss returned by :meth:`forward` is::

        loss = label_smoothed_loss
             + oversmoothing_weight * oversmoothing_loss
             + unlikelihood_weight * unlikelihood_loss
             + utilization_weight * utilization_loss

    The oversmoothing and unlikelihood terms are skipped entirely when their
    weight is 0. The utilization term is always computed because it is logged.
    """

    def __init__(self, task, sentence_avg,
                 label_smoothing,
                 ignore_prefix_size=0,
                 report_accuracy=False,
                 oversmoothing_weight=0.0,
                 oversmoothing_margin=0,
                 unlikelihood_weight=0.0,
                 utilization_weight=0.0,
                 utilization_type='u'):

        super().__init__(task, sentence_avg, label_smoothing, ignore_prefix_size, report_accuracy)
        self.pad_idx = task.tgt_dict.pad()
        self.eos_idx = task.tgt_dict.eos()
        self.label_smoothing_eps = label_smoothing
        self.oversmoothing_margin = oversmoothing_margin
        self.oversmoothing_weight = oversmoothing_weight
        self.unlikelihood_weight = unlikelihood_weight
        self.utilization_weight = utilization_weight
        self.utilization_type = utilization_type

    @staticmethod
    def marginalize_probs(concept_probs, non_pad_mask):
        """Average ``concept_probs`` over non-pad positions along dim 1.

        Probabilities are cast to float32 and clamped to the smallest positive
        float32 before masking.
        """
        concept_probs = concept_probs.to(torch.float32)
        concept_probs = torch.clamp(concept_probs, min=torch.finfo(torch.float32).tiny)
        concept_distr = concept_probs * non_pad_mask
        marginal_concept_distr = concept_distr.sum(dim=1) / non_pad_mask.sum(1)
        return marginal_concept_distr

    def compute_utilization_loss(self, sample, logits, non_pad_mask):
        """Compute the utilization loss from the probabilities of ``sample['concepts']``.

        The softmax probabilities of the ids in ``sample['concepts']`` are
        gathered from ``logits`` and averaged over non-pad positions. They are
        then reduced according to ``self.utilization_type``:

        * ``'c'``: sum weighted by ``sample['concepts_weights']``, divided by
          the sum of ``sample['sem_type_weights']``.
        * ``'s'``: sum weighted by ``sample['sem_type_weights']``, divided by
          the sum of ``sample['sem_type_weights']``.
        * anything else: unweighted mean.
        """
        concepts = sample['concepts']
        concepts_weights_sample = sample['concepts_weights']
        semantic_type_weights_sample = sample['sem_type_weights']

        # NOTE(review): gather's output takes the shape of its index, i.e.
        # concepts.shape + (1,). The shape comments below therefore only hold
        # if `concepts` is laid out accordingly. If it is (bsz, n_concepts), it
        # would need expanding over seq_len first. Confirm against the data loader.
        concept_probs = torch.gather(logits.softmax(-1), -1, concepts.unsqueeze(-1))  # [bsz, seq_len, concepts]
        concept_distr = self.marginalize_probs(concept_probs, non_pad_mask)  # [bsz, concepts]

        if self.utilization_type == 'c':
            # NOTE(review): normalises by the semantic-type weights, not the
            # concept weights; confirm this is intended.
            return (concept_distr * concepts_weights_sample).sum() / semantic_type_weights_sample.sum()
        elif self.utilization_type == 's':
            return (concept_distr * semantic_type_weights_sample).sum() / semantic_type_weights_sample.sum()
        return concept_distr.mean()

    def forward(self, model, sample, reduce=True):
        """Compute the combined loss for ``sample``.

        Returns:
            ``(loss, sample_size, logging_output)``. ``sample_size`` is the
            number of sentences when ``sentence_avg`` is set, otherwise
            ``sample['ntokens']``.
        """
        updated_features, extra_stats = model(**sample["net_input"])
        net_output = (updated_features, extra_stats)

        loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce)

        sample_size = (
            sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
        )

        logging_output = {
            "nll_loss": nll_loss.item(),
            "ntokens": sample["ntokens"],
            "nsentences": sample["target"].size(0),
            "sample_size": sample_size,
        }
        # The report_accuracy flag used to be accepted but never honoured.
        if self.report_accuracy:
            n_correct, total = self.compute_accuracy(model, net_output, sample)
            logging_output["n_correct"] = n_correct.item()
            logging_output["total"] = total.item()

        # Use the dictionary's pad index instead of assuming it is 1.
        non_pad_mask = (sample["target"] != self.pad_idx).float()[:, :, None]

        target = model.get_targets(sample, updated_features)

        # Skip disabled terms: saves compute and keeps any NaN/inf they might
        # produce out of the loss (0 * inf would be NaN).
        if self.oversmoothing_weight != 0:
            oversmoothing_loss, _ = compute_oversmoothing_logratio(
                updated_features, target, non_pad_mask, self.eos_idx, self.oversmoothing_margin)
            loss = loss + self.oversmoothing_weight * oversmoothing_loss
        if self.unlikelihood_weight != 0:
            unlikelihood_loss = compute_unlikelihood_loss(
                updated_features.log_softmax(dim=-1), self.pad_idx, target)
            loss = loss + self.unlikelihood_weight * unlikelihood_loss

        utilization_loss = self.compute_utilization_loss(sample, updated_features, non_pad_mask)
        loss = loss + self.utilization_weight * utilization_loss

        logging_output['utilization_loss'] = utilization_loss.item()
        logging_output['loss'] = loss.item()

        return loss, sample_size, logging_output
