from fairseq.models.transformer import *
from types import MethodType
import copy
from torch.autograd import Variable
from torch.autograd import Function
from typing import Any, Dict, List, NamedTuple, Optional
from torch import Tensor
import random








class Grad_switch(Function):
    """Identity in the forward pass; rescales the gradient in the backward pass.

    ``forward`` returns ``Fout`` unchanged. ``backward`` multiplies the
    incoming gradient by ``mul`` and reports a zero gradient for ``mul``
    itself, so ``mul`` acts as a pure gradient multiplier.

    Implemented as a new-style autograd Function (static methods + ``apply``);
    the legacy instance-call style is rejected by current PyTorch releases.
    """

    @staticmethod
    def forward(ctx, Fout, mul):
        # Only the multiplier is needed to compute the backward pass.
        ctx.save_for_backward(mul)
        # Return a view rather than the input object itself so that the
        # output gets its own autograd node.
        return Fout.view_as(Fout)

    @staticmethod
    def backward(ctx, grad_output):
        (mul,) = ctx.saved_tensors
        # ``mul`` never learns through this op: hand back zeros when autograd
        # asks for its gradient, and nothing when it does not.
        grad_mul = torch.zeros_like(mul) if ctx.needs_input_grad[1] else None
        return grad_output * mul, grad_mul


def f_Grad_switch(Fout, mul):
    """Apply :class:`Grad_switch`: return ``Fout`` with its gradient scaled by ``mul``."""
    return Grad_switch.apply(Fout, mul)

























def Append_TransformerModel(TransformerModel,Tran_for,en_for,de_for,de_ext):
    """Bind replacement forward functions onto a model instance.

    Args:
        TransformerModel: model object exposing ``encoder`` and ``decoder``.
        Tran_for: function bound as ``TransformerModel.forward``.
        en_for: function bound as ``TransformerModel.encoder.forward``.
        de_for: function bound as ``TransformerModel.decoder.forward``.
        de_ext: function bound as ``TransformerModel.decoder.extract_features``.

    Returns:
        The same ``TransformerModel`` object, patched in place.
    """
    patches = (
        (TransformerModel.encoder, "forward", en_for),
        (TransformerModel.decoder, "extract_features", de_ext),
        (TransformerModel.decoder, "forward", de_for),
        (TransformerModel, "forward", Tran_for),
    )
    for owner, attr_name, func in patches:
        # Bound to the instance only; the class itself is left untouched.
        setattr(owner, attr_name, MethodType(func, owner))
    return TransformerModel





def Append_encoder_forward(
        self,
        src_tokens,
        src_lengths,
        Encoder_List: Optional[nn.ModuleList] = None,
        Encoder_addList: Optional[list] = None,
        cls_input: Optional[Tensor] = None,
        return_all_hiddens: bool = True,
    ):
    """Encoder forward pass with extra skip links between layer outputs.

    After encoder layer ``i`` runs, each link module ``Encoder_List[i][0][j]``
    is applied as ``x = link(x, states[k])`` with
    ``k = Encoder_addList[i][0][j]``, where ``states`` is the list of hidden
    states collected so far (index 0 is the embedding output).

    Args:
        src_tokens: source tokens; compared against ``self.padding_idx`` to
            build the padding mask.
        src_lengths: accepted for interface compatibility; not used here.
        Encoder_List: per-layer link modules. When ``None``, both this and
            ``Encoder_addList`` are taken from ``self``.
        Encoder_addList: per-layer source-state indices matching
            ``Encoder_List``.
        cls_input: accepted for interface compatibility; not used here.
        return_all_hiddens: whether the collected states are returned in
            ``encoder_states``. Forced on when ``self.layer_wise_attention``.

    Returns:
        EncoderOut with ``encoder_out``, ``encoder_padding_mask``,
        ``encoder_embedding`` and ``encoder_states`` (``None`` unless
        ``return_all_hiddens``).
    """
    if Encoder_List is None:
        Encoder_List = self.encoder_modellist
        Encoder_addList = self.encoder_addList
    if self.layer_wise_attention:
        return_all_hiddens = True

    x, encoder_embedding = self.forward_embedding(src_tokens)

    # B x T x C -> T x B x C
    x = x.transpose(0, 1)

    # compute padding mask
    encoder_padding_mask = src_tokens.eq(self.padding_idx)

    # The skip links read earlier hidden states, so they are always tracked
    # here; ``return_all_hiddens`` only decides whether they are returned.
    layer_states = [x]

    for layer_idx, layer in enumerate(self.layers):
        # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
        dropout_probability = torch.empty(1).uniform_()
        if not self.training or (dropout_probability > self.encoder_layerdrop):
            x = layer(x, encoder_padding_mask)
            for link_idx, source_idx in enumerate(Encoder_addList[layer_idx][0]):
                x = Encoder_List[layer_idx][0][link_idx](x, layer_states[source_idx])
            # NOTE: a layer skipped by LayerDrop appends nothing, so the
            # positions of later states shift down by one.
            layer_states.append(x)

    if self.layer_norm is not None:
        x = self.layer_norm(x)
        # Keep the last stored state consistent with the normalized output.
        layer_states[-1] = x

    return EncoderOut(
        encoder_out=x,  # T x B x C
        encoder_padding_mask=encoder_padding_mask,  # B x T
        encoder_embedding=encoder_embedding,  # B x T x C
        encoder_states=layer_states if return_all_hiddens else None,  # List[T x B x C]
    )



def Append_decoder_forward(
        self,
        prev_output_tokens,
        DecoderLayer_ModuleList: Optional[nn.ModuleList] = None,
        addlink_addList: Optional[list] = None,
        encoder_out: Optional[EncoderOut] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        features_only: bool = False,
        alignment_layer: Optional[int] = None,
        alignment_heads: Optional[int] = None,
        src_lengths: Optional[Any] = None,
        return_all_hiddens: bool = True,
    ):
    """Decoder forward pass that routes the skip-link lists to ``extract_features``.

    Args:
        prev_output_tokens (LongTensor): previous decoder outputs of shape
            `(batch, tgt_len)`, for teacher forcing
        DecoderLayer_ModuleList: per-layer link modules. When ``None``, both
            this and ``addlink_addList`` are taken from ``self``.
        addlink_addList: per-layer source-state indices matching
            ``DecoderLayer_ModuleList``.
        encoder_out (optional): output from the encoder, used for
            encoder-side attention
        incremental_state (dict): dictionary used for storing state during
            :ref:`Incremental decoding`
        features_only (bool, optional): only return features without
            applying output layer (default: False).
        alignment_layer, alignment_heads: forwarded to ``extract_features``.
        src_lengths, return_all_hiddens: accepted for interface
            compatibility; not used here.

    Returns:
        tuple:
            - the decoder's output of shape `(batch, tgt_len, vocab)`
            - a dictionary with any model-specific outputs
    """
    if DecoderLayer_ModuleList is None:
        DecoderLayer_ModuleList = self.decoder_modellist
        addlink_addList = self.decoder_addList
    x, extra = self.extract_features(
        prev_output_tokens,
        DecoderLayer_ModuleList,
        addlink_addList,
        encoder_out=encoder_out,
        incremental_state=incremental_state,
        alignment_layer=alignment_layer,
        alignment_heads=alignment_heads,
    )
    if not features_only:
        x = self.output_layer(x)
    return x, extra






def Append_extract_features(
        self,
        prev_output_tokens,
        DecoderLayer_ModuleList: Optional[nn.ModuleList] = None,
        addlink_addList: Optional[list] = None,
        encoder_out: Optional[EncoderOut] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        full_context_alignment: bool = False,
        alignment_layer: Optional[int] = None,
        alignment_heads: Optional[int] = None,
    ):
    """
    Similar to *forward* but only return features, with extra skip links.

    Before decoder layer ``idx`` runs, two groups of link modules are applied:

    * memory links ``DecoderLayer_ModuleList[idx][0][j]`` update the encoder
      memory ``m`` with ``encoder_out.encoder_states[addlink_addList[idx][0][j]]``;
    * self links ``DecoderLayer_ModuleList[idx][1][j]`` update ``x`` with
      ``inner_states[addlink_addList[idx][1][j]]``.

    Includes several features from "Jointly Learning to Align and
    Translate with Transformer Models" (Garg et al., EMNLP 2019).

    Args:
        DecoderLayer_ModuleList: per-layer link modules. When ``None``, both
            this and ``addlink_addList`` are taken from ``self``.
        addlink_addList: per-layer source-state indices matching
            ``DecoderLayer_ModuleList``.
        full_context_alignment (bool, optional): don't apply
            auto-regressive mask to self-attention (default: False).
        alignment_layer (int, optional): return mean alignment over
            heads at this layer (default: last layer).
        alignment_heads (int, optional): only average alignment over
            this many heads (default: all heads).

    Returns:
        tuple:
            - the decoder's features of shape `(batch, tgt_len, embed_dim)`
            - a dictionary with any model-specific outputs
    """
    # ``encoder_out`` is optional, so only read its states when it exists.
    EncoderLayer_list = (
        encoder_out.encoder_states if encoder_out is not None else None
    )
    if DecoderLayer_ModuleList is None:
        DecoderLayer_ModuleList = self.decoder_modellist
        addlink_addList = self.decoder_addList
    if alignment_layer is None:
        alignment_layer = self.num_layers - 1

    # embed positions
    positions = (
        self.embed_positions(
            prev_output_tokens, incremental_state=incremental_state
        )
        if self.embed_positions is not None
        else None
    )

    if incremental_state is not None:
        prev_output_tokens = prev_output_tokens[:, -1:]
        if positions is not None:
            positions = positions[:, -1:]

    # embed tokens and positions
    x = self.embed_scale * self.embed_tokens(prev_output_tokens)

    if self.project_in_dim is not None:
        x = self.project_in_dim(x)

    if positions is not None:
        x += positions

    if self.layernorm_embedding is not None:
        x = self.layernorm_embedding(x)

    x = F.dropout(x, p=self.dropout, training=self.training)

    # B x T x C -> T x B x C
    x = x.transpose(0, 1)

    self_attn_padding_mask: Optional[Tensor] = None
    if self.cross_self_attention or prev_output_tokens.eq(self.padding_idx).any():
        self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx)

    # decoder layers
    attn: Optional[Tensor] = None
    inner_states: List[Optional[Tensor]] = [x]
    for idx, layer in enumerate(self.layers):
        encoder_state: Optional[Tensor] = None
        if encoder_out is not None:
            if self.layer_wise_attention:
                encoder_states = encoder_out.encoder_states
                assert encoder_states is not None
                encoder_state = encoder_states[idx]
            else:
                encoder_state = encoder_out.encoder_out

        if incremental_state is None and not full_context_alignment:
            self_attn_mask = self.buffered_future_mask(x)
        else:
            self_attn_mask = None

        # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
        dropout_probability = torch.empty(1).uniform_()
        if not self.training or (dropout_probability > self.decoder_layerdrop):
            # Memory links: mix earlier encoder states into the memory that
            # this layer attends to.
            m = encoder_state
            for link_idx, source_idx in enumerate(addlink_addList[idx][0]):
                m = DecoderLayer_ModuleList[idx][0][link_idx](
                    m, EncoderLayer_list[source_idx]
                )
            # Self links: mix earlier decoder states into this layer's input.
            for link_idx, source_idx in enumerate(addlink_addList[idx][1]):
                x = DecoderLayer_ModuleList[idx][1][link_idx](
                    x, inner_states[source_idx]
                )
            x, layer_attn, _ = layer(
                x,
                m,
                encoder_out.encoder_padding_mask
                if encoder_out is not None
                else None,
                incremental_state,
                self_attn_mask=self_attn_mask,
                self_attn_padding_mask=self_attn_padding_mask,
                need_attn=bool((idx == alignment_layer)),
                need_head_weights=bool((idx == alignment_layer)),
            )
            # NOTE: a layer skipped by LayerDrop appends nothing, so the
            # positions of later states shift down by one.
            inner_states.append(x)
            if layer_attn is not None and idx == alignment_layer:
                attn = layer_attn.float().to(x)

    if attn is not None:
        if alignment_heads is not None:
            attn = attn[:alignment_heads]

        # average probabilities over heads
        attn = attn.mean(dim=0)

    if self.layer_norm is not None:
        x = self.layer_norm(x)

    # T x B x C -> B x T x C
    x = x.transpose(0, 1)

    if self.project_out_dim is not None:
        x = self.project_out_dim(x)

    return x, {"attn": [attn], "inner_states": inner_states}














def Append_transformer_forward(
        self,
        src_tokens,
        src_lengths,
        prev_output_tokens,
        Encoder_List= None,Encoder_addList= None, Decoder_List= None, decoder_addList= None,
        cls_input: Optional[Tensor] = None,
        return_all_hiddens: bool = True,
        features_only: bool = False,
        alignment_layer: Optional[int] = None,
        alignment_heads: Optional[int] = None,
    ):
    """
    Run the forward pass for an encoder-decoder model with skip links.

    Copied from the base class, but without ``**kwargs``,
    which are not supported by TorchScript.

    Args:
        src_tokens, src_lengths: passed to ``self.encoder``.
        prev_output_tokens: passed to ``self.decoder``.
        Encoder_List, Encoder_addList: encoder link modules and their source
            indices. Each defaults to the attribute stored on ``self``
            (``encoder_modellist`` / ``encoder_addList``) when ``None``.
        Decoder_List, decoder_addList: decoder link modules and their source
            indices. Each defaults to the attribute stored on ``self``
            (``decoder_modellist`` / ``decoder_addList``) when ``None``.
        cls_input, return_all_hiddens, features_only, alignment_layer,
        alignment_heads: forwarded unchanged to the encoder / decoder.

    Returns:
        Whatever ``self.decoder`` returns.
    """
    # Use an explicitly supplied list when given; otherwise fall back to the
    # lists attached to the model.
    if Encoder_List is None:
        Encoder_List = self.encoder_modellist
    if Encoder_addList is None:
        Encoder_addList = self.encoder_addList
    if Decoder_List is None:
        Decoder_List = self.decoder_modellist
    if decoder_addList is None:
        decoder_addList = self.decoder_addList

    encoder_out = self.encoder(
        src_tokens,
        Encoder_List=Encoder_List,
        Encoder_addList=Encoder_addList,
        src_lengths=src_lengths,
        cls_input=cls_input,
        return_all_hiddens=return_all_hiddens,
    )
    decoder_out = self.decoder(
        prev_output_tokens,
        DecoderLayer_ModuleList=Decoder_List,
        addlink_addList=decoder_addList,
        encoder_out=encoder_out,
        features_only=features_only,
        alignment_layer=alignment_layer,
        alignment_heads=alignment_heads,
        src_lengths=src_lengths,
        return_all_hiddens=return_all_hiddens,
    )
    return decoder_out


























def Re_train(self, mode=True):
    r"""Switch this module and its direct children to train or eval mode.

    Stores ``mode`` in ``self.training`` and then calls ``train(mode)`` on
    every object yielded by ``self.children()``.

    Args:
        mode (bool): ``True`` for training mode, ``False`` for evaluation
                     mode. Default: ``True``.

    Returns:
        Module: self
    """
    self.training = mode
    for child in self.children():
        child.train(mode)
    return self


def Append_retrain(TransformerModel,Re_train=None):
    """Bind a replacement ``train`` method onto a model instance.

    Args:
        TransformerModel: object to patch.
        Re_train: function ``(self, mode=True)`` to bind as
            ``TransformerModel.train``. When ``None``, the module-level
            ``Re_train`` function is used.

    Returns:
        The same ``TransformerModel`` object, patched in place.
    """
    if Re_train is None:
        # The parameter shadows the module-level function of the same name,
        # so fetch the default through the module globals.
        Re_train = globals()["Re_train"]
    TransformerModel.train = MethodType(Re_train, TransformerModel)
    return TransformerModel














def add_parameters(TransformerModel, Encoder_List,Encoder_addList, Decoder_List, decoder_addList):
    """Attach the skip-link modules and index lists to a model.

    The four objects are stored on ``TransformerModel`` itself; the encoder
    pair is also stored on ``TransformerModel.encoder`` and the decoder pair
    on ``TransformerModel.decoder``.

    Returns:
        The same ``TransformerModel`` object.
    """
    encoder = TransformerModel.encoder
    decoder = TransformerModel.decoder
    # Assignment order is kept as-is (module lists before index lists on the
    # top-level model).
    assignments = (
        (TransformerModel, "encoder_modellist", Encoder_List),
        (TransformerModel, "decoder_modellist", Decoder_List),
        (TransformerModel, "encoder_addList", Encoder_addList),
        (TransformerModel, "decoder_addList", decoder_addList),
        (encoder, "encoder_modellist", Encoder_List),
        (encoder, "encoder_addList", Encoder_addList),
        (decoder, "decoder_modellist", Decoder_List),
        (decoder, "decoder_addList", decoder_addList),
    )
    for owner, attr_name, value in assignments:
        setattr(owner, attr_name, value)
    return TransformerModel











class addlink(nn.Module):
    """Learnable skip link computing ``source + par * add``.

    ``par`` is a single trainable scalar. Its gradient is rescaled through
    ``f_Grad_switch`` by one of three stored multipliers, selected by the
    ``normal`` / ``slow`` flags:

    * ``normal`` set           -> ``grad_nor``  (1)
    * ``slow`` set (not normal) -> ``grad_slow`` (0.01)
    * otherwise                -> ``grad_mul``  (``60 * grad_mul``)

    Args:
        mul: initial value of ``par``.
        grad_mul: factor applied to the base multiplier of 60 for ``grad_mul``.
    """

    def __init__(self, mul=1e-7,grad_mul=1):#0.0001#1e-5
        super(addlink, self).__init__()
        self.par = nn.Parameter(torch.ones(1) * mul)
        self.normal = False
        self.slow = False
        # The multipliers are constants, not weights. ``nn.Parameter`` turns
        # requires_grad on unless told otherwise, so it is switched off
        # explicitly; they stay Parameters so they remain in the state dict.
        self.grad_nor = nn.Parameter(torch.ones(1), requires_grad=False)
        self.grad_mul = nn.Parameter(torch.ones(1) * 60 * grad_mul, requires_grad=False)
        self.grad_slow = nn.Parameter(torch.ones(1) * 0.01, requires_grad=False)

    def forward(self, source, add):
        """Return ``source + par * add`` with the gradient of ``par`` rescaled."""
        if self.normal:
            scale = self.grad_nor
        elif self.slow:
            scale = self.grad_slow
        else:
            scale = self.grad_mul
        par = f_Grad_switch(self.par, scale)
        return source + par * add

    def change_normal(self, Bo_=True):
        """Set the ``normal`` flag (selects the ``grad_nor`` multiplier)."""
        self.normal = Bo_

    def start_slow(self, Bo_=True):
        """Set the ``slow`` flag (selects ``grad_slow`` unless ``normal`` is set)."""
        self.slow = Bo_

    def change_grad_mul(self, grad_mul):
        """Overwrite the value of the ``grad_mul`` multiplier with ``grad_mul``."""
        # ``grad_mul`` is a registered Parameter, so a plain tensor cannot be
        # assigned over it; update its value in place instead. This also keeps
        # it on its current device.
        with torch.no_grad():
            self.grad_mul.copy_(torch.ones_like(self.grad_mul) * grad_mul)


def ini_model_list( coder_num, layer_num):#decodcer:layer_num=2------encodcer:layer_num=1
    """Create empty containers for skip links.

    Args:
        coder_num: number of outer entries (one per coder layer).
        layer_num: number of link slots per entry.

    Returns:
        tuple ``(layer_ModuleList, addlink_addList)``: a
        ``coder_num x layer_num`` nest of empty ``nn.ModuleList`` objects and
        a plain-list nest of the same shape holding empty lists.
    """
    # Every slot is a fresh object, so nothing is shared between layers.
    layer_ModuleList = nn.ModuleList(
        [
            nn.ModuleList([nn.ModuleList() for _ in range(layer_num)])
            for _ in range(coder_num)
        ]
    )
    addlink_addList = [[[] for _ in range(layer_num)] for _ in range(coder_num)]
    return layer_ModuleList, addlink_addList

    #coder_num=len(TransformerModel.encoder.layers)+1

def make_new_connections(Encoder_List,Encoder_addList, Decoder_List, decoder_addList,mul=1,grad_mul=1):
    """Create one ``addlink`` module for every index recorded in the add-lists.

    For each encoder entry, one link is appended to ``Encoder_List[p][0]``
    per index in ``Encoder_addList[p][0]``. For each decoder entry, links are
    appended to ``Decoder_List[p][0]`` and ``Decoder_List[p][1]`` per index in
    the matching ``decoder_addList`` slots. Entries equal to ``[[]]`` are
    skipped.

    Args:
        mul, grad_mul: forwarded to every ``addlink`` constructor.

    Returns:
        The four input containers, in the order they were given.
    """
    def _new_link():
        return addlink(mul=mul, grad_mul=grad_mul)

    for layer_idx, sources in enumerate(Encoder_addList):
        if sources == [[]]:
            continue
        for _ in sources[0]:
            Encoder_List[layer_idx][0].append(_new_link())

    for layer_idx, sources in enumerate(decoder_addList):
        if sources == [[]]:
            continue
        for slot in (0, 1):
            for _ in sources[slot]:
                Decoder_List[layer_idx][slot].append(_new_link())

    return Encoder_List, Encoder_addList, Decoder_List, decoder_addList



def new_connection_random(Encoder_List,Encoder_addList, Decoder_List, decoder_addList,E_num,D_num,total_num=5):
    """Randomly add ``total_num`` new skip links to the encoder/decoder.

    A random number of links (between 0 and ``total_num``) connect an encoder
    state to the memory input of a decoder layer (decoder slot 0). Each
    remaining link is, with equal probability, either an encoder link (state
    of an earlier-or-same position into an encoder layer) or a decoder self
    link (decoder slot 1). The containers are modified in place.

    Args:
        Encoder_List, Decoder_List: nested containers of link modules.
        Encoder_addList, decoder_addList: nested lists of source indices,
            kept parallel to the module containers.
        E_num: number of encoder layers.
        D_num: number of decoder layers.
        total_num: total number of links to add.

    Returns:
        tuple of the four input containers followed by ``new_E_list`` and
        ``new_D_list``, which hold ``(layer, slot, position)`` for every link
        added to the encoder / decoder respectively.
    """
    new_E_list = []
    new_D_list = []

    # Number of encoder->decoder memory links; never exceeds total_num.
    memory_link_count = random.randint(0, total_num)
    # The result of this second draw was never used; the draw itself is kept
    # so that a seeded run consumes the random stream exactly as before.
    random.randint(0, total_num)

    for _ in range(memory_link_count):
        target_layer = random.randint(0, D_num - 1)
        # The encoder exposes E_num + 1 states (embedding + one per layer).
        source_state = random.randint(0, E_num)

        Decoder_List[target_layer][0].append(addlink())
        decoder_addList[target_layer][0].append(source_state)
        new_D_list.append((target_layer, 0, len(Decoder_List[target_layer][0]) - 1))

    for _ in range(total_num - memory_link_count):
        if random.random() > 0.5:
            # Encoder link: source is any state available at that layer.
            target_layer = random.randint(0, E_num - 1)
            source_state = random.randint(0, target_layer)

            Encoder_List[target_layer][0].append(addlink())
            Encoder_addList[target_layer][0].append(source_state)
            new_E_list.append((target_layer, 0, len(Encoder_List[target_layer][0]) - 1))
        else:
            # Decoder self link. The target layer indexes Decoder_List, so it
            # must be drawn from the decoder layer count, not the encoder's.
            target_layer = random.randint(0, D_num - 1)
            source_state = random.randint(0, target_layer)

            Decoder_List[target_layer][1].append(addlink())
            decoder_addList[target_layer][1].append(source_state)
            new_D_list.append((target_layer, 1, len(Decoder_List[target_layer][1]) - 1))

    return Encoder_List, Encoder_addList, Decoder_List, decoder_addList, new_E_list, new_D_list


#
# def load_add_model(model_list,new_list):
#     new_add_model_list=[]
#     for tup in new_list:
#         new_add_model_list.append(model_list[tup[0]][tup[1]][tup[2]])
#
#     return new_add_model_list








# def tmp_para(new_E_list,lrm0=1,para_list=[]):#decodcer:layer_num=4------encodcer:layer_num=2
#
#     para_list_temp=[]
#     for para in para_list:
#         para_list_temp.append(para)
#
#     for model_ in new_E_list:
#         para_list_temp.append({'params': model_.classifier.parameters(), 'lrm': lrm0*100})#lr0*1000
#
#
#
#
#
#
#     return para_list_temp
#
#
def return_normal(_model_list):#decodcer:layer_num=4------encodcer:layer_num=2##To all 2 model list
    """Call ``change_normal()`` on every link in a three-level nested container."""
    links = (
        link
        for per_layer in _model_list
        for slot in per_layer
        for link in slot
    )
    for link in links:
        link.change_normal()


def set_slow(_model_list):#decodcer:layer_num=4------encodcer:layer_num=2##To all 2 model list
    """Call ``start_slow()`` on every link in a three-level nested container."""
    links = (
        link
        for per_layer in _model_list
        for slot in per_layer
        for link in slot
    )
    for link in links:
        link.start_slow()




from fairseq.file_io import PathManager
import os
from fairseq.checkpoint_utils import _upgrade_state_dict

from torch.serialization import default_restore_location


def load_model_ensemble(filenames, arg_overrides=None, task=None,modelO=None):
    """Load checkpoints into an ensemble and return it with the model args.

    Thin wrapper around ``load_model_ensemble_and_task`` that drops the task
    from the result.

    Args:
        filenames (List[str]): checkpoint files to load
        arg_overrides (Dict[str,Any], optional): override model args that
            were used during model training
        task (fairseq.tasks.FairseqTask, optional): task to use for loading
        modelO: passed through to ``load_model_ensemble_and_task``.

    Returns:
        tuple ``(ensemble, args)``.
    """
    loaded = load_model_ensemble_and_task(
        filenames,
        arg_overrides=arg_overrides,
        task=task,
        modelO=modelO,
    )
    models, model_args = loaded[0], loaded[1]
    return models, model_args


def load_model_ensemble_and_task(filenames, arg_overrides=None, task=None,modelO=None):
    """Load one model per checkpoint file and return them with args and task.

    Args:
        filenames (List[str]): checkpoint files to load.
        arg_overrides (Dict[str,Any], optional): passed to
            ``load_checkpoint_to_cpu`` for every file.
        task: task to use; when ``None`` it is set up from the args of the
            first checkpoint via ``fairseq.tasks.setup_task``.
        modelO: already-built model that receives the weights. The first
            checkpoint is loaded into ``modelO`` itself; every further
            checkpoint is loaded into a deep copy, so that ensemble members
            do not overwrite each other. When ``None``, a model is built with
            ``task.build_model(args)``.

    Returns:
        tuple ``(ensemble, args, task)``; ``args`` are those of the last
        checkpoint loaded, or ``None`` when ``filenames`` is empty.

    Raises:
        IOError: if a checkpoint file does not exist.
    """
    ensemble = []
    args = None
    for filename in filenames:
        if not os.path.exists(filename):
            raise IOError("Model file not found: {}".format(filename))
        state = load_checkpoint_to_cpu(filename, arg_overrides)

        args = state["args"]
        if task is None:
            # Imported only when a task actually has to be created.
            from fairseq import tasks

            task = tasks.setup_task(args)

        # build model for ensemble
        if modelO is None:
            model = task.build_model(args)
        elif ensemble:
            # modelO already holds an earlier checkpoint; loading into it
            # again would overwrite that ensemble member.
            model = copy.deepcopy(modelO)
        else:
            model = modelO
        model.load_state_dict(state["model"], strict=True, args=args)
        ensemble.append(model)
    return ensemble, args, task


def load_checkpoint_to_cpu(path, arg_overrides=None):
    """Loads a checkpoint to CPU (with upgrading for backward compatibility).

    Args:
        path: checkpoint location, opened through ``PathManager``.
        arg_overrides (dict, optional): attribute name -> value pairs set on
            the checkpoint's ``args`` object before upgrading.

    Returns:
        The state returned by ``_upgrade_state_dict``.
    """
    def _restore_on_cpu(storage, _location):
        # Ignore the recorded location; always map storages to the CPU.
        return default_restore_location(storage, "cpu")

    # NOTE(review): torch.load unpickles the file, so only trusted
    # checkpoints should be passed here.
    with PathManager.open(path, "rb") as handle:
        state = torch.load(handle, map_location=_restore_on_cpu)

    ckpt_args = state["args"]
    if arg_overrides is not None:
        for name in arg_overrides:
            setattr(ckpt_args, name, arg_overrides[name])
    return _upgrade_state_dict(state)


def Del_append(Model_):
    """Undo the monkey-patching and detach the skip-link containers.

    Deletes the instance-level ``forward`` / ``extract_features`` / ``train``
    overrides from the model, its encoder and its decoder, then removes the
    link containers from all three objects.

    Returns:
        tuple ``(Encoder_List, Encoder_addList, Decoder_List, decoder_addList)``
        as they were stored on ``Model_`` before removal.
    """
    # 1) Drop the bound-method overrides.
    for owner_name, attr_name in (
        ("encoder", "forward"),
        ("decoder", "extract_features"),
        ("decoder", "forward"),
    ):
        delattr(getattr(Model_, owner_name), attr_name)
    for attr_name in ("train", "forward"):
        delattr(Model_, attr_name)

    # 2) Grab the containers before they are detached.
    Encoder_List = Model_.encoder_modellist
    Decoder_List = Model_.decoder_modellist
    Encoder_addList = Model_.encoder_addList
    decoder_addList = Model_.decoder_addList

    # 3) Detach them from the model, then from the encoder and decoder.
    for attr_name in (
        "encoder_modellist",
        "decoder_modellist",
        "encoder_addList",
        "decoder_addList",
    ):
        delattr(Model_, attr_name)
    for owner_name, attr_name in (
        ("encoder", "encoder_modellist"),
        ("encoder", "encoder_addList"),
        ("decoder", "decoder_modellist"),
        ("decoder", "decoder_addList"),
    ):
        delattr(getattr(Model_, owner_name), attr_name)

    return Encoder_List, Encoder_addList, Decoder_List, decoder_addList