##########################################################################
# Copyright (C) 2022 COAI @ Tsinghua University

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#         http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##########################################################################

import math
import re
import logging
from functools import reduce
import numpy as np
from typing import Union, Tuple, Optional
import sys
from torch import jit

import torch
from torch import Tensor
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from torch.autograd import Function
from ..custom_ops import bi_dag_best_alignment, bi_dag_loss, find_best_path_among_all


from .utilities import parse_anneal_argument, get_anneal_value

logger = logging.getLogger(__name__)

########### gpu use tracker ###########
# Debug switch for GPU memory inspection. When set to True, a MemTracker from
# fairseq.gpu_mem_track is created at import time as `gpu_tracker`, and the
# criterion's forward() prints the reserved CUDA memory (in MiB) and the batch
# dimensions to stderr. Leave it False for normal training.
# import inspect
SHOW_MEMORY_USE=False
if SHOW_MEMORY_USE:
    from fairseq.gpu_mem_track import MemTracker
    gpu_tracker = MemTracker()
########################################
def reverse_feat(t, length):
    """Reverse the first ``length`` steps of each sequence of feature vectors.

    ``t`` is unpacked as ``(B, L, D)`` and ``length`` has to broadcast against
    a ``(B, L)`` grid of step indices (the callers pass ``(B, 1)``).

    Output step ``j`` is read from input step ``(length - 1 - j) mod L``.
    For ``j >= length`` the index wraps around to the end of the sequence, so
    the trailing region is reversed as well instead of being left untouched.
    """
    batch, steps, width = t.shape
    step_ids = torch.arange(steps, device=t.device).unsqueeze(0).expand(batch, steps)
    # torch.remainder follows the sign of the divisor, so negative offsets wrap.
    source = torch.remainder(length - 1 - step_ids, steps)
    return torch.gather(t, 1, source.unsqueeze(-1).expand(-1, -1, width))

def reverse_seq(t, length):
    """Reverse the first ``length`` entries of each row of a 2-D tensor.

    ``t`` is unpacked as ``(B, L)`` and ``length`` has to broadcast against a
    ``(B, L)`` grid of column indices (the callers pass ``(B, 1)``).

    Output column ``j`` is read from input column ``(length - 1 - j) mod L``.
    Columns at or beyond ``length`` wrap around to the end of the row, so the
    trailing region comes out reversed too.
    """
    rows, cols = t.shape
    col_ids = torch.arange(cols, device=t.device).unsqueeze(0).expand(rows, cols)
    # torch.remainder follows the sign of the divisor, so negative offsets wrap.
    source = torch.remainder(length - 1 - col_ids, cols)
    return torch.gather(t, 1, source)

def logsumexp_with_scale(x: Tensor, dim: int, fb_weight=0.5) -> Tensor:
    """Weighted log-sum-exp of two log-scores stacked along ``dim``.

    Args:
        x: log-scores; the two entries to combine are stacked along ``dim``.
        dim: axis to reduce. Negative values are accepted.
        fb_weight: either
            * a Python number ``w``: entry 0 along ``dim`` is weighted by ``w``
              and entry 1 by ``1 - w`` in probability space, i.e. the result is
              ``log(w * exp(x[0]) + (1 - w) * exp(x[1]))``; ``x`` is expected
              to hold exactly two entries along ``dim``; or
            * a tensor of log-weights that is added to ``x`` before a plain
              log-sum-exp (it must broadcast against ``x``).

    Returns:
        Tensor with ``dim`` reduced away. Positions where every entry is
        ``-inf`` yield ``-inf`` (not NaN).

    Raises:
        ValueError: if ``fb_weight`` is neither a number nor a tensor.
    """
    if isinstance(fb_weight, torch.Tensor):
        x = x + fb_weight
        weight = None
    elif isinstance(fb_weight, (int, float)) and not isinstance(fb_weight, bool):
        # Build the (w, 1 - w) pair so that it lines up with ``dim`` and
        # broadcasts over every other axis, whatever the rank of ``x`` is.
        axis = dim % x.dim()
        first = x.new_full([1] * x.dim(), fb_weight)
        weight = torch.cat([first, 1 - first], dim=axis)
    else:
        raise ValueError("fb_weight needs to be tensor or float")

    m, _ = x.max(dim=dim)
    # Where the maximum is -inf every entry is -inf; subtract 0 there instead
    # so that (-inf) - (-inf) never produces NaN.
    mask = m == -float('inf')
    scaled = (x - m.masked_fill(mask, 0).unsqueeze(dim=dim)).exp()
    if weight is not None:
        scaled = scaled * weight
    s = scaled.sum(dim=dim)
    # log(1) + (-inf) = -inf for the fully masked positions.
    return s.masked_fill(mask, 1).log() + m.masked_fill(mask, -float('inf'))

@register_criterion("bi_dag_nat_loss")
class BINATDAGLoss(FairseqCriterion):

    def __init__(self, cfg, task):
        """Store the configuration and derive the criterion's settings from it.

        ``cfg.label_smoothing`` must be 0. ``cfg.glat_p`` is parsed into an
        annealing schedule and the glancing probability is initialised for
        update 0. The remaining options are optional on ``cfg`` and fall back
        to the defaults listed below when absent.
        """
        super().__init__(task)
        self.cfg = cfg
        assert cfg.label_smoothing == 0, "DAG does not support label smoothing"
        self.glance_strategy = cfg.glance_strategy
        self._glat_p_anneal_params = parse_anneal_argument(cfg.glat_p)

        # Needs the parsed schedule above; sets self.glat_p.
        self.set_update_num(0)

        # (attribute on self, attribute on cfg, default when cfg lacks it)
        optional_settings = (
            ("wo_emit", "without_emit", False),
            ("train_mode", "train_mode", "sum"),
            ("glat_wo_emit", "glat_wo_emit", False),
            ("fixed_fb_factor", "fb_factor", 0.5),
            ("glat_fb_adpt", "glat_fb_adpt", False),
            ("risk_training", "risk_training", False),
        )
        for attr_name, cfg_key, fallback in optional_settings:
            setattr(self, attr_name, getattr(cfg, cfg_key, fallback))

    @staticmethod
    def add_args(parser):
        """Add criterions-specific arguments to the parser."""
        parser.add_argument("--label-smoothing", type=float, default=0, help="DA-Transformer does not use label smoothing for now")
        parser.add_argument("--glat-p", type=str, default="0", help="Glancing probability. 0.5:0.1@200k indicates annealing p from 0.5 to 0.1 in 200k steps.")
        parser.add_argument("--glance-strategy", type=str, default=None, help='Glancing strategy. Possible values: "number-random" or "None" or "CMLM"')
        parser.add_argument("--no-force-emit", action="store_true", help="If true, do not fix the position of glance tokens in the second forward pass")

        parser.add_argument("--torch-dag-logsoftmax-gather", action="store_true", help="Use torch implementation for logsoftmax-gather, which supports GPU and CPU device. (Cuda implementation only supports GPU)")
        parser.add_argument("--torch-dag-best-alignment", action="store_true", help="Use torch implementation for dag-best-alignment, which supports GPU and CPU device. (Cuda implementation only supports GPU)")
        parser.add_argument("--torch-dag-loss", action="store_true", help="Use torch implementation for dag-loss, which supports GPU and CPU device. (Cuda implementation only supports GPU)")

        # --- my arguments ---
        parser.add_argument("--without-emit", action="store_true")
        parser.add_argument("--train-mode", type=str, default="sum")
        parser.add_argument("--glat-wo-emit", action="store_true")

        parser.add_argument("--fb-factor", type=float, default=0.5)
        parser.add_argument("--glat-fb-adpt", action="store_true")
        parser.add_argument("--risk-training", action="store_true")

    def get_bleu(self, logits, target, transitions, output_length, source_length, viterbi_scale):
        """Score the best decoded path of every sentence against its reference.

        References are the entries of ``target`` that differ from
        ``self.padding_idx``. Hypotheses are the non-negative entries of the
        path returned by ``find_best_path_among_all``. Everything runs under
        ``torch.no_grad()``.

        Returns:
            A tensor created with ``logits.new_tensor`` holding one score per
            sentence.

        NOTE(review): ``sentence_bleu`` is neither defined nor imported in the
        visible part of this file, so calling this method raises ``NameError``
        unless the name is provided elsewhere. The call shape
        ``sentence_bleu([reference], hypothesis)`` looks like NLTK's - confirm
        and add the intended import.
        """
        with torch.no_grad():
            nonpad = target.ne(self.padding_idx)
            refs = [ref_row[keep].cpu().tolist() for ref_row, keep in zip(target, nonpad)]

            best_path = find_best_path_among_all(
                logits,
                transitions,
                output_length,
                source_length=source_length,
                max_viterbi_scale=viterbi_scale,
            )
            # Negative path entries are dropped from the hypothesis.
            aligned = best_path >= 0
            hypos = [path_row[keep].cpu().tolist() for path_row, keep in zip(best_path, aligned)]

            bleus = [sentence_bleu([ref], hyp) for ref, hyp in zip(refs, hypos)]

        return logits.new_tensor(bleus)

    def _compute_dag_loss(self, outputs, output_masks, targets, target_masks, links, label_smoothing=0.0, name="loss",
                factor=1.0, matchmask=None, keep_word_mask=None, model=None, adpt_fb_factor=None, src_lengths=None):
        """Compute the bidirectional DAG loss and its logging statistics.

        Args:
            outputs: vertex logits; indexed as (batch, prelen, vocab).
            output_masks: mask over graph vertices; its sum over the last
                axis is used as the per-sentence output length.
            targets: target token ids, indexed as (batch, tarlen).
            target_masks: mask over target tokens; its sum over the last axis
                is used as the per-sentence target length.
            links: transition scores, forwarded unchanged to ``bi_dag_loss``.
            label_smoothing, model, src_lengths: accepted for interface
                compatibility; not used here.
            name: name stored in the returned dict.
            factor: multiplier applied to the returned ``loss``.
            matchmask, keep_word_mask: glancing masks. When ``matchmask`` is
                given and ``cfg.no_force_emit`` is off, glanced vertices are
                restricted to their aligned target token.
            adpt_fb_factor: optional adaptive forward/backward weights from
                the model. It is transposed and added to the stacked
                log-likelihoods - presumably (batch, 2) log-weights; TODO
                confirm against the model.

        Returns:
            dict with ``name``, ``loss`` (scaled by ``factor``), ``nll_loss``,
            ``factor``, ``ntokens``, ``nvalidtokens``, ``nsentences``,
            ``loss_nofactor``, ``invalid_nsentences``, ``loss_f``, ``loss_b``.

        Raises:
            ValueError: if risk training is enabled but no ``adpt_fb_factor``
                was supplied.
        """
        # Fail fast with a clear message: without this check the call below
        # would die on ``None.transpose`` only after the whole DAG computation.
        if self.risk_training and adpt_fb_factor is None:
            raise ValueError(
                "risk_training is enabled but adpt_fb_factor is None; "
                "the model must provide an adaptive forward/backward factor"
            )

        batch_size, prelen = outputs.shape[0], outputs.shape[1]
        tarlen = targets.shape[1]
        nvalidtokens = output_masks.sum()

        output_length = output_masks.sum(dim=-1)
        target_length = target_masks.sum(dim=-1)

        # Log-probability of every target token at every vertex.
        match_f = torch.log_softmax(outputs, dim=-1, dtype=torch.float32).gather(dim=-1, index=targets.unsqueeze(1).tile(1, prelen, 1))  # (B, L, T)

        if matchmask is not None and not self.cfg.no_force_emit:
            # Force emit: a glanced vertex (keep_word_mask) may only emit the
            # target token it was aligned to (matchmask); every other token
            # gets -inf there. The glanced entries are detached, so they carry
            # no gradient. Non-glanced vertices keep their scores unchanged.
            glat_prev_mask = keep_word_mask.unsqueeze(1)
            match_f = match_f.transpose(1, 2)
            match_f = match_f.masked_fill(glat_prev_mask, 0) + match_f.masked_fill(~matchmask, float("-inf")).masked_fill(~glat_prev_mask, 0).detach()
            match_f = match_f.transpose(1, 2)

        # Backward-direction scores: reverse along the vertex axis, then along
        # the target axis (each vertex row is reversed with its sentence's
        # target length).
        match_b = reverse_feat(match_f, output_length.unsqueeze(-1)).view(-1, match_f.size(-1))
        match_b = reverse_seq(match_b, target_length.unsqueeze(-1).tile(1, prelen).view(-1, 1)).view(batch_size, prelen, tarlen)

        loss_result_f, loss_result_b = bi_dag_loss([match_f, match_b],
                                                     links,
                                                     output_length, target_length,
                                                     mode=self.train_mode, wo_emit=self.wo_emit)

        # Combine the two directions, either with the fixed factor or with the
        # adaptive one supplied by the model.
        stacked = torch.stack([loss_result_f, loss_result_b], dim=0)
        if adpt_fb_factor is None:
            loss_result = logsumexp_with_scale(stacked, dim=0, fb_weight=self.fixed_fb_factor)
        else:
            loss_result = logsumexp_with_scale(stacked, dim=0, fb_weight=adpt_fb_factor.transpose(0,1))

        # Sentences whose likelihood is inf/NaN in either direction contribute
        # zero loss and are counted separately.
        invalid_masks_f = loss_result_f.isinf().logical_or(loss_result_f.isnan())
        invalid_masks_b = loss_result_b.isinf().logical_or(loss_result_b.isnan())
        invalid_masks = invalid_masks_f | invalid_masks_b
        loss_result.masked_fill_(invalid_masks, 0)
        invalid_nsentences = invalid_masks.sum().detach()

        loss = -(loss_result / target_length).mean()
        nll_loss = loss.detach()
        loss_f = -1 * (loss_result_f.detach().masked_fill(invalid_masks_f, 0) / target_length).mean()
        loss_b = -1 * (loss_result_b.detach().masked_fill(invalid_masks_b, 0) / target_length).mean()
        nsentences, ntokens = targets.shape[0], targets.ne(self.task.tgt_dict.pad()).sum()

        loss_nofactor = loss
        loss = loss * factor

        return {"name": name, "loss": loss, "nll_loss": nll_loss,
                "factor": factor, "ntokens": ntokens, "nvalidtokens": nvalidtokens, "nsentences": nsentences,
                "loss_nofactor": loss_nofactor, "invalid_nsentences": invalid_nsentences,
                "loss_f": loss_f, "loss_b": loss_b}

    def _custom_loss(self, loss, name="loss", factor=1.0):
        return {"name": name, "loss": loss, "factor": factor}

    def set_update_num(self, update_num):
        """Refresh ``self.glat_p`` from the annealing schedule for ``update_num``."""
        schedule = self._glat_p_anneal_params
        self.glat_p = get_anneal_value(schedule, update_num)

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
           (always 1 here)
        3) logging outputs to display while training

        When ``sample`` carries an ``update_num`` the glancing probability is
        re-annealed first. If that probability is non-zero, the model is given
        ``glat_function`` (defined below) to build the glanced decoder input.

        Raises:
            ValueError: if glancing is active and ``self.glance_strategy`` is
                not one of ``None``, ``"number-random"`` or ``"cmlm"``.
        """
        if SHOW_MEMORY_USE:
            print(torch.cuda.memory_reserved() / 1024 / 1024, file=sys.stderr, flush=True)
            gpu_tracker.clear_cache()

        # B x T
        src_tokens, src_lengths = (
            sample["net_input"]["src_tokens"],
            sample["net_input"]["src_lengths"],
        )
        tgt_tokens = sample["target"]

        if SHOW_MEMORY_USE:
            print(sample["net_input"]["src_tokens"].shape[0], sample["net_input"]["src_tokens"].shape[1], tgt_tokens.shape[1], file=sys.stderr, end=" ")

        if sample.get("update_num", None) is not None: # in training
            self.set_update_num(sample['update_num'])

        prev_output_tokens = model.initialize_output_tokens_by_tokens(src_tokens, tgt_tokens)

        if self.glat_p == 0:
            glat = None
        else:
            glat = {
                "context_p": max(self.glat_p, 0),
                "require_glance_grad": False
            }

        def glat_function(model, word_ins_out, tgt_tokens, prev_output_tokens, glat, links=None, adpt_fb_factor=None):
            """Pick the vertices to reveal (glance) for the second decoder pass.

            Aligns the targets to the graph with ``bi_dag_best_alignment``,
            compares the aligned (oracle) tokens with the current argmax
            predictions and samples a keep mask according to
            ``self.glance_strategy``. Returns the glanced decoder input, the
            (unchanged) targets and a dict of glancing statistics and masks.
            """
            batch_size, prelen, _ = links[0].shape if isinstance(links, list) else links.shape
            tarlen = tgt_tokens.shape[1]
            nonpad_positions = ~tgt_tokens.eq(model.pad)
            target_length = (nonpad_positions).sum(1)
            output_length = prev_output_tokens.ne(model.pad).sum(1)

            pred_tokens = word_ins_out.argmax(-1)

            # Log-probability of every target token at every vertex.
            match_f = torch.log_softmax(word_ins_out, dim=-1, dtype=torch.float32).gather(dim=-1,
                                                                     index=tgt_tokens.unsqueeze(1).tile(1, prelen, 1))  # (B, L, T)

            path, check_eq = bi_dag_best_alignment(match_f, links, output_length, target_length, self.glat_wo_emit,
                                            glat_fb_adpt=self.glat_fb_adpt, adpt_fb_factor=adpt_fb_factor)

            # path holds, per vertex, the index of the aligned target token;
            # negative entries are treated as "not aligned".
            predict_align_mask = path >= 0
            # One-hot of the alignment over the target axis; the extra leading
            # slot absorbs the negative entries and is sliced off again.
            matchmask = torch.zeros(batch_size, tarlen + 1, prelen, device=match_f.device, dtype=torch.bool).scatter_(1, path.unsqueeze(1) + 1, 1)[:, 1:]
            oracle = tgt_tokens.gather(-1, path.clip(min=0)) # bsz * prelen
            same_num = ((pred_tokens == oracle) & predict_align_mask).sum(1)

            if self.glance_strategy is None:
                # Every aligned vertex is kept independently with a probability
                # proportional to the fraction of wrongly predicted tokens.
                keep_prob = ((target_length - same_num) / target_length * glat['context_p']).unsqueeze(-1) * predict_align_mask.float()

            elif self.glance_strategy in ('number-random', 'cmlm'):
                # Both strategies keep a fixed number of randomly chosen
                # aligned vertices and differ only in how that number is set.
                prob = torch.randn(oracle.shape, device=tgt_tokens.device, dtype=torch.float)
                prob.masked_fill_(~predict_align_mask, -100)
                if self.glance_strategy == 'number-random':
                    glance_nums = ((target_length - same_num) * glat['context_p'] + 0.5).to(torch.long)
                else:
                    glance_nums = (target_length * torch.rand_like(target_length, dtype=torch.float) + 0.5).to(torch.long)
                # The glance_nums-th largest random score is the threshold.
                prob_thresh = prob.sort(descending=True)[0].gather(-1, (glance_nums - 1).clip(min=0).unsqueeze(-1)).squeeze(-1)
                prob_thresh.masked_fill_(glance_nums == 0, 100)
                keep_prob = (prob >= prob_thresh.unsqueeze(-1)).to(prob.dtype)

            else:
                # Previously this fell through and failed later with an
                # UnboundLocalError on keep_prob.
                raise ValueError(
                    f"Unsupported glance strategy {self.glance_strategy!r}; "
                    'expected None, "number-random" or "cmlm"'
                )

            keep_word_mask = (torch.rand(prev_output_tokens.shape, device=prev_output_tokens.device) < keep_prob).bool()

            # Kept vertices receive their oracle token, the rest stay as is.
            glat_prev_output_tokens = prev_output_tokens.masked_fill(keep_word_mask, 0) + oracle.masked_fill(~keep_word_mask, 0)
            glat_tgt_tokens = tgt_tokens

            glat_info = {
                "glat_accu": (same_num.sum() / target_length.sum()).detach(),
                "glat_context_p": glat['context_p'],
                "glat_keep": keep_prob.mean().detach(),
                "matchmask": matchmask,
                "keep_word_mask": keep_word_mask,
                "glat_prev_output_tokens": glat_prev_output_tokens,
                "check_eq": check_eq,
            }

            return glat_prev_output_tokens, glat_tgt_tokens, glat_info

        outputs = model(src_tokens, src_lengths, prev_output_tokens, tgt_tokens, glat, glat_function)

        losses = []

        # DAG loss
        _losses = self._compute_dag_loss(
            outputs["word_ins"].get("out"),
            prev_output_tokens.ne(self.task.tgt_dict.pad()),
            outputs["word_ins"].get("tgt"),
            outputs["word_ins"].get("mask", None),
            outputs["links"],
            name="dag-loss",
            factor=1,
            matchmask=outputs.get('matchmask', None),
            keep_word_mask=outputs.get('keep_word_mask', None),
            model=model,
            adpt_fb_factor=outputs.get("adpt_fb_factor", None),
            src_lengths = src_lengths
        )

        losses += [_losses]
        dag_nll_loss = _losses.get("nll_loss", 0.0)
        nsentences = _losses["nsentences"]
        ntokens = _losses["ntokens"]
        nvalidtokens = _losses["nvalidtokens"]
        invalid_nsentences = _losses["invalid_nsentences"]
        loss_f = _losses["loss_f"]
        loss_b = _losses["loss_b"]

        # No length loss is computed by this criterion; a zero is logged.
        length_nll_loss = _losses["nll_loss"].new_zeros(1)

        loss = sum(entry["loss"] for entry in losses)

        sample_size = 1
        logging_output = {
            "loss": loss.data,
            "dag_nll-loss": dag_nll_loss.data,
            "length_nll-loss": length_nll_loss.data,
            "ntokens": ntokens,
            "nvalidtokens": nvalidtokens,
            "nsentences": nsentences,
            "invalid_nsentences": invalid_nsentences,
            "sample_size": sample_size,
            "glat_acc": outputs.get("glat_accu", 0),
            "glat_keep": outputs.get("glat_keep", 0),
            "check_eq": outputs.get("check_eq", 0),
            "loss_f": loss_f.data,
            "loss_b": loss_b.data
        }

        # Each named loss is additionally logged without its factor.
        for entry in losses:
            logging_output[entry["name"]] = (
                utils.item(entry["loss_nofactor"])
                if reduce
                else entry["loss_nofactor"]
            )

        return loss, sample_size, logging_output

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training.

        Sums every statistic over ``logging_outputs``, normalises the loss-like
        values by the summed ``sample_size`` and logs them with
        ``metrics.log_scalar``. Every key of the first logging output that ends
        in ``-loss`` is logged as well, under the key without that suffix.

        A zero denominator (no samples, tokens or sentences) yields 0.0 for the
        affected metric instead of raising, and an empty ``logging_outputs``
        list is accepted.
        """

        def total(key):
            # Sum of one statistic over all workers, as a plain Python number.
            return utils.item(sum(log.get(key, 0) for log in logging_outputs))

        def ratio(numerator, denominator):
            # Same zero-denominator convention for every metric.
            return numerator / denominator if denominator > 0 else 0.0

        sample_size = total("sample_size")  # each batch is 1

        ntokens = total("ntokens")
        nvalidtokens = total("nvalidtokens")
        nsentences = total("nsentences")
        invalid_nsentences = total("invalid_nsentences")

        res = {
            "ntokens": ntokens,
            "nsentences": nsentences,
            "nvalidtokens": nvalidtokens,
            "invalid_nsentences": invalid_nsentences,
            'tokens_perc': ratio(nvalidtokens, ntokens),
            'sentences_perc': 1 - invalid_nsentences / nsentences if nsentences > 0 else 0.0,
        }
        for key in ("loss", "glat_acc", "glat_keep", "check_eq", "loss_f", "loss_b"):
            res[key] = ratio(total(key), sample_size)

        for key, value in res.items():
            metrics.log_scalar(
                key, value, sample_size, round=3
            )

        first_log = logging_outputs[0] if logging_outputs else {}
        for key in first_log:
            if key.endswith("-loss"):
                metrics.log_scalar(
                    key[:-5],
                    ratio(total(key), sample_size),
                    sample_size,
                    round=3,
                )

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improves distributed training speed.
        """
        return True
