import torch
import argparse
import datasets
from transformers import AutoModelForSequenceClassification, BitsAndBytesConfig
from influence.utils import collate_fn, prepare_model
from influence.main_influence import compute_r_l


def _str_to_bool(value):
    """Convert a command-line token such as ``True``/``false``/``1``/``no`` to a bool.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so every
    non-empty value -- including ``"False"`` -- is parsed as ``True``.  This
    converter interprets the text instead.  The empty string maps to ``False``,
    which matches what ``bool("")`` returned before.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in {"true", "t", "yes", "y", "on", "1"}:
        return True
    if normalized in {"false", "f", "no", "n", "off", "0", ""}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parser():
    """Define the command-line options of this script and parse ``sys.argv``.

    Returns:
        argparse.Namespace: the parsed options (model/tokenizer paths, data
        path and validation-set names, quantization settings, batch sizes).
    """
    # Named ``arg_parser`` so the local does not shadow this function's own name.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--model_path", type=str, default="")
    arg_parser.add_argument("--tokenizer_path", type=str, default="Qwen/Qwen1.5-1.8B")
    # data
    # ``required=True`` means a default would never be used, so none is given.
    arg_parser.add_argument("--train_data_path", type=str, required=True)
    arg_parser.add_argument("--val_names", type=str, default="A,B,CD")
    # quantization
    # Booleans go through ``_str_to_bool`` so that e.g. ``--load_in_4bit False``
    # really disables the option.
    arg_parser.add_argument("--load_in_4bit", type=_str_to_bool, default=True)
    arg_parser.add_argument("--bnb_4bit_quant_type", type=str, default="nf4")
    arg_parser.add_argument("--use_bnb_nested_quant", type=_str_to_bool, default=True)
    arg_parser.add_argument("--torch_dtype", type=str, default="bfloat16")

    arg_parser.add_argument("--train_batchsize", type=int, default=1)
    arg_parser.add_argument("--eval_batchsize", type=int, default=1)
    return arg_parser.parse_args()


if __name__ == "__main__":
    # Script entry point: load the (optionally 4-bit quantized) model, read the
    # precomputed per-validation-set gradient dicts and lambda dict stored under
    # ``model_path``, run ``compute_r_l`` over the training set, and save one
    # ``r_l_dict_<name>.pt`` per validation set next to them.
    args = parser()

    # Parse the comma-separated names once; strip whitespace and drop empty
    # entries so that e.g. "A, B," does not produce file names such as
    # "val_grad_avg_dict_ B.pt" or "val_grad_avg_dict_.pt".
    val_names = [name.strip() for name in args.val_names.split(",") if name.strip()]
    if not val_names:
        # Fail before the expensive model load rather than on a missing file later.
        raise ValueError("--val_names must contain at least one validation set name")
    print(f"val_names are {val_names}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    quantization_config = None
    if args.load_in_4bit:
        # float16 is the fallback compute dtype; any explicit --torch_dtype other
        # than "auto" is resolved to the torch dtype of the same name.
        compute_dtype = torch.float16
        if args.torch_dtype not in {"auto", None}:
            compute_dtype = getattr(torch, args.torch_dtype)
        print(compute_dtype)
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=compute_dtype,
            bnb_4bit_quant_type=args.bnb_4bit_quant_type,
            bnb_4bit_use_double_quant=args.use_bnb_nested_quant,
        )

    # NOTE(review): recent transformers releases deprecate ``use_flash_attention_2``
    # in favour of ``attn_implementation="flash_attention_2"``; kept as-is because
    # the installed version is not visible from this file -- confirm and migrate.
    model = AutoModelForSequenceClassification.from_pretrained(
        args.model_path,
        num_labels=1,
        quantization_config=quantization_config,
        use_flash_attention_2=True,
    )
    # turn the lora weights to require_grad, to compute gradients and send model to device
    model = prepare_model(model, device, args.model_path)
    # Load the training set from disk (presumably already tokenized and padded
    # to max_length -- verify against the preprocessing step).
    train_dataset = datasets.load_from_disk(args.train_data_path)
    # Fall back to the EOS token when the model config defines no pad token.
    if model.config.pad_token_id is None:
        model.config.pad_token_id = model.config.eos_token_id

    # first load (1) val_grad_avg_dict, and (2) lambda_dict
    val_grad_avg_dict_list = [
        torch.load(f"{args.model_path}/val_grad_avg_dict_{val_name}.pt")
        for val_name in val_names
    ]
    lambda_dict = torch.load(f"{args.model_path}/lambda_dict.pt")

    r_l_dict_list = compute_r_l(
        model, train_dataset, val_grad_avg_dict_list, lambda_dict, args.train_batchsize, collate_fn
    )
    # Store each r_l_dict under the name of its validation set (assumes
    # compute_r_l returns one dict per entry of val_grad_avg_dict_list, in the
    # same order -- TODO confirm).
    for val_name, r_l_dict in zip(val_names, r_l_dict_list):
        torch.save(r_l_dict, f"{args.model_path}/r_l_dict_{val_name}.pt")
