import numpy as np
import random
import time
import os
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.linalg as LA
import torch.distributed as dist
from torch.nn.init import xavier_uniform_, xavier_normal_, constant_
import copy

from libs.layers import CudaVariable, GaussianNoise
import nmt_const as Const

from nmt_model_admin import myEmbedding, myLinear, FeedForward, PositionalEncoding,\
                             MultiHeadAttention

import admin_torch

# Pick the accelerator once at import time; fall back to CPU when CUDA is absent.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")

class LabelSmoothingCrossEntropy(nn.Module):
    """Per-element cross entropy with uniform label smoothing.

    Referenced from
    https://github.com/seominseok0429/label-smoothing-visualization-pytorch

    The loss for each element is
    ``(1 - epsilon) * NLL(target) + epsilon * mean_c(-log p_c)``,
    returned unreduced (same convention as ``reduction='none'``).
    """

    def __init__(self, epsilon):
        super(LabelSmoothingCrossEntropy, self).__init__()
        self.epsilon = epsilon

    def forward(self, x, target):
        """Compute the smoothed loss.

        Args:
            x: logits whose LAST dimension indexes classes, e.g. ``(N, C)`` or
                ``(B, T, C)``.
            target: integer class indices with shape ``x.shape[:-1]``.

        Returns:
            Tensor of shape ``x.shape[:-1]`` with one loss value per element.
        """
        smoothing = self.epsilon
        confidence = 1. - smoothing
        logprobs = F.log_softmax(x, dim=-1)
        # Gather along the class axis; using -1 (not 1) supports any number of
        # leading dimensions while matching the original 2-D behaviour.
        nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(-1)).squeeze(-1)
        # Cross entropy against the uniform distribution over classes.
        smooth_loss = -logprobs.mean(dim=-1)
        return confidence * nll_loss + smoothing * smooth_loss

def get_scale(nin, nout):
    """Return the Xavier/Glorot uniform bound, i.e. sqrt(6) / sqrt(nin + nout)."""
    fan_total = nin + nout
    numerator = math.sqrt(6)
    return numerator / math.sqrt(fan_total)

def one_hot(input, class_tensor, num_classes):
    """Encode a 2-D index tensor as float one-hot vectors.

    Args:
        input: integer tensor of shape ``(Bn, Tx)``.
        class_tensor: tensor with ``num_classes`` entries; position ``c`` of the
            output is 1 wherever ``input`` equals ``class_tensor[c]``.
        num_classes: number of entries in ``class_tensor``.

    Returns:
        Float tensor of shape ``(Bn, Tx, num_classes)``.
    """
    batch, length = input.size()
    classes = class_tensor.reshape(1, 1, num_classes)
    matches = input.reshape(batch, length, 1) == classes
    return matches.float()

class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention that blends in a second ("negative") attention map.

    Args:
        dk: key/query dimension; logits are divided by ``sqrt(dk)``.
        drop_p: dropout probability applied to the blended attention weights.
        grad_monitor: when exactly ``True``, registers a zero-initialised
            500-element parameter that is added to the logits (indexed by key
            position), so its gradient can be inspected.
    """

    def __init__(self, dk, drop_p=0., grad_monitor=False):
        super(ScaledDotProductAttention, self).__init__()
        self.temper = float(dk) ** 0.5
        self.dropout = nn.Dropout(p=drop_p)

        monitor = None
        if grad_monitor is True:
            monitor = nn.parameter.Parameter(torch.zeros((500,), dtype=torch.float))
        self.attn_logit_grad_monitor = monitor

    def compute_attn_score(self, q, k, mask=None):  # q, k: B H T E
        """Return softmax(q k^T / sqrt(dk)) over the last (key) axis.

        Positions where ``mask < 0.1`` get ``-inf`` logits before the softmax.
        """
        scores = q.matmul(k.transpose(-2, -1)) / self.temper  # B H Tq Tk
        monitor = self.attn_logit_grad_monitor
        if monitor is not None:
            scores += monitor[:scores.size(-1)]

        if mask is not None:
            scores = scores.masked_fill(mask < 0.1, float('-inf'))

        return F.softmax(scores, dim=-1)

    def forward(self, q, k, v, attn_neg, lambdas, mask=None):
        """Attend with weights ``(1 + lambdas[0]) * A - lambdas[1] * attn_neg``.

        ``lambdas`` holds two entries (Python scalars or tensors broadcastable
        to the attention map). Returns ``(output, blended_attention)``.
        """
        positive = self.compute_attn_score(q, k, mask=mask)
        pos_coef, neg_coef = lambdas[0], lambdas[1]
        blended = self.dropout((1 + pos_coef) * positive - neg_coef * attn_neg)
        return torch.matmul(blended, v), blended

class MultipleParallelLinear(nn.Linear):
    """Block-diagonal linear layer: ``n_multi`` independent projections in parallel.

    The ``n_feature`` input channels are split into ``n_multi`` contiguous groups
    of ``n_feature // n_multi`` channels. Group ``n`` is projected to
    ``out_feature`` outputs that land in output channels
    ``[n*out_feature, (n+1)*out_feature)``, so the output has
    ``out_feature * n_multi`` channels and groups never mix.

    Args:
        n_feature: total input features (``n_multi * per-group features``).
        n_multi: number of independent groups.
        out_feature: outputs produced per group.
        bias: if True, adds a bias of size ``out_feature * n_multi``.

    Raises:
        SyntaxError: if ``n_feature`` is not divisible by ``n_multi``
            (exception type kept for backward compatibility).
    """

    def __init__(self, n_feature, n_multi, out_feature, bias=False):
        super(MultipleParallelLinear, self).__init__(n_feature, n_feature, bias=bias)
        if n_feature % n_multi != 0:
            raise SyntaxError("n_feature must be divided up by n_multi")

        self.n_feature = n_feature  # Total features: n_head*ind_feature
        self.n_multi = n_multi
        self.ind_feature = n_feature // n_multi
        self.out_feature = out_feature
        # Keep nn.Linear's bookkeeping (used by repr) in sync with the real output size.
        self.out_features = self.out_feature * self.n_multi

        # Stored compactly as (out_feature, n_feature): each group reads only its own
        # column slice once the mask below is applied.
        self.weight = nn.parameter.Parameter(
                            torch.empty((self.out_feature, n_feature)))
        if bias:
            # nn.Linear(n_feature, n_feature) created a bias of the wrong size; the
            # layer actually emits out_feature * n_multi channels.
            self.bias = nn.parameter.Parameter(
                            torch.empty((self.out_feature * self.n_multi,)))

        # Block-diagonal 0/1 mask: block n covers rows [n*out, (n+1)*out) and
        # columns [n*ind, (n+1)*ind). Registered as a non-persistent buffer so it
        # follows .to()/.cuda()/.cpu() without changing state_dict keys.
        mask = torch.block_diag(*[torch.ones((self.out_feature, self.ind_feature))
                                  for _ in range(self.n_multi)])
        self.register_buffer('weight_matrix_mask', mask, persistent=False)
        self.reset_parameters()

    def build_parallel_weight(self):
        """Return the full ``(out_feature*n_multi, n_feature)`` block-diagonal weight."""
        return self.weight_matrix_mask * self.weight.repeat(self.n_multi, 1)

    def forward(self, input):
        weight_matrix = self.build_parallel_weight()
        return F.linear(input, weight_matrix, self.bias)

class MultiHeadAttention_Neg(nn.Module):
    """Multi-head attention mixing a positive and a "negative" attention map.

    Per head, a standard attention map ``A`` is computed from ``q_new``/``k_new``
    and a second map ``A_neg`` from ``q_neg``/``k_neg``. The negative projections
    are per-head ``MultipleParallelLinear`` layers applied to ``relu`` of the
    positive projections (keys only when ``neg_key`` is True; otherwise
    ``k_neg`` is ``k_new``). The maps are combined inside
    ``ScaledDotProductAttention`` as ``(1 + lambdas[0]) * A - lambdas[1] * A_neg``.

    ``negatt_mode`` selects how ``lambdas`` is produced:

    * ``'const'`` (and the deprecated ``'lal'`` / ``'lal_inv'``): the constructor's
      ``pos_lambda`` / ``neg_lambda`` values.
    * ``'param'``: one learned value per head, squared, used for both entries.
    * ``'separam'`` / ``'gseparam'``: separate learned per-head values, squared.
    * ``'reg_param'`` / ``'reg_separam'``: predicted per key position from the
      mask-averaged projected values, squared.

    Any other mode raises ``ValueError`` when the module is called.
    """

    def __init__(self, n_head, dim_model, dk, dv, negatt_mode='const',\
                         pos_lambda=1.0, neg_lambda=1.0, neg_key=False, drop_p=0., grad_monitor=False):
        super(MultiHeadAttention_Neg, self).__init__()
        self.n_head, self.dk, self.dv = n_head, dk, dv
        self.neg_key = neg_key

        self.mat_qs = myLinear(dim_model, n_head*dk, bias=False)
        self.mat_neg_qs = MultipleParallelLinear(n_head*dk, n_head, dk, bias=False)
        self.mat_ks = myLinear(dim_model, n_head*dk, bias=False)
        self.mat_neg_ks = MultipleParallelLinear(n_head*dk, n_head, dk, bias=False)\
                             if neg_key is True else None
        self.mat_vs = myLinear(dim_model, n_head*dv, bias=False)

        self.negatt_mode = negatt_mode
        if self.negatt_mode == 'param':
            self.neg_lambda = nn.parameter.Parameter(torch.ones((self.n_head,),\
                                                     dtype=torch.float), requires_grad=True)
        elif self.negatt_mode in ['separam', 'gseparam']:
            self.pos_lambda = nn.parameter.Parameter(torch.ones((self.n_head,),\
                                                     dtype=torch.float), requires_grad=True)
            self.neg_lambda = nn.parameter.Parameter(torch.ones((self.n_head,),\
                                                     dtype=torch.float), requires_grad=True)
        elif self.negatt_mode in ['const', 'lal', 'lal_inv']:
            self.pos_lambda = pos_lambda
            self.neg_lambda = neg_lambda
        elif self.negatt_mode == 'reg_param':
            self.mat_lambdas = nn.Sequential(nn.ReLU(),\
                                        MultipleParallelLinear(n_head*dv, n_head, 1, bias=False))
        elif self.negatt_mode == 'reg_separam':
            self.mat_lambdas = nn.Sequential(nn.ReLU(),\
                                        MultipleParallelLinear(n_head*dv, n_head, 2, bias=False))

        self.sdp_attn = ScaledDotProductAttention(dk, drop_p=drop_p, grad_monitor=grad_monitor)

        self.out_proj = myLinear(n_head*dv, dim_model)
        self.dropout = nn.Dropout(p=drop_p)

    def _project_qk(self, q, k):
        """Return ``(q_new, q_neg, k_new, k_neg)``, each shaped (B, H, T, dk)."""
        n_head, dk = self.n_head, self.dk
        Bn, Tq, Tk = q.size(0), q.size(1), k.size(1)

        q_new = self.mat_qs(q)  # B, Tq, H*dk
        q_neg = self.mat_neg_qs(F.relu(q_new))  # B, Tq, H*dk
        q_new = q_new.view(Bn, Tq, n_head, dk).transpose(1, 2)
        q_neg = q_neg.view(Bn, Tq, n_head, dk).transpose(1, 2)

        if self.neg_key is False:
            k_new = self.mat_ks(k).view(Bn, Tk, n_head, dk).transpose(1, 2)
            k_neg = k_new
        else:
            k_new = self.mat_ks(k)
            k_neg = self.mat_neg_ks(F.relu(k_new))
            k_new = k_new.view(Bn, Tk, n_head, dk).transpose(1, 2)
            k_neg = k_neg.view(Bn, Tk, n_head, dk).transpose(1, 2)
        return q_new, q_neg, k_new, k_neg

    def _masked_value_mean(self, v, attn_mask, Tk):
        """Average projected values ``v`` (B, Tv, H*dv) over the positions allowed by the mask.

        A mask whose second dim is 1 (B, 1, Tk) is shared by every row; any other
        mask is used as-is. Without a mask every key position counts.
        """
        if attn_mask is None:
            tmp_attn_mask = torch.ones((v.size(0), Tk, Tk), dtype=torch.float, device=v.device)
        elif attn_mask.size(1) == 1:
            tmp_attn_mask = attn_mask.type(torch.float).repeat(1, Tk, 1)
        else:
            tmp_attn_mask = attn_mask.type(torch.float)
        return torch.bmm(tmp_attn_mask, v) / tmp_attn_mask.sum(-1, keepdim=True)

    def _compute_lambdas(self, v, attn_mask, Bn, Tq, Tk):
        """Build ``[pos_lambdas, neg_lambdas]`` for the current ``negatt_mode``.

        ``v`` is the projected value tensor before head splitting (B, Tv, H*dv);
        only the ``reg_*`` modes read it. Tensor lambdas are (B, H, Tq, Tk).
        """
        n_head = self.n_head
        mode = self.negatt_mode

        if mode in ['const', 'lal', 'lal_inv']:
            return [self.pos_lambda, self.neg_lambda]

        if mode == 'param':
            neg_lambdas = torch.square(self.neg_lambda)
            neg_lambdas = neg_lambdas.reshape(1, -1, 1, 1).repeat(Bn, 1, Tq, Tk)
            return [neg_lambdas, neg_lambdas]

        if mode in ['separam', 'gseparam']:
            pos_lambdas = torch.square(self.pos_lambda)
            pos_lambdas = pos_lambdas.reshape(1, -1, 1, 1).repeat(Bn, 1, Tq, Tk)
            neg_lambdas = torch.square(self.neg_lambda)
            neg_lambdas = neg_lambdas.reshape(1, -1, 1, 1).repeat(Bn, 1, Tq, Tk)
            return [pos_lambdas, neg_lambdas]

        if mode == 'reg_param':
            v_lambda_in = self._masked_value_mean(v, attn_mask, Tk)  # B, Tk, H*dv
            neg_lambdas = self.mat_lambdas(v_lambda_in)  # B, Tk, H
            neg_lambdas = torch.square(neg_lambdas).transpose(-1, -2)  # B, H, Tk
            # Same per-key value for every query row.
            neg_lambdas = neg_lambdas.reshape(Bn, n_head, Tk, 1)\
                                     .repeat(1, 1, 1, Tq).transpose(-1, -2)  # B, H, Tq, Tk
            return [neg_lambdas, neg_lambdas]

        if mode == 'reg_separam':
            v_lambda_in = self._masked_value_mean(v, attn_mask, Tk)  # B, Tk, H*dv
            tmp_lambdas = self.mat_lambdas(v_lambda_in).reshape(Bn, Tk, n_head, 2)
            neg_lambdas = torch.square(tmp_lambdas[:, :, :, 0]).transpose(-1, -2)  # B, H, Tk
            pos_lambdas = torch.square(tmp_lambdas[:, :, :, 1]).transpose(-1, -2)  # B, H, Tk
            neg_lambdas = neg_lambdas.reshape(Bn, n_head, Tk, 1)\
                                     .repeat(1, 1, 1, Tq).transpose(-1, -2)  # B, H, Tq, Tk
            pos_lambdas = pos_lambdas.reshape(Bn, n_head, Tk, 1)\
                                     .repeat(1, 1, 1, Tq).transpose(-1, -2)  # B, H, Tq, Tk
            return [pos_lambdas, neg_lambdas]

        raise ValueError("Unsupported negatt_mode: %r" % (mode,))

    def _attend(self, q, k, v, attn_mask):
        """Shared core of forward/forward_state_monitor.

        Returns the concatenated head outputs (B, Tq, H*dv) before ``out_proj``
        and the blended attention weights (B, H, Tq, Tk).
        """
        n_head, dv = self.n_head, self.dv
        Bn, Tq, Tk, Tv = q.size(0), q.size(1), k.size(1), v.size(1)

        q_new, q_neg, k_new, k_neg = self._project_qk(q, k)

        v_proj = self.mat_vs(v)  # B, Tv, H*dv
        # Lambdas use the un-split values and the un-expanded (3-D) mask.
        lambdas = self._compute_lambdas(v_proj, attn_mask, Bn, Tq, Tk)
        v_heads = v_proj.view(Bn, Tv, n_head, dv).transpose(1, 2)

        if attn_mask is not None:  # B ? T -> B 1 ? T (broadcast over heads)
            attn_mask = attn_mask.unsqueeze(1)

        attn_neg = self.sdp_attn.compute_attn_score(q_neg, k_neg, mask=attn_mask)
        sdp_output, attn = self.sdp_attn(q_new, k_new, v_heads, attn_neg,\
                                         lambdas, mask=attn_mask)  # (Bn H Tq E), (Bn H Tq Tk)
        sdp_output = sdp_output.transpose(1, 2).contiguous().view(Bn, Tq, -1)  # Bn Tq H*E
        return sdp_output, attn

    def forward(self, q, k, v, attn_mask=None):  # (QK')V # q = B T E
        """Return ``(output (B, Tq, dim_model), attention (B, H, Tq, Tk))``."""
        sdp_output, attn = self._attend(q, k, v, attn_mask)
        output = self.dropout(self.out_proj(sdp_output))
        return output, attn

    def forward_state_monitor(self, q, k, v, attn_mask=None):  # (QK')V # q = B T E
        """Like ``forward`` but also returns the pre-projection head outputs.

        Returns ``(final_output, sdp_output, attn)``.
        """
        sdp_output, attn = self._attend(q, k, v, attn_mask)
        final_output = self.dropout(self.out_proj(sdp_output))
        return final_output, sdp_output, attn




class TM_EncoderLayer(nn.Module):
    """Post-LN Transformer encoder layer with Admin-style weighted residuals.

    Sub-layers: self-attention (``MultiHeadAttention_Neg`` when ``negatt_apply``
    is ``'full'``, ``'enc'`` or ``'encdecca'``, plain ``MultiHeadAttention``
    otherwise), then a feed-forward block. Each residual adds the sub-layer input
    multiplied by a learned vector created with ``admin_torch.as_parameter``
    (sized for ``2*n_layers`` residual branches) before the LayerNorm.
    ``n_head_neg`` is accepted for interface compatibility but not used here.
    """

    def __init__(self, n_layers, dim_model, dim_ff, n_head, dk, dv, n_head_neg=1,\
                     negatt_mode='const', pos_lambda=1.0, neg_lambda=1.0, negatt_apply='full',\
                     neg_key=False, drop_p=0., grad_monitor=False):
        super(TM_EncoderLayer, self).__init__()
        neg_self_attention = negatt_apply in ['full', 'enc', 'encdecca']
        if not neg_self_attention:
            self.self_attn = MultiHeadAttention(n_head, dim_model, dk, dv, drop_p=drop_p)
        else:
            self.self_attn = MultiHeadAttention_Neg(n_head, dim_model, dk, dv,
                                                    negatt_mode=negatt_mode,
                                                    pos_lambda=pos_lambda,
                                                    neg_lambda=neg_lambda,
                                                    neg_key=neg_key, drop_p=drop_p,
                                                    grad_monitor=grad_monitor)
        self.self_layer_norm = nn.LayerNorm(dim_model)
        admin_torch.as_parameter(self, 'self_attn_omega', 2*n_layers, dim_model)

        self.ff_layer = FeedForward(dim_model, dim_ff, drop_p=drop_p)
        self.ff_layer_norm = nn.LayerNorm(dim_model)
        admin_torch.as_parameter(self, 'ff_omega', 2*n_layers, dim_model)

    @staticmethod
    def _weighted_residual(x, omega, batch, steps):
        """Return ``x`` scaled element-wise by ``omega`` tiled to (batch, steps, -1)."""
        return x * omega.reshape(1, 1, -1).repeat(batch, steps, 1)

    def forward(self, enc_in, x_mask=None):
        """Run one encoder layer; returns ``(layer_output, self_attention_weights)``."""
        batch, steps, _ = enc_in.size()

        sa_out, attn = self.self_attn(enc_in, enc_in, enc_in, attn_mask=x_mask)
        sa_sum = sa_out + self._weighted_residual(enc_in, self.self_attn_omega, batch, steps)
        sa_norm = self.self_layer_norm(sa_sum)

        ff_out = self.ff_layer(sa_norm)
        ff_sum = ff_out + self._weighted_residual(sa_norm, self.ff_omega, batch, steps)
        return self.ff_layer_norm(ff_sum), attn

    def forward_monitor(self, enc_in, x_mask=None):
        raise NotImplementedError



class TM_Encoder(nn.Module):
    """Embedding + positional encoding followed by a stack of ``TM_EncoderLayer``.

    ``negatt_mode`` must be one of the supported modes below. ``'gseparam'``
    additionally makes every layer reuse layer 0's ``pos_lambda``/``neg_lambda``
    parameters. ``emb_noise`` is accepted for interface compatibility but unused.

    Raises:
        SyntaxError: for the removed ``'lal'`` / ``'lal_inv'`` modes (kept type
            for backward compatibility).
        ValueError: for any other unknown ``negatt_mode``.
    """

    def __init__(self, src_words_n, n_layers=6, n_head=8, dk=64, dv=64,\
                    dim_wemb=512, dim_model=512, dim_ff=1024, drop_p=0., emb_noise=0., max_len=250,\
                    n_head_neg=1, negatt_mode='const', pos_lambda=1.0, neg_lambda=1.0,\
                    negatt_apply='full', neg_key=False):
        super(TM_Encoder, self).__init__()

        # 'lal' / 'lal_inv' used to scale neg_lambda linearly up / down across
        # layers; they were removed once pos_lambda became configurable.
        if negatt_mode in ['lal', 'lal_inv']:
            raise SyntaxError("'%s' is deprecated by the addition of pos_lambda configuration"
                              % negatt_mode)
        if negatt_mode not in ['const', 'learnable', 'learnable_relu', 'param', 'separam',
                               'reg_param', 'reg_separam', 'gseparam']:
            # Previously this silently left layer_stack undefined and forward()
            # failed later with an AttributeError.
            raise ValueError("Unsupported negatt_mode: %r" % (negatt_mode,))

        # first in :
        self.src_emb = myEmbedding(src_words_n, dim_wemb)#, padding_idx=Const.PAD)
        self.pos_enc = PositionalEncoding(dim_wemb, max_len, drop_p)
        # repeated layer
        self.layer_stack = nn.ModuleList([
            TM_EncoderLayer(n_layers, dim_model, dim_ff, n_head, dk, dv, n_head_neg=n_head_neg,\
                             negatt_mode=negatt_mode, pos_lambda=pos_lambda,\
                             neg_lambda=neg_lambda,\
                             negatt_apply=negatt_apply, neg_key=neg_key, drop_p=drop_p)\
                             for _ in range(n_layers)])
        if negatt_mode == 'gseparam':
            # Global lambdas: all layers share layer 0's parameters.
            main_pos_lambda = self.layer_stack[0].self_attn.pos_lambda
            main_neg_lambda = self.layer_stack[0].self_attn.neg_lambda
            for i in range(1, n_layers):
                self.layer_stack[i].self_attn.pos_lambda = main_pos_lambda
                self.layer_stack[i].self_attn.neg_lambda = main_neg_lambda

    def forward(self, src_seq, src_mask):
        """Encode ``src_seq`` (B, Tx) with padding mask ``src_mask`` (B, Tx); returns (B, Tx, E)."""
        src_mask = src_mask.unsqueeze(1) # B 1 T
        enc_out = self.src_emb(src_seq) # Word embedding look up # Bn Tx Emb
        enc_out = self.pos_enc(enc_out) # Position Encoding

        for enc_layer in self.layer_stack:
            enc_out, _ = enc_layer(enc_out, x_mask=src_mask)
        return enc_out

    def forward_monitor(self, src_seq, src_mask):
        raise NotImplementedError

    def forward_attn_monitor(self, src_seq, src_mask):
        raise NotImplementedError

    def forward_svd_monitor(self, src_seq, src_mask):
        raise NotImplementedError

    def forward_rescon_monitor(self, src_seq, src_mask):
        raise NotImplementedError



class TM_DecoderLayer(nn.Module):
    """Post-LN Transformer decoder layer with Admin-style weighted residuals.

    Sub-layers, in order: masked self-attention, cross-attention over the encoder
    output, feed-forward. Self-attention is ``MultiHeadAttention_Neg`` when
    ``negatt_apply`` is ``'full'`` or ``'decsa'``; cross-attention is
    ``MultiHeadAttention_Neg`` when it is ``'full'``, ``'decca'`` or
    ``'encdecca'``; otherwise plain ``MultiHeadAttention`` is used. Each residual
    adds the sub-layer input scaled by a learned vector from
    ``admin_torch.as_parameter`` (sized for ``3*n_layers`` branches) before the
    LayerNorm. ``n_head_neg`` is accepted but unused.
    """

    def __init__(self, n_layers, dim_model, dim_ff, n_head, dk, dv,\
                         n_head_neg=1, negatt_mode='const',\
                         pos_lambda=1.0, neg_lambda=1.0, negatt_apply='full', neg_key=False,\
                         drop_p=0.):
        super(TM_DecoderLayer, self).__init__()

        def make_attention(use_negative):
            # Build either the negative-attention or the plain attention module.
            if not use_negative:
                return MultiHeadAttention(n_head, dim_model, dk, dv, drop_p=drop_p)
            return MultiHeadAttention_Neg(n_head, dim_model, dk, dv,
                                          negatt_mode=negatt_mode, pos_lambda=pos_lambda,
                                          neg_lambda=neg_lambda, neg_key=neg_key,
                                          drop_p=drop_p)

        self.self_attn = make_attention(negatt_apply in ['full', 'decsa'])
        self.self_layer_norm = nn.LayerNorm(dim_model)
        admin_torch.as_parameter(self, 'self_attn_omega', 3*n_layers, dim_model)

        self.cross_attn = make_attention(negatt_apply in ['full', 'decca', 'encdecca'])
        self.cross_layer_norm = nn.LayerNorm(dim_model)
        admin_torch.as_parameter(self, 'cross_attn_omega', 3*n_layers, dim_model)

        self.ff_layer = FeedForward(dim_model, dim_ff, drop_p=drop_p)
        self.ff_layer_norm = nn.LayerNorm(dim_model)
        admin_torch.as_parameter(self, 'ff_omega', 3*n_layers, dim_model)

    @staticmethod
    def _weighted_residual(x, omega, batch, steps):
        """Return ``x`` scaled element-wise by ``omega`` tiled to (batch, steps, -1)."""
        return x * omega.reshape(1, 1, -1).repeat(batch, steps, 1)

    def forward(self, dec_in, enc_out, y_mask=None, enc_mask=None):
        """Run one decoder layer.

        Returns ``(layer_output, self_attention_weights, cross_attention_weights)``.
        """
        batch, steps, _ = dec_in.size()

        sa_out, self_attn = self.self_attn(dec_in, dec_in, dec_in, attn_mask=y_mask)
        sa_sum = sa_out + self._weighted_residual(dec_in, self.self_attn_omega, batch, steps)
        sa_norm = self.self_layer_norm(sa_sum)

        ca_out, cross_attn = self.cross_attn(sa_norm, enc_out,\
                                             enc_out, attn_mask=enc_mask)
        ca_sum = ca_out + self._weighted_residual(sa_norm, self.cross_attn_omega, batch, steps)
        ca_norm = self.cross_layer_norm(ca_sum)

        ff_out = self.ff_layer(ca_norm)
        ff_sum = ff_out + self._weighted_residual(ca_norm, self.ff_omega, batch, steps)
        return self.ff_layer_norm(ff_sum), self_attn, cross_attn

    def forward_state_monitor(self, dec_in, enc_out, y_mask=None, enc_mask=None):
        raise NotImplementedError

    def forward_rescon_monitor(self, dec_in, enc_out, y_mask=None, enc_mask=None):
        raise NotImplementedError

class TM_Decoder(nn.Module):
    """Embedding + positional encoding followed by a stack of ``TM_DecoderLayer``.

    ``'gseparam'`` makes every layer reuse layer 0's ``pos_lambda``/``neg_lambda``
    separately for self-attention and cross-attention. ``emb_noise`` is accepted
    for interface compatibility but unused.

    Raises:
        SyntaxError: for the removed ``'lal'`` / ``'lal_inv'`` modes (kept type
            for backward compatibility).
        ValueError: for any other unknown ``negatt_mode``.
    """

    def __init__(self, trg_words_n, n_layers=6, n_head=8, dk=64, dv=64,\
                dim_wemb=512, dim_model=512, dim_ff=1024, drop_p=0., emb_noise=0., max_len=250,\
                n_head_neg=1, negatt_mode='const', pos_lambda=1.0, neg_lambda=1.0,\
                negatt_apply='full', neg_key=False):
        super(TM_Decoder, self).__init__()

        # 'lal' / 'lal_inv' used to scale neg_lambda linearly up / down across
        # layers; they were removed once pos_lambda became configurable.
        if negatt_mode in ['lal', 'lal_inv']:
            raise SyntaxError("'%s' is deprecated by the addition of pos_lambda configuration"
                              % negatt_mode)
        if negatt_mode not in ['const', 'learnable', 'learnable_relu', 'param', 'separam',
                               'reg_param', 'reg_separam', 'gseparam']:
            # Previously this silently left layer_stack undefined and forward()
            # failed later with an AttributeError.
            raise ValueError("Unsupported negatt_mode: %r" % (negatt_mode,))

        # first in
        self.dec_emb = myEmbedding(trg_words_n, dim_wemb)#, padding_idx=Const.PAD)
        self.pos_enc = PositionalEncoding(dim_wemb, max_len, drop_p)
        # repeated layer
        self.layer_stack = nn.ModuleList([
            TM_DecoderLayer(n_layers, dim_model, dim_ff, n_head, dk, dv, n_head_neg=n_head_neg,\
                             negatt_mode=negatt_mode, pos_lambda=pos_lambda,\
                             neg_lambda=neg_lambda,\
                             negatt_apply=negatt_apply, neg_key=neg_key, drop_p=drop_p)\
                             for _ in range(n_layers)])
        if negatt_mode == 'gseparam':
            # Global lambdas: all layers share layer 0's self-attention parameters ...
            main_pos_lambda = self.layer_stack[0].self_attn.pos_lambda
            main_neg_lambda = self.layer_stack[0].self_attn.neg_lambda
            for i in range(1, n_layers):
                self.layer_stack[i].self_attn.pos_lambda = main_pos_lambda
                self.layer_stack[i].self_attn.neg_lambda = main_neg_lambda

            # ... and, separately, layer 0's cross-attention parameters.
            main_pos_lambda = self.layer_stack[0].cross_attn.pos_lambda
            main_neg_lambda = self.layer_stack[0].cross_attn.neg_lambda
            for i in range(1, n_layers):
                self.layer_stack[i].cross_attn.pos_lambda = main_pos_lambda
                self.layer_stack[i].cross_attn.neg_lambda = main_neg_lambda

    def get_subsequent_mask(self, seq):
        """Return a (1, T, T) float32 lower-triangular (causal) mask on ``seq.device``.

        ``seq`` is either token ids (B, T) or one-hot values (B, T, C).

        Raises:
            ValueError: if ``seq`` is not 2-D or 3-D.
        """
        if seq.dim() not in (2, 3):
            raise ValueError("seq must be (B, T) or (B, T, C), got shape %s"
                             % (tuple(seq.size()),))
        seq_len = seq.size(1)
        mask = torch.tril(torch.ones((seq_len, seq_len), device=seq.device), diagonal=0)
        # Stay on seq.device instead of forcing torch.cuda.FloatTensor, which
        # breaks on CPU and is deprecated.
        return mask.float().unsqueeze(0)

    def forward(self, trg_seq, trg_mask, enc_out, src_mask):
        """Decode ``trg_seq`` (B, Ty) against ``enc_out``; returns (B, Ty, E)."""
        src_mask = src_mask.unsqueeze(1) # B 1 T
        trg_mask = trg_mask.unsqueeze(1) # B 1 T
        dec_out = self.dec_emb(trg_seq) # Word embedding look up
        dec_out = self.pos_enc(dec_out) # Position Encoding

        mh_attn_sub_mask = self.get_subsequent_mask(trg_seq) # lower triangle matrix
        y_mask = trg_mask * mh_attn_sub_mask

        for dec_layer in self.layer_stack:
            dec_out, _, _ = dec_layer(dec_out, enc_out, y_mask=y_mask, enc_mask=src_mask)

        return dec_out # B Ty E

    def forward_monitor(self, src_seq, src_mask):
        raise NotImplementedError

    def forward_attn_monitor(self, src_seq, src_mask):
        raise NotImplementedError

    def forward_svd_monitor(self, src_seq, src_mask):
        raise NotImplementedError

    def forward_rescon_monitor(self, src_seq, src_mask):
        raise NotImplementedError



class NegAtt_Admin_Transformer(nn.Module):
    """Encoder-decoder Transformer with negative attention and Admin residuals.

    Hyper-parameters are read from ``args`` (see ``__init__``). ``forward`` and
    ``decode_forward`` return per-token losses (masked by the target mask, not
    reduced); ``translate`` performs greedy decoding.
    """

    def __init__(self, args=None):
        super(NegAtt_Admin_Transformer, self).__init__()
        self.class_tensor = torch.arange(args.trg_words_n).cuda()
        self.trg_words_n = args.trg_words_n
        self.max_len = args.test_max_length # 121 is fairseq's maximum sentence length considers during test in IWSLT2014 (there are longer sentence but filtered out)

        src_words_n, trg_words_n = args.src_words_n, args.trg_words_n
        dim_wemb, dim_model = args.dim_wemb, args.dim_model
        drop_p = args.dropout_p
        dim_ff, n_layers = args.tm_dim_ff, args.tm_n_layers
        n_head, dk, dv = args.tm_n_head, args.tm_dk, args.tm_dv
        assert dim_model == dim_wemb, 'dim_model == dim_wemb for residual connections'

        n_head_neg = getattr(args, 'n_head_neg', 1)
        neg_lambda = getattr(args, 'neg_lambda', 1.0)
        pos_lambda = getattr(args, 'pos_lambda', 1.0)
        negatt_mode = getattr(args, 'negatt_mode', 'const')
        negatt_apply = getattr(args, 'negatt_apply', 'full')
        neg_key = bool(getattr(args, 'neg_key', 0))
        self.encoder=TM_Encoder(src_words_n, n_layers=n_layers, n_head=n_head, dk=dk, dv=dv,
                dim_wemb=dim_wemb, dim_model=dim_model, dim_ff=dim_ff, drop_p=drop_p,
                emb_noise=args.emb_noise, max_len=self.max_len, n_head_neg=n_head_neg,
                negatt_mode=negatt_mode, pos_lambda=pos_lambda, neg_lambda=neg_lambda,
                negatt_apply=negatt_apply, neg_key=neg_key)
        self.decoder=TM_Decoder(trg_words_n, n_layers=n_layers, n_head=n_head, dk=dk, dv=dv,
                dim_wemb=dim_wemb, dim_model=dim_model, dim_ff=dim_ff, drop_p=drop_p,
                emb_noise=args.emb_noise, max_len=self.max_len, n_head_neg=n_head_neg,
                negatt_mode=negatt_mode, pos_lambda=pos_lambda, neg_lambda=neg_lambda,
                negatt_apply=negatt_apply, neg_key=neg_key)
        self.logit_layer = nn.Linear(dim_model, trg_words_n)

        # Logit layer can share weight with encoder's embedding : It is better not to share
        #self.encoder.src_emb.weight = self.logit_layer.weight

        if args.joined_dictionary == 1:
            self.decoder.dec_emb.weight = self.encoder.src_emb.weight

        if args.label_smoothing > 0.0:
            self.criterion = LabelSmoothingCrossEntropy(args.label_smoothing)
        else:
            self.criterion = nn.CrossEntropyLoss(reduction='none')

        self.nll = nn.NLLLoss(reduction='none')

    def forward(self, x_data, x_mask, y_data, y_mask):
        """Encode the source and return masked per-token target losses.

        Non-tensor inputs are converted with ``CudaVariable``. The returned tensor
        has ``B * (Ty - 1)`` entries (targets shifted by one for teacher forcing),
        each multiplied by the target mask.
        """
        if not torch.is_tensor(x_data):
            x_data = CudaVariable(torch.LongTensor(x_data)) # B T
            x_mask = CudaVariable(torch.FloatTensor(x_mask)) # B T

        enc_out = self.encoder(x_data, x_mask) # B Tx E
        # Decoding + loss are identical to decode_forward; reuse it.
        return self.decode_forward(enc_out, x_mask, y_data, y_mask)

    def decode_forward(self, enc_out, x_mask, y_data, y_mask):
        """Teacher-forced decoding from precomputed ``enc_out``; returns masked per-token losses."""
        if not torch.is_tensor(y_data):
            y_data = CudaVariable(torch.LongTensor(y_data)) # B T
            y_mask = CudaVariable(torch.FloatTensor(y_mask)) # B T

        y_target = y_data[:,1:] # label
        y_mask = y_mask[:,1:]
        y_in = y_data[:,:-1] # input as teacher forcing

        dec_out = self.decoder(y_in, y_mask, enc_out, x_mask) # B Ty E
        out = self.logit_layer(dec_out) # B Ty V

        loss = self.criterion(out.view(-1, out.size(2)), y_target.contiguous().view(-1))
        loss = loss * y_mask.contiguous().view(-1)
        return loss

    def sample_idx(self, probs, train, random_mode=False):
        """Greedy choice: return the arg-max column of ``probs`` as shape (rows, 1)."""
        _, index = probs.topk(1, dim=1) # Bn, 1
        return index

    def translate(self, enc_out, x_mask, CRT_coeff, train=True,\
                                 start_time=None):
        """Greedily decode a batch from precomputed encoder states.

        Args:
            enc_out: encoder output for the batch.
            x_mask: source mask of shape (Bn, Tx).
            CRT_coeff: unused; kept for interface compatibility.
            train: forwarded to ``sample_idx``.
            start_time: reference timestamp (``time.time()`` seconds) for the
                per-sample timings; defaults to the moment of the call. (The
                old default was evaluated once at import time.)

        Returns:
            ``(dec_seq, dec_mask, enc_out, times)``: generated ids (Bn, T) starting
            with BOS and padded with PAD after EOS, the matching 0/1 mask, the
            unchanged ``enc_out``, and a CPU tensor of seconds elapsed since
            ``start_time`` when each sample emitted EOS (or when decoding ended).
        """
        if start_time is None:
            start_time = time.time()
        Bn, _ = x_mask.size()
        dev = enc_out.device

        pad = torch.full((Bn, 1), Const.PAD, dtype=torch.long, device=dev)
        EOSs = torch.zeros((Bn, 1), device=dev)  # >0 once a sample has produced EOS
        dec_seq = torch.full((Bn, 1), Const.BOS, dtype=torch.long, device=dev)
        dec_mask = torch.ones((Bn, 1), dtype=torch.long, device=dev)

        times = torch.ones(Bn)*(time.time() - start_time)
        for yi in range(self.max_len):
            dec_out = self.decoder(dec_seq, dec_mask, enc_out, x_mask) # Bn, T, E
            # Only the last position is needed to pick the next token.
            logits = self.logit_layer(dec_out)[:, -1, :] # Bn, C
            probs = F.softmax(logits, dim=1) # Bn, C
            last_index = self.sample_idx(probs, train=train) # Bn, 1

            finished = torch.gt(EOSs, 0).type(torch.long)
            # Finished samples keep emitting PAD and are masked out.
            tmp_dec_seq = (1-finished)*last_index + finished*pad
            dec_seq = torch.cat((dec_seq, tmp_dec_seq), dim=1) # Bn, T+1
            dec_mask = torch.cat((dec_mask, (1-finished)), dim=1) # Bn, T+1

            EOS1 = torch.eq(tmp_dec_seq, Const.EOS).view(Bn, 1)
            newly_finished = (EOS1 & torch.le(EOSs, 0)).view(-1).cpu()
            if newly_finished.any():
                times[newly_finished] = time.time() - start_time
            EOSs = EOSs + EOS1
            if yi > 0 and torch.sum(torch.gt(EOSs,0)) >= Bn:
                break

        # Samples that never produced EOS are timed at the end of decoding.
        unfinished = torch.le(EOSs, 0).view(-1).cpu()
        times[unfinished] = time.time() - start_time

        return dec_seq, dec_mask, enc_out, times


class IndNegAtt_Admin_NMT(nn.Module):
    """Train/eval wrapper around ``NegAtt_Admin_Transformer``.

    ``forward`` dispatches on its last positional argument (``'training'``,
    ``'validation'`` or ``'multi_validation'``), so one module call performs
    one complete training or evaluation step.
    """

    # How many of the most recent batch sizes feed the running mean that
    # normalizes the summed loss (an arbitrarily chosen window).
    _BATCH_SIZE_WINDOW = 1000

    def __init__(self, args=None, mean_batch=1):
        super(IndNegAtt_Admin_NMT, self).__init__()
        self.src_lang = args.src_lang
        self.trg_lang = args.trg_lang

        self.model = NegAtt_Admin_Transformer(args=args)

        # Recent batch sizes and their running mean (see compute_mean_batch_sizes).
        self.batch_sizes = []
        self.mean_batch = mean_batch

        self.test_max_len = args.test_max_length

    def compute_norm(self, variable, mask):
        """Return the mask-weighted mean L2 norm of the vectors in ``variable``.

        Args:
            variable: 3-D tensor ``(B, T, E)``; the norm is taken over ``E``.
            mask: tensor broadcastable to ``(B, T)`` weighting each position
                (presumably a 0/1 padding mask, as passed by ``forward``).

        Returns:
            0-d tensor ``sum(mask * ||variable||_2) / sum(mask)``.
        """
        norm = LA.norm(variable, dim=2)  # (B, T)
        return (mask * norm).sum() / mask.sum()

    def compute_std(self, variable, mask):
        """Return the feature-averaged batch std of masked sequence means.

        Each sequence in ``variable`` (shape ``(B, T, E)``) is averaged over
        its unmasked time steps (``mask`` is ``(B, T)``), giving ``(B, E)``.
        The unbiased std is then taken over the batch axis and averaged over
        features. With ``B == 1`` the result is NaN.
        """
        mask = mask.unsqueeze(-1)  # (B, T, 1); broadcasts over E
        seq_mean = (mask * variable).sum(dim=1) / mask.sum(dim=1)  # (B, E)
        return torch.std(seq_mean, dim=0).mean()

    def compute_mean_batch_sizes(self, new_B):
        """Record ``new_B`` and refresh ``self.mean_batch``.

        Only the last ``_BATCH_SIZE_WINDOW`` batch sizes are kept, so
        ``mean_batch`` is a moving average of recent batch sizes.
        """
        self.batch_sizes.append(new_B)

        if len(self.batch_sizes) > self._BATCH_SIZE_WINDOW:
            self.batch_sizes = self.batch_sizes[-self._BATCH_SIZE_WINDOW:]

        self.mean_batch = np.mean(np.array(self.batch_sizes))

    def _normalized_loss(self, enc_out, x_mask, y_data, y_mask):
        """Sum ``decode_forward``'s loss and divide by the running mean batch size."""
        trans_loss = self.model.decode_forward(enc_out, x_mask, y_data, y_mask)
        return torch.sum(trans_loss) / self.mean_batch

    @staticmethod
    def _to_cuda_tensors(arrays):
        """Convert each array-like to a ``LongTensor`` wrapped by ``CudaVariable``."""
        return [CudaVariable(torch.LongTensor(a)) for a in arrays]

    def translation(self, x_data, x_mask, y_data, y_mask):
        """Encode the source and compute the decoder loss against ``y_data``.

        Returns:
            ``(loss, enc_out)``: the summed ``decode_forward`` loss divided by
            the running mean batch size, and the encoder output.
        """
        enc_out = self.model.encoder(x_data, x_mask)
        trans_loss = self._normalized_loss(enc_out, x_mask, y_data, y_mask)
        return trans_loss, enc_out

    def add_eos_padding(self, sample, B_max, max_len):
        """Pad a 2-D ``(B, T)`` token tensor with ``Const.EOS`` up to ``(B_max, max_len)``.

        Rows are appended until there are ``B_max`` of them, then columns
        until there are ``max_len``. A dimension that is already at or above
        its target is left as is; nothing is truncated. Presumably this gives
        per-replica outputs a common shape for multi-GPU gathering.
        """
        B, T = sample.size()

        if B < B_max:
            extra_rows = torch.full((B_max - B, T), Const.EOS,
                                    dtype=torch.long, device=sample.device)
            sample = torch.cat((sample, extra_rows), dim=0)  # (B_max, T)
        if T < max_len:
            # Use the current row count: it is B_max after the step above, or
            # B itself when B >= B_max (a fixed B_max here would make cat fail).
            extra_cols = torch.full((sample.size(0), max_len - T), Const.EOS,
                                    dtype=torch.long, device=sample.device)
            sample = torch.cat((sample, extra_cols), dim=1)  # (rows, max_len)
        return sample

    def generation(self, x_data, x_mask, multi_gpu=False, B_max=0, max_len=0):
        """Decode translations for a source batch.

        Returns:
            ``(tokens, enc_out, times)``. The first column of the decoded
            sequence is dropped (presumably the BOS start token; confirm
            against ``self.model.translate``).
            - Without ``multi_gpu``: the tokens are a CPU numpy array.
            - With ``multi_gpu``: they stay a tensor, EOS-padded by
              ``add_eos_padding`` to ``B_max`` rows and ``max_len`` columns.

            ``times`` is what ``translate`` reports, measured from the start
            of this call.
        """
        start_time = time.time()
        enc_out = self.model.encoder(x_data, x_mask)

        gen_y_data, _, _, times = self.model.translate(
            enc_out, x_mask, 1.0, train=False, start_time=start_time)

        if not multi_gpu:
            return gen_y_data.detach().cpu().numpy()[:, 1:], enc_out, times

        gen_y_data = gen_y_data.detach()[:, 1:]
        gen_y_data = self.add_eos_padding(gen_y_data, B_max, max_len)
        return gen_y_data, enc_out, times

    def forward(self, *inputs, **kwargs):
        """Run one step selected by the mode string in ``inputs[-1]``.

        Modes:
            ``'training'``:
                inputs ``(x_data, x_mask, y_data, y_mask, mode)``.
                Returns the normalized translation loss.
            ``'validation'``:
                same inputs. Returns ``(generated tokens as numpy, loss,
                mean decoding time)``.
            ``'multi_validation'``:
                inputs ``(x_data, x_mask, y_data, y_mask, B_max, mode)``.
                Returns ``(EOS-padded generated tokens, loss, mean decoding
                time, EOS-padded y_data)``.

        Raises:
            ValueError: if the mode is not one of the above.
        """
        mode = inputs[-1]
        if mode == 'training':
            self.model.train()
            x_data, x_mask, y_data, y_mask = self._to_cuda_tensors(inputs[:4])

            self.compute_mean_batch_sizes(x_data.size(0))

            trans_loss, _ = self.translation(x_data, x_mask, y_data, y_mask)
            return trans_loss

        if mode == 'validation':
            self.model.eval()
            x_data, x_mask, y_data, y_mask = self._to_cuda_tensors(inputs[:4])

            gen_y_data, enc_out, times = self.generation(x_data, x_mask)
            avg_time = times.mean()

            trans_loss = self._normalized_loss(enc_out, x_mask, y_data, y_mask)
            return gen_y_data, trans_loss, avg_time

        if mode == 'multi_validation':
            self.model.eval()
            x_data, x_mask, y_data, y_mask = self._to_cuda_tensors(inputs[:4])
            B_max = inputs[4]

            gen_y_data, enc_out, times = self.generation(
                x_data, x_mask, multi_gpu=True,
                B_max=B_max, max_len=self.test_max_len)
            avg_time = times.mean()

            trans_loss = self._normalized_loss(enc_out, x_mask, y_data, y_mask)

            # Pad the references the same way as the hypotheses.
            y_data = self.add_eos_padding(y_data, B_max, self.test_max_len)
            return gen_y_data, trans_loss, avg_time, y_data

        raise ValueError(f"unknown forward mode: {mode!r}")
