import torch
import wandb
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser
from lorra_finetune.src.args import ModelArguments, TrainingArguments, LoraArguments, LorraArguments
from lorra_finetune.src.train_val_datasets import load_tqa_sentences, HalluWithAnsDataset, AlpacaSupervisedDataset
from tmt.bean import DefaultArgument
from ext.train_normal import DollyDataProvider

def dolly_fmt_fun(tokenizer, lorra_args):
    """Build the Dolly training set and the validation sets.

    Args:
        tokenizer: Tokenizer handed to ``DollyDataProvider``.
        lorra_args: Argument object providing ``user_tag``, ``assistant_tag``
            and ``lorra_train_date_file``.

    Returns:
        A ``(train_dataset, val_datasets)`` tuple, where ``val_datasets`` is a
        dict with the single key ``"tqa"``.
    """
    user_tag = lorra_args.user_tag
    assistant_tag = lorra_args.assistant_tag

    dolly_provider = DollyDataProvider(
        user_tag,
        assistant_tag,
        tokenizer,
        data_file=lorra_args.lorra_train_date_file,
    )
    # NOTE(review): 10000 / 200 look like train / held-out sizes -- confirm
    # against DollyDataProvider.genrate_dataloader. The second result is unused.
    train_split, _unused_split = dolly_provider.genrate_dataloader(10000, 200)

    # Only the "tqa" validation set is active; an "arc-e" entry
    # (load_arc_sentences) is currently disabled.
    tqa_sentences = load_tqa_sentences(user_tag, assistant_tag)
    return train_split, {"tqa": tqa_sentences}

def lorra_fmt_fun(tokenizer, lorra_args):
    """Build the LoRRA training set and the validation sets.

    The training set is selected by ``lorra_args.lorra_train_data``:

    * ``"custom"`` -> ``HalluWithAnsDataset`` (30000 examples, read from
      ``lorra_args.lorra_train_date_file``)
    * ``"alpaca"`` -> ``AlpacaSupervisedDataset`` (10000 examples)

    Args:
        tokenizer: Tokenizer forwarded to the dataset constructor.
        lorra_args: Argument object providing ``lorra_train_data``,
            ``lorra_train_date_file``, ``user_tag`` and ``assistant_tag``.

    Returns:
        A ``(train_dataset, val_datasets)`` tuple, where ``val_datasets`` is a
        dict with the single key ``"tqa"``.

    Raises:
        ValueError: If ``lorra_args.lorra_train_data`` is neither ``"custom"``
            nor ``"alpaca"``.
    """
    data_source = lorra_args.lorra_train_data
    if data_source == "custom":
        train_dataset = HalluWithAnsDataset(tokenizer=tokenizer, num_examples=30000, lorra_args=lorra_args,
                                            data_files=lorra_args.lorra_train_date_file)
    elif data_source == "alpaca":
        train_dataset = AlpacaSupervisedDataset(tokenizer=tokenizer, num_examples=10000, lorra_args=lorra_args)
    else:
        # Keep the historical message prefix, but say what was received and
        # what is accepted so a misconfigured run is diagnosable from the log.
        raise ValueError(
            "training data set value error! "
            f"Unsupported lorra_train_data={data_source!r}; expected 'custom' or 'alpaca'."
        )
    print(f"LoRRA using training dataset: {lorra_args.lorra_train_data}")

    val_datasets = {
        "tqa": load_tqa_sentences(lorra_args.user_tag, lorra_args.assistant_tag),
        # "arc-e": load_arc_sentences(),
    }
    return train_dataset, val_datasets

def fine_tune_proc(data_prepare_fun, trainer_constructor,use_online_pos=True, use_online_neg=True):
    """Parse CLI arguments, build a LoRA-wrapped causal LM and run training.

    Steps, in order: parse the command line into the argument dataclasses,
    start a wandb run, load the base model and tokenizer, attach a fresh LoRA
    adapter named ``"center"``, optionally load the ``"pos"`` / ``"neg"``
    adapters, build the datasets, construct the trainer and call ``train()``.

    Args:
        data_prepare_fun: Callable invoked as
            ``data_prepare_fun(tokenizer, lorra_args)``; must return
            ``(train_dataset, val_datasets)``.
        trainer_constructor: Callable invoked with the keyword arguments
            ``model``, ``tokenizer``, ``lorra_args``, ``args``,
            ``train_dataset``, ``val_dataset`` and ``def_args``; the returned
            object must expose ``train()``.
        use_online_pos: When true, load the adapter at
            ``def_args.pos_adapter_path`` under the name ``"pos"``.
        use_online_neg: When true, load the adapter at
            ``def_args.neg_adapter_path`` under the name ``"neg"``.

    Raises:
        ValueError: If ``lorra_args.target_layers`` contains an entry that is
            not an integer.
    """
    # The whole model is placed on the first CUDA device.
    device_map = "cuda:0"
    parser = HfArgumentParser((ModelArguments, TrainingArguments, LoraArguments, LorraArguments, DefaultArgument))
    model_args, training_args, lora_args, lorra_args, def_args = parser.parse_args_into_dataclasses()

    wandb.login()
    wandb.init(
        project=f"{training_args.project_info}",
        config={"model_args": model_args, "training_args": training_args, "lora_args": lora_args,
                "lorra_args": lorra_args},
        name=training_args.run_name,
        mode=training_args.wandb_mode
    )
    # NOTE(review): cache_dir is passed to the tokenizer only, not to the
    # model -- confirm whether that asymmetry is intended.
    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        device_map=device_map,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        use_fast=False,
        trust_remote_code=True
    )
    if tokenizer.pad_token is None:
        # Prefer unk as the padding token. Some tokenizers define no unk token
        # either; assigning None would leave padding unset, so use eos then.
        pad_fallback = tokenizer.unk_token
        if pad_fallback is None:
            pad_fallback = tokenizer.eos_token
        tokenizer.pad_token = pad_fallback

    lorra_target_layers = [int(layer) for layer in lorra_args.target_layers.split(",")]  # target representations
    # LoRA covers every layer from 0 up to the deepest target layer. Use max()
    # rather than the last list element so an unsorted --target_layers value
    # (e.g. "20,10") still covers all target layers.
    lora_layers_to_transform = list(range(max(lorra_target_layers) + 1))  # LoRA layers
    lora_config = LoraConfig(
        r=lora_args.lora_r,
        lora_alpha=lora_args.lora_alpha,
        target_modules=lora_args.lora_target_modules,
        lora_dropout=lora_args.lora_dropout,
        bias=lora_args.lora_bias,
        layers_to_transform=lora_layers_to_transform,
        task_type="CAUSAL_LM",
    )

    # "center" is the freshly initialised adapter created from lora_config.
    peft_model = get_peft_model(model, lora_config, adapter_name="center")
    # load adapters
    if use_online_pos:
        peft_model.load_adapter(def_args.pos_adapter_path, adapter_name="pos")

    if use_online_neg:
        peft_model.load_adapter(def_args.neg_adapter_path, adapter_name="neg")

    # Re-select "center" as the active adapter after the optional loads above.
    peft_model.set_adapter("center")
    peft_model.print_trainable_parameters()

    train_dataset, val_datasets = data_prepare_fun(tokenizer, lorra_args)

    if training_args.gradient_checkpointing:
        # Presumably needed so gradients reach the adapter weights through
        # checkpointed blocks of the frozen base model -- see PEFT docs.
        peft_model.enable_input_require_grads()

    trainer = trainer_constructor(model=peft_model, tokenizer=tokenizer, lorra_args=lorra_args ,args=training_args,
        train_dataset=train_dataset,  val_dataset=val_datasets, def_args=def_args)
    # Disable the generation KV cache for the training run.
    peft_model.config.use_cache = False

    trainer.train()
