import torch
import argparse
import datasets
from transformers import AutoModelForSequenceClassification, BitsAndBytesConfig, AutoTokenizer
from influence import utils
from influence.main_influence import compute_val_grad_avg


def _pad_examples(example, tokenizer, max_length=1000):
    """Right-pad the chosen/rejected token ids and attention masks of one example.

    For both the ``chosen`` and ``rejected`` sides, ``input_ids_*`` is extended
    with ``tokenizer.pad_token_id`` and ``attention_mask_*`` is extended with
    zeros until each list has ``max_length`` entries.

    Sequences that already have ``max_length`` or more entries are left at
    their current length: this function never truncates.

    Args:
        example: Mapping holding the list-valued keys ``input_ids_chosen``,
            ``attention_mask_chosen``, ``input_ids_rejected`` and
            ``attention_mask_rejected``. It is updated in place.
        tokenizer: Object exposing a ``pad_token_id`` attribute.
        max_length: Target length of every padded list.

    Returns:
        The same ``example`` object, with the four keys rebound to padded lists.
    """
    for side in ("chosen", "rejected"):
        ids_key = f"input_ids_{side}"
        mask_key = f"attention_mask_{side}"

        token_ids = example[ids_key]
        # A negative repeat count yields an empty list, so over-long inputs pass through.
        example[ids_key] = token_ids + [tokenizer.pad_token_id] * (max_length - len(token_ids))

        mask = example[mask_key]
        example[mask_key] = mask + [0] * (max_length - len(mask))
    return example


def _str_to_bool(value):
    """Convert a command-line string into a bool, for use as an argparse ``type=``.

    ``type=bool`` is a trap: argparse would call ``bool("False")``, which is
    ``True`` because the string is non-empty. This converter reads the text
    instead (case-insensitively).

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in {"true", "t", "yes", "y", "1", "on"}:
        return True
    # The empty string is kept as False: with the former ``type=bool`` it was the
    # only value that disabled a flag, so existing launch scripts may rely on it.
    if normalized in {"false", "f", "no", "n", "0", "off", ""}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parser(argv=None):
    """Parse the command-line options of the validation-gradient script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default),
            argparse reads ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace with the model/tokenizer paths, the evaluation
        dataset location and split names, the quantization settings and the
        batch-size / partition options.
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--model_path", type=str, default="")
    arg_parser.add_argument("--tokenizer_path", type=str, default="")
    # dataset
    arg_parser.add_argument("--eval_dataset_dir", type=str, default="")
    arg_parser.add_argument("--eval_dataset_name", type=str, default="")
    # comma-separated list of validation split names
    arg_parser.add_argument("--val_names", type=str, default="")
    # quantization
    arg_parser.add_argument("--load_in_4bit", type=_str_to_bool, default=True)
    arg_parser.add_argument("--bnb_4bit_quant_type", type=str, default="nf4")
    arg_parser.add_argument("--use_bnb_nested_quant", type=_str_to_bool, default=True)
    arg_parser.add_argument("--torch_dtype", type=str, default="bfloat16")

    arg_parser.add_argument("--train_batchsize", type=int, default=1)
    arg_parser.add_argument("--eval_batchsize", type=int, default=1)
    arg_parser.add_argument("--eval_partition", type=int, default=1000)
    return arg_parser.parse_args(argv)


if __name__ == "__main__":
    # Script entry point: load the (optionally 4-bit quantized) sequence-classification
    # model, then compute and save one averaged validation-gradient dict per split.
    args = parser()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Build the bitsandbytes config only when 4-bit loading is requested.
    quantization_config = None
    if args.load_in_4bit:
        # "auto"/None keep the float16 fallback; any other value must be the name
        # of a torch attribute (e.g. "bfloat16").
        compute_dtype = torch.float16
        if args.torch_dtype not in {"auto", None}:
            compute_dtype = getattr(torch, args.torch_dtype)
        print(compute_dtype)
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=compute_dtype,
            bnb_4bit_quant_type=args.bnb_4bit_quant_type,
            bnb_4bit_use_double_quant=args.use_bnb_nested_quant,
        )

    model = AutoModelForSequenceClassification.from_pretrained(
        args.model_path, num_labels=1, quantization_config=quantization_config
    )
    # Per the original author's note: turn the lora weights to require_grad, to
    # compute gradients, and send the model to device.
    model = utils.prepare_model(model, device, args.model_path)

    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path, use_fast=True)
    # Fall back to the EOS token when no dedicated padding token is configured.
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
    # NOTE(review): the model falls back to its own EOS id rather than
    # tokenizer.pad_token_id, so the two pad ids can differ when the tokenizer
    # already defines a pad token -- confirm this is intended.
    if model.config.pad_token_id is None:
        model.config.pad_token_id = model.config.eos_token_id

    # One on-disk dataset per comma-separated validation name:
    # <eval_dataset_dir>/<eval_dataset_name>_<name>
    val_names = args.val_names.split(",")
    # Load every split up front so that a bad path fails before any gradient work starts.
    eval_datasets = [
        datasets.load_from_disk(args.eval_dataset_dir + f"/{args.eval_dataset_name}_{name}")
        for name in val_names
    ]

    #####
    # step 1. compute gradients for each validation dataset and retrieve the average
    # gradient for each weight, memory : O(p)
    #####
    for name, eval_dataset in zip(val_names, eval_datasets):
        val_grad_avg_dict = compute_val_grad_avg(
            model, eval_dataset, args.eval_partition, args.eval_batchsize, utils.collate_fn
        )
        # Save immediately instead of collecting every split's result first: only one
        # gradient dict is held at a time, and splits that already finished are kept
        # on disk if a later split fails.
        torch.save(val_grad_avg_dict, args.model_path + f"/val_grad_avg_dict_{name}.pt")
        del val_grad_avg_dict
