from transformers.models.llama.modeling_llama import LlamaMLP , LlamaDecoderLayer , LlamaModel, LlamaForCausalLM
import torch.nn as nn
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
import math
import warnings
from typing import List, Optional, Tuple, Union
from transformers.models.llama.configuration_llama import LlamaConfig

class taskLlamaMLP(LlamaMLP):
    """Llama MLP variant whose three projections also receive expert weights."""

    def forward(self, hidden_states: torch.Tensor, expert_weight: torch.Tensor):
        """Apply the gated feed-forward block, passing `expert_weight` to each projection.

        Computes `down_proj(act_fn(gate_proj(x, w)) * up_proj(x, w), w)`.

        NOTE(review): every projection is called with two positional
        arguments, so they are presumably replaced elsewhere by expert-aware
        layers that accept `(input, expert_weight)` -- confirm against the
        code that builds the model.
        """
        # Activated gate branch first, then the up branch, matching the
        # evaluation order of the single-expression form.
        gated = self.act_fn(self.gate_proj(hidden_states, expert_weight))
        lifted = self.up_proj(hidden_states, expert_weight)
        return self.down_proj(gated * lifted, expert_weight)
class taskLlamaDecoderLayer(LlamaDecoderLayer):
    """Llama decoder layer whose feed-forward block is a `taskLlamaMLP`.

    The MLP needs expert weights on every call. They are not a `forward`
    argument: the layer reads them from its `expert_weight` attribute, which
    must be assigned before the layer runs (`moetaskLlama.forward` does this
    for every layer of the wrapped model).
    """

    def __init__(self, config: LlamaConfig, layer_idx: int):
        super().__init__(config, layer_idx)
        # Use the expert-aware MLP as this layer's feed-forward block.
        self.mlp = taskLlamaMLP(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Run self-attention and the expert-weighted MLP, each with a residual connection.

        The expert weights are read from `self.expert_weight`; they are not
        passed as an argument.

        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*):
                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
                query_sequence_length, key_sequence_length)` if default attention is used.
            position_ids (`torch.LongTensor`, *optional*): forwarded unchanged to `self.self_attn`.
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            cache_position (`torch.LongTensor`, *optional*): forwarded unchanged to `self.self_attn`.
            **kwargs: forwarded unchanged to `self.self_attn`; a `padding_mask` entry triggers a
                deprecation warning.

        Returns:
            A tuple whose first element is the updated hidden states, followed by the attention
            weights when `output_attentions` is true and by the present key/value when `use_cache`
            is true.

        Raises:
            AttributeError: if `expert_weight` has not been assigned on this layer.
        """
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )

        # Fail before doing any attention work, and say how to fix it, instead
        # of surfacing a bare missing-attribute error halfway through the layer.
        if not hasattr(self, "expert_weight"):
            raise AttributeError(
                f"{type(self).__name__}.expert_weight is not set; assign it on every decoder "
                "layer before calling forward (moetaskLlama.forward does this)."
            )

        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # Fully Connected, weighted by the externally assigned expert weights
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states, self.expert_weight)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs


class taskLlamaModel(LlamaModel):
    """`LlamaModel` whose decoder stack is made of `taskLlamaDecoderLayer` modules."""

    def __init__(self, config):
        super().__init__(config)

        # One expert-aware decoder layer per hidden layer, indexed from 0.
        depth = config.num_hidden_layers
        task_layers = [taskLlamaDecoderLayer(config, idx) for idx in range(depth)]
        self.layers = nn.ModuleList(task_layers)

        # Run post-initialisation again now that `layers` has been replaced.
        self.post_init()
class taskLlamaForCausalLM(LlamaForCausalLM):
    """`LlamaForCausalLM` built around a `taskLlamaModel` backbone."""

    def __init__(self, config):
        super().__init__(config)

        # Backbone with expert-aware decoder layers.
        self.model = taskLlamaModel(config)

        vocab_size = config.vocab_size
        self.vocab_size = vocab_size
        # Bias-free projection from hidden states to one logit per vocabulary entry.
        self.lm_head = nn.Linear(config.hidden_size, vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()
class Gate(nn.Module):
    """
    Simplest softmax router.

    A bias-free linear layer maps each input vector to one logit per expert,
    and a softmax over the expert dimension turns the logits into routing
    weights that sum to 1.
    """
    def __init__(self, input_size, expert_num):
        """
        Args:
            input_size: size of the last dimension of the router input.
            expert_num: number of experts, i.e. size of the output's last dimension.
        """
        super().__init__()
        # Linear projection (no bias) from input features to per-expert logits.
        self.GateL = nn.Linear(input_size, expert_num, bias=False)
        # Normalise over the last dimension, which is the expert dimension of
        # the linear output. For the usual (batch, input_size) input this is
        # the same as dim=1, and it also stays correct for unbatched or
        # higher-rank inputs.
        self.act = nn.Softmax(dim=-1)

    def forward(self, x: torch.Tensor, ood_emb: Optional[torch.Tensor] = None):
        """Return routing weights for `x`.

        Args:
            x: tensor whose last dimension has size `input_size`.
            ood_emb: accepted for interface compatibility; not used.

        Returns:
            Tensor shaped like `x` with the last dimension replaced by
            `expert_num`; values along that dimension sum to 1.
        """
        logits = self.GateL(x)
        return self.act(logits)

class moetaskLlama(nn.Module):
    """Wrap a language model with a gate that turns a task embedding into expert weights.

    On every call the gate output is assigned to the `expert_weight` attribute
    of each decoder layer of the wrapped model, and the model is then run on
    `input_ids`.
    """

    def __init__(self, llm, expert_num, task_emb_size: int = 768):
        """
        Args:
            llm: wrapped model; must expose `model.layers` and
                `model.config.num_hidden_layers`, and be callable on `input_ids`.
            expert_num: number of experts the gate produces weights for.
            task_emb_size: size of the last dimension of the task embedding fed
                to the gate. Defaults to 768, the previously hard-coded value.
        """
        super().__init__()
        self.llm = llm
        self.gate = Gate(task_emb_size, expert_num)

    def forward(self, input_ids, task_emb, **kwargs):
        """Compute expert weights from `task_emb`, publish them to the layers, run the model.

        Args:
            input_ids: passed positionally to the wrapped model.
            task_emb: input to the gate; its last dimension must match the
                gate's input size.
            **kwargs: forwarded unchanged to the wrapped model.

        Returns:
            Whatever the wrapped model returns.
        """
        expert_weight = self.gate(task_emb)

        backbone = self.llm.model
        # The layers read the weights from this attribute during their own
        # forward pass, so it has to be set before the model is called. It is
        # deliberately left in place afterwards.
        for decoder_layer in backbone.layers[: backbone.config.num_hidden_layers]:
            decoder_layer.expert_weight = expert_weight

        return self.llm(input_ids, **kwargs)
