import torch
import argparse
import datasets
from transformers import AutoModelForSequenceClassification, BitsAndBytesConfig
from influence.utils import collate_fn, prepare_model
from influence.main_influence import compute_lambda


def _str2bool(value):
    """Convert a command-line string such as "true"/"False"/"1"/"no" to a bool.

    ``argparse`` with ``type=bool`` calls ``bool(value)`` on the raw string, so
    every non-empty string -- including ``"False"`` -- would become ``True``.
    This converter parses the text instead.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean
            spelling (argparse turns this into a usage error).
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in {"1", "true", "t", "yes", "y", "on"}:
        return True
    if text in {"0", "false", "f", "no", "n", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parser(argv=None):
    """Define the command-line options for this script and parse them.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    arg_parser = argparse.ArgumentParser()
    # NOTE(review): --tokenizer_path is not read by the __main__ block of this script.
    arg_parser.add_argument("--tokenizer_path", type=str, default="")
    arg_parser.add_argument("--model_path", type=str, default="")
    # data
    arg_parser.add_argument(
        "--train_data_path",
        type=str,
        default="",
    )
    # quantization; boolean flags use _str2bool so that "False" really means False
    arg_parser.add_argument("--load_in_4bit", type=_str2bool, default=True)
    arg_parser.add_argument("--bnb_4bit_quant_type", type=str, default="nf4")
    arg_parser.add_argument("--use_bnb_nested_quant", type=_str2bool, default=True)
    arg_parser.add_argument("--torch_dtype", type=str, default="bfloat16")

    arg_parser.add_argument("--train_batchsize", type=int, default=1)
    # NOTE(review): the eval options below are not read by the __main__ block of this script.
    arg_parser.add_argument("--eval_batchsize", type=int, default=1)

    arg_parser.add_argument("--eval_partition", type=int, default=1000)
    return arg_parser.parse_args(argv)


if __name__ == "__main__":
    import os

    args = parser()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Optional 4-bit bitsandbytes quantization of the model weights.
    if args.load_in_4bit:
        compute_dtype = torch.float16
        if args.torch_dtype not in {"auto", None}:
            compute_dtype = getattr(torch, args.torch_dtype, None)
            # getattr on the torch module can return non-dtype objects
            # (e.g. "cuda" -> torch.cuda), so check the result explicitly.
            if not isinstance(compute_dtype, torch.dtype):
                raise ValueError(
                    "--torch_dtype must name a torch dtype (e.g. 'bfloat16'), "
                    f"got {args.torch_dtype!r}"
                )
        print(compute_dtype)
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=compute_dtype,
            bnb_4bit_quant_type=args.bnb_4bit_quant_type,
            bnb_4bit_use_double_quant=args.use_bnb_nested_quant,
        )
    else:
        quantization_config = None

    model = AutoModelForSequenceClassification.from_pretrained(
        args.model_path, num_labels=1, quantization_config=quantization_config
    )
    # prepare_model is expected to turn the LoRA weights to requires_grad (so that
    # gradients can be computed) and send the model to device -- see influence.utils.
    model = prepare_model(model, device, args.model_path)
    # Load the (already preprocessed) training dataset saved with datasets' save_to_disk.
    train_dataset = datasets.load_from_disk(args.train_data_path)
    # Fall back to EOS as the padding token when the config does not define one.
    if model.config.pad_token_id is None:
        model.config.pad_token_id = model.config.eos_token_id

    #####
    # step 2. compute r_l for each layer and store them, using batches of the training dataset
    #####
    # step 2.1 compute lambda_const for each layer
    lambda_dict = compute_lambda(model, train_dataset, args.train_batchsize, collate_fn)
    # save as pt next to the model; os.path.join avoids "/lambda_dict.pt" for an empty path
    lambda_path = os.path.join(args.model_path, "lambda_dict.pt")
    torch.save(lambda_dict, lambda_path)
