import os
import sys
from typing import List
from transformers import AutoTokenizer, AutoModelForCausalLM,AutoModel
import fire
import numpy as np
import torch
import transformers
from transformers import Trainer
from datasets import load_dataset
from accelerate import Accelerator
import wandb
import copy
# accelerator = Accelerator(device_placement=True, mixed_precision= "fp16" , cpu=False)
"""
Unused imports:
import torch.nn as nn
import bitsandbytes as bnb
"""
# Point both proxy variables at the same proxy endpoint for this process.
os.environ.update(
    {
        "https_proxy": "http://10.10.1.3:10000",
        "http_proxy": "http://10.10.1.3:10000",
    }
)
# Value exported as the WANDB__SERVICE_WAIT environment variable
# (presumably a wandb service start-up timeout in seconds -- confirm).
WANDB__SERVICE_WAIT = 300
os.environ["WANDB__SERVICE_WAIT"] = f"{WANDB__SERVICE_WAIT}"
# Layer indices 0..31; used below as the default layer selection.
layers = [*range(32)]
def get_kqvogud_list(model, n=None):
    """Collect per-layer weights and three model-level weights from ``model``.

    Parameters are selected purely by name: a parameter belongs to a selected
    layer when the third dot-separated component of its name (the ``0`` in
    ``model.layers.0.self_attn.q_proj.weight``) is one of the indices in ``n``.
    Within a layer, tensors are grouped by the concatenation of the last two
    name components (e.g. ``q_projweight``).

    Args:
        model: module exposing ``named_parameters()`` and ``state_dict()``.
        n: iterable of layer indices to collect. ``None`` (the default) selects
            every index in the module-level ``layers`` list.

    Returns:
        A tuple ``(qkvogud_list, special_list)``:

        * ``qkvogud_list`` has one entry per matched layer, each a list of nine
          tensors ordered q_proj, k_proj, v_proj, o_proj, gate_proj, down_proj,
          up_proj, input_layernorm, post_attention_layernorm.
        * ``special_list`` is ``[lm_head.weight, model.embed_tokens.weight,
          model.norm.weight]`` read from ``model.state_dict()``.

    Raises:
        KeyError: if no selected layer provides one of the nine expected
            weights, or the state dict lacks one of the three special keys.
    """
    if n is None:
        n = layers
    wanted = {str(i) for i in n}

    # The last named parameter is skipped, as in the original implementation.
    # NOTE(review): for a LLaMA causal LM this is presumably ``lm_head.weight``,
    # whose two-component name would fail the ``parts[2]`` lookup -- confirm.
    named_parameters = list(model.named_parameters())[:-1]

    dicts = {}
    for name, param in named_parameters:
        parts = name.split(".")
        if parts[2] in wanted:
            dicts.setdefault(parts[-2] + parts[-1], []).append(param)

    group_keys = (
        'q_projweight',
        'k_projweight',
        'v_projweight',
        'o_projweight',
        'gate_projweight',
        'down_projweight',
        'up_projweight',
        'input_layernormweight',
        'post_attention_layernormweight',
    )
    qkvogud_list = [
        list(group) for group in zip(*[dicts[key] for key in group_keys])
    ]

    # Build the state dict once; every state_dict() call re-creates the mapping.
    # NOTE(review): with the default keep_vars=False, state_dict() hands back
    # detached tensors, so these three likely carry no gradient, unlike the
    # per-layer parameters above -- confirm that this is intended.
    state = model.state_dict()
    special_list = [
        state['lm_head.weight'],
        state['model.embed_tokens.weight'],
        state['model.norm.weight'],
    ]
    return qkvogud_list, special_list
def cosine_similarity(model, kqvogud_list0, special_list0,
                      accum_device='cuda:1', special_device='cuda:3', num_devices=4):
    """Return the mean cosine similarity between ``model``'s weights and reference weights.

    Every weight tensor returned by ``get_kqvogud_list(model)`` is flattened
    and compared with its counterpart in the reference lists. Each tensor pair
    contributes one cosine value, and the values are averaged with equal
    weight regardless of tensor size.

    Args:
        model: model whose current weights are compared.
        kqvogud_list0: reference per-layer tensors, indexed by the layer
            indices in the module-level ``layers`` list.
        special_list0: reference tensors paired, in order, with the special
            list returned by ``get_kqvogud_list``.
        accum_device: device on which the running sum (and result) lives.
        special_device: device used to compare the special tensors.
        num_devices: per-layer comparisons run on ``cuda:{index % num_devices}``.

    Returns:
        A tuple ``(cos, abs(cos))`` of 0-dim tensors on ``accum_device``.
    """
    qkvogud_list, special_list = get_kqvogud_list(model)
    cos_sum = torch.tensor(0.0).to(accum_device)
    # Plain Python counter: avoids shadowing the builtin ``len`` and a
    # per-pair increment of a CUDA tensor.
    pair_count = 0

    for layer_tensors, index in zip(qkvogud_list, layers):
        # Spread the per-layer work round-robin over the available GPUs.
        work_device = f'cuda:{index % num_devices}'
        for current, reference in zip(layer_tensors, kqvogud_list0[index]):
            flat_current = torch.flatten(current.to(work_device))
            flat_reference = torch.flatten(reference.to(work_device))
            cos_sum += torch.nn.functional.cosine_similarity(
                flat_current, flat_reference, 0
            ).to(accum_device)
            pair_count += 1

    for current, reference in zip(special_list, special_list0):
        flat_current = torch.flatten(current.to(special_device))
        flat_reference = torch.flatten(reference.to(special_device))
        cos_sum += torch.nn.functional.cosine_similarity(
            flat_current, flat_reference, 0
        ).to(accum_device)
        pair_count += 1

    cos = cos_sum / pair_count
    return cos, torch.abs(cos)
class LLamaTrainer(Trainer):
    """Trainer whose loss adds a cosine-similarity term against a frozen reference model.

    At construction a reference checkpoint is loaded and its weights are
    collected with ``get_kqvogud_list``. On every ``compute_loss`` call the
    absolute mean cosine similarity between the trained model's weights and
    the reference weights is added (scaled by ``cos_loss_scale``) to the
    model's own loss. Both components are buffered and logged as ``cos_sim``
    and ``oringin_loss``.
    """

    def __init__(self, *args, reference_model_name="decapoda-research/llama-7b-hf",
                 cos_loss_scale=1, **kwargs):
        """Build the trainer and load the reference weights.

        Args:
            *args, **kwargs: forwarded unchanged to ``Trainer.__init__``.
            reference_model_name: checkpoint passed to
                ``AutoModelForCausalLM.from_pretrained`` for the reference weights.
            cos_loss_scale: multiplier applied to the cosine term in the loss.
        """
        super().__init__(*args, **kwargs)
        model_llama = AutoModelForCausalLM.from_pretrained(reference_model_name)
        # The reference is never trained: freeze it.
        model_llama.requires_grad_(False)
        model_llama.to('cpu')
        self.kqvogud_list0, self.special_list0 = get_kqvogud_list(model_llama)
        self.cos_loss_scale = cos_loss_scale
        # Per-step values buffered between two logging events.
        self.cos_sim_sum = []
        self.oringin_loss_sum = []

    def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval):
        """Log, evaluate and save according to ``self.control``.

        In addition to ``loss`` and ``learning_rate``, the training log gets
        the mean of the buffered ``cos_sim`` and ``oringin_loss`` values, after
        which both buffers are cleared.
        NOTE(review): the rest of this method appears to mirror the parent
        Trainer implementation of one specific transformers version -- confirm
        it matches the installed version.
        """
        if self.control.should_log:

            logs = {}

            # all_gather + mean() to get average loss over all processes
            tr_loss_scalar = self._nested_gather(tr_loss).mean().item()

            # reset tr_loss to zero
            tr_loss -= tr_loss

            logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
            logs["learning_rate"] = self._get_learning_rate()
            logs["cos_sim"] = torch.mean(torch.tensor(self.cos_sim_sum)).item()
            self.cos_sim_sum = []
            logs["oringin_loss"] = torch.mean(torch.tensor(self.oringin_loss_sum)).item()
            self.oringin_loss_sum = []
            self._total_loss_scalar += tr_loss_scalar
            self._globalstep_last_logged = self.state.global_step
            self.store_flos()

            self.log(logs)

        metrics = None
        if self.control.should_evaluate:
            if isinstance(self.eval_dataset, dict):
                metrics = {}
                for eval_dataset_name, eval_dataset in self.eval_dataset.items():
                    dataset_metrics = self.evaluate(
                        eval_dataset=eval_dataset,
                        ignore_keys=ignore_keys_for_eval,
                        metric_key_prefix=f"eval_{eval_dataset_name}",
                    )
                    metrics.update(dataset_metrics)
            else:
                metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
            self._report_to_hp_search(trial, self.state.global_step, metrics)

            # Run delayed LR scheduler now that metrics are populated
            if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                metric_to_check = self.args.metric_for_best_model
                if not metric_to_check.startswith("eval_"):
                    metric_to_check = f"eval_{metric_to_check}"
                self.lr_scheduler.step(metrics[metric_to_check])

        if self.control.should_save:
            self._save_checkpoint(model, trial, metrics=metrics)
            self.control = self.callback_handler.on_save(self.args, self.state, self.control)

    def compute_loss(self, model, inputs, return_outputs=False):
        """Return the model loss plus the scaled cosine-similarity term.

        Args:
            model: model being trained.
            inputs: keyword arguments for ``model``; ``labels`` is popped when
                a label smoother is configured.
            return_outputs: when true, return ``(loss, outputs)``.

        Returns:
            The combined loss, or ``(loss, outputs)`` if ``return_outputs``.

        Raises:
            ValueError: if the model returns a dict without a ``loss`` entry
                and no label smoother is in use.
        """
        if self.label_smoother is not None and "labels" in inputs:
            labels = inputs.pop("labels")
        else:
            labels = None
        outputs = model(**inputs)
        # Save past state if it exists
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]

        if labels is not None:
            # Imported here because the file does not import these names at
            # module level; only this branch needs them.
            from transformers.modeling_utils import unwrap_model
            from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES

            if unwrap_model(model)._get_name() in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
                base_loss = self.label_smoother(outputs, labels, shift_labels=True)
            else:
                base_loss = self.label_smoother(outputs, labels)
        else:
            if isinstance(outputs, dict) and "loss" not in outputs:
                raise ValueError(
                    "The model did not return a loss from the inputs, only the following keys: "
                    f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
                )
            # We don't use .loss here since the model may return tuples instead of ModelOutput.
            base_loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]

        cos, loss2 = cosine_similarity(model, self.kqvogud_list0, self.special_list0)
        # Buffer detached copies so logging does not keep autograd graphs alive.
        self.cos_sim_sum.append(cos.detach())
        self.oringin_loss_sum.append(base_loss.detach())
        # Move the regulariser to wherever the model loss lives before adding.
        loss = base_loss + self.cos_loss_scale * loss2.to(base_loss.device)
        return (loss, outputs) if return_outputs else loss
from peft import (
    LoraConfig,
    get_peft_model,
    get_peft_model_state_dict,
    prepare_model_for_int8_training,
    set_peft_model_state_dict,
)
from transformers import LlamaForCausalLM, LlamaTokenizer

from utils.prompter import Prompter


def train(
    # model/data params
    base_model: str = "",  # the only required argument
    data_path: str = "yahma/alpaca-cleaned",
    output_dir: str = "./lora-alpaca",
    # training hyperparams
    batch_size: int = 128,
    micro_batch_size: int = 4,
    num_epochs: int = 3,
    learning_rate: float = 3e-4,
    cutoff_len: int = 256,
    val_set_size: int = 2000,
    # llm hyperparams
    train_on_inputs: bool = True,  # if False, masks out inputs in loss
    add_eos_token: bool = False,
    group_by_length: bool = False,  # faster, but produces an odd training loss curve
    # wandb params
    wandb_project: str = "",
    wandb_run_name: str = "",
    wandb_watch: str = "",  # options: false | gradients | all
    wandb_log_model: str = "",  # options: false | true
    resume_from_checkpoint: str = '/home/byzeng/project/alpaca-lora-main/alpaca_cosmm/checkpoint-777/',  # either training checkpoint or final adapter
    prompt_template_name: str = "alpaca",  # The prompt template to use, will default to alpaca.
):
    """Fine-tune ``base_model`` on an instruction dataset with ``LLamaTrainer``.

    The model is loaded in bfloat16 and passed to the trainer directly; the
    LoRA/PEFT wrapping found in similar scripts is not applied here.

    Args:
        base_model: checkpoint name or path for the model and tokenizer (required).
        data_path: ``.json``/``.jsonl`` file, or a dataset name for ``load_dataset``.
        output_dir: where checkpoints, the final model and the tokenizer are saved.
        batch_size: target batch size; divided by ``micro_batch_size`` (and the
            world size under DDP) to obtain the gradient-accumulation steps.
        micro_batch_size: per-device train batch size.
        num_epochs: number of training epochs.
        learning_rate: learning rate.
        cutoff_len: maximum tokenized prompt length.
        val_set_size: number of examples held out for evaluation; 0 disables it.
        train_on_inputs: if False, prompt tokens are masked with -100 in the labels.
        add_eos_token: whether the user-prompt tokenization appends EOS when masking.
        group_by_length: forwarded to ``TrainingArguments``.
        wandb_project, wandb_run_name, wandb_watch, wandb_log_model: wandb
            settings; non-empty values are exported as environment variables
            (``wandb_run_name`` is passed as ``run_name``).
        resume_from_checkpoint: checkpoint directory handed to
            ``trainer.train``; an empty value starts training from scratch.
        prompt_template_name: template name given to ``Prompter``.
    """
    if int(os.environ.get("LOCAL_RANK", 0)) == 0:
        print(
            f"Training Alpaca-LoRA model with params:\n"
            f"base_model: {base_model}\n"
            f"data_path: {data_path}\n"
            f"output_dir: {output_dir}\n"
            f"batch_size: {batch_size}\n"
            f"micro_batch_size: {micro_batch_size}\n"
            f"num_epochs: {num_epochs}\n"
            f"learning_rate: {learning_rate}\n"
            f"cutoff_len: {cutoff_len}\n"
            f"val_set_size: {val_set_size}\n"
            f"train_on_inputs: {train_on_inputs}\n"
            f"add_eos_token: {add_eos_token}\n"
            f"group_by_length: {group_by_length}\n"
            f"wandb_project: {wandb_project}\n"
            f"wandb_run_name: {wandb_run_name}\n"
            f"wandb_watch: {wandb_watch}\n"
            f"wandb_log_model: {wandb_log_model}\n"
            f"resume_from_checkpoint: {resume_from_checkpoint or False}\n"
            f"prompt template: {prompt_template_name}\n"
        )
    assert (
        base_model
    ), "Please specify a --base_model, e.g. --base_model='huggyllama/llama-7b'"
    gradient_accumulation_steps = batch_size // micro_batch_size

    prompter = Prompter(prompt_template_name)

    device_map = "auto"
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size != 1
    if ddp:
        # One full model replica per process, placed on that process's GPU.
        device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
        gradient_accumulation_steps = gradient_accumulation_steps // world_size

    # Check if parameter passed or if set within environ
    use_wandb = bool(wandb_project) or bool(os.environ.get("WANDB_PROJECT"))
    # Only overwrite environ if wandb param passed
    if wandb_project:
        os.environ["WANDB_PROJECT"] = wandb_project
    if wandb_watch:
        os.environ["WANDB_WATCH"] = wandb_watch
    if wandb_log_model:
        os.environ["WANDB_LOG_MODEL"] = wandb_log_model

    model = LlamaForCausalLM.from_pretrained(
        base_model,
        torch_dtype=torch.bfloat16,
        device_map=device_map,
    )

    tokenizer = LlamaTokenizer.from_pretrained(base_model)

    tokenizer.pad_token_id = (
        0  # unk. we want this to be different from the eos token
    )
    tokenizer.padding_side = "left"  # Allow batched inference

    def tokenize(prompt, add_eos_token=True):
        """Tokenize ``prompt`` (truncated to ``cutoff_len``) and copy the ids into ``labels``."""
        result = tokenizer(
            prompt,
            truncation=True,
            max_length=cutoff_len,
            padding=False,
            return_tensors=None,
        )
        # Append EOS only if it is missing and there is still room for it.
        if (
            result["input_ids"][-1] != tokenizer.eos_token_id
            and len(result["input_ids"]) < cutoff_len
            and add_eos_token
        ):
            result["input_ids"].append(tokenizer.eos_token_id)
            result["attention_mask"].append(1)

        result["labels"] = result["input_ids"].copy()

        return result

    def generate_and_tokenize_prompt(data_point):
        """Render one dataset row through the prompter and tokenize it."""
        full_prompt = prompter.generate_prompt(
            data_point["instruction"],
            data_point["input"],
            data_point["output"],
        )
        tokenized_full_prompt = tokenize(full_prompt)
        if not train_on_inputs:
            user_prompt = prompter.generate_prompt(
                data_point["instruction"], data_point["input"]
            )
            tokenized_user_prompt = tokenize(
                user_prompt, add_eos_token=add_eos_token
            )
            user_prompt_len = len(tokenized_user_prompt["input_ids"])

            if add_eos_token:
                user_prompt_len -= 1

            # Mask the prompt part so only the response contributes to the loss.
            tokenized_full_prompt["labels"] = [
                -100
            ] * user_prompt_len + tokenized_full_prompt["labels"][
                user_prompt_len:
            ]
        return tokenized_full_prompt

    if data_path.endswith((".json", ".jsonl")):
        data = load_dataset("json", data_files=data_path)
    else:
        data = load_dataset(data_path)

    if val_set_size > 0:
        train_val = data["train"].train_test_split(
            test_size=val_set_size, shuffle=True, seed=42
        )
        train_data = (
            train_val["train"].shuffle().map(generate_and_tokenize_prompt)
        )
        val_data = (
            train_val["test"].shuffle().map(generate_and_tokenize_prompt)
        )
    else:
        train_data = data["train"].shuffle().map(generate_and_tokenize_prompt)
        val_data = None

    if not ddp and torch.cuda.device_count() > 1:
        # keeps Trainer from trying its own DataParallelism when more than 1 gpu is available
        model.is_parallelizable = True
        model.model_parallel = True

    trainer = LLamaTrainer(
        model=model,
        train_dataset=train_data,
        eval_dataset=val_data,
        args=transformers.TrainingArguments(
            per_device_train_batch_size=micro_batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            warmup_steps=100,
            num_train_epochs=num_epochs,
            learning_rate=learning_rate,
            logging_steps=1,
            optim="adamw_torch",
            evaluation_strategy="epoch" if val_set_size > 0 else "no",
            save_strategy="epoch",
            eval_steps=2000 if val_set_size > 0 else None,
            output_dir=output_dir,
            save_total_limit=30,
            load_best_model_at_end=val_set_size > 0,
            ddp_find_unused_parameters=False if ddp else None,
            group_by_length=group_by_length,
            report_to="wandb" if use_wandb else None,
            run_name=wandb_run_name if use_wandb else None,
        ),
        data_collator=transformers.DataCollatorForSeq2Seq(
            tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
        ),
    )
    model.config.use_cache = False

    # NOTE(review): the trainer was built with the uncompiled model above, so
    # this compiled wrapper is only used for the final save -- confirm intent.
    if torch.__version__ >= "2" and sys.platform != "win32":
        model = torch.compile(model)

    # Honour the caller-supplied checkpoint instead of a hard-coded path; an
    # empty/false value means "do not resume".
    trainer.train(resume_from_checkpoint=resume_from_checkpoint or None)

    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)

    print(
        "\n If there's a warning about missing keys above, please disregard :)"
    )


# Script entry point: hand ``train`` to python-fire when the file is executed directly.
if __name__ == "__main__":
    fire.Fire(train)
