
# Domain router variant: routes using averaged hidden states; the domain gate is not shared across layers.
import sys,os

from typing import List
from dataclasses import dataclass
from transformers import Trainer,TrainingArguments
from torch.utils.data import DataLoader
from datasets import Dataset
import torch
from datasets import load_from_disk
from datetime import datetime
import transformers
"""
Unused imports:
import torch.nn as nn
import bitsandbytes as bnb
"""

from peft import (
    LoraConfig,
   
    prepare_model_for_int8_training,
   
)
from  moepeft import( 
    find_and_replace,
   Gate
    )

from typing import List
from prompter import format
from dataclasses import dataclass

import torch

from dataclasses import dataclass
from transformers import LlamaForCausalLM, LlamaTokenizer
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from transformers.trainer_callback import TrainerCallback
IGNORE_INDEX = -100
import torch.nn as nn
from typing import Optional
from utils import load_expert_group_weights,get_split_indices,MyCustomTrainer,freeze_model,MyDataCollator
def add_mlp_domain_gates(model,expert_num_per_domain,domain_num):
    """Attach per-domain expert gates and a domain gate to every ``*mlp`` module.

    For each submodule whose qualified name ends with ``"mlp"`` this sets:

    * ``gates`` -- a ``torch.nn.ModuleDict`` holding
      ``gates[str(i)]`` for ``i`` in ``range(domain_num)`` (a ``Gate`` from
      ``hidden_size`` to ``expert_num_per_domain`` outputs) and
      ``gates["domain"]`` (a ``Gate`` from ``hidden_size`` to ``domain_num``
      outputs, created per MLP, i.e. not shared across layers);
    * ``domain_num`` -- the number of domains;
    * ``get_divers_loss`` -- ``True`` only for the layers whose index is
      returned by ``get_split_indices(model.config.num_hidden_layers)``.

    Every ``GateL.weight`` is initialised from a normal distribution with
    mean 0 and std 0.02, and each gate is moved to the device of the MLP's
    ``gate_proj.weight``.

    Args:
        model: model whose MLP submodules are patched in place.
        expert_num_per_domain: output size of each per-domain gate.
        domain_num: number of per-domain gates / output size of the domain gate.

    Returns:
        The same ``model`` object, modified in place.
    """
    # Include the trailing "." so that e.g. layer 1 does not also match
    # "layers.10" ... "layers.19" (a bare "layers.1" substring would).
    target_diverse_loss_layer = [
        f"layers.{i}." for i in get_split_indices(model.config.num_hidden_layers)
    ]
    target_key = "mlp"
    # Snapshot the names first: registering gates adds new submodules.
    mlp_keys = [key for key, _ in model.named_modules() if key.endswith(target_key)]
    for key in mlp_keys:
        target = model.get_submodule(key)
        device = target.gate_proj.weight.device
        target.gates = torch.nn.ModuleDict({})
        target.domain_num = domain_num
        target.get_divers_loss = any(sub in key for sub in target_diverse_loss_layer)
        for i in range(domain_num):
            expert_gate = Gate(model.config.hidden_size, expert_num_per_domain).to(device)
            torch.nn.init.normal_(expert_gate.GateL.weight, mean=0, std=0.02)
            target.gates[str(i)] = expert_gate
        # One domain gate per MLP (no sharing across layers).
        domain_gate = Gate(model.config.hidden_size, domain_num).to(device)
        torch.nn.init.normal_(domain_gate.GateL.weight, mean=0, std=0.02)
        target.gates["domain"] = domain_gate
    return model

def my_get_peft_model_state_dict(model, state_dict=None,):
    """Select the adapter and gate entries of a state dict.

    An entry is kept when its key contains ``"lora_"`` or ``"GateL"``; all
    other entries are dropped. Insertion order is preserved.

    Args:
        model: object whose ``state_dict()`` is read when *state_dict* is None.
        state_dict: optional mapping to filter instead of ``model.state_dict()``.

    Returns:
        A new dict with only the selected key/value pairs.
    """
    source = model.state_dict() if state_dict is None else state_dict
    selected = {}
    for name in source:
        if "lora_" in name or "GateL" in name:
            selected[name] = source[name]
    return selected
def get_moe_model(model,expert_num,cluster,adapter_name,
                  adapter_path="/mnt/mednas/jiangyinuo.jyn/HiMoLE/src/config",
                  router_aux_loss_coef=0.001,
                  diverse_loss_coef=0.0001):
    """Replace target modules of *model* with MoE-LoRA layers and set loss coefficients.

    Args:
        model: base model, modified in place by ``find_and_replace``.
        expert_num: total number of experts (across all domains).
        cluster: forwarded unchanged to ``find_and_replace``.
        adapter_name: adapter name forwarded to ``find_and_replace``.
        adapter_path: directory passed to ``LoraConfig.from_pretrained``.
        router_aux_loss_coef: stored as ``model.router_aux_loss_coef``.
        diverse_loss_coef: stored as ``model.diverse_loss_coef``.

    Returns:
        The same ``model`` object with ``router_aux_loss_coef``,
        ``diverse_loss_coef`` and ``num_experts`` attributes set.
    """
    lora_config = LoraConfig.from_pretrained(adapter_path)
    find_and_replace(model, adapter_name, lora_config, expert_num, cluster)
    model.router_aux_loss_coef = router_aux_loss_coef
    model.diverse_loss_coef = diverse_loss_coef
    # Used in load_balancing_loss_func; expected to be the total expert count.
    model.num_experts = expert_num
    return model

class PeftSavingCallback(TrainerCallback):
    """Trainer callback that stores only adapter/gate weights in each checkpoint.

    On every save it writes ``adapter.pth`` (the entries selected by
    ``my_get_peft_model_state_dict``) into the checkpoint folder and removes
    ``pytorch_model.bin`` from that folder if it exists.
    """

    def on_save(self, args, state, control, **kwargs):
        # on_save is dispatched on every process; only the main process should
        # touch the shared checkpoint folder, otherwise ranks race on the same
        # files (and may write before the folder has been created).
        if not state.is_world_process_zero:
            return
        checkpoint_folder = os.path.join(
            args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"
        )
        new_state_dict = my_get_peft_model_state_dict(kwargs["model"])
        torch.save(new_state_dict, os.path.join(checkpoint_folder, "adapter.pth"))

        pytorch_model_path = os.path.join(checkpoint_folder, "pytorch_model.bin")
        if os.path.exists(pytorch_model_path):
            os.remove(pytorch_model_path)
        
def train(adapter_paths,
    # model/data params
    base_model: str = "",  # the only required argument
    output_dir: str = "./lora-alpaca",

    cluster: bool = False,
    expert_num_per_domain: int = 4,
    # training hyperparams
    batch_size: int = 128,
    micro_batch_size: int = 4,
    num_epochs: int = 5,
    learning_rate: float = 3e-5,
    cutoff_len: int = 1024,
    val_set_size: int = 2000,

    resume_from_checkpoint: Optional[str] = None,  # either training checkpoint or final adapter
    # data params
    train_data_path: str = "/mnt/prev_nas/jiangyinuo.jyn/code/evaluation/NER/id_train",
    val_data_path: str = "/mnt/prev_nas/jiangyinuo.jyn/code/evaluation/NER/id_test",
):
    """Train hierarchical MoE-LoRA gates initialised from per-domain adapters.

    Loads ``base_model`` in 8-bit and swaps in MoE-LoRA layers (``get_moe_model``).
    It then calls ``freeze_model`` and attaches domain/expert gates to every MLP
    (``add_mlp_domain_gates``). Experts are initialised from the ``adapter.pth``
    inside each entry of ``adapter_paths``, and training runs with
    ``MyCustomTrainer``. Checkpoints and the final output store only the keys
    selected by ``my_get_peft_model_state_dict``, as ``adapter.pth``.

    Args:
        adapter_paths: directories each containing one domain's ``adapter.pth``;
            the number of entries is the number of domains.
        base_model: path or hub id of the Llama base model.
        output_dir: directory for checkpoints and the final adapter.
        cluster: forwarded to ``get_moe_model``.
        expert_num_per_domain: experts per domain; the total expert count is
            ``len(adapter_paths) * expert_num_per_domain``.
        batch_size: effective global batch size.
        micro_batch_size: per-device batch size. Gradient accumulation is
            ``batch_size // micro_batch_size``, divided by the world size under
            DDP and never less than 1.
        num_epochs: number of training epochs.
        learning_rate: optimizer learning rate.
        cutoff_len: maximum token length of prompt + target (including EOS).
        val_set_size: used only as an on/off switch; ``> 0`` enables step-wise
            evaluation and best-model loading. The validation split is always
            loaded in full from ``val_data_path``.
        resume_from_checkpoint: optional directory containing an ``adapter.pth``
            whose entries are loaded into the model before training.
        train_data_path: ``datasets.load_from_disk`` location of the train split.
        val_data_path: ``datasets.load_from_disk`` location of the validation split.
    """
    adapter_paths = [os.path.join(adapter_dir, "adapter.pth") for adapter_dir in adapter_paths]
    domain_num = len(adapter_paths)
    expert_num = domain_num * expert_num_per_domain

    gradient_accumulation_steps = batch_size // micro_batch_size
    device_map = "auto"
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size != 1
    if ddp:
        # Each DDP process places the whole model on its local GPU.
        device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
        gradient_accumulation_steps = gradient_accumulation_steps // world_size
    # Avoid 0 accumulation steps when batch_size < micro_batch_size * world_size.
    gradient_accumulation_steps = max(1, gradient_accumulation_steps)

    model = LlamaForCausalLM.from_pretrained(
        base_model,
        load_in_8bit=True,
        torch_dtype=torch.float16,
        device_map=device_map,
    )
    tokenizer = LlamaTokenizer.from_pretrained(base_model)
    data_collator = MyDataCollator(tokenizer=tokenizer)

    def tokenize(data):
        # Batched map function. Prompt ("sentence") tokens are masked out of the
        # loss with IGNORE_INDEX; the target ("output") is terminated by EOS.
        eos = tokenizer.eos_token_id
        tokenized_sources = tokenizer(list(data["sentence"]), return_attention_mask=False)
        tokenized_targets = tokenizer(list(data["output"]), return_attention_mask=False, add_special_tokens=False)
        all_input_ids = []
        all_labels = []
        for s, t in zip(tokenized_sources["input_ids"], tokenized_targets["input_ids"]):
            t = t + [eos]
            ids = s + t
            lbls = [IGNORE_INDEX] * len(s) + t
            if len(ids) > cutoff_len and ids[:cutoff_len][-1] != eos:
                # Truncated mid-sequence: keep cutoff_len - 1 tokens and end with EOS.
                input_ids = torch.LongTensor(ids[:cutoff_len - 1] + [eos])
                labels = torch.LongTensor(lbls[:cutoff_len - 1] + [eos])
            else:
                input_ids = torch.LongTensor(ids)[:cutoff_len]
                labels = torch.LongTensor(lbls)[:cutoff_len]
            assert len(input_ids) == len(labels)
            assert input_ids[-1] == eos
            assert labels[-1] == eos
            all_input_ids.append(input_ids)
            all_labels.append(labels)
        return {"input_ids": all_input_ids, "labels": all_labels}

    model = prepare_model_for_int8_training(model)
    adapter_name = "id"
    model = get_moe_model(model, expert_num, cluster, adapter_name)
    # freeze other part
    freeze_model(model)
    model = add_mlp_domain_gates(model, expert_num_per_domain, domain_num)
    # init weight
    model = load_expert_group_weights(model, adapter_name, adapter_paths, expert_num_per_domain)
    if resume_from_checkpoint is not None:
        adapter_state_dict_path = os.path.join(resume_from_checkpoint, "adapter.pth")
        adapter_state_dict = torch.load(adapter_state_dict_path, map_location=model.device)
        # Overlay the saved adapter/gate entries on the full state dict so that
        # load_state_dict sees every key.
        full_state_dict = model.state_dict()
        full_state_dict.update(adapter_state_dict)
        model.load_state_dict(full_state_dict)

    model_name = base_model.split("/")[-1]

    def load_split(path):
        # Load, prompt-format, shuffle and tokenize one split, then keep only the
        # model inputs. Columns are computed from *this* split after mapping.
        data = format(load_from_disk(path), model_name)
        data = data.shuffle().map(tokenize, batched=True, batch_size=1000)
        data = data.remove_columns(
            [col for col in data.column_names if col not in ("input_ids", "labels")]
        )
        data.set_format("torch")
        return data

    train_data = load_split(train_data_path)
    val_data = load_split(val_data_path)

    if not ddp and torch.cuda.device_count() > 1:
        # keeps Trainer from trying its own DataParallelism when more than 1 gpu is available
        model.is_parallelizable = True
        model.model_parallel = True

    trainer = MyCustomTrainer(
        model=model,
        train_dataset=train_data,
        eval_dataset=val_data,
        args=TrainingArguments(
            per_device_train_batch_size=micro_batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            warmup_steps=1,
            num_train_epochs=num_epochs,
            learning_rate=learning_rate,
            # fp16=True,
            label_names=["labels"],
            logging_steps=10,
            optim="adamw_torch",
            evaluation_strategy="steps" if val_set_size > 0 else "no",
            save_strategy="steps",
            eval_steps=50 if val_set_size > 0 else None,
            save_steps=50,
            output_dir=output_dir,
            save_total_limit=3,
            load_best_model_at_end=True if val_set_size > 0 else False,
            ddp_find_unused_parameters=False if ddp else None,
        ),
        data_collator=data_collator,
        callbacks=[PeftSavingCallback],
    )
    model.config.use_cache = False

    # Patch state_dict so every caller of model.state_dict() (checkpointing,
    # save_pretrained) only sees the adapter/gate entries.
    old_state_dict = model.state_dict
    model.state_dict = (
        lambda self, *_, **__: my_get_peft_model_state_dict(
            self, old_state_dict()
        )
    ).__get__(model, type(model))

    # NOTE: the model is deliberately not wrapped with torch.compile. The
    # Trainer already holds the uncompiled module, so compiling here would not
    # affect training. The compiled wrapper's state_dict() also prefixes every
    # key with "_orig_mod.", which would make the adapter.pth saved below
    # unloadable by the resume path above.
    trainer.train()

    model.save_pretrained(output_dir)

    new_state_dict = my_get_peft_model_state_dict(model)
    torch.save(new_state_dict, os.path.join(output_dir, "adapter.pth"))

if __name__ == "__main__":
    base_model = "/mnt/gruntdata/rs_nas/workspace2/yinuo/models/OneKE"
    model_name = base_model.split("/")[-1]
    # model config
    expert_num_per_domain = 4
    # train config
    task = "NER"
    cluster = False
    today = datetime.now().date()
    out_base = f"/mnt/mednas/jiangyinuo.jyn/ckpt/{task}/{model_name}_expert_hie_gate"
    # makedirs creates missing parent directories and tolerates an existing
    # directory, unlike an exists() check followed by os.mkdir.
    os.makedirs(out_base, exist_ok=True)
    output_dir = os.path.join(out_base, str(today))
    # One pre-trained adapter checkpoint directory per domain (example paths).
    adapter_paths = [
        "/mnt/mednas/jiangyinuo.jyn/ckpt/NER/OneKE_co_0/2025-04-22_100/checkpoint-100",
        "/mnt/mednas/jiangyinuo.jyn/ckpt/NER/OneKE_co_1/2025-04-22_100/checkpoint-100",
        "/mnt/mednas/jiangyinuo.jyn/ckpt/NER/OneKE_co_2/2025-04-22_150/checkpoint-150",
        "/mnt/mednas/jiangyinuo.jyn/ckpt/NER/OneKE_co_3/2025-04-22_150/checkpoint-150",
    ]
    train(
        adapter_paths,
        base_model=base_model,
        output_dir=output_dir,
        cluster=cluster,
        expert_num_per_domain=expert_num_per_domain,
    )