import re
from  .lora import LoraLayer
import torch
import torch.nn as nn
import torch.nn.functional as F
import bitsandbytes as bnb
from typing import Optional

def freeze_adapter(model):
    """Freeze the LoRA parameters of ``model`` while keeping the router trainable.

    Every parameter whose qualified name contains ``"lora"`` gets
    ``requires_grad = False``, except those whose name also contains
    ``"GateL"`` (the router's linear layer). Parameters whose name does not
    contain ``"lora"`` are left untouched.
    """
    lora_params_to_freeze = (
        param
        for param_name, param in model.named_parameters()
        if "lora" in param_name and "GateL" not in param_name
    )
    for param in lora_params_to_freeze:
        param.requires_grad = False
def topk_mask(expert_weight, k):
    """Keep the top-``k`` weights of each row, zero the rest, and renormalise.

    Args:
        expert_weight (torch.Tensor): routing weights, typically of shape
            ``(batch * sequence_length, n_experts)``; the top-k selection is
            taken over the last dimension.
        k (int): number of entries to keep per row. Values larger than the
            size of the last dimension are clamped to it.

    Returns:
        torch.Tensor: same shape, dtype and device as ``expert_weight``. Each
        row holds only its top-``k`` values, divided by their sum. Rows whose
        kept values sum to zero stay all-zero instead of turning into NaN.
    """
    # torch.topk raises if k exceeds the dimension size (e.g. fewer experts than k).
    k = min(k, expert_weight.size(-1))
    topk_values, topk_indices = torch.topk(expert_weight, k, dim=-1)

    # Out-of-place scatter into a zero tensor: non-selected experts get weight 0.
    result = torch.zeros_like(expert_weight).scatter(-1, topk_indices, topk_values)

    # Renormalise the kept weights; guard the denominator to avoid 0/0 = NaN.
    denom = result.sum(dim=-1, keepdim=True)
    denom = torch.where(denom == 0, torch.ones_like(denom), denom)
    return result / denom
class Gate(nn.Module):
    """Simplest softmax router.

    A bias-free linear layer (``GateL``) scores every expert, and a softmax
    over the last dimension turns those scores into mixing weights.
    """

    def __init__(self, input_size, expert_num):
        super().__init__()
        self.GateL = nn.Linear(input_size, expert_num, bias=False)
        # Normalise over the expert axis (always the last one) so the router is
        # also correct for inputs with extra leading dims, e.g. (batch, seq, dim).
        self.act = nn.Softmax(dim=-1)

    def forward(self, x: torch.Tensor, ood_emb: Optional[torch.Tensor] = None):
        """Return expert weights of shape ``(*x.shape[:-1], expert_num)``.

        ``ood_emb`` is accepted for call compatibility and is ignored.
        """
        return self.act(self.GateL(x))
class moetaskLlama(nn.Module):
    """Wrap a language model so a per-sample router supplies its expert weights.

    A ``Gate`` maps one embedding per sample to expert weights. These are
    broadcast to every token position and passed to ``llm`` as the
    ``expert_weight`` keyword argument.
    """

    def __init__(self, llm, expert_num, embedding_dim: int = 768):
        super().__init__()
        self.llm = llm
        # The router's input size must match the last dim of the `embeddings`
        # passed to forward (768 by default, as before).
        self.gate = Gate(embedding_dim, expert_num)
        self.gate.to(self.llm.device)

    def forward(self, **kwargs):
        """Route on ``kwargs["embeddings"]`` and call the wrapped model.

        ``embeddings`` is removed from ``kwargs``; ``input_ids`` is read (not
        removed) to obtain ``sequence_length`` from its second dimension. All
        remaining keyword arguments are forwarded to ``llm`` together with
        ``expert_weight`` of shape ``(batch * sequence_length, n_experts)``.

        Raises:
            ValueError: if no ``embeddings`` keyword argument is given.
        """
        embedding = kwargs.pop('embeddings', None)
        if embedding is None:
            raise ValueError("moetaskLlama.forward requires an `embeddings` keyword argument")
        input_ids = kwargs.get('input_ids')

        # expert_weight: (batch, n_experts)
        expert_weight = self.gate(embedding)
        _, n_experts = expert_weight.shape
        _, sequence_length = input_ids.shape

        # Same weights for every token of a sample: (batch * sequence_length, n_experts)
        expert_weight = (
            expert_weight.unsqueeze(1)
            .expand(-1, sequence_length, -1)
            .reshape(-1, n_experts)
        )
        return self.llm(expert_weight=expert_weight, **kwargs)
# class cosineScorer(nn.Module):
#     def __init__(self, input_size, expert_num, proj_dim=256):
#         super().__init__()
#         init_t=0.5
#         self.temperature = torch.nn.Parameter(torch.log(torch.full([1], 1.0 / init_t)), requires_grad=True)
#         self.cosine_projector = torch.nn.Linear(input_size, proj_dim)
#         self.sim_matrix = torch.nn.Parameter(torch.randn(size=(proj_dim, expert_num)), requires_grad=True)
#         self.clamp_max = torch.log(torch.tensor(1. / 0.01)).item()
#         torch.nn.init.normal_(self.sim_matrix, 0, 0.01)
#     def forward(self,x:torch.Tensor):
#         logits = torch.matmul(F.normalize(self.cosine_projector(x), dim=1),
#                               F.normalize(self.sim_matrix, dim=0))
#         logit_scale = torch.clamp(self.temperature, max=self.clamp_max).exp()
#         logits = logits * logit_scale
#         return logits
# class Gate(nn.Module):
#     """
#     cosine list router
#     """
#     def __init__(self, input_size, expert_num):

#         super().__init__()
#         # 使用embedding来代替线性层
#         self.GateA = cosineScorer(input_size, expert_num)
#         self.sample_dim=768
#         self.GateB = cosineScorer(self.sample_dim, expert_num)
        
#         self.act = nn.Softmax(dim=1)    # 第0维为batch size        
#     def forward(self,  x: torch.Tensor,ood_emb:Optional[torch.Tensor] = None):
#         if ood_emb is not None:
#             y1 = self.GateA(x)
#             y2 = self.GateB(ood_emb)
#             logits = 0.9*y1+0.1*y2
#         else:
#             logits = self.GateA(x)
#         # TODO: add noise  logits_w_noise = logits + gctx.gate_noise * torch.randn_like(logits) / self.num_global_experts  
#         y = self.act(logits)
#         return y
class MOELoraLayer(bnb.nn.Linear8bitLt, LoraLayer):
    """8-bit linear layer plus a mixture of independent LoRA experts.

    Expert ``i`` owns its own ``lora_A``/``lora_B``/``lora_dropout`` entries,
    stored under the key ``f"{adapter_name}_{i}"``. ``forward`` mixes the
    expert outputs with caller-supplied routing weights that ``topk_mask``
    restricts to the top-``topk_`` experts.
    """

    def __init__(
            self,
            adapter_name,
            in_features,
            out_features,
            r: int = 0,
            lora_alpha: int = 1,
            lora_dropout: float = 0.0,
            expert_num: int = 4,
            cluster: bool = False,
            **kwargs,
        ):
        bnb.nn.Linear8bitLt.__init__(
                self,
                in_features,
                out_features,
                bias=kwargs.get("bias", True),
                has_fp16_weights=kwargs.get("has_fp16_weights", True),
                memory_efficient_backward=kwargs.get("memory_efficient_backward", False),
                threshold=kwargs.get("threshold", 0.0),
                index=kwargs.get("index", None),
        )
        LoraLayer.__init__(self, in_features=in_features, out_features=out_features)
        # Filled externally (e.g. by set_lora_cluster); not read in forward.
        self.lora_gate = nn.ModuleDict({})
        self.cluster = cluster
        self.expert_num = expert_num
        self.active_adapter = adapter_name
        # The pretrained base weight stays frozen.
        self.weight.requires_grad = False
        init_lora_weights = True
        self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)

        # Number of experts kept per token by topk_mask in forward.
        self.topk_ = 2

    def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights):
        """Register ``expert_num`` rank-``r`` LoRA experts for ``adapter_name``.

        When ``r`` is 0 only ``r``/``lora_alpha`` are recorded and no
        trainable parameters (or scaling) are created.
        """
        self.r[adapter_name] = r
        self.lora_alpha[adapter_name] = lora_alpha
        if lora_dropout > 0.0:
            lora_dropout_layer = nn.Dropout(p=lora_dropout)
        else:
            lora_dropout_layer = nn.Identity()

        if r > 0:
            for i in range(self.expert_num):
                key = f"{adapter_name}_{i}"
                self.lora_A[key] = nn.Linear(self.in_features, r, bias=False)
                self.lora_B[key] = nn.Linear(r, self.out_features, bias=False)
                self.lora_dropout[key] = lora_dropout_layer
            self.scaling[adapter_name] = lora_alpha / r
            # Initialise the adapter just created, not whichever one is active.
            if init_lora_weights:
                self.reset_lora_parameters(adapter_name)
        self.to(self.weight.device)

    def reset_lora_parameters(self, adapter_name=None):
        """Re-initialise every expert of ``adapter_name`` (default: active adapter).

        ``lora_A`` weights are drawn from N(0, 0.01**2) and ``lora_B`` weights
        are zeroed, so each expert's contribution starts at exactly zero.
        """
        if adapter_name is None:
            adapter_name = self.active_adapter
        for i in range(self.expert_num):
            key = f"{adapter_name}_{i}"
            nn.init.normal_(self.lora_A[key].weight, mean=0.0, std=0.01)
            nn.init.zeros_(self.lora_B[key].weight)

    def forward(self, x: torch.Tensor, expert_weight: torch.Tensor):
        """Return the base projection of ``x`` plus the routed expert updates.

        Args:
            x: input activations; flattened internally to ``(-1, in_features)``.
            expert_weight: routing weights with one row per flattened position
                of ``x`` and one column per expert.

        Returns:
            Tensor with the same shape and dtype as the base layer's output.
        """
        adapter_name = self.active_adapter
        result = super().forward(x)
        if f"{adapter_name}_0" not in self.lora_A:
            # r == 0 (or no experts): nothing to add to the base projection.
            return result

        out_shape = result.shape
        expected_dtype = result.dtype
        x = x.float().reshape(-1, self.in_features)

        # Sparsify the routing; never request more experts than exist.
        expert_weight = topk_mask(expert_weight, min(self.topk_, self.expert_num))
        scaling = self.scaling[adapter_name]

        for i in range(self.expert_num):
            key = f"{adapter_name}_{i}"
            delta = self.lora_B[key](self.lora_A[key](self.lora_dropout[key](x)))
            delta = delta * scaling * expert_weight[..., i].unsqueeze(1)
            # Cast after mixing so a float32 routing weight cannot promote the
            # output dtype away from the base layer's.
            result = result + delta.to(expected_dtype).reshape(out_shape)
        return result
class HyLoraLayer(bnb.nn.Linear8bitLt, LoraLayer):
    """8-bit linear layer plus LoRA experts that share one down-projection.

    All experts share a single ``lora_A`` (key ``f"{adapter_name}_"``), while
    expert ``i`` has its own ``lora_B`` (key ``f"{adapter_name}_{i}"``).
    ``forward`` mixes the expert outputs with caller-supplied routing weights
    that ``topk_mask`` restricts to the top-``topk_`` experts.
    """

    def __init__(
            self,
            adapter_name,
            in_features,
            out_features,
            r: int = 0,
            lora_alpha: int = 1,
            lora_dropout: float = 0.0,
            expert_num: int = 4,
            cluster: bool = False,
            **kwargs,
        ):
        bnb.nn.Linear8bitLt.__init__(
                self,
                in_features,
                out_features,
                bias=kwargs.get("bias", True),
                has_fp16_weights=kwargs.get("has_fp16_weights", True),
                memory_efficient_backward=kwargs.get("memory_efficient_backward", False),
                threshold=kwargs.get("threshold", 0.0),
                index=kwargs.get("index", None),
        )
        LoraLayer.__init__(self, in_features=in_features, out_features=out_features)
        # Filled externally (e.g. by set_lora_cluster); not read in forward.
        self.lora_gate = nn.ModuleDict({})
        self.cluster = cluster
        self.expert_num = expert_num
        self.active_adapter = adapter_name
        # The pretrained base weight stays frozen.
        self.weight.requires_grad = False
        init_lora_weights = True
        self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)

        # Number of experts kept per token by topk_mask in forward.
        self.topk_ = 2

    def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights):
        """Register a shared ``lora_A`` and ``expert_num`` ``lora_B`` heads.

        When ``r`` is 0 only ``r``/``lora_alpha`` are recorded and no
        trainable parameters (or scaling) are created.
        """
        self.r[adapter_name] = r
        self.lora_alpha[adapter_name] = lora_alpha
        if lora_dropout > 0.0:
            lora_dropout_layer = nn.Dropout(p=lora_dropout)
        else:
            lora_dropout_layer = nn.Identity()

        if r > 0:
            self.lora_A[f"{adapter_name}_"] = nn.Linear(self.in_features, r, bias=False)
            for i in range(self.expert_num):
                key = f"{adapter_name}_{i}"
                self.lora_B[key] = nn.Linear(r, self.out_features, bias=False)
                # The same dropout module is registered under every expert key.
                self.lora_dropout[key] = lora_dropout_layer
            self.scaling[adapter_name] = lora_alpha / r
            # Initialise the adapter just created, not whichever one is active.
            if init_lora_weights:
                self.reset_lora_parameters(adapter_name)
        self.to(self.weight.device)

    def reset_lora_parameters(self, adapter_name=None):
        """Re-initialise ``adapter_name`` (default: active adapter).

        The shared ``lora_A`` is drawn from N(0, 0.01**2) and every ``lora_B``
        is zeroed, so each expert's contribution starts at exactly zero.
        """
        if adapter_name is None:
            adapter_name = self.active_adapter
        nn.init.normal_(self.lora_A[f"{adapter_name}_"].weight, mean=0.0, std=0.01)
        for i in range(self.expert_num):
            nn.init.zeros_(self.lora_B[f"{adapter_name}_{i}"].weight)

    def forward(self, x: torch.Tensor, expert_weight: torch.Tensor):
        """Return the base projection of ``x`` plus the routed expert updates.

        Args:
            x: input activations; flattened internally to ``(-1, in_features)``.
            expert_weight: routing weights with one row per flattened position
                of ``x`` and one column per expert.

        Returns:
            Tensor with the same shape and dtype as the base layer's output.
        """
        adapter_name = self.active_adapter
        result = super().forward(x)
        shared_key = f"{adapter_name}_"
        if shared_key not in self.lora_A or self.expert_num == 0:
            # r == 0 (or no experts): nothing to add to the base projection.
            return result

        out_shape = result.shape
        expected_dtype = result.dtype
        x = x.float().reshape(-1, self.in_features)

        # Sparsify the routing; never request more experts than exist.
        expert_weight = topk_mask(expert_weight, min(self.topk_, self.expert_num))
        scaling = self.scaling[adapter_name]

        # lora_A and the dropout module are shared by all experts (see
        # update_layer), so the down-projection is computed only once.
        hidden = self.lora_A[shared_key](self.lora_dropout[f"{adapter_name}_0"](x))
        for i in range(self.expert_num):
            delta = self.lora_B[f"{adapter_name}_{i}"](hidden)
            delta = delta * scaling * expert_weight[..., i].unsqueeze(1)
            # Cast after mixing so a float32 routing weight cannot promote the
            # output dtype away from the base layer's.
            result = result + delta.to(expected_dtype).reshape(out_shape)
        return result


def _get_submodules(model, key):
    parent = model.get_submodule(".".join(key.split(".")[:-1]))
    target_name = key.split(".")[-1]
    target = model.get_submodule(key)
    return parent, target, target_name

def find_and_replace(model, adapter_name, lora_config, expert_num, cluster, layer_type="moe"):
    """Replace targeted 8-bit ``Linear`` modules with LoRA-expert layers, in place.

    Args:
        model: model to patch.
        adapter_name: name under which the LoRA experts are registered.
        lora_config: provides ``target_modules`` (a regex string matched with
            ``re.fullmatch``, or a list of name suffixes), ``r``,
            ``lora_alpha``, ``lora_dropout`` and ``init_lora_weights``.
        expert_num: number of LoRA experts per replaced layer.
        cluster: forwarded to the new layers; when true, an empty
            ``model.gate`` ModuleDict is also created.
        layer_type: ``"hydra"`` builds ``HyLoraLayer``; anything else builds
            ``MOELoraLayer``.

    Matching modules that are already ``LoraLayer`` instances only get a new
    adapter via ``update_layer``. Other matches are replaced only when
    ``model.is_loaded_in_8bit`` is true and the module is a
    ``bnb.nn.Linear8bitLt``. A warning is emitted if nothing matched.
    """
    import warnings

    model.expert_num = expert_num
    # Mirrors the `topk_ = 2` that the expert layers use.
    model.num_experts_per_tok = 2
    loaded_in_8bit = getattr(model, "is_loaded_in_8bit", False)
    is_target_modules_in_base_model = False
    if cluster:
        model.gate = nn.ModuleDict({})
    layer_cls = HyLoraLayer if layer_type == "hydra" else MOELoraLayer

    # Snapshot the names first: modules are replaced while we walk the list.
    key_list = [key for key, _ in model.named_modules()]
    for key in key_list:
        if isinstance(lora_config.target_modules, str):
            target_module_found = re.fullmatch(lora_config.target_modules, key)
        else:
            target_module_found = any(key.endswith(target_key) for target_key in lora_config.target_modules)
        if not target_module_found:
            continue
        is_target_modules_in_base_model = True

        parent, target, target_name = _get_submodules(model, key)
        if isinstance(target, LoraLayer):
            target.update_layer(
                adapter_name,
                lora_config.r,
                lora_config.lora_alpha,
                lora_config.lora_dropout,
                lora_config.init_lora_weights,
            )
        elif loaded_in_8bit and isinstance(target, bnb.nn.Linear8bitLt):
            eightbit_kwargs = {
                "has_fp16_weights": target.state.has_fp16_weights,
                "memory_efficient_backward": target.state.memory_efficient_backward,
                "threshold": target.state.threshold,
                "index": target.index,
            }
            # bias=False here: _replace_module copies the original bias over.
            new_module = layer_cls(
                adapter_name,
                target.in_features,
                target.out_features,
                lora_config.r,
                lora_config.lora_alpha,
                lora_config.lora_dropout,
                expert_num,
                cluster,
                bias=False,
                **eightbit_kwargs,
            )
            _replace_module(parent, target_name, new_module, target)

    if not is_target_modules_in_base_model:
        # Deliberately a warning rather than an error: a mismatch used to pass silently.
        warnings.warn(
            f"Target modules {lora_config.target_modules} not found in the base model. "
            "Please check the target modules and try again.",
            stacklevel=2,
        )

def _replace_module(parent_module, child_name, new_module, old_module):
        setattr(parent_module, child_name, new_module)
        new_module.weight = old_module.weight
        if hasattr(old_module, "bias"):
            if old_module.bias is not None:
                new_module.bias = old_module.bias

        if getattr(old_module, "state", None) is not None:
            new_module.state = old_module.state
            new_module.to(old_module.weight.device)

        # dispatch to correct device
        for name, module in new_module.named_modules():
            if "lora_" in name:
                module.to(old_module.weight.device)

def set_lora_cluster(peft_model, expert_num, task_dim: int = 768):
    """Create one ``Gate`` router and register it in every LoRA-expert layer.

    Args:
        peft_model: wrapper whose ``.model`` holds the patched layers; the new
            router is stored as ``peft_model.gate``.
        expert_num: number of experts the router scores.
        task_dim: input size of the router (previously an undefined global,
            which made every call fail with NameError).
    """
    peft_model.gate = Gate(task_dim, expert_num)
    for module in peft_model.model.modules():
        if isinstance(module, (MOELoraLayer, HyLoraLayer)):
            # The same Gate instance is registered everywhere, so its
            # parameters are shared across layers.
            module.lora_gate.update(nn.ModuleDict({module.active_adapter: peft_model.gate}))
