##########################################################################
# Copyright (C) 2022 COAI @ Tsinghua University

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#         http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##########################################################################

import os
import math
import sys

import torch
from torch import nn, Tensor
from torch.nn import functional as F
from torch.autograd import Function
from torch.utils.cpp_extension import load
from torch.utils.checkpoint import checkpoint
from torch import jit
from typing import Any, Dict, List, Optional, Tuple


####################### Cuda Version of DAG Operations ####################

# Directory that contains this source file.
module_path = os.path.dirname(__file__)
# Module-level slot initialised to None; nothing in the code visible here assigns or reads it
# (presumably a handle for a lazily loaded CUDA extension -- TODO confirm against the rest of the project).
dag_kernel = None

####################### Torch Version of DAG Operations ####################


def reverse_seq(t, length):
    """Reverse each row of ``t`` around its own length.

    Args:
        t: tensor of shape (B, L).
        length: integer tensor of shape (B, 1) holding the per-row length.

    Returns:
        Tensor of shape (B, L) whose entry ``[b, j]`` is
        ``t[b, (length[b] - 1 - j) % L]``. The modulo keeps indices for
        positions ``j >= length[b]`` inside the valid range.
    """
    n_rows, n_cols = t.size()
    positions = torch.arange(n_cols, device=t.device).unsqueeze(0).repeat(n_rows, 1)
    source_idx = torch.remainder(length - 1 - positions, n_cols)
    return torch.gather(t, 1, source_idx)

def reverse_feat(t, length):
    """Reverse the second axis of ``t`` around a per-sample length.

    Args:
        t: tensor of shape (B, L, D).
        length: integer tensor of shape (B, 1) holding the per-sample length.

    Returns:
        Tensor of shape (B, L, D) whose slice ``[b, j, :]`` is
        ``t[b, (length[b] - 1 - j) % L, :]``.
    """
    n_batch, n_pos, n_feat = t.size()
    positions = torch.arange(n_pos, device=t.device).unsqueeze(0).repeat(n_batch, 1)
    source_idx = torch.remainder(length - 1 - positions, n_pos)
    # Every feature channel of a position is fetched from the same source position.
    source_idx = source_idx.unsqueeze(-1).repeat(1, 1, n_feat)
    return torch.gather(t, 1, source_idx)


# @jit.script
def logsumexp_keepdim(x: Tensor, dim: int) -> Tensor:
    """Log-sum-exp over ``dim`` (kept as size 1) that stays finite-gradient on all ``-inf`` slices.

    A slice whose maximum is ``-inf`` would produce ``-inf - (-inf) = nan`` in
    the usual max-shifted formulation; such slices are shifted by 0 instead
    and their result is forced to ``-inf``.
    See https://github.com/pytorch/pytorch/issues/31829
    """
    neg_inf = -float('inf')
    peak = x.max(dim=dim, keepdim=True)[0]
    all_neg_inf = peak == neg_inf
    # The shift is treated as a constant: no gradient flows through the max.
    peak = peak.detach()
    # Temporarily shift the degenerate slices by 0 so the subtraction is well defined.
    peak.masked_fill_(all_neg_inf, 0)
    total = (x - peak).exp_().sum(dim=dim, keepdim=True)
    # log(1) == 0 for degenerate slices; the -inf restored into `peak` below
    # then makes the final sum -inf there.
    total.masked_fill_(all_neg_inf, 1)
    total.log_()
    peak.masked_fill_(all_neg_inf, neg_inf)
    return total + peak

# @jit.script
def loop_function_noempty(last_f, last_b, links_f, links_b, match_f, match_b):
    """Advance the "f" and "b" log-space recursions by one step.

    For each direction the previous column (batch * prelen * 1) is added to
    the link scores, reduced with log-sum-exp over dim 1, transposed back to
    batch * prelen * 1 and then offset by the matching scores of this step.

    Returns:
        Tuple ``(f_next, b_next)``.
    """
    def advance(previous, links, match):
        reduced = logsumexp_keepdim(previous + links, 1)  # batch * 1 * prelen
        return reduced.transpose(1, 2) + match  # batch * prelen * 1

    f_next = advance(last_f, links_f, match_f)
    b_next = advance(last_b, links_b, match_b)
    return f_next, b_next

# @jit.script
def loop_function_noempty_max(last_f: Tensor, links: Tensor, match: Tensor, wo_emit=False, emit_score=None):
    """Advance the max-product (Viterbi-style) recursion by one step.

    Args:
        last_f: previous column, batch * prelen * 1.
        links: link scores, batch * prelen * prelen; the max is taken over dim 1.
        match: matching scores added for this step, batch * prelen * 1.
        wo_emit: must be falsy; the emission-free variant is not supported.
        emit_score: unused (only meaningful for the unsupported ``wo_emit`` variant).

    Returns:
        Tuple ``(f_next, None)`` where ``f_next`` is batch * prelen * 1. The
        second slot is kept so callers can unpack a pair.

    Raises:
        ValueError: if ``wo_emit`` is truthy.
    """
    # Reject the unsupported variant up front instead of after doing the work.
    if wo_emit:
        raise ValueError("don't support now...")
    best_prev, _ = torch.max(last_f + links, dim=1)  # batch * prelen
    f_next = best_prev.unsqueeze(-1) + match  # batch * prelen * 1
    return f_next, None

def bi_dag_loss(match_all, links, output_length, target_length, mode="sum", wo_emit=False):
    r"""
    Run the pure-PyTorch log-space DP over the DAG for two directions ("f" and "b").

    Input:
        match_all (sequence of two torch.FloatTensor or torch.HalfTensor):
            ``match_all[0]`` feeds the "f" recursion and ``match_all[1]`` the "b" recursion
            (presumably forward / backward -- confirm against the caller).
            Shape of each: [batch_size, prelen, tarlen]
            Entry [b, j, k] is the matching score added for the j-th vertex at the k-th target step.
            (Note: float32 are preferred; float16 may cause precision problem)
        links (sequence of two torch.FloatTensor or torch.HalfTensor):
            ``links[0]`` / ``links[1]`` are the transition scores for the "f" / "b" recursion.
            Shape of each: [batch_size, prelen, prelen]
            Entry [b, i, j] is the log-space score of moving from the i-th vertex to **the j-th vertex**.
            (Note: this parameter is different from the cuda version)
        output_length (torch.LongTensor):
            Shape: [batch_size]
            The result is read at vertex ``output_length - 1`` for each sample.
        target_length (torch.LongTensor):
            Shape: [batch_size]
            The result is read at target step ``target_length - 1`` for each sample.
        mode (str):
            Only "sum" is supported.
        wo_emit (bool):
            Must be False; the emission-free variant is not supported.

    Output (tuple of two torch.FloatTensor or torch.HalfTensor):
        ``(loss_result_f, loss_result_b)``, each of shape [batch_size]: the DP value of the
        "f" / "b" recursion at the final vertex and final target step of each sample.

    Raises:
        ValueError: if ``wo_emit`` is truthy, ``mode == "max"`` (not supported yet), or ``mode``
            is any other value than "sum".
    """
    match_all_f = match_all[0]
    match_all_b = match_all[1]
    links_f, links_b = links[0], links[1]

    batch_size, prelen, tarlen = match_all_f.shape
    assert links_f.shape[1] == links_f.shape[2], "links should be batch_size * prelen * prelen"

    # Validate the options once, before any work, so that an unsupported request is
    # rejected even when the DP loop below would not execute (tarlen == 1).
    if wo_emit:
        raise ValueError("don't support now...")
    if mode == "max":
        raise ValueError("don't support now...")
    if mode != "sum":
        raise ValueError("sum or max...")

    # Only vertex 0 is reachable at target step 0; everything else starts at -inf.
    neg_inf = float("-inf")
    f_init = torch.full((batch_size, prelen, 1), neg_inf, dtype=match_all_f.dtype, device=match_all_f.device)
    b_init = torch.full((batch_size, prelen, 1), neg_inf, dtype=match_all_b.dtype, device=match_all_b.device)
    f_init[:, 0, 0] = match_all_f[:, 0, 0]
    b_init[:, 0, 0] = match_all_b[:, 0, 0]
    f_arr, b_arr = [f_init], [b_init]

    match_all_f_chunk = torch.chunk(match_all_f, tarlen, -1)  # k * [batch * prelen * 1]
    match_all_b_chunk = torch.chunk(match_all_b, tarlen, -1)  # k * [batch * prelen * 1]
    for k in range(1, tarlen):
        f_now, b_now = loop_function_noempty(f_arr[-1], b_arr[-1], links_f, links_b,
                                             match_all_f_chunk[k], match_all_b_chunk[k])
        f_arr.append(f_now)
        b_arr.append(b_now)

    # Pick, per sample, the DP cell at (last vertex, last target step).
    batch_idx = range(batch_size)
    last_vertex = output_length - 1
    last_step = target_length - 1
    loss_result_f = torch.cat(f_arr, -1)[batch_idx, last_vertex, last_step]
    loss_result_b = torch.cat(b_arr, -1)[batch_idx, last_vertex, last_step]
    return loss_result_f, loss_result_b

# @jit.script
def compute_dp_for_all(last_dp, links, lse_emit):
    """Advance a log-sum-exp recursion by one step.

    ``last_dp + links`` is reduced with log-sum-exp over dim 1 (giving
    B, 1, L), transposed to B, L, 1 and offset by ``lse_emit``.
    """
    reduced = logsumexp_keepdim(last_dp + links, 1)  # B, 1, L
    return reduced.transpose(1, 2) + lse_emit  # B, L, 1


def __torch_max_loss(match_all, links, output_length, target_length, wo_emit=False):
    """Max-product DP over the graph; returns one score per sample.

    Args:
        match_all: matching scores, batch_size * prelen * tarlen.
        links: transition scores, batch_size * prelen * prelen.
        output_length: per-sample vertex count; the result is read at vertex ``output_length - 1``.
        target_length: per-sample target length; the result is read at step ``target_length - 1``.
        wo_emit: when true, the recursion starts from 0 instead of the first
            matching score and a separate emission table is accumulated and
            added to the final score.

    Returns:
        Tensor of shape (batch_size,).
    """
    batch_size, prelen, tarlen = match_all.shape
    assert links.shape[1] == links.shape[2], "links should be batch_size * prelen * prelen"

    # Column 0 of the DP table: only vertex 0 is finite.
    first_col = torch.full((batch_size, prelen, 1), float("-inf"),
                           dtype=match_all.dtype, device=match_all.device)
    if wo_emit:
        first_col[:, 0, 0] = 0.
        first_emit = first_col.clone().squeeze(-1)
        first_emit[:, 0] = match_all[:, 0, 0]
    else:
        first_col[:, 0, 0] = match_all[:, 0, 0]
        first_emit = None
    score_cols, emit_cols = [first_col], [first_emit]

    # One DP column per remaining target step.
    for step_match in torch.chunk(match_all, tarlen, -1)[1:]:
        next_col, next_emit = loop_function_noempty_max(score_cols[-1], links, step_match,
                                                        wo_emit=wo_emit, emit_score=emit_cols[-1])
        score_cols.append(next_col)
        emit_cols.append(next_emit)

    batch_idx = range(batch_size)
    last_vertex = output_length - 1
    last_step = target_length - 1
    alllogprob = torch.cat(score_cols, -1)[batch_idx, last_vertex, last_step]
    if wo_emit:
        emit_table = torch.cat([e.unsqueeze(-1) for e in emit_cols], -1)
        alllogprob = alllogprob + emit_table[batch_idx, last_vertex, last_step]

    return alllogprob

def bi_dag_best_alignment(match, links, output_length, target_length, wo_emit=False, glat_fb_adpt=False, adpt_fb_factor=None):
    """Recover the best vertex-to-target alignment from the gradient of the max-product score.

    Args:
        match: (B, L, T) matching scores. ``requires_grad_()`` is called on it in place.
        links: pair ``[(B, L, L), (B, L, L)]`` of transition scores, one per direction.
        output_length: (B,) per-sample graph size.
        target_length: (B,) per-sample target length.
        wo_emit: forwarded to the max-product DP.
        glat_fb_adpt: if false, one direction is drawn at random for the whole
            batch; if true, both directions are scored and the higher-scoring
            one is kept per sample.
        adpt_fb_factor: optional tensor; when given (and ``glat_fb_adpt`` is
            true) its argmax over the last dim is compared with the chosen direction.

    Returns:
        Tuple ``(path, check_eq)``. ``path`` is (B, L): for each vertex the
        target index with the largest gradient, or -1 where that gradient is
        below 0.5. ``check_eq`` is 0 unless ``adpt_fb_factor`` was compared, in
        which case it is the mean of ``argmax(adpt_fb_factor) == 1 - mask``.
    """
    check_eq = 0
    with torch.enable_grad():
        match.requires_grad_()
        n_targets = match.size(-1)
        n_vertices = match.size(1)

        # Second-direction view of `match`: reverse the vertex axis, then the target axis.
        vertex_reversed = reverse_feat(match, output_length.unsqueeze(-1)).view(-1, n_targets)  # (BL, T)
        per_row_target_len = target_length.unsqueeze(-1).tile(1, n_vertices).view(-1, 1)
        match_b = reverse_seq(vertex_reversed, per_row_target_len).view(*match.size())
        match_fb = [match, match_b]

        if glat_fb_adpt:
            alllogprob_f = __torch_max_loss(match_fb[0], links[0], output_length, target_length, wo_emit=wo_emit)
            alllogprob_b = __torch_max_loss(match_fb[1], links[1], output_length, target_length, wo_emit=wo_emit)
            mask = (alllogprob_f > alllogprob_b).type_as(alllogprob_b)
            if adpt_fb_factor is not None:
                predicted = adpt_fb_factor.argmax(dim=-1).type_as(mask)
                check_eq = predicted.eq(1 - mask).float().mean()
            alllogprob = alllogprob_f * mask + alllogprob_b * (1 - mask)
        else:
            # `i` is intentionally kept as the one-element tensor returned by randint.
            i = torch.randint(low=0, high=len(links), size=(1, 1))
            alllogprob = __torch_max_loss(match_fb[i], links[i], output_length, target_length, wo_emit=wo_emit)

        matchgrad = torch.autograd.grad(alllogprob.sum(), [match])[0]  # B L T

    pathvalue, path = matchgrad.max(dim=2)
    path.masked_fill_(pathvalue < 0.5, -1)
    return path, check_eq


def torch_dag_logsoftmax_gather_inplace(word_ins_out, select_idx):
    r"""Apply log_softmax over the last dim (in float32) and gather ``select_idx`` from it.

    Despite the name, this pure-PyTorch version does not modify
    ``word_ins_out``; the input is handed back unchanged as the first element.

    Returns:
        Tuple ``(word_ins_out, match)`` where ``match`` has the shape of ``select_idx``.
    """
    log_probs = word_ins_out.log_softmax(-1, dtype=torch.float32)
    return word_ins_out, torch.gather(log_probs, -1, select_idx)






##########################################################################################
def compute_max_dp_among_all(last_f, last_b, transitions, max_emit_f, max_emit_b, wo_emit=False,
                             f_emit=None, b_emit=None, out_len=None, cur_len=None, vp=None):
    """Advance the "f" and "b" max-product recursions by one step.

    Args:
        last_f, last_b: previous DP columns, (B, L, 1).
        transitions: pair of (B, L, L) transition scores, ``[0]`` for "f" and ``[1]`` for "b".
        max_emit_f, max_emit_b: per-vertex emission scores, (B, L).
        wo_emit: if true, emissions are not added to the DP column but
            accumulated separately in the returned emit tensors.
        f_emit, b_emit: previous accumulated emission scores, (B, L); only read when ``wo_emit``.
        out_len: (B,) per-sample graph size; only read when ``vp`` is given.
        cur_len: current step index; only read when ``vp`` is given.
        vp: optional penalty factor. When given, ``vp * cur_len`` is subtracted
            from every transition into vertex ``out_len - 1``.

    Returns:
        ``(f_next, b_next, new_f_emit, new_b_emit)``; the columns are (B, L, 1)
        and the emit tensors are (B, L) or None when ``wo_emit`` is false.
    """
    penalty = torch.zeros_like(last_f).squeeze(-1)
    if vp is not None:
        penalty[range(out_len.size(0)), out_len-1] = 1.
        penalty = vp * cur_len * penalty
    penalty = penalty.unsqueeze(1)  # B, 1, L: indexed by destination vertex

    best_f, arg_f = (last_f + transitions[0] - penalty).max(dim=1)  # B, L
    best_b, arg_b = (last_b + transitions[1] - penalty).max(dim=1)  # B, L
    best_f = best_f.unsqueeze(-1)
    best_b = best_b.unsqueeze(-1)

    if wo_emit:
        # Follow the chosen predecessor in the separate emission table.
        new_f_emit = f_emit.gather(dim=1, index=arg_f) + max_emit_f
        new_b_emit = b_emit.gather(dim=1, index=arg_b) + max_emit_b
        return best_f, best_b, new_f_emit, new_b_emit

    return best_f + max_emit_f.unsqueeze(-1), best_b + max_emit_b.unsqueeze(-1), None, None

def _max_score_among_all(lprobs_f, transitions, output_length, force_bos_eos_decoding=False,
                         normalize_length=False, wo_emit=False, viterbi_penalty=None, max_viterbi_scale=0., source_length=None):
    """Max-product DP over all path lengths, in both directions, keeping the better one.

    Args:
        lprobs_f: (B, L, V) per-vertex token scores; the max over V is used as the emission.
        transitions: pair of (B, L, L) transition scores, ``[0]`` for "f" and ``[1]`` for "b".
        output_length: (B,) per-sample graph size; scores are read at vertex ``output_length - 1``.
        force_bos_eos_decoding: must be False (not supported).
        normalize_length: if true, each candidate score is divided by its path
            length before the best length is chosen; lengths 1 and 2 and lengths
            above ``output_length`` are divided by 0.1 instead (presumably to
            push those negative scores out of contention -- confirm).
        wo_emit: if true, emissions are accumulated in a separate table and
            added only after the best length and direction were chosen.
        viterbi_penalty: forwarded to ``compute_max_dp_among_all`` as ``vp``.
        max_viterbi_scale: when ``source_length`` is given and this is >= 1, the
            number of DP steps is ``int(max_viterbi_scale * source_length.max())``
            instead of L.
        source_length: optional tensor of source lengths.

    Returns:
        Tuple ``(max_scores, f_or_b)``, both (B,). ``f_or_b`` is 1 where the
        "f" direction was selected and 0 where "b" was.
    """
    batch_size, prelen, _ = lprobs_f.size()
    assert transitions[0].size(1) == transitions[0].size(2)
    assert not force_bos_eos_decoding, "don't support now..."

    max_emit_f = lprobs_f.max(dim=-1)[0]  # (B, L)
    max_emit_b = reverse_seq(max_emit_f, length=output_length.unsqueeze(-1))  # (B, L)

    # DP column 0: only vertex 0 is reachable.
    f_init = torch.full((batch_size, prelen, 1), float("-inf"), dtype=lprobs_f.dtype, device=lprobs_f.device)
    b_init = f_init.clone()
    if wo_emit:
        f_init[:, 0, 0], b_init[:, 0, 0] = 0., 0.
        f_emit_init = f_init.clone().squeeze(-1)
        f_emit_init[:, 0] = max_emit_f[:, 0]
        b_emit_init = b_init.clone().squeeze(-1)
        b_emit_init[:, 0] = max_emit_b[:, 0]
        f_e_arr, b_e_arr = [f_emit_init], [b_emit_init]
    else:
        f_init[:, 0, 0] = max_emit_f[:, 0]
        b_init[:, 0, 0] = max_emit_b[:, 0]
        f_e_arr, b_e_arr = [None], [None]
    f_arr, b_arr = [f_init], [b_init]

    if source_length is None or max_viterbi_scale < 1:
        steps = prelen
    else:
        # Must be a plain Python int: `range` rejects float tensors, so a
        # fractional scale would otherwise fail.
        steps = int(max_viterbi_scale * source_length.max())

    for i in range(1, steps):
        f_now, b_now, new_f_emit, new_b_emit = compute_max_dp_among_all(
            f_arr[-1], b_arr[-1], transitions, max_emit_f, max_emit_b,
            wo_emit=wo_emit, f_emit=f_e_arr[-1], b_emit=b_e_arr[-1],
            out_len=output_length, cur_len=i, vp=viterbi_penalty)
        f_arr.append(f_now)
        b_arr.append(b_now)
        f_e_arr.append(new_f_emit)
        b_e_arr.append(new_b_emit)

    # Read every DP column at the last vertex of each sample -> (B, n_cols).
    n_cols = len(f_arr)
    last_vertex = output_length[:, None, None].tile(1, 1, n_cols) - 1
    all_scores_f = torch.cat(f_arr, -1).gather(dim=1, index=last_vertex).squeeze(1)
    all_scores_b = torch.cat(b_arr, -1).gather(dim=1, index=last_vertex).squeeze(1)

    if normalize_length:
        path_len = torch.arange(1, n_cols + 1, device=all_scores_f.device).to(all_scores_f.dtype)
        path_len = path_len.unsqueeze(0).tile(all_scores_f.size(0), 1)
        valid = path_len <= output_length[:, None]
        valid[:, :2] = False
        path_len.masked_fill_(~valid, 0.1)
        max_scores_f, f_max_idx = (all_scores_f / path_len).max(dim=-1)  # B
        max_scores_b, b_max_idx = (all_scores_b / path_len).max(dim=-1)  # B
        pick_f = max_scores_f >= max_scores_b
    else:
        max_scores_f, f_max_idx = all_scores_f.max(dim=-1)  # B
        max_scores_b, b_max_idx = all_scores_b.max(dim=-1)  # B
        pick_f = max_scores_f > max_scores_b
    f_or_b = pick_f.type_as(max_scores_b)

    if wo_emit:
        # TODO: choose the max according to the sum of emit and transition score.
        f_e_score = torch.cat([e.unsqueeze(-1) for e in f_e_arr], -1).gather(dim=1, index=last_vertex).squeeze(1)
        b_e_score = torch.cat([e.unsqueeze(-1) for e in b_e_arr], -1).gather(dim=1, index=last_vertex).squeeze(1)
        max_scores_f = max_scores_f + f_e_score.gather(dim=1, index=f_max_idx.unsqueeze(1)).squeeze(-1)
        max_scores_b = max_scores_b + b_e_score.gather(dim=1, index=b_max_idx.unsqueeze(1)).squeeze(-1)

    # torch.where instead of `f * m + b * (1 - m)`: the arithmetic blend turns
    # into NaN whenever the unselected side is -inf (-inf * 0).
    max_scores = torch.where(pick_f, max_scores_f, max_scores_b)
    return max_scores, f_or_b

def bi_find_best_path_among_all(lprobs_f, transitions, output_length, force_bos_eos_decoding=False,
                                normalize_length=False,
                                wo_emit=False, adpt_fb_factor=None, viterbi_penalty=None, max_viterbi_scale=0,
                                source_length=None):
    """Extract the best path by differentiating the best-sequence score w.r.t. ``lprobs_f``.

    Args:
        lprobs_f: (B, L, V) per-vertex token scores. ``requires_grad_()`` is called on it in place.
        adpt_fb_factor: accepted but not used by this function.
        All other arguments are forwarded to ``_max_score_among_all``.

    Returns:
        best_path: (B, L). For each vertex the vocabulary index with the
        largest gradient, or -100 where that gradient is not positive.
    """
    with torch.enable_grad():
        lprobs_f.requires_grad_()
        sequence_scores, _ = _max_score_among_all(
            lprobs_f, transitions, output_length, force_bos_eos_decoding,
            normalize_length=normalize_length,
            wo_emit=wo_emit,
            viterbi_penalty=viterbi_penalty,
            max_viterbi_scale=max_viterbi_scale,
            source_length=source_length,
        )
        (score_grad,) = torch.autograd.grad(sequence_scores.sum(), [lprobs_f])  # same size as lprobs_f
    top_grad, best_path = score_grad.max(dim=-1)
    best_path.masked_fill_(top_grad <= 0, -100)
    return best_path

########################################################################################################################

