# -*- coding:utf-8 _*-
# @License: MIT Licence

# @Time: 23/5/2023
import os
import re
import math
from typing import List, Optional, Tuple, Union, Callable
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import functional as F
import inspect
import sys

sys.path.append(os.path.abspath(os.path.join(__file__, '../../src')))
from model import Block, GPT2Model, Attention, GPT2LMModel


def disable_dropout_Modified(model):
    """Set the drop probability of every LoRA dropout module in ``model`` to zero.

    A sub-module is selected when it is a ``torch.nn.Dropout`` instance or when
    its qualified name contains ``"drop"``. Every selected module must have
    ``'lora_dropout'`` in its name, otherwise the assertion below fails.

    Args:
        model: module exposing ``named_modules()``.

    Returns:
        The same ``model`` object, modified in place.
    """
    for module_name, submodule in model.named_modules():
        looks_like_dropout = isinstance(submodule, torch.nn.Dropout) or "drop" in module_name
        if not looks_like_dropout:
            continue
        # Only LoRA dropout is handled; anything else is rejected loudly.
        assert 'lora_dropout' in module_name, "dropout in other modules are not supported yet"
        submodule.p = 0.0
    return model


def replace_dropout_Modified(model, modified_dropout, model_name="none"):
    """Install the configured structured-dropout variants on ``model``.

    At most one "hiddencut_*" pattern is installed (the first of element,
    column, span whose rate is > 0): every ``Block`` receives a
    ``modified_dropout`` module and ``Block.forward`` is replaced at class
    level; for the span pattern ``GPT2Model.forward`` is replaced as well.
    Independently, if any "dropkey_*" / "dropattn_*" rate is > 0,
    ``Attention._attn`` is replaced and every ``Attention`` module receives a
    ``modified_dropout`` dict.

    Modules whose layer index (last 1-2 digit number found in the module name)
    is below 12 get the dropout switched off (``p = 0.0`` / empty dict).

    Args:
        model: module exposing ``named_modules()``.
        modified_dropout: mapping from pattern name to dropout rate.
        model_name: must be ``"gpt2.md"``.

    Returns:
        The same ``model`` object, modified in place.
    """
    assert model_name in ["gpt2.md", ], "other models not supported yet"

    def _is_enabled(pattern):
        return pattern in modified_dropout and modified_dropout[pattern] > 0.

    def _layer_index(module_name):
        return int(re.findall(r"\d{1,2}", module_name)[-1])

    # First enabled hidden-state pattern wins, in this fixed priority order.
    hidden_patterns = ("hiddencut_element", "hiddencut_column", "hiddencut_span")
    chosen = next((pattern for pattern in hidden_patterns if _is_enabled(pattern)), None)

    if chosen is not None:
        rate = modified_dropout[chosen]
        for name, module in model.named_modules():
            if not isinstance(module, Block):
                continue
            if chosen == "hiddencut_element":
                # Element-wise variant reuses the stock dropout, tagged with its pattern.
                module.modified_dropout = nn.Dropout(p=rate)
                module.modified_dropout.dropout_pattern = chosen
            else:
                module.modified_dropout = Dropout_Modified(p=rate, dropout_pattern=chosen)
            if _layer_index(name) < 12:
                module.modified_dropout.p = 0.0
        Block.forward = Block_forward_Modified
        if chosen == "hiddencut_span":
            # The span variant needs per-sample lengths, which the patched model forward supplies.
            GPT2Model.forward = GPT2Model_forward_Modified

    attention_patterns = ["dropkey_element", "dropkey_column", "dropkey_span", "dropattn_element",
                          "dropattn_column", "dropattn_span", ]
    if any(_is_enabled(pattern) for pattern in attention_patterns):
        Attention._attn = Attention__attn_Modified
        for name, module in model.named_modules():
            if isinstance(module, Attention):
                module.modified_dropout = {} if _layer_index(name) < 12 else modified_dropout

    return model


class Dropout_Modified(nn.Dropout):
    """Dropout that zeroes whole token positions of a 3-D ``(batch, seq, hidden)`` input.

    Two patterns are supported:

    * ``"hiddencut_column"``: every (sample, position) pair is dropped
      independently with probability ``p``.
    * ``"hiddencut_span"``: per sample, one contiguous span of
      ``floor(len * p)`` positions is dropped, with the start drawn uniformly
      from ``[0, floor(len * (1 - p)))``.

    The mask is a plain 0/1 mask: kept activations are NOT rescaled by
    ``1 / (1 - p)``.
    """

    def __init__(self, p: float = 0.5, inplace: bool = False, dropout_pattern: str = "none"):
        super().__init__(p, inplace)
        self.dropout_pattern = dropout_pattern
        assert self.dropout_pattern in ["hiddencut_column", "hiddencut_span", ]

    def forward(self, input: Tensor, modified_input_len=None) -> Tensor:
        """Apply the configured structured dropout.

        Args:
            input: 3-D tensor; the first two dimensions are treated as batch
                and sequence position.
            modified_input_len: per-sample lengths (1-D integer tensor) used by
                ``"hiddencut_span"`` to size and place the span. When ``None``
                the full sequence length is used for every sample.

        Returns:
            ``input`` unchanged when ``p == 0`` or the module is in eval mode;
            otherwise ``input`` multiplied by the 0/1 mask (in place when
            ``self.inplace`` is set).

        Raises:
            ValueError: if ``self.p`` is outside ``[0, 1]``.
            NotImplementedError: if ``dropout_pattern`` was changed to an
                unsupported value after construction.
        """
        if self.p < 0.0 or self.p > 1.0:
            raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(self.p))

        if self.p == 0 or not self.training:
            return input

        # training mode: the mask itself never needs gradients
        with torch.no_grad():
            bz, seq_len, _ = input.size()
            if self.dropout_pattern == "hiddencut_column":
                # Allocate on the input's device so no host->device copy is needed.
                mask = torch.ones(bz, seq_len, 1, device=input.device)
                # `!= 0` discards F.dropout's 1/(1-p) rescaling and leaves a pure keep mask.
                mask = F.dropout(mask, self.p, self.training, self.inplace) != 0
                mask = mask.expand_as(input).type_as(input)
            elif self.dropout_pattern == "hiddencut_span":
                dropout_rate = self.p
                if modified_input_len is None:
                    # No length information: treat every position as valid.
                    emb_len = torch.full((bz,), seq_len, dtype=torch.long, device=input.device)
                else:
                    emb_len = modified_input_len
                mask_len = (emb_len * dropout_rate).long().clamp(min=0)
                index_high = (emb_len * (1 - dropout_rate)).long().clamp(min=0)
                start_indices = (torch.rand_like(index_high.float()) * index_high).long().clamp(min=0)
                end_indices = (start_indices + mask_len).clamp(max=emb_len)
                positions = torch.arange(seq_len, device=start_indices.device)
                in_span = (start_indices.view(-1, 1) <= positions) & (end_indices.view(-1, 1) > positions)
                mask = torch.ones((bz, seq_len), dtype=input.dtype, device=input.device)
                mask[in_span] = 0.
                mask = mask.unsqueeze(-1).expand_as(input)

            else:
                raise NotImplementedError("dropout pattern {} not implemented".format(self.dropout_pattern))

        if self.inplace:
            input *= mask
            return input
        else:
            output = input * mask
            return output


def Block_forward_Modified(self, x, layer_past=None, len_past=None, modified_input_len=None):
    """Transformer block forward pass with ``self.modified_dropout`` applied to the MLP output.

    Args:
        x: block input.
        layer_past: forwarded to ``self.attn``.
        len_past: forwarded to ``self.attn``.
        modified_input_len: passed to ``self.modified_dropout`` only when its
            ``forward`` declares a ``modified_input_len`` parameter.

    Returns:
        Tuple ``(output, present)`` where ``present`` is the second value
        returned by ``self.attn``.
    """
    attn_out, present = self.attn(self.ln_1(x), layer_past=layer_past, len_past=len_past)
    x = x + attn_out

    mlp_out = self.mlp(self.ln_2(x))
    # Stock nn.Dropout does not take the length argument; the custom variant does.
    forward_params = inspect.signature(self.modified_dropout.forward).parameters
    if 'modified_input_len' in forward_params:
        mlp_out = self.modified_dropout(mlp_out, modified_input_len)
    else:
        mlp_out = self.modified_dropout(mlp_out)

    x = x + mlp_out
    return x, present


def GPT2Model_forward_Modified(
        self,
        input_ids,
        position_ids=None,
        token_type_ids=None,
        past=None,
        len_past=None,
):
    """Model forward pass that also hands each block the per-sample input length.

    The length is the number of entries of ``input_ids`` that differ from 0 in
    each row and is passed to every block as ``modified_input_len``.

    Args:
        input_ids: token ids; flattened to 2-D on the last dimension.
        position_ids: optional explicit positions. Ignored (overwritten) when
            ``len_past`` is given.
        token_type_ids: optional ids embedded with ``self.wte`` and added to
            the input embeddings.
        past: optional per-block cache; defaults to ``None`` for every block.
        len_past: optional tensor; when given, positions are
            ``len_past.unsqueeze(1)``.

    Returns:
        Tuple ``(hidden_states, presents)``: the output of ``self.ln_f``
        reshaped to ``input_ids.size() + (hidden,)`` and the list of per-block
        ``present`` values.
    """
    if past is None:
        past_length = 0
        past = [None] * len(self.h)
    elif len_past is None:
        # equal size for past. []
        past_length = past[0][0].size(-2)

    if len_past is not None:
        position_ids = (len_past).unsqueeze(1)  # .long()
    elif position_ids is None:
        first_pos = past_length
        last_pos = input_ids.size(-1) + past_length
        position_ids = torch.arange(first_pos, last_pos, dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)

    original_shape = input_ids.size()
    flat_input_ids = input_ids.view(-1, input_ids.size(-1))
    flat_position_ids = position_ids.view(-1, position_ids.size(-1))

    token_embeds = self.wte(flat_input_ids)
    pos_embeds = self.wpe(flat_position_ids)
    if token_type_ids is None:
        type_embeds = 0
    else:
        flat_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
        type_embeds = self.wte(flat_type_ids)

    hidden_states = token_embeds + pos_embeds + type_embeds

    # Count of non-zero ids per row, consumed by the span dropout inside each block.
    modified_input_len = (flat_input_ids != 0).sum(-1)

    presents = []
    for block, layer_past in zip(self.h, past):
        hidden_states, block_present = block(hidden_states, layer_past=layer_past, len_past=len_past,
                                             modified_input_len=modified_input_len)
        presents.append(block_present)

    hidden_states = self.ln_f(hidden_states)
    output_shape = original_shape + (hidden_states.size(-1),)
    return hidden_states.view(*output_shape), presents


def Attention__attn_Modified(self, q, k, v, len_kv=None):
    """Compute attention output with optional DropKey / DropAttention masking.

    Scores are ``q @ k`` (divided by ``sqrt(v.size(-1))`` when ``self.scale``
    is truthy), masked with ``self.bias`` and, when ``len_kv`` is given, with a
    length mask. In training mode, at most one "dropkey_*" pattern is applied
    to the scores before the softmax and at most one "dropattn_*" pattern is
    applied to the probabilities after it. A pattern is active when its key is
    in ``self.modified_dropout`` with a value > 0; within each family the
    first active one of element / column / span is used.

    Args:
        q: queries; per the shape notes below, (batch, head, q_seq_length, head_features).
        k: keys, already transposed: (batch, head, head_features, kv_seq_length).
        v: values: (batch, head, kv_seq_length, head_features).
        len_kv: optional 1-D tensor of per-sample key lengths; key positions
            at or beyond the length are filled with -1.0e10.

    Returns:
        ``torch.matmul(attention_probs, v)``.

    Raises:
        AssertionError: if the score matrix is not 512 x 512.

    Note:
        If the final probabilities contain NaN, diagnostics are printed and the
        process is terminated via ``exit(-1)``.
    """
    w = torch.matmul(q, k)
    if self.scale:
        w = w / math.sqrt(v.size(-1))
    nd, ns = w.size(-2), w.size(-1)
    # NOTE(review): hard-coded to a 512 x 512 score matrix; any other query/key
    # length (e.g. incremental decoding) fails here — confirm this is intended.
    assert nd == 512 and ns == 512, "nd, ns should be 512, but got {} {}".format(nd, ns)
    # presumably self.bias is the 0/1 causal mask buffer of the attention module — TODO confirm
    b = self.bias[:, :, ns - nd:ns, :ns]
    # Keep scores where b == 1, replace them with -1e10 where b == 0.
    w = w * b - 1e10 * (1 - b)

    # q : (batch, head, q_seq_length, head_features)
    # k : (batch, head, head_features, kv_seq_length)
    # w : (batch, head, q_seq_length, kv_seq_length)
    # v : (batch, head, kv_seq_length, head_features)
    if len_kv is not None:
        # Mask key positions whose index is >= the per-sample length.
        _len = torch.arange(k.size(-1), device=k.device)
        _input_msk = _len[None, :] >= (len_kv)[:, None]
        w = w.masked_fill(_input_msk.unsqueeze(1).unsqueeze(2), -1.0e10)

    attention_scores = w

    # ---- DropKey: additive -inf mask on the scores, before the softmax ----
    if not self.training:
        pass
    elif "dropkey_element" in self.modified_dropout and self.modified_dropout["dropkey_element"] > 0.:
        dropout_rate = self.modified_dropout["dropkey_element"]
        bz, hd, _, seq_len = attention_scores.size()

        with torch.no_grad():
            # Lower-triangular visibility mask, flattened to one row per (batch, head, query).
            attention_mask = torch.ones_like(attention_scores).tril().bool().view(-1, seq_len)
            drop_mask = torch.bernoulli(torch.ones_like(attention_scores) * (1 - dropout_rate)) \
                .type_as(attention_scores).bool().view(-1, seq_len)  # 0: drop 1: keep
            # Always keep the (query 0, key 0) entry.
            drop_mask.view_as(attention_scores)[..., 0, 0] = True
            # Rows in which every visible (lower-triangular) entry was dropped.
            idx_0 = torch.where((attention_mask & drop_mask).sum(-1) == 0)[0]
            # Resample those rows until each keeps at least one visible entry.
            while len(idx_0) > 0:
                patch_mask = torch.bernoulli(torch.ones((len(idx_0), seq_len)) * (1 - dropout_rate)) \
                    .type_as(attention_scores).bool()  # 0: drop 1: keep
                drop_mask[idx_0] = patch_mask
                idx_0 = idx_0[torch.where((attention_mask[idx_0] & patch_mask).sum(-1) == 0)[0]]
            drop_mask = (~ drop_mask.view(bz, hd, seq_len, seq_len)).float()  # 0: keep 1: drop
            # Turn the drop flags into an additive mask: 0 keeps, -inf removes.
            drop_mask[drop_mask == 1] = float("-inf")
        attention_scores = attention_scores + drop_mask

    elif "dropkey_column" in self.modified_dropout and self.modified_dropout["dropkey_column"] > 0.:
        dropout_rate = self.modified_dropout["dropkey_column"]
        bz, hd, _, seq_len = attention_scores.size()
        # One flag per (batch, head, key); the size-1 query axis broadcasts over all queries.
        drop_mask = torch.bernoulli(torch.ones((bz, hd, 1, seq_len)) * dropout_rate) \
            .type_as(attention_scores)  # 0: keep 1: drop
        drop_mask[drop_mask == 1] = float("-inf")
        # Key 0 is never dropped.
        drop_mask[..., 0, 0] = 0.
        attention_scores = attention_scores + drop_mask

    elif "dropkey_span" in self.modified_dropout and self.modified_dropout["dropkey_span"] > 0.:
        dropout_rate = self.modified_dropout["dropkey_span"]
        bz, hd, _, seq_len = attention_scores.size()
        # Per query row i (0-based) the reference length is i + 1.
        emb_len = torch.arange(1, seq_len + 1).unsqueeze(0).unsqueeze(0).repeat(bz, hd, 1)
        emb_len = emb_len.view(bz, -1)
        # Span length is floor(len * rate); start is drawn below floor(len * (1 - rate)).
        mask_len = (emb_len * dropout_rate).long().clamp(min=0)
        index_high = (emb_len * (1 - dropout_rate)).long().clamp(min=0)
        start_indices = (torch.rand_like(index_high.float()) * index_high).long().clamp(min=0)
        end_indices = (start_indices + mask_len).clamp(max=emb_len)
        drop_mask = torch.zeros((bz, hd * seq_len, seq_len)).type_as(attention_scores)

        drop_mask = drop_mask.view(-1, seq_len)
        # Mark key positions in [start, end) of every row with -inf.
        drop_mask[(start_indices.view(-1, 1) <= torch.arange(0, seq_len).type_as(start_indices))
                  & (end_indices.view(-1, 1) > torch.arange(0, seq_len).type_as(end_indices))] = float("-inf")
        drop_mask = drop_mask.view(bz, hd, seq_len, seq_len)
        # Always keep the (query 0, key 0) entry.
        drop_mask[..., 0, 0] = 0.
        attention_scores = attention_scores + drop_mask

    attention_probs = nn.Softmax(dim=-1)(attention_scores)

    # ---- DropAttention: multiplicative 0/1 mask on the probabilities, then renormalise ----
    if not self.training:
        pass
    elif "dropattn_element" in self.modified_dropout and self.modified_dropout["dropattn_element"] > 0.:
        dropout_rate = self.modified_dropout["dropattn_element"]
        bz, hd, _, seq_len = attention_scores.size()

        with torch.no_grad():
            attention_mask = torch.ones_like(attention_probs).tril().bool().view(-1, seq_len)  # True: keep  False: drop
            drop_mask = torch.bernoulli(torch.ones_like(attention_probs) * (1 - dropout_rate)) \
                .type_as(attention_probs).bool().view(-1, seq_len)  # 0: drop 1: keep
            # Always keep the (query 0, key 0) entry.
            drop_mask.view_as(attention_probs)[..., 0, 0] = True
            # Same resampling scheme as dropkey_element: no row may lose all visible entries.
            idx_0 = torch.where((attention_mask & drop_mask).sum(-1) == 0)[0]
            while len(idx_0) > 0:
                patch_mask = torch.bernoulli(torch.ones((len(idx_0), seq_len)) * (1 - dropout_rate)) \
                    .type_as(attention_probs).bool()  # 0: drop 1: keep
                drop_mask[idx_0] = patch_mask
                idx_0 = idx_0[torch.where((attention_mask[idx_0] & patch_mask).sum(-1) == 0)[0]]
            drop_mask = drop_mask.view(bz, hd, seq_len, seq_len)
        attention_probs = attention_probs * drop_mask
        # Renormalise each row; the normaliser is detached so no gradient flows through it.
        attention_probs = attention_probs / (attention_probs.sum(-1, keepdim=True).detach() + 1e-6)

    elif "dropattn_column" in self.modified_dropout and self.modified_dropout["dropattn_column"] > 0.:
        dropout_rate = self.modified_dropout["dropattn_column"]
        bz, hd, _, seq_len = attention_probs.size()
        # One keep flag per (batch, head, key), broadcast over all queries.
        drop_mask = torch.bernoulli(torch.ones((bz, hd, 1, seq_len)) * (1 - dropout_rate)) \
            .type_as(attention_scores)  # 0: drop 1: keep
        # Key 0 is never dropped.
        drop_mask[..., 0, 0] = 1.
        attention_probs = attention_probs * drop_mask
        attention_probs = attention_probs / (attention_probs.sum(-1, keepdim=True).detach() + 1e-6)

    elif "dropattn_span" in self.modified_dropout and self.modified_dropout["dropattn_span"] > 0.:
        dropout_rate = self.modified_dropout["dropattn_span"]
        bz, hd, _, seq_len = attention_scores.size()
        # Same span sampling as dropkey_span, but producing a 0/1 multiplicative mask.
        emb_len = torch.arange(1, seq_len + 1).unsqueeze(0).unsqueeze(0).repeat(bz, hd, 1)
        emb_len = emb_len.view(bz, -1)
        mask_len = (emb_len * dropout_rate).long().clamp(min=0)
        index_high = (emb_len * (1 - dropout_rate)).long().clamp(min=0)
        start_indices = (torch.rand_like(index_high.float()) * index_high).long().clamp(min=0)
        end_indices = (start_indices + mask_len).clamp(max=emb_len)
        drop_mask = torch.ones((bz, hd * seq_len, seq_len)).type_as(attention_scores)

        drop_mask = drop_mask.view(-1, seq_len)
        drop_mask[(start_indices.view(-1, 1) <= torch.arange(0, seq_len).type_as(start_indices))
                  & (end_indices.view(-1, 1) > torch.arange(0, seq_len).type_as(end_indices))] = 0.
        drop_mask = drop_mask.view(bz, hd, seq_len, seq_len)

        # Always keep the (query 0, key 0) entry.
        drop_mask[..., 0, 0] = 1.
        attention_probs = attention_probs * drop_mask
        attention_probs = attention_probs / (attention_probs.sum(-1, keepdim=True).detach() + 1e-6)

    # NOTE(review): terminates the whole process on NaN instead of raising —
    # confirm callers do not need to recover from this.
    if attention_probs.isnan().any():
        print("attention_probs.isnan().sum() > 0")
        print(attention_probs)
        print("max:", attention_probs.max(), "min:", attention_probs.min())
        exit(-1)
    return torch.matmul(attention_probs, v)


def aug_loss_Modified(model, modified_aug_loss, modified_aug_loss_weight):
    """Enable the auxiliary loss on every ``GPT2LMModel`` inside ``model``.

    Each ``GPT2LMModel`` found via ``named_modules()`` gets a
    ``modified_aug_loss`` attribute of the form ``{name: weight}``, and
    ``GPT2LMModel.forward`` is replaced at class level.

    Args:
        model: module exposing ``named_modules()``.
        modified_aug_loss: name of the auxiliary loss, used as the dict key.
        modified_aug_loss_weight: weight of the auxiliary loss; converted with
            ``float()``.

    Returns:
        The same ``model`` object, modified in place.
    """
    lm_modules = (module for _, module in model.named_modules() if isinstance(module, GPT2LMModel))
    for lm_module in lm_modules:
        lm_module.modified_aug_loss = {modified_aug_loss: float(modified_aug_loss_weight)}

    # Class-level patch: affects every GPT2LMModel instance, not only those in `model`.
    GPT2LMModel.forward = GPT2LMModel_forward_Modified
    return model


def GPT2LMModel_forward_Modified(
        self,
        input_ids,
        lm_labels=None,
        lm_mask=None,
        past=None,
        len_past=None,
        label_smooth=0.0,
        is_report_accuracy=False
):
    """Language-model forward pass with the auxiliary loss added to the LM loss.

    Args:
        input_ids: 2-D tensor of token ids, ``(batch, len)``.
        lm_labels: optional target ids of the same shape. When ``None`` no loss
            is computed.
        lm_mask: optional ``(batch, len)`` weights for the per-token loss;
            defaults to all ones.
        past: forwarded to ``self.transformer``.
        len_past: forwarded to ``self.transformer``.
        label_smooth: label-smoothing factor; used when greater than 0.0001.
        is_report_accuracy: when true, also return per-sample accuracy flags.

    Returns:
        * ``(lm_logits, presents)`` when ``lm_labels`` is ``None``;
        * ``(lm_logits, loss)`` otherwise;
        * ``(lm_logits, loss, _t1_acc, _all_acc)`` when ``is_report_accuracy``
          is set. ``_t1_acc[b]`` is 1.0 if the first position with
          ``lm_mask >= 1`` is predicted correctly; ``_all_acc[b]`` is 1.0 if
          every position with ``lm_mask >= 1`` is predicted correctly.
    """
    _batch, _len = input_ids.shape
    hidden_states, presents = self.transformer(input_ids, past=past, len_past=len_past)

    # batch, seq, vocab
    lm_logits = self.lm_head(hidden_states)

    if lm_labels is None:
        return lm_logits, presents

    # Default the mask up front so the accuracy computation below can rely on it
    # (previously it was only created after the accuracy block had used it).
    if lm_mask is None:
        lm_mask = torch.ones((_batch, _len), dtype=lm_logits.dtype, device=lm_logits.device)

    if is_report_accuracy:
        _pred_token = torch.argmax(lm_logits, dim=-1)
        _hit = (_pred_token == lm_labels) * lm_mask

        _scored = lm_mask >= 1.0
        # Exactly the first scored position of each row: scored and running count == 1.
        _first_scored = _scored & (_scored.cumsum(dim=-1) == 1)
        _t1_acc = (_first_scored & (_hit > 0)).any(dim=-1) \
            .to(dtype=torch.float, device=input_ids.device)
        # A row succeeds unless some scored position is a miss (rows with no scored position succeed).
        _all_acc = (~(_scored & (_hit <= 0)).any(dim=-1)) \
            .to(dtype=torch.float, device=input_ids.device)

    if label_smooth > 0.0001:
        logprobs = torch.nn.functional.log_softmax(lm_logits.view(-1, lm_logits.size(-1)), dim=-1)
        nll_loss = -logprobs.gather(dim=-1, index=lm_labels.view(-1).unsqueeze(1))
        nll_loss = nll_loss.squeeze(1)
        smooth_loss = -logprobs.mean(dim=-1)
        loss = (1.0 - label_smooth) * nll_loss + label_smooth * smooth_loss
        loss = loss.view(_batch, _len)
    else:
        # reduction='none' is the supported spelling of the deprecated reduce=False.
        loss_fct = nn.CrossEntropyLoss(ignore_index=-1, reduction='none')
        loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), lm_labels.view(-1)).view(_batch, _len)

    loss = loss * lm_mask
    # Small constant guards against an all-zero mask.
    loss = loss.sum() / (lm_mask.sum() + 0.0001)

    # aug loss: kl or js
    aug_loss = aug_kl_js_loss_Modified(transformer=self.transformer, lm_head=self.lm_head, input_ids=input_ids,
                                      lm_mask=lm_mask, past=past, len_past=len_past,
                                      modified_aug_loss=self.modified_aug_loss, src_logits=lm_logits)

    loss = loss + aug_loss.type_as(loss)

    if is_report_accuracy:
        return lm_logits, loss, _t1_acc, _all_acc
    return lm_logits, loss


def aug_kl_js_loss_Modified(transformer, lm_head, input_ids, lm_mask, past, len_past, src_logits, modified_aug_loss):
    """Compute the weighted auxiliary consistency loss between two forward passes.

    For ``"kl"`` the model is run a second time without gradients and the
    bidirectional KL between ``src_logits`` and the second pass' logits is
    summed and divided by ``lm_mask.sum() + 0.0001``, then multiplied by the
    configured weight.

    Args:
        transformer: callable ``(input_ids, past=..., len_past=...)`` returning
            ``(hidden_states, presents)``.
        lm_head: callable mapping hidden states to logits.
        input_ids: 2-D tensor of token ids, ``(batch, len)``.
        lm_mask: ``(batch, len)`` mask or ``None`` (treated as all ones).
        past: forwarded to ``transformer``.
        len_past: forwarded to ``transformer``.
        src_logits: logits of the first (gradient-carrying) pass.
        modified_aug_loss: dict with at most one entry, ``{"kl": weight}`` or
            ``{"js": weight}``; an empty dict disables the auxiliary loss.

    Returns:
        Scalar tensor with the weighted auxiliary loss (zero for an empty dict).

    Raises:
        AssertionError: on unknown keys or more than one entry.
        NotImplementedError: for ``"js"``, which is not implemented.
    """
    assert set(modified_aug_loss.keys()) <= {"kl", "js"}, \
        "modified_aug_loss should be in ['kl', 'js'], but got {}".format(modified_aug_loss.keys())

    # Handle "no auxiliary loss" before the single-entry check; previously that
    # check ran first and made this case unreachable.
    if not modified_aug_loss:
        return torch.zeros((), dtype=src_logits.dtype, device=src_logits.device)

    assert len(modified_aug_loss) == 1, "only support one modified_aug_loss now"

    if "kl" in modified_aug_loss:
        # Second pass is the fixed target: no gradients flow through it.
        with torch.no_grad():
            _batch, _len = input_ids.shape
            hidden_states, _ = transformer(input_ids, past=past, len_past=len_past)
            tgt_logits = lm_head(hidden_states)

        tgt_logits = tgt_logits.detach()
        src_logits = src_logits.view(-1, src_logits.size(-1))
        tgt_logits = tgt_logits.view(-1, tgt_logits.size(-1))

        aug_loss = compute_kl_loss_Modified(src_logits, tgt_logits, reduction="none")
        # Sum over the vocabulary to get one value per token.
        aug_loss = aug_loss.sum(dim=-1).view(_batch, _len)
        if lm_mask is None:
            lm_mask = torch.ones(aug_loss.shape, dtype=aug_loss.dtype, device=aug_loss.device)
        # NOTE(review): lm_mask only scales the denominator here; the per-token
        # KL is summed over every position without being multiplied by lm_mask.
        # Confirm whether masking the numerator was intended.
        aug_loss = aug_loss.sum() / (lm_mask.sum() + 0.0001)

        aug_loss = aug_loss * modified_aug_loss["kl"]

    elif "js" in modified_aug_loss:
        # Raised explicitly (instead of `assert False`) so it also fires under `python -O`.
        raise NotImplementedError("not implemented yet")

    else:
        raise NotImplementedError(f"Unknown modified_aug_loss_strategy: {modified_aug_loss}")

    return aug_loss


def get_normalized_probs(logits, log_probs):
    """Normalize ``logits`` over the last dimension in float32.

    Args:
        logits: tensor of unnormalized scores.
        log_probs: if truthy return log-probabilities, otherwise probabilities.

    Returns:
        ``log_softmax`` or ``softmax`` of ``logits.float()`` along ``dim=-1``.
    """
    if log_probs:
        return F.log_softmax(logits.float(), dim=-1)
    return F.softmax(logits.float(), dim=-1)


def compute_kl_loss_Modified(src_logits, tgt_logits, reduction="batchmean"):
    """Return the average of the two directed KL terms between two logit tensors.

    Both inputs are converted to log-probabilities over the last dimension and
    ``F.kl_div`` is evaluated once in each direction with ``log_target=True``.

    Args:
        src_logits: logits of the first distribution.
        tgt_logits: logits of the second distribution.
        reduction: passed through to ``F.kl_div``.

    Returns:
        ``(kl_div(src, tgt) + kl_div(tgt, src)) / 2`` with the given reduction.
    """
    src_log_probs = get_normalized_probs(src_logits, log_probs=True)
    tgt_log_probs = get_normalized_probs(tgt_logits, log_probs=True)

    src_as_input = F.kl_div(src_log_probs, tgt_log_probs, reduction=reduction, log_target=True)
    tgt_as_input = F.kl_div(tgt_log_probs, src_log_probs, reduction=reduction, log_target=True)

    return (src_as_input + tgt_as_input) / 2


def compute_js_loss_Modified(src_logits, tgt_logits, reduction="batchmean"):
    """Unvalidated Jensen-Shannon-style loss; currently always fails its guard.

    The leading assertion fails unconditionally, so the body below it only
    runs when assertions are disabled.

    Args:
        src_logits: logits of the first distribution.
        tgt_logits: logits of the second distribution.
        reduction: passed through to ``F.kl_div``.

    Returns:
        Average of the two ``F.kl_div`` terms against the mean distribution.

    Raises:
        AssertionError: always, while assertions are enabled.
    """
    assert False, "not checked yet"

    src_log_probs = get_normalized_probs(src_logits, log_probs=True)
    src_probs = get_normalized_probs(src_logits, log_probs=False)
    tgt_log_probs = get_normalized_probs(tgt_logits, log_probs=True)
    tgt_probs = get_normalized_probs(tgt_logits, log_probs=False)

    mean_probs = (src_probs + tgt_probs) / 2
    # NOTE(review): F.kl_div takes log-probabilities as input and the reference
    # distribution as target; here the mean is the target — verify the argument
    # order against the intended JS definition before enabling this function.
    src_term = F.kl_div(src_log_probs, mean_probs, reduction=reduction)
    tgt_term = F.kl_div(tgt_log_probs, mean_probs, reduction=reduction)

    return (src_term + tgt_term) / 2
