import gc
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
import transformers
from transformers import Trainer, BitsAndBytesConfig #, deepspeed
import torch
from lorra_finetune.src.train_val_datasets import AlpacaSupervisedDataset,HalluWithAnsDataset, load_tqa_sentences, load_arc_sentences, get_logprobs_accuracy
import wandb

from lorra_finetune.src.args import (
    ModelArguments,
    TrainingArguments,
    LoraArguments,
    LorraArguments,
)
# Seed torch's global RNG at import time (module-level side effect).
torch.manual_seed(0)

class LoRRaEngTrainer(Trainer):
    """``Trainer`` that trains LoRA adapters with a LoRRA representation loss.

    Rather than a language-modelling loss, :meth:`compute_loss` pulls the
    adapted model's hidden states at ``lorra_args.target_layers`` toward
    ``orig + lorra_alpha * (pos - neg)``, where ``orig``/``pos``/``neg`` are
    hidden states of the model with its adapter disabled, computed for three
    variants of each example. :meth:`evaluate` runs a log-prob accuracy check
    over ``val_datasets`` and saves the adapter whenever the tracked accuracy
    improves.
    """

    def __init__(self, model, tokenizer, lorra_args: LorraArguments, args: TrainingArguments, train_dataset, val_datasets):
        """Set up the trainer and cache LoRRA settings.

        Args:
            model: PEFT-wrapped causal LM; must provide ``disable_adapter()``.
            tokenizer: forwarded to ``Trainer`` and used by :meth:`evaluate`.
            lorra_args: LoRRA settings; ``target_layers`` is a comma-separated
                list of hidden-state indices.
            args: Hugging Face training arguments.
            train_dataset: training dataset forwarded to ``Trainer``.
            val_datasets: mapping ``name -> (questions, answers, labels)``
                consumed by :meth:`evaluate`; may be empty.
        """
        super().__init__(
            model=model, tokenizer=tokenizer,
            args=args, train_dataset=train_dataset)

        # Best tracked eval accuracy so far, and the suffix of the next saved adapter.
        self.max_tqa = 0
        self.model_idx = 0
        self.lorra_args = lorra_args
        self.training_args = args
        self.lora_target_layers = [int(layer) for layer in lorra_args.target_layers.split(",")]
        self.val_datasets = val_datasets
        # Number of trailing token positions the representation loss covers.
        if self.lorra_args.hidden_tk_len > 0:
            self.min_length = self.lorra_args.hidden_tk_len
        else:
            self.min_length = self.lorra_args.max_res_len

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Return the LoRRA representation-matching loss for one batch.

        ``inputs["input_ids"]`` and ``inputs["attention_mask"]`` hold three
        variants per example along dim 1: index 0 = original, 1 = positive,
        2 = negative prompt.

        Extra keyword arguments (such as ``num_items_in_batch``, which newer
        ``transformers`` releases pass to ``compute_loss``) are accepted and
        ignored. The loss here is a mean over positions, so it needs no
        per-token normalisation.

        Raises:
            ValueError: if dim 1 of ``input_ids`` does not have size 3.
        """
        input_ids = inputs.get("input_ids")
        attention_mask = inputs.get("attention_mask")

        if input_ids.shape[1] != 3:
            raise ValueError(
                "expected input_ids with 3 variants (orig, pos, neg) along dim 1, "
                f"got shape {tuple(input_ids.shape)}"
            )

        orig_input_ids = input_ids[:, 0]
        pos_input_ids = input_ids[:, 1]
        neg_input_ids = input_ids[:, 2]

        orig_attention_mask = attention_mask[:, 0]
        pos_attention_mask = attention_mask[:, 1]
        neg_attention_mask = attention_mask[:, 2]

        # Mask of the last `min_length` positions, replicated once per target
        # layer, with a trailing singleton dim to broadcast over hidden size.
        response_attention_mask = orig_attention_mask[:, -self.min_length:].repeat(len(self.lora_target_layers), 1, 1).unsqueeze(-1)

        # Reference pass: adapter disabled, eval mode, no gradients.
        with model.disable_adapter():
            model.eval()
            with torch.no_grad():
                orig_outputs = model(
                    input_ids=orig_input_ids,
                    attention_mask=orig_attention_mask,
                    output_hidden_states=True
                )['hidden_states']
                orig_hidden = [orig_outputs[l][:, -self.min_length:].detach() for l in self.lora_target_layers]
                pos_outputs = model(
                    input_ids=pos_input_ids,
                    attention_mask=pos_attention_mask,
                    output_hidden_states=True
                )['hidden_states']
                neg_outputs = model(
                    input_ids=neg_input_ids,
                    attention_mask=neg_attention_mask,
                    output_hidden_states=True
                )['hidden_states']
                # Contrast direction (pos - neg) for each target layer.
                direction_hidden = [pos_outputs[l][:, -self.min_length:].detach() - \
                                    neg_outputs[l][:, -self.min_length:].detach() \
                                    # + beta * torch.tensor(pca_directions[l - len(pca_directions)], device=model.device, dtype=torch.float16) \
                                    for l in self.lora_target_layers]
                target_hidden = torch.stack([orig_hidden[i] + self.lorra_args.lorra_alpha * direction_hidden[i] for i in
                                             range(len(self.lora_target_layers))]) * response_attention_mask

                # Drop the reference activations before the gradient-carrying pass.
                del orig_outputs, pos_outputs, neg_outputs, orig_hidden, direction_hidden
                gc.collect()
                torch.cuda.empty_cache()

        # Trainable pass with the adapter enabled.
        model.train()
        lora_outputs = model(
            input_ids=orig_input_ids,
            attention_mask=orig_attention_mask,
            output_hidden_states=True
        )['hidden_states']
        lora_hidden = torch.stack([lora_outputs[l][:, -self.min_length:] for l in self.lora_target_layers]) * response_attention_mask

        # L2 distance along the last dim, averaged with nanmean over all other
        # dims. Masked positions are zero in both operands, so they add zeros
        # to the mean.
        loss = torch.norm(lora_hidden - target_hidden, dim=-1, p=2).nanmean()

        return (loss, lora_hidden) if return_outputs else loss

    def save_peft(self):
        """Save the PEFT adapter to ``<output_dir>/<run_name>_<model_idx>``."""
        self.model.save_pretrained(f"{self.training_args.output_dir}/{self.training_args.run_name}_{self.model_idx}",
                                   safe_serialization=False)

    def evaluate(self, eval_dataset=None, ignore_keys=None, sanity_check=False, **kwargs):
        """Score every validation set, log the metrics, and checkpoint on improvement.

        ``eval_dataset`` and ``ignore_keys`` are accepted for ``Trainer``
        compatibility but are not used. The checkpoint metric is the
        TruthfulQA accuracy (``"tqa"``) when that set was evaluated; otherwise
        it is the last evaluated set's accuracy.

        Returns:
            dict mapping ``"<set>_accuracy"`` to the accuracy value.
        """
        self.model.eval()

        if sanity_check:
            print('Sanity check...')
        metrics = {}
        for val_set, (questions, answer, labels) in self.val_datasets.items():
            print(f'Evaluating {val_set} accuracy...')
            with torch.no_grad():
                acc = get_logprobs_accuracy(self.model, self.tokenizer, questions, answer, labels, self.training_args.per_device_eval_batch_size)
            acc_key = 'acc' if val_set == 'tqa' else 'acc_norm'
            metrics[f"{val_set}_accuracy"] = acc[acc_key]
        self.model.train()
        print("===Eval results===")
        print(metrics)
        wandb.log(metrics)

        # Select checkpoints on TruthfulQA when it was evaluated, as the
        # `max_tqa` name implies. Without this, adding another validation set
        # would silently switch the selection metric to whichever set ran last.
        if "tqa_accuracy" in metrics:
            correct_rate = metrics["tqa_accuracy"]
        elif metrics:
            correct_rate = list(metrics.values())[-1]
        else:
            correct_rate = 0

        if correct_rate > self.max_tqa:
            self.max_tqa = correct_rate
            self.save_peft()
            self.model_idx += 1

        return metrics
    # Alternative evaluate() that scores the real TruthfulQA MC task file;
    # it needs `test_tqa`, which this module does not import.
    # def evaluate(self, eval_dataset=None, ignore_keys=None, sanity_check=False, **kwargs):
    #
    #     correct_rate = test_tqa("/home/data/TruthfulQA/data/mc_task.json", self.model, self.tokenizer, printlog=False)
    #     if correct_rate > self.max_tqa:
    #         self.max_tqa = correct_rate
    #         self.model.save_pretrained(f"{self.training_args.output_dir}/{self.training_args.run_name}_{self.model_idx}",
    #                               safe_serialization=False)
    #         self.model_idx += 1
    #
    #     metrics = {"tqa_score":correct_rate,"total_epoch":self.state.epoch}
    #     print("===Eval results Real TQA ===")
    #     print(metrics)
    #     wandb.log(metrics)
    #
    #     return metrics

def main():
    """Parse arguments, build the LoRA model and datasets, and run LoRRA training.

    Command-line arguments are parsed into ``ModelArguments``,
    ``TrainingArguments``, ``LoraArguments`` and ``LorraArguments``. The run is
    logged to Weights & Biases. The base model is loaded on ``cuda:0`` in
    bfloat16, or quantized to 4-bit NF4 when ``lora_args.q_lora`` is set. It
    is wrapped with LoRA on every layer up to the deepest target layer and
    trained with ``LoRRaEngTrainer``.

    Raises:
        ValueError: if ``lorra_args.lorra_train_data`` is neither ``"custom"``
            nor ``"alpaca"``.
    """
    parser = transformers.HfArgumentParser(
        (ModelArguments, TrainingArguments, LoraArguments, LorraArguments)
    )
    # split token
    (
        model_args,
        training_args,
        lora_args,
        lorra_args,
    ) = parser.parse_args_into_dataclasses()

    device_map = "cuda:0"
    ddp = False
    wandb.login()
    wandb.init(
        project=f"{training_args.project_info}",
        config={"model_args": model_args, "training_args": training_args, "lora_args": lora_args,
                "lorra_args": lorra_args},
        name=training_args.run_name,
        mode=training_args.wandb_mode
    )

    compute_dtype = (
        torch.float16
        if training_args.fp16
        else (torch.bfloat16 if training_args.bf16 else torch.float32)
    )
    # Unquantized weights are always loaded in bfloat16. compute_dtype is used
    # only as the 4-bit compute dtype when QLoRA is requested. Without a
    # quantization config the model would never be quantized, so the q_lora
    # path below would not actually train a QLoRA model.
    quantization_config = (
        BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=compute_dtype,
        )
        if lora_args.q_lora
        else None
    )
    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        device_map=device_map,
        torch_dtype=torch.bfloat16,
        quantization_config=quantization_config,
        trust_remote_code=True
    )

    lorra_target_layers = [int(layer) for layer in lorra_args.target_layers.split(",")]  # target representations
    # LoRA goes on every layer up to the deepest target. Use max() so an
    # unordered --target_layers list (e.g. "20,10") still covers the deepest one.
    lora_layers_to_transform = list(range(max(lorra_target_layers) + 1))

    lora_config = LoraConfig(
        r=lora_args.lora_r,
        lora_alpha=lora_args.lora_alpha,
        target_modules=lora_args.lora_target_modules,
        lora_dropout=lora_args.lora_dropout,
        bias=lora_args.lora_bias,
        layers_to_transform=lora_layers_to_transform,
        task_type="CAUSAL_LM",
    )

    if lora_args.q_lora:
        model = prepare_model_for_kbit_training(
            model, use_gradient_checkpointing=training_args.gradient_checkpointing
        )
        if not ddp and torch.cuda.device_count() > 1:
            # keeps Trainer from trying its own DataParallelism when more than 1 gpu is available
            model.is_parallelizable = True
            model.model_parallel = True

    model = get_peft_model(model, lora_config)

    if training_args.deepspeed is not None and training_args.local_rank == 0:
        model.print_trainable_parameters()

    if training_args.gradient_checkpointing:
        model.enable_input_require_grads()

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="left",
        use_fast=False,
        trust_remote_code=True
    )
    if tokenizer.pad_token is None:
        # Prefer unk as before. Some tokenizers have no unk token, so fall back
        # to eos rather than leaving pad_token unset.
        tokenizer.pad_token = tokenizer.unk_token if tokenizer.unk_token is not None else tokenizer.eos_token

    if lorra_args.lorra_train_data == "custom":
        train_dataset = HalluWithAnsDataset(tokenizer=tokenizer, num_examples=30000, lorra_args=lorra_args,
                                            data_files=lorra_args.lorra_train_date_file)
    elif lorra_args.lorra_train_data == "alpaca":
        train_dataset = AlpacaSupervisedDataset(tokenizer=tokenizer, num_examples=10000, lorra_args=lorra_args)
    else:
        raise ValueError("training data set value error!")
    print(f"LoRRA using training dataset: {lorra_args.lorra_train_data}")

    if training_args.do_eval:
        val_datasets = {
            "tqa": load_tqa_sentences(lorra_args.user_tag, lorra_args.assistant_tag),
            #"arc-e": load_arc_sentences(),
        }
    else:
        val_datasets = {}

    trainer = LoRRaEngTrainer(
        model=model, tokenizer=tokenizer, lorra_args=lorra_args, args=training_args,
        train_dataset=train_dataset, val_datasets=val_datasets)
    model.config.use_cache = False

    print("LoRRA finetune...")
    trainer.train()
    # test_tqa("/home/data/TruthfulQA/data/mc_task.json",model,tokenizer, printlog=True)

# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()