from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM, default_data_collator, get_linear_schedule_with_warmup
from peft import get_peft_config, get_peft_model, get_peft_model_state_dict, PrefixTuningConfig, TaskType,PeftModelForCausalLM,PromptTuningInit, PromptTuningConfig
from llm import LLM
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
import torch
import argparse
import os
from trl import SFTTrainer
from transformers import (
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    AutoTokenizer,
    TrainingArguments,
    AutoConfig
)
from transformers import TrainerCallback, TrainerState, TrainerControl
import numpy as np
def str2bool(v):
    """Convert a command-line style value into a ``bool``.

    Real booleans are returned unchanged. Otherwise ``v`` is lower-cased
    and matched against the accepted spellings of true and false.

    Raises:
        argparse.ArgumentTypeError: if ``v`` is not one of the accepted
            spellings.
    """
    if isinstance(v, bool):
        return v

    truthy_tokens = ('yes', 'true', 't', 'y', '1')
    falsy_tokens = ('no', 'false', 'f', 'n', '0')

    lowered = v.lower()
    if lowered in truthy_tokens:
        return True
    if lowered in falsy_tokens:
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
            

def parse_args():
    """Parse the script's command-line options from ``sys.argv``.

    Returns:
        argparse.Namespace with the attributes ``base_model_name``,
        ``dataset_name``, ``seed``, ``peft``, ``lr`` and ``device``.
    """
    # (flag, type, default) for every supported option, in declaration order.
    option_specs = (
        ("--base_model_name", str, "lmsys/vicuna-13b-v1.5"),
        ("--dataset_name", str, "gold_distraction9"),
        ("--seed", int, 1),
        ("--peft", str, 'None'),
        ("--lr", float, 1e-5),
        ("--device", str, '1'),
    )

    cli = argparse.ArgumentParser()
    for flag, value_type, fallback in option_specs:
        cli.add_argument(flag, type=value_type, default=fallback)
    return cli.parse_args()

# ---------------------------------------------------------------------------
# Script setup: runs at import time. Parses CLI options, configures the
# environment, locates the dataset files and optional PEFT checkpoint, and
# builds the model wrapper used by the evaluation below.
# ---------------------------------------------------------------------------
args=parse_args()
# Map the two known Vicuna model names to hard-coded local snapshot
# directories; any other value is used as given.
if args.base_model_name=='lmsys/vicuna-7b-v1.5':
    model_name_or_path='/data/somebody/data/huggingface_cache/hub/models--lmsys--vicuna-7b-v1.5/snapshots/3321f76e3f527bd14065daf69dad9344000a201d'
elif args.base_model_name=='lmsys/vicuna-13b-v1.5':
    model_name_or_path='/data/somebody/data/huggingface_cache/hub/models--lmsys--vicuna-13b-v1.5/snapshots/c8327bf999adbd2efe2e75f6509fa01436100dc2'
else:
    model_name_or_path = args.base_model_name

os.environ["TOKENIZERS_PARALLELISM"] = "false"
# NOTE(review): this is set after `import torch`; it presumably still takes
# effect because CUDA has not been initialised yet at this point -- confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = args.device
# Per the CUDA documentation this makes kernel launches synchronous, which
# helps debugging but slows execution.
os.environ['CUDA_LAUNCH_BLOCKING']="1"


tokenizer_name_or_path = model_name_or_path

dataset_name=args.dataset_name
# NOTE(review): `load_dataset` is already imported at the top of the file.
from datasets import load_dataset
# dataset = load_dataset("financial_phrasebank", "sentences_allagree")
# dataset = dataset["train"].train_test_split(test_size=0.1)
# dataset["validation"] = dataset["test"]
# del dataset["test"]

# NOTE(review): this tokenizer is replaced by `model.tokenizer` further down.
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
tokenizer.pad_token = tokenizer.eos_token 
    
# Train/test JSON files are named after the dataset and live in fixed
# directories.
train_file=os.path.join(f'/data/somebody/data/datasets/power_of_noise/prompt_answer/train',dataset_name)+".json"
test_file=os.path.join(f'/data/somebody/data/datasets/power_of_noise/prompt_answer/test',dataset_name)+".json"
print(train_file)
print(test_file)
from prompt_dataset1 import QueryDataset,are_answers_matching
# NOTE(review): `dataloader` is not used anywhere in the visible code.
dataloader=QueryDataset(train_file)

# NOTE(review): `train_dataset` / `test_dataset` are not used in the visible
# code; the evaluation reads the files through QueryDataset instead.
train_dataset = load_dataset('json',data_files={"train":train_file,"test":test_file},split='train')
test_dataset = load_dataset('json',data_files={"train":train_file,"test":test_file},split='test')
# `--peft` defaults to the string 'None' (not the None object), meaning
# "evaluate the base model without a checkpoint".
if args.peft!='None':
    from generate_answers_llm import get_checkpoint_path
    check_point=os.path.join(f'/data/somebody/data/checkpoints/final/{args.base_model_name}/{args.peft}',dataset_name)
    # presumably resolves the directory to a concrete checkpoint path --
    # verify against generate_answers_llm.get_checkpoint_path
    check_point=get_checkpoint_path(check_point)
else:
    check_point=None
print("check point==",check_point)
# Project-local model wrapper; see llm.LLM for the meaning of the arguments.
model = LLM(
    model_name_or_path,args.dataset_name, quantization_bits=4, 
    model_max_length=4096,check_point=check_point,
)
# From here on, use the tokenizer owned by the model wrapper.
tokenizer = model.tokenizer
def _extract_answer(output, marker="Answer:"):
    """Return the stripped text that follows the first ``marker`` in ``output``.

    When ``marker`` does not occur, the whole stripped output is returned.
    (A bare ``find()`` would yield -1 there and silently cut off the first
    characters of the output.)
    """
    _, found, tail = output.partition(marker)
    return (tail if found else output).strip()


def eval():
    """Measure answer-match accuracy on the train and test files and log it.

    Builds a ``QueryDataset`` for ``train_file`` and ``test_file``, generates
    answers with the module-level ``model`` / ``tokenizer``, scores them with
    ``are_answers_matching`` and appends one result line to ``result.txt``.

    NOTE: the name shadows the builtin ``eval``; it is kept because the
    script calls this function by that name.
    """
    from prompt_dataset1 import QueryDataset
    train_dataloader = QueryDataset(data_path=train_file)
    test_dataloader = QueryDataset(data_path=test_file)

    def get_acc(dataloader):
        """Return the fraction of generated answers matching the references.

        Stops early (returning the accuracy so far, without printing it) once
        at least ``min_samples`` answers were scored and accuracy is still
        below ``give_up_below``. Returns NaN when nothing was scored.
        """
        min_samples = 30
        give_up_below = 0.1
        ans_match_after_norms = []

        def current_accuracy():
            if not ans_match_after_norms:
                # Nothing scored (e.g. empty dataset): report NaN explicitly
                # instead of relying on a NumPy divide-by-zero warning.
                return float("nan")
            return 1.0 * np.sum(ans_match_after_norms) / len(ans_match_after_norms)

        for batch in tqdm(dataloader):
            prompts = batch['prompt']
            answers = batch['answers']

            generated_output = model.generate(prompts, tokenizer)
            print('pred==', generated_output)

            generated_answers = [_extract_answer(output) for output in generated_output]

            for generated_answer in generated_answers:
                # NOTE(review): the reference passed to the matcher is the
                # item's whole `answers` value wrapped in a list, as before;
                # confirm this matches what are_answers_matching expects.
                answer = [answers]
                # Score each generated answer itself (previously the first
                # answer of the batch was scored on every iteration).
                ans_match_after_norm: bool = are_answers_matching(generated_answer, answer)
                ans_match_after_norms.append(ans_match_after_norm)

                print('generated answers==', generated_answer)
                print('answers==', answer)

            if len(ans_match_after_norms) >= min_samples:
                acc = current_accuracy()
                if acc < give_up_below:
                    return acc

        acc = current_accuracy()
        print("acc==", acc)
        return acc

    train_acc = get_acc(train_dataloader)
    test_acc = get_acc(test_dataloader)
    with open('result.txt', 'a') as f:
        # NOTE(review): base_model_name is written twice; the line format is
        # left untouched so existing result files stay consistent.
        content = f'{args.base_model_name} {args.base_model_name} {args.peft} {args.lr} {args.dataset_name} {train_acc} {test_acc}\n'
        f.write(content)

# Run the evaluation. This executes unconditionally at module level (there is
# no `if __name__ == "__main__"` guard), so importing this file triggers it too.
eval()

