import argparse
import sys
import pickle
import torch
import os
import torch.nn.functional as F
import torch.nn as nn
from torch.utils.data import DataLoader
from datasets import load_from_disk
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer, TrainerCallback, default_data_collator
from peft import PeftModel, LoraConfig, get_peft_model
import textwrap
from fingerprint_utils import Fingerprint, extract_fingerprints, check_training_done, calculate_batch_kl_loss
from fingerprint_utils import add_pad_token
from peft import AutoPeftModelForCausalLM
from dataclasses import dataclass, field
from typing import Dict, Optional, Sequence

def parse_arguments(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse the command-line options of the adapter evaluation script.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default) argparse falls back to ``sys.argv[1:]``, which is the
            behaviour callers had before this parameter existed.

    Returns:
        An ``argparse.Namespace`` with the attributes ``model_name``,
        ``model_folder``, ``fingerprint_adapter``, ``finetune_adapter``,
        ``logging_dir`` and ``eval_dataset``. Every option except
        ``model_name`` defaults to ``None``.
    """
    parser = argparse.ArgumentParser(description="Fingerprint training and adapter combination")
    parser.add_argument(
        "--model_name",
        type=str,
        default="meta-llama/Meta-Llama-3-8B-Instruct",
        help="Base model identifier passed to from_pretrained for both the model and the tokenizer.",
    )
    parser.add_argument(
        "--model_folder",
        type=str,
        default=None,
        help="Model folder path (not referenced by the evaluation code in this script).",
    )
    parser.add_argument(
        "--fingerprint_adapter",
        type=str,
        default=None,
        help="Adapter loaded on top of the base model under the adapter name 'fingerprint'.",
    )
    parser.add_argument(
        "--finetune_adapter",
        type=str,
        default=None,
        help="Adapter loaded on top of the base model under the adapter name 'finetune'.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default=None,
        help="Directory in which eval_results.log is written.",
    )
    parser.add_argument(
        "--eval_dataset",
        type=str,
        default=None,
        help="Path of the dataset opened with datasets.load_from_disk.",
    )
    return parser.parse_args(argv)

args = parse_arguments()
print(args)

# Fail fast: these options default to None and are handed straight to
# PeftModel.from_pretrained / load_from_disk below. Checking them here avoids
# loading the base model only to crash on a missing path afterwards.
_missing_flags = [
    flag
    for flag, value in (
        ("--fingerprint_adapter", args.fingerprint_adapter),
        ("--finetune_adapter", args.finetune_adapter),
        ("--eval_dataset", args.eval_dataset),
    )
    if value is None
]
if _missing_flags:
    sys.exit(f"Missing required argument(s): {', '.join(_missing_flags)}")

training_args = TrainingArguments(output_dir="test_trainer", per_device_eval_batch_size=32, fp16=True, report_to="none")


def _load_adapter_model(adapter_path, adapter_name):
    """Load the base model plus one PEFT adapter, together with a tokenizer.

    The call sequence is the one the script used inline for each model:
    base model, tokenizer, add_pad_token, adapter wrap, add_pad_token again.

    Args:
        adapter_path: Adapter location given to ``PeftModel.from_pretrained``.
        adapter_name: Name under which the adapter is registered.

    Returns:
        A ``(model, tokenizer)`` tuple: the adapter-wrapped model and the
        tokenizer instance that was loaded for it.
    """
    base_model = AutoModelForCausalLM.from_pretrained(args.model_name, trust_remote_code=True)
    # A fresh tokenizer is loaded for every model on purpose, as before.
    # NOTE(review): how add_pad_token treats a tokenizer that already has a
    # pad token is not visible here, so the instance is not shared.
    model_tokenizer = AutoTokenizer.from_pretrained(args.model_name, trust_remote_code=True)
    add_pad_token(base_model, model_tokenizer)
    adapted_model = PeftModel.from_pretrained(base_model, adapter_path, adapter_name=adapter_name)
    add_pad_token(adapted_model, model_tokenizer)
    return adapted_model, model_tokenizer


fingerprint_model, tokenizer = _load_adapter_model(args.fingerprint_adapter, "fingerprint")
# `tokenizer` is rebound here, so the rest of the script uses the instance
# loaded alongside the finetune model (same as the original flow).
finetune_model, tokenizer = _load_adapter_model(args.finetune_adapter, "finetune")

dataset = load_from_disk(args.eval_dataset)
# pad dataset to the length of 512
def pad_dataset(example, max_length=512):
    """Right-pad one example's ``input_ids`` and ``labels`` to ``max_length``.

    ``input_ids`` is padded with ``tokenizer.pad_token_id`` (module-level
    tokenizer) and ``labels`` with ``-100``. A sequence that is already longer
    than ``max_length`` is left unchanged (it is not truncated) and its length
    is printed as a diagnostic.

    Args:
        example: Mapping with list-valued ``"input_ids"`` and ``"labels"``
            entries; it is updated in place.
        max_length: Target length. Defaults to 512, the value that used to be
            hard-coded, so ``dataset.map(pad_dataset)`` behaves as before.

    Returns:
        The same ``example`` mapping with both sequences padded.
    """
    # A negative multiplier yields an empty list, so over-long inputs simply
    # receive no padding.
    ids_fill = [tokenizer.pad_token_id] * (max_length - len(example["input_ids"]))
    example["input_ids"] = example["input_ids"] + ids_fill
    if len(example["input_ids"]) != max_length:
        print(f"input_ids length: {len(example['input_ids'])}")

    label_fill = [-100] * (max_length - len(example["labels"]))
    example["labels"] = example["labels"] + label_fill
    if len(example["labels"]) != max_length:
        print(f"labels length: {len(example['labels'])}")
    return example

# Apply pad_dataset to every example before it reaches the Trainer.
dataset = dataset.map(pad_dataset)

eval_results = []

# The finetune model is evaluated first, then the fingerprint model; the
# metrics are collected in that order.
evaluation_targets = (
    ("finetune", finetune_model),
    ("fingerprint", fingerprint_model),
)

previous_model = None
for label, model in evaluation_targets:
    if previous_model is not None:
        # Move the already-evaluated model off the accelerator BEFORE the next
        # Trainer is created. empty_cache() can only return memory that no
        # live tensor occupies, so calling it while the previous model is
        # still resident (as the script used to) released nothing.
        previous_model.to("cpu")
        torch.cuda.empty_cache()

    trainer = Trainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=dataset,
        eval_dataset=dataset,
    )
    print(f"Running evaluation on {label} model")
    eval_print = trainer.evaluate()
    eval_results.append(eval_print)
    print(eval_print)
    previous_model = model

# Save eval_results to logging_dir. The directory is created first so a
# missing folder cannot discard the metrics after the evaluation has run.
results_path = f"{args.logging_dir}/eval_results.log"
os.makedirs(os.path.dirname(results_path), exist_ok=True)
print(f"Saving eval results to {results_path}")
with open(results_path, "w") as f:
    f.write(str(eval_results))

# accelerate launch debug.py