import pickle
import torch
import time
import os
import datasets
from torch.utils.data import DataLoader,TensorDataset
from tqdm import tqdm
from argparse import ArgumentParser
import random
import numpy as np
from trl.data_utils import maybe_apply_chat_template
from transformers import AutoModelForCausalLM, AutoTokenizer

def extract_answer_from_dataset(text):
    """Split a solution string on its ``"####"`` answer marker.

    Returns None when ``text`` contains no ``"####"``. Otherwise returns a
    ``(reasoning, answer)`` pair:

    - ``reasoning`` is everything before the first marker.
    - ``answer`` is the segment between the first and the second marker
      (or the end of the text). It is whitespace-stripped, with every
      comma removed.
    """
    if "####" not in text:
        return None
    pieces = text.split("####")
    reasoning, raw_answer = pieces[0], pieces[1]
    return reasoning, raw_answer.strip().replace(',', '')

def hyper_parameters():
    """Parse the command-line options of this scoring script.

    Every option is optional; unspecified options take the defaults
    listed in ``option_specs`` below.
    """
    cli = ArgumentParser(description='test')
    option_specs = (
        ('--model_dir', str, "./Qwen2.5-7B-Instruct"),
        ('--seed', int, 42),
        ('--device', str, "cuda:4"),
        ('--batch_size', int, 8),
        ('--output', str, './output'),
        ('--dataset_dir', str, "./gsm8k"),
    )
    for flag, value_type, fallback in option_specs:
        cli.add_argument(flag, type=value_type, default=fallback)
    return cli.parse_args()

def set_seed(seed):
    """Seed the Python, NumPy, PyTorch CPU and PyTorch CUDA random generators.

    The CUDA seeders cover the current device and then all devices.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

def calculate_answer_confidence(model, tokenizer, input_ids, answers_len, attention_mask):
    """Return the log-probability the model assigns to each trailing answer token.

    For row ``a`` of ``input_ids``, the last ``answers_len[a]`` tokens are
    treated as the answer. Each answer token is scored with the log-softmax
    of the logits at the position just before it, i.e. conditioned on
    everything that precedes it.

    Args:
        model: causal LM called with ``input_ids``, ``attention_mask`` and
            ``logits_to_keep``; its output must expose ``.logits``.
        tokenizer: unused; kept so existing callers stay compatible.
        input_ids: 2-D integer token-id tensor (batch, seq). Answers are read
            from the END of each row, so rows must be left-padded.
        answers_len: 1-D integer tensor with one answer length per row.
        attention_mask: forwarded to the model unchanged.

    Returns:
        list[list[float]]: one list per row, holding that row's answer-token
        log-probabilities ordered from the LAST answer token to the first.
    """
    n_answer = int(answers_len.max().item())
    # One extra position: the logit row that scores the first answer token
    # sits one position before that token.
    keep = n_answer + 1
    outputs = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=keep)
    # Upcast before normalising. With an fp16 model, an fp16 log_softmax
    # rounds the returned log-probs to ~3 significant digits.
    log_probs = torch.log_softmax(outputs.logits.float(), dim=-1)
    # Logits at position t score token t+1. The logit rows ending one
    # position before the sequence end therefore score the last n_answer
    # tokens. Negative indexing also works if the model ignored
    # logits_to_keep and returned every position.
    scoring = log_probs[:, -keep:-1, :]
    seq_len = input_ids.size(1)
    targets = input_ids[:, seq_len - n_answer:]
    token_lp = scoring.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
    # A single device->host transfer instead of one .item() sync per token.
    rows = token_lp.cpu().tolist()
    # Reverse so index 0 is the last answer token, then keep this row's length.
    return [row[::-1][:n] for row, n in zip(rows, answers_len.tolist())]

if __name__ == '__main__':
    hps = hyper_parameters()
    set_seed(hps.seed)
    # Fall back to load_from_disk when load_dataset cannot handle the path
    # (presumably a directory written by save_to_disk). `except Exception`
    # rather than a bare except so Ctrl-C is not swallowed.
    try:
        datas = datasets.load_dataset(hps.dataset_dir)['train']
    except Exception:
        datas = datasets.load_from_disk(hps.dataset_dir)['train']

    tokenizer = AutoTokenizer.from_pretrained(hps.model_dir)
    device = hps.device
    if not tokenizer.pad_token:
        tokenizer.pad_token = tokenizer.eos_token
    system_prompt = "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think>\n<answer> answer here </answer>."

    def make_conversation(example):
        """Build the system + user chat messages for one dataset row."""
        prompt = []
        if system_prompt is not None:
            prompt.append({"role": "system", "content": system_prompt})

        prompt.append({"role": "user", "content": example["question"]})
        return {"prompt": prompt}

    datas = datas.map(make_conversation)

    # Reference answers. Fail with a clear message, instead of a TypeError
    # on None, when a solution lacks its "####" marker.
    solution = []
    for idx, text in enumerate(datas['solution']):
        parsed = extract_answer_from_dataset(text)
        if parsed is None:
            raise ValueError(f"solution #{idx} has no '####' answer marker")
        solution.append(parsed[1])

    input_texts = []
    answers_len = []
    for example, answer in zip(datas, solution):
        # Append the reference answer as a final assistant turn; the rendered
        # text is what the model is scored on.
        messages = example["prompt"] + [{'content': '<answer> ' + answer, 'role': 'assistant'}]
        input_texts.append(maybe_apply_chat_template({'prompt': messages}, tokenizer)["prompt"])
        # NOTE(review): the answer length comes from tokenizing the answer on
        # its own. calculate_answer_confidence reads that many tokens from the
        # END of each sequence. This assumes the rendered text ends exactly
        # with the answer and that the answer tokenizes identically in
        # context — confirm for the tokenizer/template in use.
        answers_len.append(len(tokenizer(answer)['input_ids']))
    answers_len = torch.tensor(answers_len)
    # Left padding keeps every answer at the end of its row.
    inputs = tokenizer(input_texts, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False)

    # Tensors stay on the host; each batch is moved to `device` as it is used.
    data = TensorDataset(inputs['input_ids'], inputs['attention_mask'], answers_len)
    train_dataloader = DataLoader(data, batch_size=hps.batch_size, shuffle=False)

    # Load weights directly in fp16 rather than placing fp32 weights on the
    # device and converting afterwards (same final model, lower peak memory).
    model = AutoModelForCausalLM.from_pretrained(hps.model_dir, torch_dtype=torch.float16)
    model = model.to(device)
    model.eval()

    conf = []
    with torch.no_grad():
        begin = time.perf_counter()
        for batch_ids, batch_mask, batch_len in tqdm(train_dataloader):
            answer_confidence = calculate_answer_confidence(
                model, tokenizer, batch_ids.to(device), batch_len.to(device), batch_mask.to(device)
            )
            # extend() is linear overall; `conf = conf + ...` re-copied the list every batch.
            conf.extend(answer_confidence)
        end = time.perf_counter()
    print(end - begin)
    combined = list(zip(datas, conf))

    os.makedirs(hps.output, exist_ok=True)

    with open(os.path.join(hps.output, "answer_confidence.pkl"), "wb") as f:
        pickle.dump(combined, f)

    # `conf` is ragged (one list of token log-probs per sample), so it cannot
    # be summed or subtracted directly. Each sample is reduced to its mean
    # token log-prob; empty answers become NaN and are ignored in the global mean.
    # NOTE(review): the per-sample mean is an assumed aggregation — confirm.
    sample_conf = np.array([np.mean(c) if c else np.nan for c in conf])
    # NOTE(review): formula kept exactly as originally written, i.e.
    # 1 - (x - mean(x)**2). If "1 minus squared deviation" was intended, it
    # should read 1 - (x - mean(x))**2 — confirm.
    answer_score = 1 - (sample_conf - np.nanmean(sample_conf) ** 2)
    # Output filename spelling left unchanged so existing readers still find it.
    with open(os.path.join(hps.output, "answer_sorce.pkl"), "wb") as f:
        pickle.dump(answer_score, f)
