import torch,os
from moepeft import Gate,find_and_replace
from peft import LoraConfig
from transformers import Trainer
from torch.utils.data import DataLoader
from typing import Optional
from peft import (
    PeftModel, set_peft_model_state_dict
)
from datasets import Dataset
from dataclasses import dataclass
def get_split_indices(n, num_splits=4):
    """Return the last index of each of ``num_splits`` near-equal contiguous groups of ``n`` items.

    The ``n`` items are divided into ``num_splits`` consecutive groups whose sizes
    differ by at most one; the first ``n % num_splits`` groups receive the extra
    item. For example ``get_split_indices(10, 4)`` uses group sizes 3, 3, 2, 2 and
    returns ``[2, 5, 7, 9]``. When ``n < num_splits`` some groups are empty and the
    returned indices repeat (``get_split_indices(2, 4) == [0, 1, 1, 1]``).

    Args:
        n: number of items to split.
        num_splits: number of groups; a non-positive value yields ``[]``.

    Returns:
        list: the 0-based index of the final item of each group, in order.
    """
    if num_splits <= 0:
        return []
    base, extra = divmod(n, num_splits)
    # Groups 0..j contain (j + 1) * base items plus one extra item for each of the
    # first min(j + 1, extra) groups; subtracting 1 turns the count into an index.
    return [(j + 1) * base + min(j + 1, extra) - 1 for j in range(num_splits)]
def load_expert_group_weights(model, adapter_name, adpater_paths, expert_num_per_domain=4):
    """Merge several per-domain adapter checkpoints into one multi-expert model.

    Checkpoint ``i`` in ``adpater_paths`` (parameter name kept as-is for caller
    compatibility) is treated as domain ``i``:

    * keys containing ``"lora_"`` have the segment after ``lora_<X>.`` rewritten.
      That segment is expected to look like ``<name>_<j>.<leaf>``; it becomes
      ``<adapter_name>_<j + i * expert_num_per_domain>.<leaf>``, so each domain's
      experts occupy their own contiguous range of expert indices;
    * keys containing ``"GateL"`` have the text between ``mlp.`` and ``GateL``
      replaced by ``gates.<i>.``, targeting ``mlp.gates[str(i)]``;
    * all other keys are ignored.

    The rewritten tensors overwrite the corresponding entries of
    ``model.state_dict()``, which is then loaded back with ``load_state_dict``'s
    default (strict) setting. Rewritten keys the model does not have therefore raise.

    Returns:
        The same ``model`` object, with weights loaded.
    """
    merged_state = model.state_dict()
    for domain_idx, ckpt_path in enumerate(adpater_paths):
        # torch.load unpickles its input: only point this at trusted checkpoints.
        domain_state = torch.load(ckpt_path, map_location=model.device)
        for key, tensor in domain_state.items():
            if "lora_" in key:
                tail = ".".join(key.split("lora_")[1].split(".")[1:])
                param_leaf = tail.split(".")[-1]
                local_expert = int(tail.split(".")[0].split("_")[-1])
                global_expert = local_expert + domain_idx * expert_num_per_domain
                key = key.replace(tail, f"{adapter_name}_{global_expert}.{param_leaf}")
                merged_state[key] = tensor
            # Deliberately a separate `if` (not `elif`), tested on the possibly
            # rewritten key, exactly as the lora branch leaves it.
            if "GateL" in key:
                after_mlp = ".".join(key.split("mlp")[1].split(".")[1:])
                gate_prefix = after_mlp.split("GateL")[0]
                key = key.replace(gate_prefix, f"gates.{domain_idx}.")
                merged_state[key] = tensor
    model.load_state_dict(merged_state)
    return model
def topk_mask(expert_weight, k):
    """Keep the top-``k`` entries along the last dimension, zero the rest, and renormalize.

    Args:
        expert_weight (torch.Tensor): routing weights, documented as shape
            ``(batch * sequence_length, n_experts)``. Only the last dimension is
            used, so any leading shape works.
        k (int): number of entries to keep per row. ``torch.topk`` raises if ``k``
            exceeds ``expert_weight.size(-1)``.

    Returns:
        torch.Tensor: same shape, dtype and device as ``expert_weight``. The kept
        entries are divided by their row sum and all other entries are 0. A row
        whose kept entries sum to exactly 0 is returned as all zeros instead of
        NaN (previously 0 / 0).
    """
    topk_values, topk_indices = torch.topk(expert_weight, k, dim=-1)
    # zeros_like already preserves dtype and device.
    result = torch.zeros_like(expert_weight)
    result.scatter_(dim=-1, index=topk_indices, src=topk_values)
    # Keep the kept weights normalized to sum to 1 per row. Guard zero sums so
    # that a degenerate row does not inject NaNs into the forward/backward pass.
    row_sum = result.sum(dim=-1, keepdim=True)
    row_sum = torch.where(row_sum == 0, torch.ones_like(row_sum), row_sum)
    return result / row_sum
def add_mlp_domain_gates(model,expert_num_per_domain,domain_num):
    """Attach a two-level router to every submodule whose qualified name ends in "mlp".

    Each matching module is mutated in place and receives:

    * ``gates``: a ``torch.nn.ModuleDict`` holding one ``Gate(hidden_size,
      expert_num_per_domain)`` per domain under keys ``"0"`` .. ``str(domain_num - 1)``,
      plus a ``"domain"`` entry ``Gate(hidden_size, domain_num)``;
    * ``domain_num`` and ``get_divers_loss = False``.

    Every new Gate is moved to the device of the module's ``gate_proj.weight``.
    Its ``GateL.weight`` is re-drawn from N(0, 0.02).

    Returns:
        The same ``model`` object.
    """
    # Collect the names up front: attaching gates adds submodules, and the
    # module tree must not change while named_modules() is still walking it.
    mlp_names = [name for name, _ in model.named_modules() if name.endswith("mlp")]
    for name in mlp_names:
        mlp = model.get_submodule(name)
        mlp.gates = torch.nn.ModuleDict({})
        mlp.domain_num = domain_num
        mlp.get_divers_loss = False
        # One router per domain, choosing among that domain's experts.
        for domain_idx in range(domain_num):
            router = Gate(model.config.hidden_size, expert_num_per_domain).to(mlp.gate_proj.weight.device)
            mlp.gates[str(domain_idx)] = router
            torch.nn.init.normal_(router.GateL.weight, mean=0, std=0.02)
        # Domain-level router: every layer gets its own (not shared across layers).
        router = Gate(model.config.hidden_size, domain_num).to(mlp.gate_proj.weight.device)
        mlp.gates["domain"] = router
        torch.nn.init.normal_(router.GateL.weight, mean=0, std=0.02)
    return model
def add_mlp_gate(model,expert_num):
    """Attach a single token router ``gate`` to every submodule whose name ends in "mlp".

    Each matching module is mutated in place and receives:

    * ``gate = Gate(hidden_size, expert_num)``, with ``GateL.weight`` re-drawn
      from N(0, 0.02);
    * ``get_divers_loss = False``.

    The gate is placed on the device of the module's own ``gate_proj.weight``,
    matching ``add_mlp_domain_gates`` / ``add_mlp_sen_gate``. When the module has
    no such weight, it falls back to ``model.device`` (the previous behaviour).
    Using the per-layer device keeps routers next to their layer when the model
    is sharded across several devices. ``model.device`` only reports a single
    device.

    Returns:
        The same ``model`` object.
    """
    # Snapshot names first; attaching gates adds submodules to the tree.
    mlp_names = [name for name, _ in model.named_modules() if name.endswith("mlp")]
    for name in mlp_names:
        mlp = model.get_submodule(name)
        proj_weight = getattr(getattr(mlp, "gate_proj", None), "weight", None)
        device = proj_weight.device if proj_weight is not None else model.device
        mlp.gate = Gate(model.config.hidden_size, expert_num).to(device)
        mlp.get_divers_loss = False
        torch.nn.init.normal_(mlp.gate.GateL.weight, mean=0, std=0.02)
    return model
def add_mlp_sen_gate(model,domain_num):
    """Attach a ``sengate`` router to every submodule whose qualified name ends in "mlp".

    Each matching module is mutated in place and receives:

    * ``domain_num`` and ``get_divers_loss = False``;
    * ``sengate = Gate(hidden_size, domain_num)``, placed on the device of the
      module's ``gate_proj.weight``, with ``GateL.weight`` re-drawn from N(0, 0.02).

    Returns:
        The same ``model`` object.
    """
    # Snapshot names before mutating the module tree.
    mlp_names = [name for name, _ in model.named_modules() if name.endswith("mlp")]
    for name in mlp_names:
        mlp = model.get_submodule(name)
        mlp.domain_num = domain_num
        mlp.get_divers_loss = False
        router = Gate(model.config.hidden_size, domain_num).to(mlp.gate_proj.weight.device)
        mlp.sengate = router
        torch.nn.init.normal_(router.GateL.weight, mean=0, std=0.02)
    return model
def load_moe_model(mode,ckpt_path,model,adapter_name,expert_num,layer_type="moe",
                   lora_config_path="/mnt/mednas/jiangyinuo.jyn/ckpt",
                   expert_num_per_domain=4):
    """Turn ``model`` into a MoE-LoRA model and load trained adapter/router weights.

    Steps:
      1. Read a ``LoraConfig`` from ``lora_config_path`` and inject ``expert_num``
         expert adapters named after ``adapter_name`` via ``find_and_replace``.
      2. Attach routers according to ``mode``:
         ``"coop"`` -> ``add_mlp_gate(model, expert_num)``;
         ``"sen"``  -> ``add_mlp_sen_gate(model, expert_num)``;
         otherwise  -> ``add_mlp_domain_gates`` with ``expert_num_per_domain``
         experts per domain and ``expert_num / expert_num_per_domain`` (truncated)
         domains.
      3. Overlay the tensors stored in ``<ckpt_path>/adapter.pth`` onto the model
         state dict and load it. ``load_state_dict`` runs with its default strict
         setting, so keys the model does not have raise.

    Args:
        lora_config_path: directory holding the LoRA config. It used to be
            hard-coded; the default keeps that machine-specific path for backward
            compatibility, but callers should pass their own.
        expert_num_per_domain: experts per domain for the default (domain) mode;
            previously hard-coded to 4.

    Returns:
        The mutated model (the gate helpers return the same object).
    """
    lora_config = LoraConfig.from_pretrained(lora_config_path)
    find_and_replace(model, adapter_name, lora_config, expert_num, False, layer_type=layer_type)
    model.num_experts = expert_num

    if mode == "coop":
        model = add_mlp_gate(model, expert_num)
    elif mode == "sen":
        # NOTE(review): expert_num is passed as the number of domains here -- confirm intended.
        model = add_mlp_sen_gate(model, expert_num)
    else:
        # NOTE(review): if expert_num is not a multiple of expert_num_per_domain,
        # the domain count is silently truncated.
        model = add_mlp_domain_gates(model, expert_num_per_domain,
                                     int(expert_num / expert_num_per_domain))

    adapter_state_dict_path = os.path.join(ckpt_path, "adapter.pth")
    # torch.load unpickles its input: only load trusted checkpoints.
    adapter_state_dict = torch.load(adapter_state_dict_path, map_location=model.device)
    merged_state = model.state_dict()
    merged_state.update(adapter_state_dict)
    model.load_state_dict(merged_state)
    return model
### For training 

def resume_adapter(model,adapter_path):
    """Wrap ``model`` with the PEFT adapter in ``adapter_path`` and restore its weights.

    ``PeftModel.from_pretrained`` builds the wrapper from ``adapter_path``. Then the
    tensors in ``<adapter_path>/adapter.pth`` (loaded onto ``model.device``) are
    applied to the adapter named ``"default"`` with ``set_peft_model_state_dict``.

    Returns:
        The ``PeftModel`` wrapper (not the original ``model``).
    """
    peft_wrapped = PeftModel.from_pretrained(model, adapter_path)
    weights_file = os.path.join(adapter_path, "adapter.pth")
    # torch.load unpickles its input: only resume from trusted checkpoints.
    saved_weights = torch.load(weights_file, map_location=model.device)
    set_peft_model_state_dict(peft_wrapped, saved_weights, "default")
    return peft_wrapped

class MyCustomTrainer(Trainer):
    """``Trainer`` whose train and eval DataLoaders keep dataset order.

    Both loaders are plain ``DataLoader(..., shuffle=False)`` objects passed through
    ``self.accelerator.prepare``. No random sampler is used, so training batches
    follow the order of ``train_dataset``.
    NOTE(review): presumably deliberate (e.g. pre-ordered or domain-grouped data)
    -- confirm.
    """

    def _build_ordered_dataloader(self, dataset, batch_size) -> DataLoader:
        """Wrap ``dataset`` in an unshuffled DataLoader configured from ``self.args``."""
        dataloader_params = {
            "batch_size": batch_size,
            "collate_fn": self.data_collator,
            "num_workers": self.args.dataloader_num_workers,
            "pin_memory": self.args.dataloader_pin_memory,
            "persistent_workers": self.args.dataloader_persistent_workers,
        }
        return self.accelerator.prepare(DataLoader(dataset, shuffle=False, **dataloader_params))

    def get_train_dataloader(self) -> DataLoader:
        """Return the training DataLoader (dataset order, ``self._train_batch_size``).

        Raises:
            ValueError: if the trainer has no ``train_dataset``.
        """
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")
        return self._build_ordered_dataloader(self.train_dataset, self._train_batch_size)

    def get_eval_dataloader(self,eval_dataset: Optional[Dataset] = None) -> DataLoader:
        """Return the evaluation DataLoader (dataset order).

        Uses ``self.args.eval_batch_size``. The previous version used the training
        batch size, which silently ignored ``per_device_eval_batch_size``.

        Args:
            eval_dataset: dataset to evaluate. ``None`` falls back to
                ``self.eval_dataset``. A ``str`` is looked up as a key of a
                dict-valued ``self.eval_dataset``; recent ``Trainer.evaluate``
                versions may pass dataset names this way.

        Raises:
            ValueError: if neither ``eval_dataset`` nor ``self.eval_dataset`` is set.
        """
        if eval_dataset is None and self.eval_dataset is None:
            raise ValueError("Trainer: evaluation requires an eval_dataset.")
        if isinstance(eval_dataset, str):
            eval_dataset = self.eval_dataset[eval_dataset]
        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
        return self._build_ordered_dataloader(eval_dataset, self.args.eval_batch_size)
def freeze_model(model):
    """Make exactly the parameters whose qualified name contains "lora" trainable.

    Every parameter of ``model`` gets ``requires_grad = ("lora" in name)``. All
    others, including base-model weights, are frozen. Mutates ``model`` in place
    and returns None.

    NOTE(review): routers attached in this file (``gates`` / ``gate`` /
    ``sengate`` with a ``GateL`` sub-layer) have no "lora" in their names, so they
    are frozen too -- confirm this is intended for the training stage using it.
    """
    for param_name, tensor in model.named_parameters():
        tensor.requires_grad = "lora" in param_name
@dataclass
class MyDataCollator(object):
    """Collate examples for supervised fine-tuning into a padded batch.

    Each instance must be a mapping with ``input_ids`` and ``labels`` tensors.
    ``pad_sequence`` requires tensors, and all tensors must share their trailing
    dimensions. Sequences are right-padded to the longest in the batch:
    ``input_ids`` with ``tokenizer.pad_token_id`` and ``labels`` with -100.

    The previous version read ``self.tokenizer`` without declaring it as a
    dataclass field, so ``MyDataCollator(tokenizer=tok)`` raised TypeError.
    Assigning ``collator.tokenizer = tok`` after construction still works.
    """

    # Object exposing ``pad_token_id`` (e.g. a Hugging Face tokenizer).
    tokenizer: Optional[object] = None

    def __call__(self, instances):
        """Return a dict with ``input_ids``, ``labels`` and a boolean ``attention_mask``.

        Raises:
            ValueError: if no tokenizer is set or its ``pad_token_id`` is None.
        """
        pad_id = getattr(self.tokenizer, "pad_token_id", None)
        if pad_id is None:
            raise ValueError("MyDataCollator requires a tokenizer with a pad_token_id.")
        input_ids = [instance["input_ids"] for instance in instances]
        labels = [instance["labels"] for instance in instances]
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=pad_id
        )
        # -100 is the default ignore_index of torch.nn.CrossEntropyLoss, so padded
        # label positions contribute no loss.
        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100)
        return dict(
            input_ids=input_ids,
            labels=labels,
            # NOTE: any real token equal to pad_token_id (e.g. when pad == eos) is
            # masked out as well.
            attention_mask=input_ids.ne(pad_id),
        )