#!/usr/bin/env python3

import math

import torch
import torch.nn.functional as F
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.criterions.label_smoothed_cross_entropy import label_smoothed_nll_loss
from fairseq import metrics, utils
from fairseq.data.data_utils import lengths_to_mask, post_process
from fairseq.logging.meters import safe_round


@register_criterion("guided_label_smoothed_cross_entropy_with_ctc")
class GuidedCrossEntAccCriterion(FairseqCriterion):
    """Label-smoothed cross entropy with text-guided distillation, CTC and OT losses.

    For dual-input batches (speech plus its transcript) the decoder log-probs
    are split into a speech half and a text half.

    - The speech half is trained against the target and, once
      ``disable_text_guide_update_num`` updates have passed, against the
      detached text-half distribution (weight ``guide_alpha``).
    - The text half is trained with label-smoothed cross entropy.

    Optionally, a CTC loss on the speech encoder output (``ctc_weight``) and an
    optimal-transport loss between speech and text encoder states
    (``ot_weight``) are added. In ``zero_shot`` mode, dual-input batches only
    contribute the CTC (+ OT) loss.
    """

    def __init__(
        self,
        task,
        sentence_avg,
        ctc_post_process="sentencepiece",
        guide_alpha=0.,
        ctc_weight=0.,
        zero_infinity=True,
        text_input_cost_ratio=1.0,
        label_smoothing=0.1,
        disable_text_guide_update_num=0,
        zero_shot=False,
        ot_weight=0.0,
        ctc_greedy_out=False,
    ):
        """
            sentence_avg:           normalize by number of sentences instead of tokens
            ctc_post_process:       post-processing applied to CTC units before word scoring
            guide_alpha:            weight interpolating the nll loss and the text-guided (kd) loss
            ctc_weight:             weight of the CTC loss, in [0, 1] (forced to 1 in zero-shot mode)
            zero_infinity:          passed to F.ctc_loss to zero out infinite losses
            text_input_cost_ratio:  loss ratio for text-only input data
            label_smoothing:        label smoothing ratio
            disable_text_guide_update_num:  only use nll loss for the first N updates
            zero_shot:              for dual input, train only with the CTC (+ OT) loss
            ot_weight:              weight of the optimal-transport loss (>= 0)
            ctc_greedy_out:         keep CTC greedy decoding outputs during evaluation
        """
        super().__init__(task)
        self.alpha = guide_alpha
        assert 0 <= ctc_weight <= 1
        self.ctc_weight = ctc_weight
        self.zero_infinity = zero_infinity
        self.post_process = ctc_post_process
        self.sentence_avg = sentence_avg
        self.eps = label_smoothing
        self.text_input_cost_ratio = text_input_cost_ratio
        self.disable_update_num = disable_text_guide_update_num
        self.zero_shot = zero_shot
        if self.zero_shot:
            self.ctc_weight = 1.  # ctc loss will be the main loss for speech2text task
        self.ot_weight = ot_weight
        self.ctc_greedy_out = ctc_greedy_out
        assert 0 <= self.alpha <= 1.0
        # CTC targets are the source transcripts, so prefer the source dictionary.
        if self.task.source_dictionary is not None:
            self.ctc_dictionary = self.task.source_dictionary
        else:
            self.ctc_dictionary = self.task.target_dictionary
        self.blank_idx = self.ctc_dictionary.index(task.blank_symbol) if hasattr(task, 'blank_symbol') else 0
        self.pad_idx = self.ctc_dictionary.pad()
        self.eos_idx = self.ctc_dictionary.eos()
        assert self.ot_weight >= 0

    @staticmethod
    def add_args(parser):
        """Add criterion-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('--label-smoothing', default=0., type=float, metavar='D',
                            help='epsilon for label smoothing, 0 means no label smoothing')
        parser.add_argument('--guide-alpha', default=0., type=float, metavar='D',
                            help='alpha to merge kd cost from text to speech input with ce loss')
        parser.add_argument('--disable-text-guide-update-num', default=0, type=int, metavar='D',
                            help='disable guided target from text for the first N updates.')
        parser.add_argument("--zero-infinity", action="store_true",
                            help="zero inf loss when source length <= target length")
        parser.add_argument("--ctc-weight", default=0.0, type=float, metavar="D",
                            help="weight of CTC loss")
        parser.add_argument("--ctc-post-process", default="sentencepiece", type=str,
                            help="how to post process predictions into words. can be letter, "
                                 "wordpiece, BPE symbols, etc. "
                                 "See fairseq.data.data_utils.post_process() for full list of options")
        parser.add_argument("--zero-shot", action="store_true",
                            help="only compute ctc loss for dual input and text nmt loss.")
        parser.add_argument("--ot-weight", default=0.0, type=float, metavar="D",
                            help="weight of OT loss")
        parser.add_argument("--ctc-greedy-out", action="store_true",
                            help="whether to store ctc greedy decoding output.")
        # fmt: on

    def forward(self, model, sample, reduce=True):
        """Compute the loss for ``sample``.

        Args:
            model: the model; called with ``use_encoder_outputs=True`` when the
                CTC or OT loss needs the (speech, text) encoder outputs.
            sample: batch dict with ``net_input``, ``target`` and ``ntokens``.
            reduce: sum the per-token losses when True.

        Returns:
            (loss, sample_size, logging_output)
        """
        reduction = 'sum' if reduce else 'none'
        net_input = sample["net_input"]
        is_dual_input = (
            net_input['src_tokens'] is not None
            and net_input.get('src_txt_tokens') is not None
        )
        # Both the CTC and the OT losses read the encoder outputs. Requesting
        # them only when ctc_weight > 0 left `encoder_out` undefined when only
        # the OT loss was enabled.
        if is_dual_input and (self.ctc_weight > 0 or self.ot_weight > 0):
            net_output, encoder_out = model(**net_input, use_encoder_outputs=True)
        else:
            net_output = model(**net_input)

        nll_loss, speech_loss, speech_nll_loss, ctc_loss, ot_loss = None, None, None, None, None
        correct, total = 0, 0
        eval_logging_output = {}

        if is_dual_input and self.zero_shot:
            # Only use audio and transcription data; translation data is discarded.
            spch_encoder_out, text_encoder_out = encoder_out
            ctc_loss, eval_logging_output = self.compute_ctc_loss(model, net_input, spch_encoder_out)
            loss = ctc_loss
            if self.ot_weight > 0:
                ot_loss = self.compute_ot_loss(spch_encoder_out, text_encoder_out)
                # Out-of-place: `loss` aliases `ctc_loss`, which is logged on its own.
                loss = loss + self.ot_weight * ot_loss
        elif is_dual_input:
            lprobs = model.get_normalized_probs(net_output, log_probs=True)
            target = model.get_targets(sample, net_output)
            # lprobs_spch from the speech encoder and lprobs_text from the text encoder
            lprobs_spch, lprobs_text = torch.chunk(lprobs, 2)
            lprobs_spch.batch_first = lprobs.batch_first
            lprobs_text.batch_first = lprobs.batch_first

            speech_loss, speech_nll_loss, speech_correct, speech_total = \
                self.guide_loss_and_acc(model, lprobs_spch, lprobs_text, target, reduce=reduce)
            text_loss, text_nll_loss, text_correct, text_total = self.compute_loss_and_acc(
                model, lprobs_text, target, reduction=reduction)
            loss = speech_loss + text_loss
            nll_loss = speech_nll_loss + text_nll_loss
            correct = speech_correct + text_correct
            total = speech_total + text_total

            if self.ctc_weight > 0:
                ctc_loss, eval_logging_output = self.compute_ctc_loss(model, net_input, encoder_out[0])
                loss = (1 - self.ctc_weight) * loss + self.ctc_weight * ctc_loss

            if self.ot_weight > 0:
                ot_loss = self.compute_ot_loss(encoder_out[0], encoder_out[1])
                loss = loss + self.ot_weight * ot_loss
        else:
            lprobs = model.get_normalized_probs(net_output, log_probs=True)
            target = model.get_targets(sample, net_output)
            loss, nll_loss, correct, total = self.compute_loss_and_acc(model, lprobs, target, reduction=reduction)
            if net_input['src_tokens'] is None:  # text input only
                loss = loss * self.text_input_cost_ratio

        sample_size, logging_output = self.get_logging_output(
            sample, loss, nll_loss, correct, total, speech_loss, speech_nll_loss, ctc_loss, ot_loss,
            eval_logging_output, is_dual_input,
        )
        return loss, sample_size, logging_output

    def _accuracy(self, lprobs, target):
        """Return (#argmax predictions equal to target, #non-padding targets).

        Expects flattened ``lprobs`` of shape (N, C) and ``target`` of shape (N,).
        """
        mask = target.ne(self.padding_idx)
        correct = torch.sum(lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask)))
        total = torch.sum(mask)
        return correct, total

    def compute_loss_and_acc(self, model, lprobs, target, reduction='sum'):
        """Label-smoothed NLL loss and token accuracy.

        Returns:
            (loss, nll_loss, correct, total). With ``reduction='none'`` the
            losses are per-token values as returned by label_smoothed_nll_loss.
        """
        if not lprobs.batch_first:
            lprobs = lprobs.transpose(0, 1)
        lprobs = lprobs.view(-1, lprobs.size(-1))  # -> (B x T) x C
        target = target.view(-1)
        loss, nll_loss = label_smoothed_nll_loss(
            lprobs, target, self.eps, ignore_index=self.padding_idx, reduce=(reduction == 'sum'),
        )
        correct, total = self._accuracy(lprobs, target)
        return loss, nll_loss, correct, total

    def guide_loss_and_acc(self, model, lprobs, lprobs_teacher, target, reduce=True):
        """Interpolate the NLL of ``lprobs`` with its cross entropy against ``lprobs_teacher``.

        The teacher distribution is detached and zeroed at padding positions.
        Falls back to the plain label-smoothed loss when ``guide_alpha`` is 0
        or during the first ``disable_text_guide_update_num`` updates.

        Returns:
            (loss, nll_loss, correct, total). With ``reduce=False`` the losses
            have shape (N, 1), matching compute_loss_and_acc.
        """
        if self.alpha == 0.0 or model.num_updates < self.disable_update_num:
            return self.compute_loss_and_acc(model, lprobs, target, reduction=('sum' if reduce else 'none'))
        if not lprobs.batch_first:
            lprobs = lprobs.transpose(0, 1)
            lprobs_teacher = lprobs_teacher.transpose(0, 1)

        lprobs = lprobs.view(-1, lprobs.size(-1)).float()  # -> (B x T) x C
        lprobs_teacher = lprobs_teacher.view(-1, lprobs_teacher.size(-1)).float()  # -> (B x T) x C
        target = target.view(-1)
        nll_loss = F.nll_loss(lprobs, target, ignore_index=self.padding_idx, reduction='sum' if reduce else 'none')
        if not reduce:
            # F.nll_loss returns (N,). Without this reshape, adding the (N, 1)
            # guide loss below would broadcast to an (N, N) matrix.
            nll_loss = nll_loss.unsqueeze(-1)
        pad_mask = target.unsqueeze(-1).eq(self.padding_idx)
        probs_teacher = lprobs_teacher.detach().exp().masked_fill(pad_mask, 0)
        cross_ent = -(probs_teacher * lprobs)
        guide_loss = cross_ent.sum() if reduce else cross_ent.sum(-1, keepdim=True)
        loss = self.alpha * guide_loss + (1.0 - self.alpha) * nll_loss

        correct, total = self._accuracy(lprobs, target)
        return loss, nll_loss, correct, total

    def compute_ot_loss(self, spch_encoder_out, text_encoder_out):
        """Optimal-transport cost between speech and text encoder states (see ipot_solver).

        The encoder states are transposed to batch-first. This assumes fairseq's
        usual T x B x D ``encoder_out`` layout — confirm for this model.
        """
        spch_enc_out = spch_encoder_out["encoder_out"][0].transpose(0, 1)
        spch_enc_mask = spch_encoder_out["encoder_padding_mask"][0]
        text_enc_out = text_encoder_out["encoder_out"][0].transpose(0, 1)
        text_enc_mask = text_encoder_out["encoder_padding_mask"][0]
        return ipot_solver(spch_enc_out, spch_enc_mask, text_enc_out, text_enc_mask)

    def compute_ctc_loss(self, model, net_input, encoder_out):
        """Summed CTC loss of the speech encoder output against ``src_txt_tokens``.

        Pad and EOS tokens are removed from the transcript before scoring.

        Returns:
            (loss, eval_logging_output). The dict holds greedy-decoding error
            counts only when ctc_weight > 0 and the model is not training.
        """
        transcript = net_input["src_txt_tokens"]
        if "ctc_logits" in encoder_out:
            ctc_logits = encoder_out["ctc_logits"][0]
        else:
            ctc_logits = model.encoder.compute_ctc_logits(encoder_out)
        lprobs = utils.log_softmax(ctc_logits.float(), dim=-1)  # T x B x C for F.ctc_loss
        input_lengths = lprobs.new_full((lprobs.shape[1],), lprobs.shape[0]).long()
        padding_masks = encoder_out["ctc_padding_mask"]
        if len(padding_masks) > 0 and padding_masks[0] is not None:
            # Subtract each utterance's own padded-frame count. The mask is the
            # first list entry (B x T); the previous code indexed it twice and
            # subtracted the first utterance's count from every length.
            input_lengths = input_lengths - padding_masks[0].long().sum(dim=-1)

        transcript_mask = (transcript != self.pad_idx) & (transcript != self.eos_idx)
        transcript_flat = transcript.masked_select(transcript_mask)
        transcript_lengths = transcript_mask.sum(-1)

        with torch.backends.cudnn.flags(enabled=False):
            loss = F.ctc_loss(
                lprobs,
                transcript_flat,
                input_lengths,
                transcript_lengths,
                blank=self.blank_idx,
                reduction="sum",
                zero_infinity=self.zero_infinity,
            )

        eval_logging_output = {}
        if self.ctc_weight > 0 and not model.training:
            eval_logging_output = self._ctc_greedy_eval(lprobs, transcript, input_lengths)
        return loss, eval_logging_output

    def _ctc_greedy_eval(self, lprobs, transcript, input_lengths):
        """Greedy-decode CTC log-probs (T x B x C) and count unit and word edit errors."""
        import editdistance

        eval_logging_output = {}
        if self.ctc_greedy_out:
            eval_logging_output['greedy_out'] = []

        c_err = c_len = w_errs = w_len = 0
        with torch.no_grad():
            lprobs_t = lprobs.transpose(0, 1).float().contiguous().cpu()
            for lp, t, inp_l in zip(lprobs_t, transcript, input_lengths):
                # Best path: per-frame argmax, merge repeats, drop blanks.
                toks = lp[:inp_l].argmax(dim=-1).unique_consecutive()
                pred_units_arr = toks[toks != self.blank_idx].tolist()

                targ = t[(t != self.pad_idx) & (t != self.eos_idx)]
                targ_units_arr = targ.tolist()

                # TODO: check c err
                c_err += editdistance.eval(pred_units_arr, targ_units_arr)
                c_len += len(targ_units_arr)

                targ_words = post_process(self.ctc_dictionary.string(targ), self.post_process).split()
                pred_words_raw = post_process(
                    self.ctc_dictionary.string(pred_units_arr), self.post_process
                ).split()
                w_errs += editdistance.eval(pred_words_raw, targ_words)
                w_len += len(targ_words)

                if self.ctc_greedy_out:
                    eval_logging_output["greedy_out"].append((" ".join(pred_words_raw), " ".join(targ_words)))

        # No external decoder is used here, so the raw and decoded WER counts coincide.
        eval_logging_output["wv_errors"] = w_errs
        eval_logging_output["w_errors"] = w_errs
        eval_logging_output["w_total"] = w_len
        eval_logging_output["c_errors"] = c_err
        eval_logging_output["c_total"] = c_len
        return eval_logging_output

    def get_logging_output(
        self,
        sample,
        loss,
        nll_loss=None,
        correct=0,
        total=0,
        speech_loss=None,
        speech_nll_loss=None,
        ctc_loss=None,
        ot_loss=None,
        eval_logging_output=None,
        is_dual_input=False,
    ):
        """Build the per-batch logging dict.

        Dual-input (non zero-shot) batches count every sentence twice (speech
        and text halves), so their sizes are doubled.

        Returns:
            (sample_size, logging_output)
        """
        sample_size = (
            sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
        )
        mul_size = 2 if is_dual_input and not self.zero_shot else 1

        logging_output = {
            "loss": utils.item(loss.data),  # * sample['ntokens'],
            "nll_loss": utils.item(nll_loss.data) if nll_loss is not None else 0,  # * sample['ntokens'],
            "ntokens": sample["ntokens"] * mul_size,
            "nsentences": sample["target"].size(0) * mul_size,
            "sample_size": sample_size * mul_size,
            "correct": utils.item(correct.data) if correct > 0 else 0,
            "total": utils.item(total.data) if total > 0 else 0,
            "nframes": torch.sum(sample["net_input"]["src_lengths"]).item(),
        }

        if speech_loss is not None:
            logging_output["speech_loss"] = utils.item(speech_loss.data)
            logging_output["speech_nll_loss"] = utils.item(speech_nll_loss.data)

        if ctc_loss is not None:
            logging_output["speech_ctc_loss"] = utils.item(ctc_loss.data)

        if ot_loss is not None:
            logging_output["speech_ot_loss"] = utils.item(ot_loss.data)

        if speech_loss is not None or ctc_loss is not None or ot_loss is not None:
            logging_output["sample_size_speech_cost"] = sample_size

        logging_output.update(eval_logging_output or {})

        return sample_size * mul_size, logging_output

    @staticmethod
    def aggregate_logging_outputs(logging_outputs):
        """Aggregate logging outputs from data parallel training.

        Losses are divided by their sample size and by ln(2), i.e. reported in base 2.
        """
        def _sum(key):
            return sum(log.get(key, 0) for log in logging_outputs)

        def _per_unit(value, denom):
            return value / denom / math.log(2) if denom > 0 else 0.0

        correct_sum = _sum("correct")
        total_sum = _sum("total")
        sample_size = _sum("sample_size")
        sample_size_speech = _sum("sample_size_speech_cost")

        return {
            # If sentence_avg is set, sample_size is nsentences and the loss is
            # per sentence; otherwise sample_size is ntokens and the loss is per
            # output token.
            "loss": _per_unit(_sum("loss"), sample_size),
            "nll_loss": _per_unit(_sum("nll_loss"), sample_size),
            "speech_loss": _per_unit(_sum("speech_loss"), sample_size_speech),
            "speech_nll_loss": _per_unit(_sum("speech_nll_loss"), sample_size_speech),
            "speech_ctc_loss": _per_unit(_sum("speech_ctc_loss"), sample_size_speech),
            "speech_ot_loss": _per_unit(_sum("speech_ot_loss"), sample_size_speech),
            "ntokens": _sum("ntokens"),
            "nsentences": _sum("nsentences"),
            "nframes": _sum("nframes"),
            "sample_size": sample_size,
            "acc": correct_sum * 100.0 / total_sum if total_sum > 0 else 0.0,
            "correct": correct_sum,
            # total is the number of valid (non-padding) target tokens
            "total": total_sum,
        }

    @staticmethod
    def _log_error_rate(name, errors_key, total_key):
        """Register derived meter ``name`` = 100 * errors / total (NaN when total is 0)."""
        def _rate(meters):
            denom = meters[total_key].sum
            if denom > 0:
                return safe_round(meters[errors_key].sum * 100.0 / denom, 3)
            return float("nan")

        metrics.log_derived(name, _rate)

    @classmethod
    def reduce_metrics(cls, logging_outputs):
        """Aggregate logging outputs from data parallel training."""
        agg_logging_outputs = cls.aggregate_logging_outputs(logging_outputs)
        for k, v in agg_logging_outputs.items():
            if k in {'nsentences', 'ntokens', 'sample_size'}:
                continue
            metrics.log_scalar(k, v, round=3)

        c_errors = sum(log.get("c_errors", 0) for log in logging_outputs)
        metrics.log_scalar("_c_errors", c_errors)
        c_total = sum(log.get("c_total", 0) for log in logging_outputs)
        metrics.log_scalar("_c_total", c_total)
        w_errors = sum(log.get("w_errors", 0) for log in logging_outputs)
        metrics.log_scalar("_w_errors", w_errors)
        wv_errors = sum(log.get("wv_errors", 0) for log in logging_outputs)
        metrics.log_scalar("_wv_errors", wv_errors)
        w_total = sum(log.get("w_total", 0) for log in logging_outputs)
        metrics.log_scalar("_w_total", w_total)

        if c_total > 0:
            cls._log_error_rate("uer", "_c_errors", "_c_total")

        if w_total > 0:
            cls._log_error_rate("wer", "_w_errors", "_w_total")
            cls._log_error_rate("raw_wer", "_wv_errors", "_w_total")


def euclidean_dist(enc_out1, enc_out2):
    """
    Pairwise Euclidean (L2) distance between two batches of vector sequences.

    enc_out1: B x T1 x D
    enc_out2: B x T2 x D
    return:   B x T1 x T2
    """
    pairwise = torch.cdist(enc_out1, enc_out2, p=2.0)
    return pairwise


def cosine_dist(enc_out1, enc_out2, eps=1e-8):
    """
    Pairwise cosine distance (1 - cosine similarity) between two batches of vector sequences.

    enc_outi: B x Ti x D
    eps:      lower bound on each vector's norm during normalization, so that
              near-zero vectors are not blown up to unit length
    return:   B x T1 x T2
    """
    enc_out1 = torch.nn.functional.normalize(enc_out1, dim=-1, eps=eps)
    enc_out2 = torch.nn.functional.normalize(enc_out2, dim=-1, eps=eps)
    return 1. - torch.bmm(enc_out1, enc_out2.transpose(1, 2))


def ipot_solver(enc_out1, enc_mask1, enc_out2, enc_mask2,
                max_iters=50, beta=1.0, dist_func=cosine_dist, distribution='normalize'):
    """
    Approximate optimal-transport cost between two padded batches of sequences
    using proximal-point (IPOT) iterations.

    enc_out1:     B x T1 x D
    enc_mask1:    B x T1 boolean mask, True at padding positions
    enc_out2:     B x T2 x D
    enc_mask2:    B x T2 boolean mask, True at padding positions
    max_iters:    number of proximal-point iterations
    beta:         proximal step size; the kernel is exp(-cost / beta)
    dist_func:    pairwise cost function returning B x T1 x T2
    distribution: marginal weights per position:
                  "uniform"   -> 1 / sequence length
                  "normalize" -> vector norm / sum of vector norms
    return:       scalar, transport cost summed over the batch
    raises:       ValueError for an unknown ``distribution``
    """
    len_mask1 = 1. - enc_mask1.type(enc_out1.dtype)  # B x T1
    len_mask2 = 1. - enc_mask2.type(enc_out2.dtype)  # B x T2

    len2 = len_mask2.sum(-1, keepdim=True)
    if distribution == "uniform":
        # distribution ~ 1 / seq_len
        norm1 = 1. / len_mask1.sum(-1, keepdim=True)  # B x 1
        norm2 = 1. / len2  # B x 1
    elif distribution == "normalize":
        # distribution ~ |w| / sum(|w|)
        weight1 = enc_out1.norm(dim=-1) * len_mask1  # B x T1
        weight2 = enc_out2.norm(dim=-1) * len_mask2  # B x T2
        norm1 = weight1 / weight1.sum(dim=-1, keepdim=True)  # B x T1
        norm2 = weight2 / weight2.sum(dim=-1, keepdim=True)  # B x T2
    else:
        # Previously fell through and failed later with a NameError on norm1.
        raise ValueError(f"unknown distribution: {distribution!r}")

    cost_mask = torch.einsum('bi,bj->bij', len_mask1, len_mask2)  # B x T1 x T2

    cost = dist_func(enc_out1, enc_out2)  # B x T1 x T2
    transport = cost_mask
    auxiliary = torch.exp(-cost / beta)
    sigma = (1. / len2) * len_mask2  # B x T2
    for _ in range(max_iters):
        q = auxiliary * transport  # B x T1 x T2
        _delta = torch.einsum('bij,bj->bi', q, sigma)  # B x T1
        # Filling padded positions with inf makes their scaling 0 instead of 0/0 = NaN.
        delta = norm1 / _delta.masked_fill(enc_mask1, float("inf"))
        _sigma = torch.einsum('bij,bi->bj', q, delta)  # B x T2
        sigma = norm2 / _sigma.masked_fill(enc_mask2, float("inf"))
        transport = delta.unsqueeze(-1) * q * sigma.unsqueeze(1)

    return (cost * transport * cost_mask).sum()  # cost_mask may not be necessary

