from transformers import AutoModelForCausalLM, AutoTokenizer
import datasets
from tqdm.notebook import tqdm
import numpy as np
import torch
from torch.optim import AdamW
from torch.nn import CrossEntropyLoss,MSELoss, NLLLoss, KLDivLoss
import json
import random
import matplotlib.pyplot as plt
import transformers
import sys
sys.path.append('../.')
from utils.lora import LoRANetwork
from utils.metrics import get_wmdp_accuracy, get_mmlu_accuracy, get_hp_accuracy

# Raise the transformers logger threshold to CRITICAL so that only critical
# messages from model/tokenizer loading reach the output.
transformers.utils.logging.set_verbosity(transformers.logging.CRITICAL)


# model_id = 'meta-llama/Meta-Llama-3-8B'
# model_id = 'meta-llama/Llama-2-7b-hf'
# model_id = 'meta-llama/Llama-2-7b-chat-hf'
model_id = 'mistralai/Mistral-7B-v0.1'
# model_id = 'HuggingFaceH4/zephyr-7b-beta'
# model_id = 'EleutherAI/pythia-2.8b-deduped'
# model_id = 'microsoft/Phi-3-mini-128k-instruct'
# model_id = 'microsoft/Llama2-7b-WhoIsHarryPotter'

if 'mistralai' in model_id:
    model_card = 'mistral'
if 'Llama-3' in model_id:
    model_card = 'llama3'
if 'Llama-2-7b-chat' in model_id:
    model_card = 'llama2chat'
if 'Llama-2-7b-hf' in model_id:
    model_card = 'llama2'
if 'zephyr' in model_id:
    model_card = 'zephyr'
if 'mistralai' in model_id:
    model_card = 'mistral'
print(model_card)
device = 'cuda:0'
dtype = torch.float32

# Load the causal LM in the configured dtype and move it onto `device`.
# (A flash-attention option, use_flash_attention_2="flash_attention_2",
# is left disabled here.)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=dtype,
).to(device)
# Freeze every parameter; this script only evaluates the model.
model.requires_grad_(False)

# Slow (non-"fast") tokenizer implementation.
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)

# Pad on the left of each sequence.
tokenizer.padding_side = "left"
# Alias the pad/mask/sep/cls token ids to EOS (presumably because the
# checkpoint does not define them all -- confirm per model).
for special_attr in ('pad_token_id', 'mask_token_id', 'sep_token_id', 'cls_token_id'):
    setattr(tokenizer, special_attr, tokenizer.eos_token_id)


# Evaluate the base model (no LoRA network attached) on each benchmark.
# All metrics share the same settings; `device`/`dtype` come from the
# configuration above instead of a hard-coded 'cuda:0', so the evaluation
# inputs always target the same device the model was moved to.
eval_kwargs = dict(network=None, batch_size=5, dtype=dtype, device=device)

harry_acc = get_hp_accuracy(model, tokenizer, **eval_kwargs)
print(f"Accuracy for Harry Potter : {np.round(harry_acc, 3)}")

# Unpacked as (per-subset accuracies, aggregate); the aggregate is not
# reported below. NOTE(review): the BIO=accs[0] / CYBER=accs[1] labels follow
# the original script -- confirm the order against utils.metrics.
accs, final_acc = get_wmdp_accuracy(model, tokenizer, **eval_kwargs)
acc = get_mmlu_accuracy(model, tokenizer, data_dir='../data/mmlu/test',
                        verbose=True, **eval_kwargs)

print('\n\n')
print(f"MODEL {model_id} | {model_card}")
print(f"Accuracy for Harry Potter : {np.round(harry_acc, 3)}")
print(f"Accuracy for CYBER: {np.round(accs[1],3)}")
print(f"Accuracy for BIO: {np.round(accs[0],3)}")
print(f"Accuracy for MMLU : {np.round(acc, 3)}")