import torch

from typing import Optional
from torch import nn
from tqdm import tqdm
from logger import logger
from gllm.layers.activation import SiluAndMul
from gllm.layers.rotary_embedding import RotaryEmbedding
from gllm.layers.attention import FlashAttention
from gllm.layers.layernorm import RMSNorm
from gllm.layers.sampler import Sampler
from gllm.input_data import InputData
from gllm.dist_utils import get_pp_load_num_layers,get_pp_used_num_layers,get_pp_load_layers,get_pp_used_layers,get_pp_size, get_pp_rank,set_pp_adjust_layers,get_pp_adjust_layers,get_pp_load_layers_native


class Qwen2MLP(nn.Module):
    """Gated feed-forward block of a Qwen2 decoder layer.

    A single fused projection produces ``2 * intermediate_size`` features that
    ``SiluAndMul`` reduces back to ``intermediate_size`` (presumably
    ``silu(gate) * up`` — TODO confirm against ``gllm.layers.activation``),
    followed by a projection back to ``hidden_size``. Neither linear layer has
    a bias; both are allocated on CUDA in ``config.torch_dtype``.
    """

    def __init__(self, config):
        super().__init__()
        hidden = config.hidden_size
        inter = config.intermediate_size
        dtype = config.torch_dtype
        # gate and up projections fused into one matmul.
        self.gate_up_proj = nn.Linear(hidden, 2 * inter,
                                      bias=False, dtype=dtype, device='cuda')
        self.down_proj = nn.Linear(inter, hidden,
                                   bias=False, dtype=dtype, device='cuda')
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up = self.gate_up_proj(x)
        activated = self.act_fn(gate_up)
        return self.down_proj(activated)


class Qwen2Attention(nn.Module):
    """Self-attention sub-block of a Qwen2 decoder layer.

    Hidden states go through one fused q/k/v projection (with bias). The
    result is split into query, key and value, rotary embeddings are applied
    to q and k at ``input_data.positions``, ``FlashAttention`` is run, and the
    output is projected back to ``hidden_size`` by ``o_proj`` (no bias).
    The number of key/value heads may differ from the number of query heads.
    """

    def __init__(self, layer_id: int, config):
        super().__init__()
        dtype = config.torch_dtype
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_kv_heads = config.num_key_value_heads
        self.q_size = self.head_dim * self.num_heads
        self.kv_size = self.head_dim * self.num_kv_heads
        # 1/sqrt(head_dim) attention score scaling.
        self.scaling = self.head_dim**-0.5
        self.rope_theta = getattr(config, 'rope_theta', 10000)

        # Output layout of the fused projection: [q | k | v].
        self.qkv_proj = nn.Linear(
            self.hidden_size, self.q_size + 2 * self.kv_size,
            bias=True, dtype=dtype, device='cuda')
        self.o_proj = nn.Linear(
            self.q_size, self.hidden_size,
            bias=False, dtype=dtype, device='cuda')
        self.rotary_emb = RotaryEmbedding(
            self.head_dim, self.head_dim, config.max_position_embeddings,
            self.rope_theta, True, dtype)
        self.attn = FlashAttention(
            layer_id, self.scaling, self.num_heads, self.num_kv_heads,
            self.head_dim, self.hidden_size)

    def forward(self, input_data: InputData, hidden_states: torch.Tensor):
        fused = self.qkv_proj(hidden_states)
        q, k, v = fused.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(input_data.positions, q, k)
        return self.o_proj(self.attn.forward(q, k, v, input_data))


class Qwen2DecoderLayer(nn.Module):
    """One pre-norm Qwen2 transformer layer: norm -> attention -> norm -> MLP.

    The residual stream is threaded through the RMSNorm calls. When a residual
    is supplied, a norm is called as ``norm(x, residual)`` and returns a
    ``(normed, residual)`` pair (presumably a fused residual-add + norm — TODO
    confirm against ``gllm.layers.layernorm``).

    The layer returns the MLP output and the residual separately without
    adding them; the caller hands both to the next layer or final norm.
    """

    def __init__(self, layer_id: int, config):
        super().__init__()
        self.self_attn = Qwen2Attention(layer_id, config)
        self.mlp = Qwen2MLP(config)
        self.input_layernorm = RMSNorm(
            config.hidden_size, config.rms_norm_eps, config.torch_dtype)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, config.rms_norm_eps, config.torch_dtype)

    def forward(self, input_data: InputData, hidden_states: torch.Tensor, residual: Optional[torch.Tensor]):
        if residual is not None:
            normed, residual = self.input_layernorm(hidden_states, residual)
        else:
            # No residual yet (e.g. first layer on the first stage):
            # the raw input starts the residual stream.
            residual = hidden_states
            normed = self.input_layernorm(hidden_states)
        attn_out = self.self_attn(input_data, normed)

        # Fully connected sub-block.
        mlp_in, residual = self.post_attention_layernorm(attn_out, residual)
        return self.mlp(mlp_in), residual


class Qwen2Model(nn.Module):
    """Qwen2 decoder stack for the pipeline-parallel (PP) stage of this process.

    Only decoder layers ``[load_start_layer, load_end_layer)`` are built here.
    ``forward`` runs the sub-range ``[used_start_layer, used_end_layer)``,
    which equals the loaded range unless ``enable_adjust_layers`` is set, and
    which ``adjust_layer`` can change at runtime.

    Stage ownership:
    - The first stage owns the token embedding.
    - The last stage owns the final norm, and also the embedding when word
      embeddings are tied.
    """

    def __init__(self, config, enable_adjust_layers=False):
        super().__init__()
        # Parenthesised for readability only: ``and`` already binds tighter
        # than ``or``, and this is the intended grouping.
        if get_pp_rank() == 0 or (config.tie_word_embeddings
                                  and get_pp_rank() == get_pp_size() - 1):
            self.embed_tokens = nn.Embedding(
                config.vocab_size, config.hidden_size, dtype=config.torch_dtype, device='cuda')
        self.enable_adjust_layers = enable_adjust_layers
        # Reset the global PP adjust-layer count (original note: "number of
        # adjusted layers, excluding the last layer") before querying ranges.
        set_pp_adjust_layers(0)

        # TODO: modify layers dynamically.
        # Layers loaded on / used by this stage.
        if self.enable_adjust_layers:
            self.load_start_layer, self.load_end_layer = get_pp_load_layers(
                config.num_hidden_layers)
            self.used_start_layer, self.used_end_layer = get_pp_used_layers(
                config.num_hidden_layers)
        else:
            self.load_start_layer, self.load_end_layer = get_pp_load_layers_native(
                config.num_hidden_layers)
            self.used_start_layer, self.used_end_layer = self.load_start_layer, self.load_end_layer

        self.num_hidden_layers = config.num_hidden_layers
        # TODO: also modify layers dynamically.
        # Layer ids handed to the decoder layers are local to this stage (0-based).
        self.layers = nn.ModuleList([
            Qwen2DecoderLayer(i - self.load_start_layer, config)
            for i in range(self.load_start_layer, self.load_end_layer)
        ])

        if get_pp_rank() == get_pp_size() - 1:
            self.norm = RMSNorm(
                config.hidden_size, config.rms_norm_eps, config.torch_dtype)

    def forward(self, input_data: InputData, hidden_states=None, residual=None):
        """Run this stage's used layers.

        On the first stage ``hidden_states`` is ignored and computed from
        ``input_data.tokens``. On other stages it (and ``residual``) comes
        from the previous stage.

        Returns:
            The final-normed hidden states on the last stage; otherwise a
            ``(hidden_states, residual)`` pair for the next stage.
        """
        if get_pp_rank() == 0:
            hidden_states = self.embed_tokens(input_data.tokens)
        # Used-layer bounds are global ids; self.layers is indexed locally.
        for i in range(self.used_start_layer - self.load_start_layer,
                       self.used_end_layer - self.load_start_layer):
            layer = self.layers[i]
            hidden_states, residual = layer(
                input_data, hidden_states, residual)

        if get_pp_rank() == get_pp_size() - 1:
            hidden_states, _ = self.norm(hidden_states, residual)
            return hidden_states
        return hidden_states, residual

    def adjust_layer(self, adjust_layer):
        """Set the global PP adjust-layer count and recompute the used range.

        Does nothing (beyond logging) if ``adjust_layer`` equals the current
        count.

        Raises:
            ValueError: if the resulting non-empty used range is not contained
                in the layers loaded on this stage. The previous adjust-layer
                count is restored before raising.
        """
        previous = get_pp_adjust_layers()
        if adjust_layer == previous:
            logger.info('Current state is already the target state, no adjustment needed')
            return
        set_pp_adjust_layers(adjust_layer)
        used_start, used_end = get_pp_used_layers(self.num_hidden_layers)
        # forward() indexes self.layers with offsets relative to
        # load_start_layer. An uncovered range would raise mid-forward, or,
        # for a negative offset, silently wrap to the wrong layer.
        if used_start < used_end and not (
                self.load_start_layer <= used_start
                and used_end <= self.load_end_layer):
            set_pp_adjust_layers(previous)
            raise ValueError(
                f"adjust_layer={adjust_layer} gives used layers "
                f"[{used_start}, {used_end}) outside loaded layers "
                f"[{self.load_start_layer}, {self.load_end_layer})")
        self.used_start_layer, self.used_end_layer = used_start, used_end


class Qwen2ForCausalLM(nn.Module):
    """Qwen2 causal LM wrapper: decoder stack plus, on the last PP stage, lm_head.

    With ``config.tie_word_embeddings`` the ``lm_head`` weight is the very
    same Parameter as ``model.embed_tokens.weight``.
    """

    def __init__(self, config, enable_adjust_layers=False):
        super().__init__()
        self.config = config
        self.max_model_len = config.max_position_embeddings
        # Number of decoder layers actually loaded on this stage.
        self.num_layers = get_pp_load_num_layers(config.num_hidden_layers)
        self.dtype = config.torch_dtype
        self.num_kv_heads = config.num_key_value_heads
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.model = Qwen2Model(config, enable_adjust_layers)
        self.ret_residual = True
        if get_pp_rank() == get_pp_size() - 1:
            tied = config.tie_word_embeddings
            # When tied, the weight is replaced immediately below. Building the
            # layer on the meta device avoids allocating a throwaway
            # vocab_size x hidden_size matrix on the GPU.
            self.lm_head = nn.Linear(
                config.hidden_size, config.vocab_size, bias=False,
                dtype=config.torch_dtype, device='meta' if tied else 'cuda')
            if tied:
                self.lm_head.weight = self.model.embed_tokens.weight
        self.sampler = Sampler()

    def get_num_avg_layers(self):
        """Return num_hidden_layers // pp_size (integer average per stage)."""
        return self.config.num_hidden_layers // get_pp_size()

    def get_adjust_layers(self):
        """Return the current global PP adjust-layer count."""
        return get_pp_adjust_layers()

    def forward(self, input_data: InputData, hidden_states=None, residual=None):
        return self.model(input_data, hidden_states, residual)

    def compute_logits(self, input_data: InputData, hidden_states: torch.Tensor):
        """Project the hidden state of the last token of each sequence to logits."""
        # Assumes query_start_loc holds cumulative sequence start offsets with
        # a leading 0, so query_start_loc[1:] - 1 are the last-token indices
        # (TODO confirm against InputData).
        last_token_idx = input_data.query_start_loc[1:] - 1
        return self.lm_head(hidden_states[last_token_idx])

    def sample(self, input_data: InputData, logits: torch.Tensor):
        return self.sampler.forward(logits, input_data)

    def load_weights(self, weights):
        """Copy checkpoint tensors from ``weights`` into this stage's parameters.

        ``weights`` maps checkpoint parameter names to tensors.

        - Local layer indices (``model.layers.<i>...``) are shifted by
          ``load_start_layer`` to the global checkpoint index.
        - Fused projections are assembled along dim 0 from separate entries:
          q/k/v -> qkv_proj (weight and bias) and gate/up -> gate_up_proj.

        Raises:
            KeyError: if a required checkpoint entry is missing.
        """
        head_dim = self.config.hidden_size // self.config.num_attention_heads
        q_end = self.config.num_attention_heads * head_dim
        k_end = q_end + self.config.num_key_value_heads * head_dim
        inter = self.config.intermediate_size
        qkv_pieces = (('q_proj', None, q_end),
                      ('k_proj', q_end, k_end),
                      ('v_proj', k_end, None))
        gate_up_pieces = (('gate_proj', None, inter),
                          ('up_proj', inter, None))

        for name, param in tqdm(dict(self.named_parameters()).items()):
            # Resolve the PP layer: local layer index -> global checkpoint index.
            if 'layers' in name:
                parts = name.split('.')
                parts[2] = str(int(parts[2]) + self.model.load_start_layer)
                name = '.'.join(parts)
            if ('self_attn.qkv_proj.weight' in name
                    or 'self_attn.qkv_proj.bias' in name):
                self._load_fused(param, name, weights, 'qkv_proj', qkv_pieces)
            elif 'gate_up_proj' in name:
                self._load_fused(param, name, weights, 'gate_up_proj', gate_up_pieces)
            else:
                param.data.copy_(weights[name])

    @staticmethod
    def _load_fused(param, name, weights, fused, pieces):
        """Fill dim-0 slices of ``param``, one per ``(sub_name, start, end)`` piece.

        Each slice is read from ``weights[name.replace(fused, sub_name)]``.
        """
        for sub_name, start, end in pieces:
            param.data[start:end] = weights[name.replace(fused, sub_name)]

    def adjust_layer(self, adjust_layer):
        self.model.adjust_layer(adjust_layer)