import torch
from torch import nn
import numpy as np

from .GPT import attn_mask

def exponential_kaiming_normal_(
    tensor,
    exponential_rate: float = 0.5,
    mean = 0,
    fan = None,
):
    r"""
    Modified variant of ``kaiming_normal_`` used to initialise model parameters.

    Fills ``tensor`` in place with samples from a normal distribution whose
    standard deviation is ``1 / fan ** exponential_rate``.

    params:
        tensor: tensor to fill in place
        exponential_rate: exponent applied to ``fan`` when computing the std
        mean: mean of the normal distribution
        fan: fan value; when ``None`` it is taken from ``tensor.size(1)``

    returns:
        the standard deviation that was used
    """
    if fan is None:
        # Default fan is the size of the second dimension of the tensor.
        fan = tensor.size(1)

    std = 1 / (fan ** exponential_rate)

    # The in-place fill must not be recorded by autograd.
    with torch.no_grad():
        tensor.normal_(mean, std)

    return std
    

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention with the residual add folded in.

    All four projections (query, key, value, output) are bias-free linear
    layers. The sizes are read from ``args``: ``n_heads``, ``d_k``, ``d_v``
    and ``d_model``.
    """

    def __init__(self, args):
        super(MultiHeadAttention, self).__init__()

        self.n_head = args.n_heads
        self.d_k = args.d_k
        self.d_v = args.d_v
        self.d_model = args.d_model

        self.W_Q = nn.Linear(args.d_model, args.d_k * args.n_heads, bias=False)
        self.W_K = nn.Linear(args.d_model, args.d_k * args.n_heads, bias=False)
        self.W_V = nn.Linear(args.d_model, args.d_v * args.n_heads, bias=False)
        self.fc = nn.Linear(args.n_heads * args.d_v, args.d_model, bias=False)

    def forward(self, input_Q, input_K, input_V, attn_mask, residual):
        '''
        input_Q: [batch_size, len_q, d_model]
        input_K: [batch_size, len_k, d_model]
        input_V: [batch_size, len_v(=len_k), d_model]
        attn_mask: [batch_size, seq_len, seq_len], True marks positions that
            must NOT be attended to
        residual: tensor added to the projected attention output,
            [batch_size, len_q, d_model]

        returns:
            (output + residual, attention weights) where the attention
            weights are [batch_size, n_heads, len_q, len_k]
        '''
        batch_size = input_Q.size(0)
        # (B, S, D) -proj-> (B, S, D_new) -split-> (B, S, H, W) -trans-> (B, H, S, W)
        Q = self.W_Q(input_Q).view(batch_size, -1, self.n_head, self.d_k).transpose(1, 2)  # [batch_size, n_heads, len_q, d_k]
        K = self.W_K(input_K).view(batch_size, -1, self.n_head, self.d_k).transpose(1, 2)  # [batch_size, n_heads, len_k, d_k]
        V = self.W_V(input_V).view(batch_size, -1, self.n_head, self.d_v).transpose(1, 2)  # [batch_size, n_heads, len_v(=len_k), d_v]

        # scores: [batch_size, n_heads, len_q, len_k]
        scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(self.d_k)

        # A singleton head axis lets masked_fill broadcast the mask over all
        # heads, so no per-head copy of the mask is materialised.
        head_mask = attn_mask.unsqueeze(1)  # [batch_size, 1, seq_len, seq_len]

        # Use the most negative finite value of the score dtype: a literal
        # such as -1e9 cannot be represented in float16 and makes masked_fill
        # raise there. After softmax the masked weights are 0 either way, and
        # a fully masked row still comes out uniform.
        scores = scores.masked_fill(head_mask, torch.finfo(scores.dtype).min)
        softmax_attn = torch.softmax(scores, dim=-1)

        context = torch.matmul(softmax_attn, V)  # [batch_size, n_heads, len_q, d_v]
        # Merge the heads back: [batch_size, len_q, n_heads * d_v]
        context = context.transpose(1, 2).reshape(batch_size, -1, self.n_head * self.d_v)
        output = self.fc(context)  # [batch_size, len_q, d_model]

        return output + residual, softmax_attn


class PoswiseFeedForwardNet(nn.Module):
    """Position-wise feed-forward block: Linear -> GELU -> Linear, all bias-free.

    The residual connection is applied inside ``forward``; the tensor to add
    is supplied by the caller.
    """

    def __init__(self, args):
        super().__init__()
        # Build the layers in forward order so parameter creation order
        # matches their position inside the Sequential container.
        expand = nn.Linear(args.d_model, args.d_feedforward, bias=False)
        activation = nn.GELU()
        contract = nn.Linear(args.d_feedforward, args.d_model, bias=False)
        self.fc = nn.Sequential(expand, activation, contract)

    def forward(self, hidden_state, residual):
        '''
        hidden_state: [batch_size, seq_len, d_model]
        residual: tensor added to the feed-forward output

        returns: [batch_size, seq_len, d_model]
        '''
        transformed = self.fc(hidden_state)
        return transformed + residual


class DecoderLayer(nn.Module):
    """One pre-LayerNorm decoder block: self-attention followed by a feed-forward net.

    Each sub-layer receives the normalised hidden state together with the
    un-normalised tensor to use as its residual.
    """

    def __init__(self, args):
        super().__init__()
        self.dec_self_attn = MultiHeadAttention(args)
        self.pos_ffn = PoswiseFeedForwardNet(args)
        self.pre_attn_layernorm = nn.LayerNorm(args.d_model, eps=1e-8)
        self.pre_ffn_layernorm = nn.LayerNorm(args.d_model, eps=1e-8)

    def forward(self, hidden_state, dec_self_attn_mask):
        '''
            hidden_state: [batch_size, tgt_len, d_model]
            dec_self_attn_mask: [batch_size, tgt_len, tgt_len]

            returns:
                hidden state [batch_size, tgt_len, d_model] and the attention
                weights [batch_size, n_heads, tgt_len, tgt_len]
        '''
        # Attention sub-layer: the normalised input serves as query, key and
        # value; the raw input is passed along as the residual.
        attn_in = self.pre_attn_layernorm(hidden_state)
        attn_out, dec_self_attn = self.dec_self_attn(
            attn_in, attn_in, attn_in, dec_self_attn_mask, hidden_state
        )

        # Feed-forward (non-linear) sub-layer, with the attention output as residual.
        ffn_in = self.pre_ffn_layernorm(attn_out)
        ffn_out = self.pos_ffn(ffn_in, attn_out)  # [batch_size, tgt_len, d_model]

        return ffn_out, dec_self_attn


class Decoder(nn.Module):
    """Stack of ``args.n_layers`` decoder layers followed by a final LayerNorm."""

    def __init__(self, args, device):
        super().__init__()
        self.device = device
        self.layers = nn.ModuleList()
        for _ in range(args.n_layers):
            self.layers.append(DecoderLayer(args))
        self.output_layernorm = nn.LayerNorm(args.d_model, eps=1e-10)

    def forward(self, hidden_state, dec_self_attn_mask):
        '''
            hidden_state: input passed to the first layer
            dec_self_attn_mask: mask forwarded unchanged to every layer

            returns:
                the layer-normalised output of the last layer, and a list with
                the attention weights returned by each layer, in layer order
        '''
        attn_per_layer = []
        for block in self.layers:
            # Each layer returns its new hidden state and its attention weights.
            hidden_state, layer_attn = block(hidden_state, dec_self_attn_mask)
            attn_per_layer.append(layer_attn)

        normalised = self.output_layernorm(hidden_state)
        return normalised, attn_per_layer

class Embedding(nn.Module):
    """Sum of a token embedding and a learned position embedding.

    ``args`` supplies ``vocab_size``, ``seq_len`` (number of position rows),
    ``d_model`` and ``freeze_embedding``. When ``freeze_embedding`` is truthy
    the token embedding is excluded from gradient updates; the position
    embedding stays trainable.
    """

    def __init__ (self, args, device):
        super(Embedding, self).__init__()
        # Kept for callers that read it; forward() no longer depends on it.
        self.device = device
        self.tgt_emb = nn.Embedding(args.vocab_size , args.d_model)
        self.pos_emb = nn.Embedding(args.seq_len, args.d_model)

        if args.freeze_embedding:
            # Freeze the token embedding table.
            for param in self.tgt_emb.parameters():
                param.requires_grad = False
            print(f'Embedding layer is frozen in {self.__class__.__name__}')

    def forward(self, X_input):
        """
        X_input: [batch_size, seq_len] tensor of token ids

        returns: [batch_size, seq_len, d_model]
        """
        seq_len = X_input.size(1)
        # Create the position ids on the input's own device rather than the
        # device cached at construction time, so the lookup keeps working
        # after the module and its inputs have been moved elsewhere.
        pos = torch.arange(seq_len, dtype = torch.long, device = X_input.device)
        pos = pos.unsqueeze(0).expand_as(X_input)

        tgt_emb = self.tgt_emb(X_input)
        pos_emb = self.pos_emb(pos)
        return tgt_emb + pos_emb

class GPT_PreLN(nn.Module):
    """Decoder-only language model built from pre-LayerNorm decoder layers.

    Token ids are embedded, run through the decoder stack, passed through one
    more LayerNorm (``self.layernorm``) and projected to vocabulary logits.
    """

    def __init__(self, args, device, layernorm_weight_zeroinit = False):
        """
        params:
            args: configuration object. This class reads ``d_model``,
                ``vocab_size``, ``n_layers`` and ``all_std``; the sub-modules
                read further fields.
            device: device used when building the attention mask in forward()
                and as the target device in generate() / test()
            layernorm_weight_zeroinit: when True, every parameter whose name
                contains 'layernorm' (weights and biases alike) is set to a
                small constant during the custom initialisation

        Initialisation is selected by ``args.all_std``:
            falsy  -> the PyTorch default initialisation is left untouched
            -1     -> ``nn.init.kaiming_normal_`` on every parameter with more
                      than one dimension whose name does not contain 'layernorm'
            other  -> normal init with std = 1 / (fan ** all_std) / n_layers,
                      where fan = p.size(1); the token embedding uses mean 0.5,
                      all other matrices use mean 0
        """
        super(GPT_PreLN, self).__init__()
        self.device = device
        self.embedding = Embedding(args, device)
        self.decoder = Decoder(args, device)
        self.projection = nn.Linear(args.d_model, args.vocab_size,bias=False)
        self.args = args
        self.layernorm = nn.LayerNorm(args.d_model)

        all_std = args.all_std
        if all_std:
            if all_std != -1:
                self._init_exponential(all_std, args.n_layers, layernorm_weight_zeroinit)
            else:
                self._init_kaiming(layernorm_weight_zeroinit)

    def _init_exponential(self, all_std, n_layers, layernorm_weight_zeroinit):
        """Normal init whose std shrinks with fan ** all_std and with the layer count."""
        for name, p in self.named_parameters():
            is_token_embedding = 'tgt_emb' in name

            if p.dim() > 1 and not is_token_embedding:
                std = 1 / (p.size(1) ** all_std)
                std = std / n_layers
                print(f"{name} module initial std is {std}")
                nn.init.normal_(p, mean = 0.0, std = std)
            elif 'layernorm' in name and layernorm_weight_zeroinit:
                # "zero" in the log line means a near-zero constant.
                nn.init.constant_(p, 1e-5)
                print('layernorm weight is initialized by zero')
            elif is_token_embedding:
                # Same std as above, but centred on 0.5.
                std = 1 / (p.size(1) ** all_std)
                std = std / n_layers
                nn.init.normal_(p, mean = 0.5, std = std)
                print('embedding weight is initialized by normal')
            else:
                print(f'The module {name} is not initialized by kaiming')

    def _init_kaiming(self, layernorm_weight_zeroinit):
        """Standard kaiming-normal init for every non-layernorm matrix."""
        for name, p in self.named_parameters():
            if p.dim() > 1 and 'layernorm' not in name:
                nn.init.kaiming_normal_(p)
            elif 'layernorm' in name and layernorm_weight_zeroinit:
                nn.init.constant_(p, 1e-8)
                print('layernorm weight is initialized by zero')
            else:
                print(f'The module {name} is not initialized by kaiming')
        print(f'kaiming initialize: {self.__class__.__name__}')

    def forward(self, X_input, ):
        """
            X_input: [batch_size, tgt_len] tensor of token ids

            returns:
                logits flattened to [batch_size * tgt_len, vocab_size], and
                the list of per-layer attention weights from the decoder
        """
        hidden_state = self.embedding(X_input)

        dec_self_attn_mask = attn_mask(X_input, self.device, )

        hidden_state, dec_self_attns = self.decoder(hidden_state, dec_self_attn_mask)
        hidden_state = self.layernorm(hidden_state)
        dec_logits = self.projection(hidden_state)

        return dec_logits.view(-1, dec_logits.size(-1)), dec_self_attns

    def greedy_decoder(self,dec_input):
        """Return the arg-max token id (a Python int) for the last row of the flattened logits."""
        projected, _ = self.forward(dec_input)
        return projected[-1, :].argmax().item()

    def generate(self, dec_input, max_len, temperature = 1.0, return_logits = False, generated_pad_mask = None):
        """
        Greedily extend a single sequence one token at a time.

        params:
            dec_input: 1-D tensor of token ids; a batch axis is added
                internally before every forward pass
            max_len: target length of the returned sequence; values below 9
                are raised to 9
            temperature: divisor applied to the logits before the softmax; it
                changes the reported probabilities but not the arg-max choice
            return_logits: when True, also return the per-step distributions
            generated_pad_mask: accepted for backward compatibility and
                ignored; forward() builds its own attention mask from the
                token ids and takes no mask argument

        returns:
            the extended 1-D token tensor, or ``(tokens, steps)`` when
            ``return_logits`` is True, where ``steps`` is a list holding one
            numpy array of post-softmax probabilities per generated token
        """
        generated = dec_input.to(self.device)
        max_len = max(max_len, 9)
        logits = []

        # Nothing returned from here carries gradients, so skip graph building.
        with torch.no_grad():
            while len(generated) < max_len:
                # forward() accepts only the token ids; passing the pad mask
                # as a second argument raised TypeError on every call.
                projected, _ = self.forward(generated.unsqueeze(0))
                probs = torch.softmax(projected[-1, :] / temperature, dim=-1)
                new_word = probs.argmax()
                print(f'new_word\'s probility: {probs[new_word]}')
                generated = torch.cat((generated, new_word.unsqueeze(0)), dim = 0)
                if return_logits:
                    logits.append(probs.detach().cpu().numpy())

        if return_logits:
            return generated, logits
        return generated

    def test(self,sentence):
        """Predict the next token id for ``sentence``, a sequence of token ids."""
        dec_input = torch.tensor(sentence, dtype=torch.long, device=self.device).unsqueeze(0)
        return self.greedy_decoder(dec_input)

    def mini_init(self):
        """Re-initialise every parameter with more than one dimension from N(0, 1e-3)."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.normal_(p, mean = 0.0, std = 1e-3)
        print(f'mini initialize: {self.__class__.__name__}')

    def large_init(self):
        """Re-initialise every parameter with more than one dimension from N(0, 1e-1)."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.normal_(p, mean = 0.0, std = 1e-1)
        # The log line used to say "mini initialize" (copied from mini_init).
        print(f'large initialize: {self.__class__.__name__}')



