from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM, default_data_collator, get_linear_schedule_with_warmup
from peft import get_peft_config, get_peft_model, get_peft_model_state_dict, PrefixTuningConfig, TaskType,PeftModelForCausalLM,PromptTuningInit, PromptTuningConfig, PeftModel
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
import torch
import argparse
import os
from peft import LoraConfig, get_peft_model
from trl import SFTTrainer
from transformers import (
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    AutoTokenizer,
    TrainingArguments,
    AutoConfig
)
from transformers import TrainerCallback, TrainerState, TrainerControl
import numpy as np
def str2bool(v):
    """Convert a command-line token into a ``bool``.

    Args:
        v: Either an actual ``bool`` (returned unchanged) or a string spelling
            of a boolean, compared case-insensitively: ``yes``/``no``,
            ``true``/``false``, ``t``/``f``, ``y``/``n``, ``1``/``0``.

    Returns:
        The boolean the token stands for.

    Raises:
        argparse.ArgumentTypeError: If the string is not one of the accepted
            spellings.
    """
    if isinstance(v, bool):
        return v

    spellings = {
        'yes': True, 'true': True, 't': True, 'y': True, '1': True,
        'no': False, 'false': False, 'f': False, 'n': False, '0': False,
    }
    token = v.lower()
    if token not in spellings:
        raise argparse.ArgumentTypeError('Boolean value expected.')
    return spellings[token]


def parse_args():
    """Build this script's command-line parser and parse ``sys.argv``.

    Every option is optional and falls back to the default listed in the
    table below.

    Returns:
        argparse.Namespace: One attribute per option, named after the flag
        without its leading dashes.
    """
    # (flag, type converter, default), in registration order.
    option_specs = (
        ("--base_model_name", str, "meta-llama/Meta-Llama-3.1-8B-Instruct"),
        ("--dataset_name", str, "gold_only_reverse"),
        ("--output_dir", str, "nq_no_ret"),
        ("--seed", int, 1),
        ("--percentage", float, 0),
        ("--num_virtual_token", int, 15),
        ("--num_document_token", int, 3),
        ("--num_virtual_document", int, 1),
        ("--lr", float, 1e-5),
        ("--total_epoch", int, 5),
        ("--train_bert", str2bool, True),
        ("--lora_r", int, 64),
        ("--device", str, "1,2,3"),
    )

    cli = argparse.ArgumentParser()
    for flag, converter, fallback in option_specs:
        cli.add_argument(flag, type=converter, default=fallback)
    return cli.parse_args()
import random
import numpy as np
import torch

def seed_everything(seed: int):
    """
    Fix the random seeds so that runs are repeatable.

    Seeds Python's ``random`` module, NumPy and PyTorch (CPU and CUDA), then
    puts cuDNN into deterministic mode with benchmarking switched off.

    Args:
        seed (int): The random seed handed to every generator.
    """
    # Order: Python, NumPy, PyTorch, the current GPU, then all GPUs.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Ask cuDNN for deterministic behaviour and disable its benchmark mode,
    # both for the sake of repeatability.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# NOTE(review): the seed is fixed at 42 here, before the command line is
# parsed, so the --seed option has no effect on this call.
seed_everything(42)

args=parse_args()
# from get_lr import get_lr
# args.lr=get_lr(args.base_model_name,args.dataset_name,args.num_virtual_token)
# print("lr==",args.lr)

# The checkpoint directory is always derived from the run configuration, so a
# value passed via --output_dir is overwritten here.
# NOTE(review): the path is built before num_document_token is forced to 1
# below, so for 'gold_only_reverse' the directory name carries the
# command-line value rather than the value actually used - confirm intended.
args.output_dir=f'/home/somebody/codes/RAGE/RPrefix_tuning/checkpoints_prefix/lora_DP1/{args.base_model_name}/{args.dataset_name}/{args.train_bert}/{args.lr}/virtual_token_{args.num_virtual_token}/document_token_{args.num_document_token}/num_virtual_document_{args.num_virtual_document}'
if args.dataset_name=='gold_only_reverse':
    args.num_document_token=1

# Vicuna is loaded from a local snapshot directory; every other model is
# referenced by the name given on the command line.
if args.base_model_name=='lmsys/vicuna-7b-v1.5':
    model_name_or_path='/data/somebody/data/huggingface_cache/hub/models--lmsys--vicuna-7b-v1.5/snapshots/3321f76e3f527bd14065daf69dad9344000a201d'
else:
    model_name_or_path = args.base_model_name

# exist_ok=True replaces the separate existence check, which could fail when
# another run created the same directory between the check and the call.
os.makedirs(args.output_dir, exist_ok=True)
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Limit the GPUs visible to CUDA to the ids passed via --device.
os.environ["CUDA_VISIBLE_DEVICES"] = args.device
os.environ['CUDA_LAUNCH_BLOCKING']="1"


tokenizer_name_or_path = model_name_or_path

dataset_name=args.dataset_name
from datasets import load_dataset

# Left-side padding and truncation; the EOS token doubles as the pad token.
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path,padding_side="left", truncation_side="left",model_max_length=4096)
tokenizer.pad_token = tokenizer.eos_token
# Reuse the chat template of the Llama-2 chat tokenizer for the base model's
# tokenizer.
llama_tokenizer=AutoTokenizer.from_pretrained('meta-llama/Llama-2-7b-chat-hf',padding_side="left", truncation_side="left",model_max_length=4096)
tokenizer.chat_template=llama_tokenizer.chat_template

# Prompt/answer files are looked up relative to the working directory.
train_file=os.path.join('prompt_answer/train',dataset_name)+".json"
test_file=os.path.join('prompt_answer/test',dataset_name)+".json"
print(train_file)
print(test_file)
# are_answers_matching is imported here at module level because eval() uses it.
from prompt_dataset import QueryDataset,are_answers_matching
# NOTE(review): this module-level object is not referenced again in this
# file; kept in case constructing it has side effects - confirm.
dataloader=QueryDataset(train_file)

# Read both JSON files with a single load_dataset call and pick the splits
# from the result instead of loading the same files twice.
raw_datasets = load_dataset('json',data_files={"train":train_file,"test":test_file})
train_dataset = raw_datasets['train']
test_dataset = raw_datasets['test']


#peft_config = PrefixTuningConfig(task_type=TaskType.SEQ_2_SEQ_LM, inference_mode=False, num_virtual_tokens=20)
# Prompt-tuning configuration: args.num_virtual_token virtual tokens,
# initialised from the instruction text below.
# NOTE(review): "irrelavant" in the init text is a typo of "irrelevant". It is
# a runtime string used to initialise the prompt, so it is left untouched.
prompt_config = PromptTuningConfig(
    task_type=TaskType.CAUSAL_LM,
    prompt_tuning_init=PromptTuningInit.TEXT,
    num_virtual_tokens=args.num_virtual_token,
    prompt_tuning_init_text="Answer the question based on the provided documents and feel free to ignore the irrelavant or distracting ones",
    tokenizer_name_or_path=model_name_or_path,
)
# 4-bit NF4 quantisation with bfloat16 as the compute dtype.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model_config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
# NOTE(review): whether `max_seq_len` is read depends on the model
# architecture - confirm it has an effect for the chosen base model.
model_config.max_seq_len=4096
# Load the quantised base model; device_map='auto' leaves device placement to
# the library.
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    trust_remote_code=True,
    config=model_config,
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
    device_map='auto',
)
# Project-local wrapper, applied further down on top of the LoRA model.
from DPrompt import DPromptModel
# LoRA hyperparameters; only the rank comes from the command line (--lora_r).
lora_alpha = 16
lora_dropout = 0.1
lora_r = args.lora_r
# NOTE(review): confirm that a module named "ffn" exists in the base model;
# the print(model) output below can be used to check the module names.
lora_config = LoraConfig(
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    r=lora_r,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "ffn"
    ],
)

# NOTE(review): `_prepare_prompt_learning_config` is a private PEFT helper
# (leading underscore) and may change between PEFT releases.
from peft.utils import _prepare_prompt_learning_config
# Step 1: wrap the quantised base model with the LoRA adapters.
model=get_peft_model(model,lora_config)
print(model)
model.print_trainable_parameters()
# Step 2: complete the prompt config from the base model's config dict.
prompt_config.base_model_name_or_path=model_name_or_path
prompt_config=_prepare_prompt_learning_config(prompt_config,model_config.to_dict())
# Step 3: wrap the LoRA model and the prompt config in the project's
# DPromptModel; `model` refers to this wrapper from here on.
model=DPromptModel(args.base_model_name,model, prompt_config,train_bert=args.train_bert,num_document_token=args.num_document_token,num_virtual_document=args.num_virtual_document, percentage=args.percentage, adapter_name='default')
model.print_trainable_parameters()

# from llm import LLM
# model=LLM(model_name_or_path,'gold_only',quantization_bits=4, model_max_length=4096)


# Values taken straight from the command line.
num_train_epochs, learning_rate = args.total_epoch, args.lr

# Optimiser, learning-rate schedule and precision.
optim, lr_scheduler_type = "paged_adamw_32bit", "linear"
warmup_ratio = 0.03
gradient_accumulation_steps = 1
bf16 = True

# Checkpointing, logging and evaluation settings.
save_strategy = "epoch"
logging_steps = eval_steps = 1
prediction_loss_only = False


def eval(epoch):
    """Score the model on the train and test prompt files and log the result.

    Every item yielded by ``QueryDataset`` supplies a ``'prompt'`` and its
    ``'answers'``. For ``epoch == 0`` the prompt is tokenized and decoded
    greedily with ``model.generate``; for any other ``epoch`` generation is
    delegated to ``model.generate_answer``. The text after the first
    ``"Answer:"`` marker of each output is compared with the reference through
    ``are_answers_matching``. One line holding the run configuration and both
    accuracies is appended to ``result.txt``.

    The train/eval mode of the model is not changed here.

    Note: the name shadows the built-in ``eval``; it is kept so existing call
    sites keep working.

    Args:
        epoch: Selects the generation path (see above) and is written as the
            last column of the ``result.txt`` line.
    """
    from prompt_dataset import QueryDataset

    answer_marker = "Answer:"

    def extract_answer(output):
        """Return the text after the first marker, or all of *output* if it has none."""
        # str.find() returns -1 when the marker is missing, which used to turn
        # into a slice that dropped the first characters of the output;
        # partition() lets that case fall back to the whole output.
        _, found, tail = output.partition(answer_marker)
        return (tail if found else output).strip()

    def generate(prompts):
        """Produce the decoded model output(s) for *prompts*."""
        if epoch != 0:
            return model.generate_answer(prompts, tokenizer)
        with torch.cuda.amp.autocast():
            inputs = tokenizer(
                prompts,
                padding=True,
                truncation=True,
                max_length=4096,
                return_tensors="pt",
            ).to(model.device)
            generated_ids = model.generate(
                **inputs,
                do_sample=False,
                max_new_tokens=15,
                repetition_penalty=1.1,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )
            return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

    def get_acc(dataloader):
        """Return the fraction of generated answers that match their reference.

        Stops early, returning the accuracy so far, once at least 30
        predictions have been scored and fewer than 10% of them match.
        """
        matches = []
        for batch in tqdm(dataloader):
            prompts = batch['prompt']
            answers = batch['answers']

            generated_output = generate(prompts)
            print('pred==', generated_output)

            generated_answers = [extract_answer(output) for output in generated_output]
            answer = [answers]
            # Score each generated answer itself; previously index 0 was
            # matched on every iteration while index i was printed.
            for generated_answer in generated_answers:
                matches.append(are_answers_matching(generated_answer, answer))
                print('generated answers==', generated_answer)
                print('answers==', answer)

            if len(matches) >= 30:
                acc = 1.0 * np.sum(matches) / len(matches)
                if acc < 0.1:
                    return acc

        acc = 1.0 * np.sum(matches) / len(matches)
        print("acc==", acc)
        return acc

    train_acc = get_acc(QueryDataset(data_path=train_file))
    test_acc = get_acc(QueryDataset(data_path=test_file))
    with open('result.txt', 'a') as f:
        content = f'{args.base_model_name} {args.num_virtual_token} {args.num_document_token} {args.train_bert} {args.lr} {args.dataset_name} {train_acc} {test_acc} {args.num_virtual_document} {epoch}\n'
        f.write(content)

class MiniBatchCallback(TrainerCallback):
    """Trainer callback that echoes the most recently logged training loss."""

    def on_step_end(self, args, state: TrainerState, control: TrainerControl, **kwargs):
        """Print the global step and the loss of the latest log entry.

        Nothing is printed while ``state.log_history`` is empty or when its
        last entry has no ``'loss'`` key.

        Args:
            args: Training arguments supplied by the trainer (unused).
            state: Trainer state; ``log_history`` and ``global_step`` are read.
            control: Trainer control object (unused).
            **kwargs: Further objects passed by the trainer (unused).
        """
        if not state.log_history:
            return
        latest = state.log_history[-1]
        # Skip log entries that carry no loss value.
        if 'loss' in latest:
            print(f"Step: {state.global_step}, Loss: {latest['loss']}")
#eval(0)

# Trainer checkpoints go to the fixed 'test' directory below; the final model
# is saved separately to args.output_dir once training has finished.
# NOTE(review): seed is hard-coded to 42, so --seed is not used here either.
# NOTE(review): save_steps is given although save_strategy is "epoch", and
# eval_steps is given without an evaluation strategy - presumably both are
# ignored; confirm against the installed transformers version.
training_arguments = TrainingArguments(
    output_dir='/home/somebody/codes/RAGE/RPrefix_tuning/checkpoints_prefix/test',
    seed=42,
    num_train_epochs=num_train_epochs,
    # auto_find_batch_size is a boolean switch. It used to be given 4, which
    # only acted as a truthy value.
    # NOTE(review): if a batch size of 4 was intended, set
    # per_device_train_batch_size=4 instead.
    auto_find_batch_size=True,
    gradient_accumulation_steps=gradient_accumulation_steps,
    optim=optim,
    save_strategy=save_strategy,
    save_steps=4,
    learning_rate=args.lr,
    lr_scheduler_type=lr_scheduler_type,
    warmup_ratio=warmup_ratio,
    logging_strategy="steps",
    logging_steps=logging_steps,
    logging_dir='logs',
    prediction_loss_only=prediction_loss_only,
    eval_steps=eval_steps,
    bf16=bf16,
    report_to='wandb',
    logging_first_step=True,
    label_smoothing_factor=0.1
)
# The callback is passed as a class, not an instance.
trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    dataset_text_field=None,
    max_seq_length=4096,
    tokenizer=tokenizer,
    args=training_arguments,
    #formatting_func=formatting_prompts_func,
    callbacks=[MiniBatchCallback],
)
trainer.train()
model.save_pretrained(args.output_dir)

#eval(1)
