import torch
import argparse
import datasets
from transformers import AutoModelForSequenceClassification, BitsAndBytesConfig
from influence.utils import prepare_model, collate_fn
from influence.main_influence import compute_influence


def _str2bool(value):
    """Convert a command-line string such as "true", "False" or "0" to a bool.

    ``argparse`` with ``type=bool`` calls ``bool(text)``, which is ``True`` for
    every non-empty string, so ``--flag False`` would silently enable the flag.
    This converter is used instead so boolean options can actually be disabled.

    The empty string maps to ``False`` (as ``bool("")`` did before).

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in {"true", "t", "yes", "y", "1", "on"}:
        return True
    if normalized in {"false", "f", "no", "n", "0", "off", ""}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parser():
    """Parse the command-line arguments of the influence script.

    Returns:
        argparse.Namespace with the attributes ``tokenizer_path``,
        ``model_path``, ``train_data_path``, ``val_names``, ``load_in_4bit``,
        ``bnb_4bit_quant_type``, ``use_bnb_nested_quant``, ``torch_dtype`` and
        ``train_batch_size``.
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--tokenizer_path", type=str, default="")
    arg_parser.add_argument("--model_path", type=str, default="")
    # data
    arg_parser.add_argument(
        "--train_data_path",
        type=str,
        required=True,
        help="path of the training dataset saved on disk",
    )
    arg_parser.add_argument(
        "--val_names",
        type=str,
        default="",
        help="comma-separated validation set names",
    )
    # quantization
    arg_parser.add_argument(
        "--load_in_4bit",
        type=_str2bool,
        default=True,
        help="load the model 4-bit quantized (true/false)",
    )
    arg_parser.add_argument("--bnb_4bit_quant_type", type=str, default="nf4")
    arg_parser.add_argument(
        "--use_bnb_nested_quant",
        type=_str2bool,
        default=True,
        help="enable bitsandbytes double quantization (true/false)",
    )
    arg_parser.add_argument(
        "--torch_dtype",
        type=str,
        default="bfloat16",
        help='name of a torch dtype, or "auto"',
    )

    arg_parser.add_argument("--train_batch_size", type=int, default=1)
    return arg_parser.parse_args()


if __name__ == "__main__":
    # Script entry point: load the (optionally 4-bit quantized) model and the
    # training dataset, read one r_l_dict per validation set name, run
    # compute_influence, and save one influence dict per validation set name
    # next to the model.
    args = parser()

    # Validate the validation-set names before the expensive model load.
    # "".split(",") yields [""], which would otherwise make the script look for
    # a file named "r_l_dict_.pt"; blank entries and surrounding whitespace are
    # dropped instead.
    val_names = [name.strip() for name in args.val_names.split(",") if name.strip()]
    if not val_names:
        raise SystemExit(
            "--val_names must list at least one validation set name (comma-separated)"
        )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if args.load_in_4bit:
        # float16 is the fallback compute dtype when --torch_dtype is "auto".
        compute_dtype = torch.float16
        if args.torch_dtype not in {"auto", None}:
            compute_dtype = getattr(torch, args.torch_dtype)
        print(compute_dtype)
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=compute_dtype,
            bnb_4bit_quant_type=args.bnb_4bit_quant_type,
            bnb_4bit_use_double_quant=args.use_bnb_nested_quant,
        )
    else:
        quantization_config = None

    model = AutoModelForSequenceClassification.from_pretrained(
        args.model_path, num_labels=1, quantization_config=quantization_config
    )
    # prepare_model is a project helper; per the original note it marks the
    # LoRA weights as requiring grad and moves the model to `device`
    # (NOTE(review): confirm against influence.utils).
    model = prepare_model(model, device, args.model_path)

    # Load the training dataset previously saved with `datasets`; presumably it
    # is already tokenized/padded -- TODO confirm.
    train_dataset = datasets.load_from_disk(args.train_data_path)

    # Fall back to the EOS token id when the model config defines no pad token.
    if model.config.pad_token_id is None:
        model.config.pad_token_id = model.config.eos_token_id

    # One r_l_dict file per validation set, stored next to the model.
    # NOTE: torch.load unpickles its input -- only load files you trust.
    r_l_dict_list = [
        torch.load(f"{args.model_path}/r_l_dict_{val_name}.pt")
        for val_name in val_names
    ]

    influence_dict_list = compute_influence(
        model, train_dataset, r_l_dict_list, args.train_batch_size, collate_fn
    )

    # Assumes compute_influence returns one result per entry of r_l_dict_list,
    # in the same order -- TODO confirm.
    for val_name, influence_dict in zip(val_names, influence_dict_list):
        torch.save(influence_dict, f"{args.model_path}/influence_dict_{val_name}.pt")
