from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM, default_data_collator, get_linear_schedule_with_warmup
from peft import get_peft_config, get_peft_model, get_peft_model_state_dict, PrefixTuningConfig, TaskType,PeftModelForCausalLM,PromptTuningInit, PromptTuningConfig
from Rpeft_model import RPeftModel
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
import torch
import argparse
import os
from peft import LoraConfig, get_peft_model
from trl import SFTTrainer
from transformers import (
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    AutoTokenizer,
    TrainingArguments,
    AutoConfig
)
from transformers import TrainerCallback, TrainerState, TrainerControl
import numpy as np
def str2bool(v):
    """Interpret a command-line value as a boolean.

    Intended for use as an ``argparse`` ``type=`` converter.

    Args:
        v: Either an actual ``bool`` (returned unchanged) or a string that
            spells a truth value, compared case-insensitively.

    Returns:
        ``True`` for 'yes'/'true'/'t'/'y'/'1', ``False`` for
        'no'/'false'/'f'/'n'/'0'.

    Raises:
        argparse.ArgumentTypeError: If the string is not one of the
            recognised spellings.
    """
    if isinstance(v, bool):
        return v

    lowered = v.lower()
    spellings_by_value = (
        (True, ('yes', 'true', 't', 'y', '1')),
        (False, ('no', 'false', 'f', 'n', '0')),
    )
    for outcome, spellings in spellings_by_value:
        if lowered in spellings:
            return outcome
    raise argparse.ArgumentTypeError('Boolean value expected.')


def parse_args():
    """Define the script's command-line options and parse ``sys.argv``.

    Every option is optional and falls back to the default listed in the
    table below.

    Returns:
        argparse.Namespace: One attribute per option (named after the flag
        without the leading dashes).
    """
    # (flag, value converter, default), in the order they are registered.
    option_table = (
        ("--base_model_name", str, "meta-llama/Llama-2-7b-chat-hf"),
        ("--dataset_name", str, "gold_distraction3"),
        ("--output_dir", str, "nq_no_ret"),
        ("--seed", int, 1),
        ("--percentage", float, 0),
        ("--num_virtual_token", int, 15),
        ("--num_document_token", int, 3),
        ("--num_virtual_document", int, 3),
        ("--lr", float, 5e-6),
        ("--total_epoch", int, 5),
        ("--train_bert", str2bool, True),
        ("--device", str, "0"),
    )

    cli = argparse.ArgumentParser()
    for flag, converter, fallback in option_table:
        cli.add_argument(flag, type=converter, default=fallback)
    return cli.parse_args()

# ---------------------------------------------------------------------------
# Run configuration: parse the CLI, derive the learning rate / checkpoint
# directory / model path, and set process-wide environment variables.
# ---------------------------------------------------------------------------
args=parse_args()
from get_lr import get_lr
# The --lr value from the command line is overwritten here by a lookup keyed
# on the base model, a dataset key and the number of virtual tokens.
# NOTE(review): the dataset key is hard-coded to 'gold_distraction3' rather
# than taken from --dataset_name -- confirm this is intentional.
args.lr=get_lr(args.base_model_name,'gold_distraction3',args.num_virtual_token)
print("lr==",args.lr)
# The --output_dir value is likewise overwritten with a directory derived from
# the run's hyper-parameters (again with the dataset segment hard-coded).
args.output_dir=f'/home/somebody/codes/RAGE/RPrefix_tuning/checkpoints_prefix/lora_DP1/{args.base_model_name}/gold_distraction3/{args.train_bert}/{args.lr}/virtual_token_{args.num_virtual_token}/document_token_{args.num_document_token}/num_virtual_document_{args.num_virtual_document}'
# Vicuna is loaded from a fixed local snapshot; every other model name is
# passed through unchanged.
if args.base_model_name=='lmsys/vicuna-7b-v1.5':
    model_name_or_path='/data/somebody/data/huggingface_cache/hub/models--lmsys--vicuna-7b-v1.5/snapshots/3321f76e3f527bd14065daf69dad9344000a201d'
else:
    model_name_or_path = args.base_model_name

# exist_ok=True avoids the check-then-create race of a separate
# os.path.exists() test when several runs start at the same time.
os.makedirs(args.output_dir, exist_ok=True)
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# NOTE(review): torch is imported before this variable is set -- confirm that
# nothing touches CUDA before this point so the device selection takes effect.
os.environ["CUDA_VISIBLE_DEVICES"] = args.device
os.environ['CUDA_LAUNCH_BLOCKING']="1"


# ---------------------------------------------------------------------------
# Tokenizer and dataset setup.
# ---------------------------------------------------------------------------
# Alias of the model path; not referenced again in the code visible here.
tokenizer_name_or_path = model_name_or_path

dataset_name=args.dataset_name
from datasets import load_dataset
# dataset = load_dataset("financial_phrasebank", "sentences_allagree")
# dataset = dataset["train"].train_test_split(test_size=0.1)
# dataset["validation"] = dataset["test"]
# del dataset["test"]

# Left padding / left truncation with a 4096-token limit.
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path,padding_side="left", truncation_side="left",model_max_length=4096)
# Reuse the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token 
# Whatever base model is selected, the chat template is copied from the
# Llama-2 chat tokenizer so all models are prompted with the same template.
llama_tokenizer=AutoTokenizer.from_pretrained('meta-llama/Llama-2-7b-chat-hf',padding_side="left", truncation_side="left",model_max_length=4096)
tokenizer.chat_template=llama_tokenizer.chat_template
# Train/test JSON files are resolved relative to the current working directory.
train_file=os.path.join(f'prompt_answer/train',dataset_name)+".json"
test_file=os.path.join(f'prompt_answer/test',dataset_name)+".json"
print(train_file)
print(test_file)
from prompt_dataset import QueryDataset,are_answers_matching
# NOTE(review): `dataloader`, `train_dataset` and `test_dataset` below are not
# read by any live code visible in this file (only by commented-out code) --
# confirm whether they are still needed.
dataloader=QueryDataset(train_file)

train_dataset = load_dataset('json',data_files={"train":train_file,"test":test_file},split='train')
test_dataset = load_dataset('json',data_files={"train":train_file,"test":test_file},split='test')


# ---------------------------------------------------------------------------
# Model setup: quantized base model -> LoRA adapter -> DPromptModel wrapper.
# ---------------------------------------------------------------------------
#peft_config = PrefixTuningConfig(task_type=TaskType.SEQ_2_SEQ_LM, inference_mode=False, num_virtual_tokens=20)
# Prompt-tuning configuration: the virtual tokens are initialised from the
# instruction text below; their count comes from --num_virtual_token.
prompt_config = PromptTuningConfig(
    task_type=TaskType.CAUSAL_LM,
    prompt_tuning_init=PromptTuningInit.TEXT,
    num_virtual_tokens=args.num_virtual_token,
    prompt_tuning_init_text="Answer the question based on the provided documents and feel free to ignore the irrelavant or distracting ones",
    tokenizer_name_or_path=model_name_or_path,
)
# 4-bit NF4 quantization with bfloat16 as the compute dtype.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model_config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
# Set on the config object before the model is instantiated with it.
model_config.max_seq_len=4096
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    trust_remote_code=True,
    config=model_config,
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
    device_map='auto',
)
from DPrompt import DPromptModel
# LoRA hyper-parameters.
lora_alpha = 16
lora_dropout = 0.1
lora_r = 64
# NOTE(review): "ffn" is listed next to the attention projections -- confirm
# the base model actually contains modules with that name.
lora_config = LoraConfig(
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    r=lora_r,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "ffn"
    ],
)

from peft.utils import _prepare_prompt_learning_config
# Wrap the quantized base model with the LoRA adapter.
model=get_peft_model(model,lora_config)
print(model)
model.print_trainable_parameters()
# Complete the prompt-tuning config from the base model's config dictionary
# (uses a private PEFT helper).
prompt_config.base_model_name_or_path=model_name_or_path
prompt_config=_prepare_prompt_learning_config(prompt_config,model_config.to_dict())

# Wrap the LoRA model in the project's DPromptModel; from here on `model`
# refers to that wrapper.
model=DPromptModel(args.base_model_name,model, prompt_config,train_bert=args.train_bert,num_document_token=args.num_document_token,num_virtual_document=args.num_virtual_document, percentage=args.percentage, adapter_name='default')
model.print_trainable_parameters()
# Presumably restores previously trained weights from the checkpoint
# directory -- confirm against DPromptModel.load_prefix_encoder.
model.load_prefix_encoder(args.output_dir)
# from llm import LLM
# model=LLM(model_name_or_path,'gold_only',quantization_bits=4, model_max_length=4096)





def _extract_answer(output, marker="Answer:"):
    """Return the text that follows the first ``marker`` in ``output``.

    Args:
        output: One generated string.
        marker: Substring that introduces the answer.

    Returns:
        The stripped text after the first occurrence of ``marker``; if the
        marker does not occur at all, the whole stripped ``output``.
    """
    # str.partition reports whether the separator was found, so a missing
    # marker cannot silently chop characters off the front of the output.
    _, found, tail = output.partition(marker)
    return (tail if found else output).strip()


def eval(epoch):
    """Measure answer accuracy on the train and test files and log it.

    Uses the module-level ``model``, ``tokenizer``, ``train_file``,
    ``test_file``, ``args`` and ``are_answers_matching``. Appends one
    space-separated line with the run's settings and both accuracies to
    ``result.txt`` in the current working directory.

    Note: this function shadows the built-in ``eval``; the name is kept so
    existing callers continue to work.

    Args:
        epoch: Currently unused; kept for caller compatibility.

    Returns:
        None.
    """
    from prompt_dataset import QueryDataset
    train_dataloader = QueryDataset(data_path=train_file)
    test_dataloader = QueryDataset(data_path=test_file)

    def get_acc(dataloader):
        """Return the fraction of generated answers that match the reference.

        Stops early and returns the running accuracy once at least 30
        examples have been scored and that accuracy is below 0.1.
        """
        match_flags = []
        for batch in tqdm(dataloader):
            #model.eval()
            prompts = batch['prompt']
            answers = batch['answers']

            generated_output = model.generate_answer(prompts, tokenizer)

            print('pred==', generated_output)
            generated_answers = [
                _extract_answer(output) for output in generated_output
            ]

            # The reference is passed to the matcher wrapped in a list.
            # NOTE(review): the same reference is used for every generated
            # answer of the batch -- confirm each batch holds one example.
            answer = [answers]
            for generated in generated_answers:
                # Score the answer that is also printed below (previously
                # index 0 was scored on every iteration).
                match_flags.append(are_answers_matching(generated, answer))

                print('generated answers==', generated)
                print('answers==', answer)

            # Abort clearly failing runs early instead of scoring everything.
            if len(match_flags) >= 30:
                acc = 1.0 * np.sum(match_flags) / len(match_flags)
                if acc < 0.1:
                    return acc

        #model.train()
        acc = 1.0 * np.sum(match_flags) / len(match_flags)
        print("acc==", acc)
        return acc

    train_acc = get_acc(train_dataloader)
    test_acc = get_acc(test_dataloader)
    with open('result.txt', 'a') as f:
        content = f'{args.base_model_name} {args.num_virtual_token} {args.num_document_token} {args.train_bert} {args.lr} {args.dataset_name} {train_acc} {test_acc} {args.num_virtual_document}\n'
        f.write(content)

# Script entry point: run one evaluation pass (eval does not read its argument).
eval(1)


# import json
# processed_dataset_train=trainer.train_dataset
# trainer = SFTTrainer(
#     model=model,
#     train_dataset=test_dataset,
#     eval_dataset=train_dataset,
#     peft_config=peft_config,
#     dataset_text_field=None,
#     max_seq_length=max_seq_length,
#     tokenizer=tokenizer,
#     args=training_arguments,
#     #formatting_func=formatting_prompts_func,
#     callbacks=[MiniBatchCallback],
# )
# processed_dataset_test=trainer.train_dataset
# with open(train_file,'r') as f:
#     train_dataset=json.load(f)
# with open(test_file,'r') as f:
#     test_dataset=json.load(f)
# train_dataloader = DataLoader(
#     processed_dataset_train, shuffle=False, collate_fn=default_data_collator, batch_size=1, pin_memory=True
# )
# test_dataloader = DataLoader(
#     processed_dataset_test, shuffle=False, collate_fn=default_data_collator, batch_size=1, pin_memory=True
# )


# def get_acc(dataloader,dataset):
#     answer_string_in_prompt = "Answer:"
#     ans_match_after_norms=[]
#     idx=0
#     for step, batch in enumerate(tqdm(dataloader)):

#         with torch.no_grad():
#             outputs = model(**batch)

#         generated_output=tokenizer.batch_decode(torch.argmax(outputs.logits, -1).detach().cpu().numpy(), skip_special_tokens=True)
        
#         print('pred==',generated_output)
#         generated_answers = []
#         for output in generated_output:
#             start = output.find(answer_string_in_prompt) + len(answer_string_in_prompt)
#             response = output[start:].strip()
#             generated_answers.append(response)

#         answers=dataset[idx]['completion']
#         idx+=1
#         answers=[answers]
#         ans_match_after_norm: bool = are_answers_matching(generated_answers[0], answers)
#         ans_match_after_norms.append(ans_match_after_norm)

#         print('generated answers==',generated_answers[0])
#         print('answers==',answers)

#     import numpy as np
#     acc=1.0*np.sum(ans_match_after_norms)/len(ans_match_after_norms)
#     print("acc==",acc)
#     return acc


# train_acc=get_acc(train_dataloader,train_dataset)
# test_acc=get_acc(test_dataloader,test_dataset)


# with open('result.txt','a') as f:
#     content=f'{args.base_model_name} {args.num_virtual_token} {args.percentage} {train_acc} {test_acc}\n'
#     f.write(content)