
# Pure token routing
import sys,os

from typing import List
from prompter import format_squad
from dataclasses import dataclass

from datasets import Dataset
import torch

from datasets import load_from_disk
import argparse
from datetime import datetime
"""
Unused imports:
import torch.nn as nn
import bitsandbytes as bnb
"""

from peft import (
    LoraConfig,
    get_peft_model,
    get_peft_model_state_dict,
    prepare_model_for_int8_training,
    set_peft_model_state_dict,
)
from  moepeft import( 
    find_and_replace,
    
    )

from typing import List
from prompter import format
from dataclasses import dataclass

import torch
import transformers
from dataclasses import dataclass
from transformers import LlamaForCausalLM, LlamaTokenizer,AutoTokenizer
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from transformers.trainer_callback import TrainerCallback
from utils import add_mlp_gate,MyCustomTrainer,freeze_model,MyDataCollator
IGNORE_INDEX = -100
import torch.nn as nn
from typing import Optional


def my_get_peft_model_state_dict(model, state_dict=None,):
    """Return only the adapter-related entries of a state dict.

    Args:
        model: object exposing ``state_dict()``; only consulted when
            ``state_dict`` is ``None``.
        state_dict: optional mapping of parameter names to values. When
            omitted, ``model.state_dict()`` is used.

    Returns:
        A new dict holding the entries whose key contains ``"lora_"`` or
        ``"mlp.gate.GateL"``, in the order they appear in the source mapping.
    """
    source = model.state_dict() if state_dict is None else state_dict
    markers = ("lora_", "mlp.gate.GateL")

    selected = {}
    for key in source:
        if any(marker in key for marker in markers):
            selected[key] = source[key]
    return selected

def get_moe_model(model, expert_num, cluster, adapter_name,
                  adapter_path="/mnt/mednas/jiangyinuo.jyn/HiMoLE/src/config"):
    """Apply ``find_and_replace`` with ``layer_type="moe"`` to ``model`` and tag it.

    Args:
        model: model object handed to ``find_and_replace``; it is modified in
            place and also returned.
        expert_num: number of experts; forwarded to ``find_and_replace`` and
            stored on the model as ``num_experts``.
        cluster: forwarded unchanged to ``find_and_replace``.
        adapter_name: adapter name forwarded to ``find_and_replace``.
        adapter_path: directory passed to ``LoraConfig.from_pretrained``.
            Defaults to the location that used to be hard-coded here, so
            existing callers are unaffected.

    Returns:
        The same ``model`` object, with ``router_aux_loss_coef`` set to
        ``0.001`` and ``num_experts`` set to ``expert_num``.
    """
    lora_config = LoraConfig.from_pretrained(adapter_path)
    find_and_replace(
        model, adapter_name, lora_config, expert_num, cluster, layer_type="moe"
    )
    model.router_aux_loss_coef = 0.001
    model.num_experts = expert_num
    return model

class PeftSavingCallback(TrainerCallback):
    """Trainer callback that keeps only the filtered adapter weights per checkpoint."""

    def on_save(self, args, state, control, **kwargs):
        """Write ``adapter.pth`` into the step's checkpoint folder.

        The folder is ``<args.output_dir>/<PREFIX_CHECKPOINT_DIR>-<global_step>``.
        After saving the filtered state dict of ``kwargs["model"]``, a
        ``pytorch_model.bin`` found in the same folder is deleted.
        """
        step_dir_name = f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"
        checkpoint_folder = os.path.join(args.output_dir, step_dir_name)

        adapter_weights = my_get_peft_model_state_dict(kwargs["model"])
        torch.save(adapter_weights, f"{checkpoint_folder}/adapter.pth")

        # Drop the full-model dump, if one was written next to the adapter file.
        full_weights_file = os.path.join(checkpoint_folder, "pytorch_model.bin")
        if os.path.exists(full_weights_file):
            os.remove(full_weights_file)
        
def train(part,
    # model/data params
    base_model: str = "",  # the only required argument
    output_dir: str = "./lora-alpaca",
    cluster: bool = False,
    expert_num: int = 4,
    # training hyperparams
    batch_size: int = 128,
    micro_batch_size: int = 4,
    num_epochs: int = 10,
    learning_rate: float = 3e-4,
    cutoff_len: int = 1024,
    val_set_size: int = 2000,
    resume_from_checkpoint: Optional[str] = None,
):
    """Fine-tune an 8-bit LLaMA model with MoE adapters and save the adapter weights.

    The function loads ``base_model`` in 8-bit, rewrites it through
    ``get_moe_model`` and ``add_mlp_gate``, tokenizes the ``id_train`` and
    ``id_test`` datasets, trains with ``MyCustomTrainer`` and finally writes
    the filtered state dict to ``<output_dir>/adapter.pth``.

    NOTE: the dataset directories are built from the module-level name
    ``task``, which is read as a global (it is assigned in this file's
    ``__main__`` block). It must be defined before this function is called.

    Args:
        part: suffix used to build the adapter name ``id_<part>``.
        base_model: path or identifier given to ``from_pretrained`` for both
            the model and the tokenizer.
        output_dir: directory for checkpoints and the final ``adapter.pth``.
        cluster: forwarded to ``get_moe_model``.
        expert_num: number of experts, forwarded to ``get_moe_model`` and
            ``add_mlp_gate``.
        batch_size: effective batch size; together with ``micro_batch_size``
            (and the world size under DDP) it determines the number of
            gradient-accumulation steps.
        micro_batch_size: per-device train batch size.
        num_epochs: number of training epochs.
        learning_rate: optimizer learning rate.
        cutoff_len: maximum number of tokens per example (prompt + answer,
            including the closing EOS token).
        val_set_size: only toggles evaluation: a value ``> 0`` enables
            step-wise evaluation and ``load_best_model_at_end``. No split of
            this size is created here.
        resume_from_checkpoint: forwarded to ``trainer.train``; ``None``
            starts a fresh run.
    """
    gradient_accumulation_steps = batch_size // micro_batch_size
    device_map = "auto"
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size != 1
    if ddp:
        # Under DDP each process owns one device and a share of the batch.
        device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
        gradient_accumulation_steps = gradient_accumulation_steps // world_size

    model = LlamaForCausalLM.from_pretrained(
        base_model,
        load_in_8bit=True,
        torch_dtype=torch.float16,
        device_map=device_map,
    )

    tokenizer = AutoTokenizer.from_pretrained(base_model)
    tokenizer.padding_side = "left"  # Allow batched inference
    data_collator = MyDataCollator(tokenizer=tokenizer)

    def tokenize(data):
        """Turn a batch of ``sentence``/``output`` pairs into ids and labels.

        Prompt positions are masked with ``IGNORE_INDEX`` in the labels. Every
        sequence ends with EOS; sequences longer than ``cutoff_len`` are cut
        to ``cutoff_len - 1`` tokens and closed with EOS again.
        """
        eos_id = tokenizer.eos_token_id
        tokenized_sources = tokenizer(
            list(data["sentence"]), return_attention_mask=False
        )
        tokenized_targets = tokenizer(
            list(data["output"]),
            return_attention_mask=False,
            add_special_tokens=False,
        )
        all_input_ids = []
        all_labels = []
        for s, t in zip(tokenized_sources['input_ids'], tokenized_targets['input_ids']):
            t = t + [eos_id]
            ids = s + t
            labels = [IGNORE_INDEX] * len(s) + t
            if len(ids) > cutoff_len:
                # Always re-terminate a truncated example, so that both the
                # input and the label sequence end with EOS even when the cut
                # falls inside the masked prompt.
                ids = ids[:cutoff_len - 1] + [eos_id]
                labels = labels[:cutoff_len - 1] + [eos_id]
            all_input_ids.append(torch.LongTensor(ids))
            all_labels.append(torch.LongTensor(labels))
        return {'input_ids': all_input_ids, 'labels': all_labels}

    def prepare(dataset):
        """Shuffle, tokenize and strip a dataset down to ``input_ids``/``labels``."""
        dataset = dataset.shuffle().map(tokenize, batched=True, batch_size=1000)
        # Use the dataset's OWN column list: each split must be filtered
        # against its own columns, not against the other split's.
        extra_columns = [
            col for col in dataset.column_names
            if col not in ['input_ids', 'labels']
        ]
        dataset = dataset.remove_columns(extra_columns)
        dataset.set_format('torch')
        return dataset

    model = prepare_model_for_int8_training(model)
    adapter_name = f"id_{part}"
    model = get_moe_model(model, expert_num, cluster, adapter_name)
    # freeze other part
    freeze_model(model)
    model = add_mlp_gate(model, expert_num)
    model_name = "OneKE"
    # `task` is a module-level global (see the NOTE in the docstring).
    train_data = load_from_disk(f"/mnt/prev_nas/jiangyinuo.jyn/code/evaluation/{task}/id_train")
    train_data = format(train_data, f"{model_name}")
    val_data = load_from_disk(f"/mnt/prev_nas/jiangyinuo.jyn/code/evaluation/{task}/id_test")
    val_data = format(val_data, f"{model_name}")
    train_data = prepare(train_data)
    val_data = prepare(val_data)

    if not ddp and torch.cuda.device_count() > 1:
        # keeps Trainer from trying its own DataParallelism when more than 1 gpu is available
        model.is_parallelizable = True
        model.model_parallel = True

    evaluate = val_set_size > 0
    trainer = MyCustomTrainer(
        model=model,
        train_dataset=train_data,
        eval_dataset=val_data,
        args=transformers.TrainingArguments(
            per_device_train_batch_size=micro_batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            warmup_steps=100,
            num_train_epochs=num_epochs,
            learning_rate=learning_rate,
            # fp16=True,
            label_names=["labels"],
            logging_steps=10,
            optim="adamw_torch",
            evaluation_strategy="steps" if evaluate else "no",
            save_strategy="steps",
            eval_steps=50 if evaluate else None,
            save_steps=50,
            output_dir=output_dir,
            save_total_limit=3,
            load_best_model_at_end=evaluate,
            ddp_find_unused_parameters=False if ddp else None,
        ),
        data_collator=data_collator,
        callbacks=[PeftSavingCallback],
    )
    model.config.use_cache = False

    # Make `model.state_dict()` return only the adapter/gate entries, so that
    # anything saving through it stores the filtered weights.
    old_state_dict = model.state_dict
    model.state_dict = (
        lambda self, *_, **__: my_get_peft_model_state_dict(
            self, old_state_dict()
        )
    ).__get__(model, type(model))

    if torch.__version__ >= "2" and sys.platform != "win32":
        model = torch.compile(model)

    trainer.train(resume_from_checkpoint=resume_from_checkpoint)

    model.save_pretrained(output_dir)

    new_state_dict = my_get_peft_model_state_dict(model)
    torch.save(new_state_dict, f"{output_dir}/adapter.pth")

    print(
        "\n If there's a warning about missing keys above, please disregard :)"
    )


if __name__ == "__main__":
    # Command-line entry point: parse the run options, derive the dated
    # output directory and launch training.
    parser = argparse.ArgumentParser(description='training KCGs')
    parser.add_argument('--part', type=str, default="")
    parser.add_argument('--expert-num', type=int, default=4)
    parser.add_argument('--ckpt', type=str, default=None)
    args = parser.parse_args()
    print(args)

    base_model = "/mnt/gruntdata/rs_nas/workspace2/yinuo/models/OneKE"
    model_name = base_model.split("/")[-1]
    today = datetime.now().date()
    # NOTE: `task` must remain a module-level name; train() reads it as a
    # global when it builds the dataset paths.
    task = "NER"
    cluster = False
    part = args.part
    resume_from_checkpoint = args.ckpt

    out_base = f"/mnt/mednas/jiangyinuo.jyn/ckpt/{task}/{model_name}_expert_co{part}"
    # makedirs(exist_ok=True) also creates missing parents and does not fail
    # when another process created the directory first.
    os.makedirs(out_base, exist_ok=True)
    output_dir = os.path.join(out_base, str(today))
    expert_num = args.expert_num

    train(
        part,
        base_model=base_model,
        output_dir=output_dir,
        cluster=cluster,
        expert_num=expert_num,
        resume_from_checkpoint=resume_from_checkpoint,
    )