from transformers import GPT2Model
import torch
from torch import nn
import torch.nn.functional as F
from typing import Optional, Tuple
from transformers.models.gpt2.modeling_gpt2 import GPT2MLP, Conv1D, ACT2FN
from transformers.models.gptj.modeling_gptj import GPTJMLP
from transformers.models.llama.modeling_llama import LlamaMLP

def cos_sim(x, y, eps=1e-8):
    """Return pairwise cosine similarities between rows of ``x`` and rows of ``y``.

    Args:
        x: tensor whose last dimension is the feature dimension; any leading
            dimensions are preserved (e.g. ``(n, d)`` or ``(batch, seq, d)``).
        y: 2-D tensor of shape ``(m, d)``.
        eps: lower bound applied to each vector's L2 norm, so an all-zero
            vector yields similarity 0 instead of NaN. Note that the bound only
            takes effect if ``eps`` is representable in the input dtype.

    Returns:
        Tensor of shape ``x.shape[:-1] + (m,)``.
    """
    # Normalise each side *before* the matmul. Dividing a raw dot product by
    # the product of raw norms can overflow in float16, making every
    # similarity 0 or NaN. It also divides by zero for all-zero vectors.
    x_unit = F.normalize(x, p=2, dim=-1, eps=eps)
    y_unit = F.normalize(y, p=2, dim=-1, eps=eps)
    return x_unit @ y_unit.T
    
# Custom (edited) MLP classes

# class GPT2MLP(nn.Module):
#     def __init__(self, intermediate_size, config):
#         super().__init__()
#         embed_dim = config.hidden_size
#         self.c_fc = Conv1D(intermediate_size, embed_dim)
#         self.c_proj = Conv1D(embed_dim, intermediate_size)
#         self.act = ACT2FN[config.activation_function]
#         self.dropout = nn.Dropout(config.resid_pdrop)

#     def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
#         hidden_states = self.c_fc(hidden_states)
#         hidden_states = self.act(hidden_states)
#         hidden_states = self.c_proj(hidden_states)
#         hidden_states = self.dropout(hidden_states)
#         return hidden_states

class CustomGPT2MLP(GPT2MLP):
    """GPT-2 MLP that adds a retrieved value vector for matching stored keys.

    Behaves exactly like ``GPT2MLP`` until facts are stored with
    ``update_KV``. Afterwards, the post-activation hidden state
    (``act(c_fc(x))``) is compared by cosine similarity with every stored
    key. Where the best match exceeds ``sim_threshold``, that key's value row
    is added to the ``c_proj`` output before dropout.
    """

    # Minimum cosine similarity the best-matching key needs to contribute.
    # Can be overridden per instance to tune the gate.
    sim_threshold = 0.65

    def __init__(self, original_mlp, config):
        # Rebuild the parent module with the same inner width as the original.
        inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
        super().__init__(intermediate_size=inner_dim, config=config)

        # Reuse the original weights, then match the original device *and*
        # dtype. The fresh module is built in the default dtype, so moving
        # only the device would leave a half-precision model with float32
        # weights here, and Conv1D's addmm would fail on the dtype mismatch.
        self.load_state_dict(original_mlp.state_dict())
        reference = original_mlp.c_fc.weight
        self.to(reference.device)
        if reference.is_floating_point():
            self.to(reference.dtype)

        # Marker flag and the stored fact memory.
        # K/V are plain attributes, not registered parameters or buffers.
        # They are therefore excluded from state_dict() and are not moved by
        # .to(); compute_delta aligns them with its input instead.
        self.edited = True
        self.K = None  # [number_facts, key_dim]; key_dim must equal the c_fc output width
        self.V = None  # [number_facts, value_dim]; value_dim must equal the c_proj output width

    def update_KV(self, K, V):
        """Append new key/value rows to the stored fact memory.

        Args:
            K: keys, one row per fact.
            V: values, one row per fact (same number of rows as ``K``).

        Raises:
            ValueError: if ``K`` and ``V`` have different numbers of rows.
            RuntimeError: if exactly one of the stored ``K``/``V`` is set.
        """
        if K.shape[0] != V.shape[0]:
            raise ValueError(
                f"K and V must have the same number of rows, got {K.shape[0]} and {V.shape[0]}"
            )
        if self.K is None and self.V is None:
            self.K, self.V = K, V
        elif self.K is not None and self.V is not None:
            self.K = torch.cat([self.K, K], dim=0)
            self.V = torch.cat([self.V, V], dim=0)
        else:
            raise RuntimeError("inconsistent fact memory: only one of K/V is set")

    # @torch.compile
    def compute_delta(self, h):
        """Return the value of the best-matching stored key, gated by similarity.

        For every position in ``h``, the stored key with the highest cosine
        similarity is selected. Its value row is returned where that
        similarity exceeds ``sim_threshold``, and zeros are returned elsewhere.

        Args:
            h: activations whose last dimension equals the key width.

        Returns:
            Tensor of shape ``h.shape[:-1] + (value_dim,)`` with ``h``'s dtype
            and device.

        Raises:
            RuntimeError: if no facts have been stored yet.
        """
        if self.K is None or self.V is None:
            raise RuntimeError("compute_delta called before any facts were stored via update_KV")
        # Align the memory with the activations (no-op when already matching).
        keys = self.K.to(h)
        values = self.V.to(h)

        sim = cos_sim(h, keys)
        best_sim, best_idx = sim.max(dim=-1)
        # Gathering the winning value row is equivalent to the one-hot @ V
        # product, but avoids materialising an (..., number_facts) matrix.
        gate = (best_sim > self.sim_threshold).unsqueeze(-1).to(values.dtype)
        return values[best_idx] * gate

    def forward(self, hidden_states):
        hidden_states = self.act(self.c_fc(hidden_states))
        output = self.c_proj(hidden_states)
        # With no stored facts, behave exactly like the unedited MLP.
        if self.K is not None and self.V is not None:
            output = output + self.compute_delta(hidden_states)
        return self.dropout(output)

# class GPTJMLP(nn.Module):
#     def __init__(self, intermediate_size, config):  # in MLP: intermediate_size= 4 * embed_dim
#         super().__init__()
#         embed_dim = config.n_embd

#         self.fc_in = nn.Linear(embed_dim, intermediate_size)
#         self.fc_out = nn.Linear(intermediate_size, embed_dim)

#         self.act = ACT2FN[config.activation_function]
#         self.dropout = nn.Dropout(config.resid_pdrop)

#     def forward(self, hidden_states: Optional[torch.FloatTensor]) -> torch.FloatTensor:
#         hidden_states = self.fc_in(hidden_states)
#         hidden_states = self.act(hidden_states)
#         hidden_states = self.fc_out(hidden_states)
#         hidden_states = self.dropout(hidden_states)
#         return hidden_states

class CustomGPTJMLP(GPTJMLP):
    """GPT-J MLP that adds a retrieved value vector for matching stored keys.

    Behaves exactly like ``GPTJMLP`` until facts are stored with
    ``update_KV``. Afterwards, the post-activation hidden state
    (``act(fc_in(x))``) is compared by cosine similarity with every stored
    key. Where the best match exceeds ``sim_threshold``, that key's value row
    is added to the ``fc_out`` output before dropout.
    """

    # Minimum cosine similarity the best-matching key needs to contribute.
    # Can be overridden per instance to tune the gate.
    sim_threshold = 0.65

    def __init__(self, original_mlp, config):  # in MLP: intermediate_size= 4 * embed_dim
        inner_dim = config.n_inner if config.n_inner is not None else 4 * config.n_embd
        super().__init__(intermediate_size=inner_dim, config=config)

        # Reuse the original weights, then match the original device *and*
        # dtype. The fresh module is built in the default dtype, so moving
        # only the device would leave a half-precision model with float32
        # weights here, and nn.Linear would fail on the dtype mismatch.
        self.load_state_dict(original_mlp.state_dict())
        reference = original_mlp.fc_out.weight
        self.to(reference.device)
        if reference.is_floating_point():
            self.to(reference.dtype)

        # Marker flag and the stored fact memory.
        # K/V are plain attributes, not registered parameters or buffers.
        # They are therefore excluded from state_dict() and are not moved by
        # .to(); compute_delta aligns them with its input instead.
        self.edited = True
        self.K = None  # [number_facts, key_dim]; key_dim must equal the fc_in output width
        self.V = None  # [number_facts, value_dim]; value_dim must equal the fc_out output width

    def update_KV(self, K, V):
        """Append new key/value rows to the stored fact memory.

        Args:
            K: keys, one row per fact.
            V: values, one row per fact (same number of rows as ``K``).

        Raises:
            ValueError: if ``K`` and ``V`` have different numbers of rows.
            RuntimeError: if exactly one of the stored ``K``/``V`` is set.
        """
        if K.shape[0] != V.shape[0]:
            raise ValueError(
                f"K and V must have the same number of rows, got {K.shape[0]} and {V.shape[0]}"
            )
        if self.K is None and self.V is None:
            self.K, self.V = K, V
        elif self.K is not None and self.V is not None:
            self.K = torch.cat([self.K, K], dim=0)
            self.V = torch.cat([self.V, V], dim=0)
        else:
            raise RuntimeError("inconsistent fact memory: only one of K/V is set")

    # @torch.compile
    def compute_delta(self, h):
        """Return the value of the best-matching stored key, gated by similarity.

        For every position in ``h``, the stored key with the highest cosine
        similarity is selected. Its value row is returned where that
        similarity exceeds ``sim_threshold``, and zeros are returned elsewhere.

        Args:
            h: activations whose last dimension equals the key width.

        Returns:
            Tensor of shape ``h.shape[:-1] + (value_dim,)`` with ``h``'s dtype
            and device.

        Raises:
            RuntimeError: if no facts have been stored yet.
        """
        if self.K is None or self.V is None:
            raise RuntimeError("compute_delta called before any facts were stored via update_KV")
        # Align the memory with the activations (no-op when already matching).
        keys = self.K.to(h)
        values = self.V.to(h)

        sim = cos_sim(h, keys)
        best_sim, best_idx = sim.max(dim=-1)
        # Gathering the winning value row is equivalent to the one-hot @ V
        # product, but avoids materialising an (..., number_facts) matrix.
        gate = (best_sim > self.sim_threshold).unsqueeze(-1).to(values.dtype)
        return values[best_idx] * gate

    def forward(self, hidden_states: Optional[torch.FloatTensor]) -> torch.FloatTensor:
        hidden_states = self.act(self.fc_in(hidden_states))
        output = self.fc_out(hidden_states)
        # With no stored facts, behave exactly like the unedited MLP.
        if self.K is not None and self.V is not None:
            output = output + self.compute_delta(hidden_states)
        return self.dropout(output)


# class LlamaMLP(nn.Module):
#     def __init__(self, config):
#         super().__init__()
#         self.config = config
#         self.hidden_size = config.hidden_size
#         self.intermediate_size = config.intermediate_size
#         self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
#         self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
#         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
#         self.act_fn = ACT2FN[config.hidden_act]

#     def forward(self, x):
#         down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
#         return down_proj

class CustomLlamaMLP(LlamaMLP):
    """Llama MLP that adds a retrieved value vector for matching stored keys.

    Behaves exactly like ``LlamaMLP`` until facts are stored with
    ``update_KV``. Afterwards, the gated intermediate activation
    ``act_fn(gate_proj(x)) * up_proj(x)`` is compared by cosine similarity
    with every stored key. Where the best match exceeds ``sim_threshold``,
    that key's value row is added to the ``down_proj`` output.
    """

    # Minimum cosine similarity the best-matching key needs to contribute.
    # Can be overridden per instance to tune the gate.
    sim_threshold = 0.65

    def __init__(self, original_mlp, config):
        super().__init__(config)
        # Reuse the original weights, then match the original device *and*
        # dtype. The fresh module is built in the default dtype, so moving
        # only the device would leave e.g. a bf16/fp16 model with float32
        # weights here, and nn.Linear would fail on the dtype mismatch.
        self.load_state_dict(original_mlp.state_dict())
        reference = original_mlp.down_proj.weight
        self.to(reference.device)
        if reference.is_floating_point():
            self.to(reference.dtype)

        # Marker flag and the stored fact memory.
        # K/V are plain attributes, not registered parameters or buffers.
        # They are therefore excluded from state_dict() and are not moved by
        # .to(); compute_delta aligns them with its input instead.
        self.edited = True
        self.K = None  # [number_facts, key_dim]; key_dim must equal the down_proj input width
        self.V = None  # [number_facts, value_dim]; value_dim must equal the down_proj output width

    def update_KV(self, K, V):
        """Append new key/value rows to the stored fact memory.

        Args:
            K: keys, one row per fact.
            V: values, one row per fact (same number of rows as ``K``).

        Raises:
            ValueError: if ``K`` and ``V`` have different numbers of rows.
            RuntimeError: if exactly one of the stored ``K``/``V`` is set.
        """
        if K.shape[0] != V.shape[0]:
            raise ValueError(
                f"K and V must have the same number of rows, got {K.shape[0]} and {V.shape[0]}"
            )
        if self.K is None and self.V is None:
            self.K, self.V = K, V
        elif self.K is not None and self.V is not None:
            self.K = torch.cat([self.K, K], dim=0)
            self.V = torch.cat([self.V, V], dim=0)
        else:
            raise RuntimeError("inconsistent fact memory: only one of K/V is set")

    # @torch.compile
    def compute_delta(self, h):
        """Return the value of the best-matching stored key, gated by similarity.

        For every position in ``h``, the stored key with the highest cosine
        similarity is selected. Its value row is returned where that
        similarity exceeds ``sim_threshold``, and zeros are returned elsewhere.

        Args:
            h: activations whose last dimension equals the key width.

        Returns:
            Tensor of shape ``h.shape[:-1] + (value_dim,)`` with ``h``'s dtype
            and device.

        Raises:
            RuntimeError: if no facts have been stored yet.
        """
        if self.K is None or self.V is None:
            raise RuntimeError("compute_delta called before any facts were stored via update_KV")
        # Align the memory with the activations (no-op when already matching).
        keys = self.K.to(h)
        values = self.V.to(h)

        sim = cos_sim(h, keys)
        best_sim, best_idx = sim.max(dim=-1)
        # Gathering the winning value row is equivalent to the one-hot @ V
        # product, but avoids materialising an (..., number_facts) matrix.
        gate = (best_sim > self.sim_threshold).unsqueeze(-1).to(values.dtype)
        return values[best_idx] * gate

    def forward(self, x):
        h = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
        down_proj = self.down_proj(h)
        # With no stored facts, behave exactly like the unedited MLP.
        if self.K is not None and self.V is not None:
            down_proj = down_proj + self.compute_delta(h)
        return down_proj