import argparse
import json
import re, os
from fraction import Fraction
from vllm import LLM, SamplingParams
import sys
import numpy as np
import jsonlines
from tqdm import trange
import random
import torch

from data_processing.answer_extraction import *
from eval.eval_script import *

# Largest native integer; used as the default `end` index so that slicing
# with it keeps every remaining dataset row.
MAX_INT = sys.maxsize
# Placeholder recorded as the prediction when no answer could be extracted
# from a completion.
INVALID_ANS = "[Invalid]"

def get_mean_logprobs(logprobs, rank=1):
    """Return the mean log-probability of the entries at a given rank.

    Args:
        logprobs: Sequence with one mapping per generated position. Each
            mapping's values must expose ``rank`` and ``logprob`` attributes
            (in this file it receives ``response.outputs[0].logprobs`` from
            ``llm.generate``). ``None`` and empty positions are skipped.
        rank: Only entries whose ``rank`` equals this value are averaged.

    Returns:
        The arithmetic mean of the selected log-probabilities, or
        ``float("nan")`` when no entry matches (for example an empty
        completion), instead of raising ``ZeroDivisionError``.
    """
    selected = []
    for position in logprobs or ():
        if not position:
            # Nothing recorded for this position; skip it.
            continue
        for entry in position.values():
            if entry.rank == rank:
                selected.append(entry.logprob)

    if not selected:
        # A mean over zero tokens is undefined; NaN keeps the return type a
        # float and lets the caller continue with the remaining samples.
        return float("nan")
    return sum(selected) / len(selected)

def chat_template_to_prompt(prompt_list):
    """Flatten a list of chat messages into a single ``<|im_start|>`` prompt.

    Each message is a mapping with ``'role'`` and ``'content'`` keys. Every
    message except the last is closed with ``<|im_end|>``. When the last
    message comes from the user, it is closed and an assistant header is
    opened so the model continues as the assistant; any other final message
    is left open.
    """
    final_index = len(prompt_list) - 1
    pieces = []
    for position, turn in enumerate(prompt_list):
        role = turn['role']
        pieces.append('<|im_start|>' + role + '\n' + turn['content'])
        if position != final_index:
            pieces.append('<|im_end|>\n')
        elif role == 'user':
            pieces.append('<|im_end|>\n<|im_start|>assistant\n')
    return "".join(pieces)

def jsonl_test(model, data, template='cot', start=0, end=MAX_INT, temperature=0.7, top_p=0.95, passk=1, tensor_parallel_size=1, seed=0, gpu=1):
    """Evaluate a model on a JSONL dataset with vLLM and save the results.

    The dataset is read from ``./datasets/{data}.jsonl``; every record must
    provide an ``"input"`` (question) and a ``"target"`` (reference answer).
    Each question is wrapped in the chosen prompt template, generated
    ``passk`` times, and the extracted answer is compared to the target with
    ``math_equal``. All completions, log-probabilities and per-pass
    correctness flags are written to
    ``./results/{model_name}/{data}/{data}-{template}-{temperature}-{top_p}-{passk}-{seed}.json``.

    Args:
        model: Model name or path handed to ``vllm.LLM``.
        data: Dataset name (file stem under ``./datasets``).
        template: One of ``'normal'``, ``'cot'``, ``'ds-math'``,
            ``'intern'``, ``'intern-v1s'``.
        start: Index of the first dataset row to evaluate.
        end: Index one past the last dataset row to evaluate.
        temperature: Sampling temperature; forced to 0 when ``passk == 1``.
        top_p: Nucleus sampling threshold; forced to 1 when ``passk == 1``.
        passk: Number of generation passes over the dataset.
        tensor_parallel_size: Tensor-parallel degree passed to vLLM.
        seed: Seed forwarded to the vLLM engine and embedded in the output
            file name.
        gpu: ``gpu_memory_utilization`` fraction passed to vLLM.

    Returns:
        The pass@k accuracy as a float: the fraction of problems answered
        correctly in at least one of the ``passk`` passes.

    Raises:
        NotImplementedError: If ``template`` is not a known template name.
    """
    # Process prompt templates
    if template == 'normal':
        problem_prompt = (
            "Below is an instruction that describes a task. "
            "Write a response that appropriately completes the request.\n\n"
            "### Instruction:\n{instruction}\n\n### Response:\n"
        )
    elif template == 'cot':
        problem_prompt = (
            "Below is an instruction that describes a task. "
            "Write a response that appropriately completes the request.\n\n"
            "### Instruction:\n{instruction}\n\n### Response: Let's think step by step."
        )
    elif template == 'ds-math':
        problem_prompt = (
            "{instruction}\nPlease reason step by step, and put your final answer within \\boxed{{}}."
        )
    elif template == 'intern-v1s':
        chat_prompt = [{"role": "user", "content": "Provide your best guess and the probability that it is correct (0.0 to 1.0) for the following question. Give ONLY the guess and probability, no other words or explanation. For example:\n\nGuess: <most likely guess, as short as possible; not a complete sentence, just the guess!>\n Probability: <the probability between 0.0 and 1.0 that your guess is correct, without any extra commentary whatsoever; just the probability!>\n\nThe question is: {instruction}"}]
        problem_prompt = chat_template_to_prompt(chat_prompt)
        print(problem_prompt)
    elif template == 'intern':
        chat_prompt = [{"role": "user", "content": "Problem:\n{instruction}\nLet's think step by step\nSolution:\n"}]
        problem_prompt = chat_template_to_prompt(chat_prompt)
        print(problem_prompt)
    else:
        raise NotImplementedError(f"Template {template} feature is not implemented yet.")

    # Read datasets
    instructions, answers = [], []
    data_path = f"./datasets/{data}.jsonl"
    print('dataset ====', data_path)
    # Read-only mode: the dataset is never modified, and "r+" would fail on
    # files or mounts without write permission.
    with open(data_path, "r", encoding="utf8") as f:
        for item in jsonlines.Reader(f):
            instructions.append(problem_prompt.format(instruction=item["input"]))
            answers.append(item['target'])
    instructions = instructions[start:end]
    answers = answers[start:end]
    print('length ====', len(instructions))

    # Process sampling parameters
    stop_tokens = ["Instruction:", "Instruction", "Response:", "Response", '<|im_end|>',]
    if passk == 1:
        # A single pass is evaluated greedily; the overridden values are also
        # what gets recorded in the results and the output file name.
        print("Greedy Sampling Evaluation")
        temperature = 0
        top_p = 1
    else:
        print("Random Sampling Evaluation")
    sampling_params = SamplingParams(temperature=temperature, top_p=top_p, max_tokens=1024, stop=stop_tokens, logprobs=1)
    print('sampleing =====', sampling_params)

    # Process results
    num_problems = len(answers)
    results = {
        "predict": [[] for _ in range(num_problems)],
        "answer": answers,
        "completion": [[] for _ in range(num_problems)],
        "cumulative_logprob": [[] for _ in range(num_problems)],
        "mean_logprob": [[] for _ in range(num_problems)],
        "prompt": instructions,
        'temperature': temperature,
        'top_p': top_p,
        "accuracy": [[] for _ in range(passk)]
    }
    # Forward the seed so that the value stored in the output file name
    # actually controls the engine's sampling.
    llm = LLM(model=model, tensor_parallel_size=tensor_parallel_size, download_dir="../cache", gpu_memory_utilization=gpu, trust_remote_code=True, seed=seed)
    for i in trange(passk):
        completions = []
        cumulative_logprobs = []
        mean_logprobs = []
        for response in llm.generate(instructions, sampling_params):
            output = response.outputs[0]
            completions.append(output.text)
            cumulative_logprobs.append(output.cumulative_logprob)
            mean_logprobs.append(get_mean_logprobs(output.logprobs))

        for idx, (completion, c_logprob, m_logprob, answer) in enumerate(zip(completions, cumulative_logprobs, mean_logprobs, answers)):
            y_pred = extract_answer(completion)
            results["cumulative_logprob"][idx].append(c_logprob)
            results["mean_logprob"][idx].append(m_logprob)
            results["completion"][idx].append(completion)
            if y_pred is None:
                results["predict"][idx].append(INVALID_ANS)
                results["accuracy"][i].append(False)
            else:
                results["predict"][idx].append(y_pred)
                # Coerce to a plain bool so the flag is JSON-serializable
                # whatever truthy type math_equal returns.
                results["accuracy"][i].append(bool(math_equal(y_pred, answer)))

    # A problem counts as solved when at least one pass got it right.
    passk_result = np.sum(np.array(results["accuracy"]), axis=0)
    passk_acc = (passk_result >= 1).mean()
    print(passk_result, passk_acc)

    model_name = os.path.basename(os.path.normpath(model))
    result_path = f"./results/{model_name}/{data}"
    os.makedirs(result_path, exist_ok=True)
    with open(f"{result_path}/{data}-{template}-{temperature}-{top_p}-{passk}-{seed}.json", 'w', encoding='utf-8') as fw:
        json.dump(results, fw, ensure_ascii=False, indent=4)

    return float(passk_acc)

def parse_args():
    """Parse the command-line options for an evaluation run.

    ``--model`` and ``--data`` are required so that a missing value is
    reported by argparse immediately rather than surfacing later as an
    obscure failure when the model or dataset file is loaded.

    Returns:
        The ``argparse.Namespace`` holding the parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, required=True, help='model name or path')
    parser.add_argument('--data', choices=['gsm8k', 'gsmhardv2', "MATH", "AIME_1983_2024", "Odyssey", "OlympiadBench"], type=str, required=True, help='dataset')
    parser.add_argument("--template", type=str, choices=['normal', 'cot', 'ds-math', 'intern', 'intern-v1s'], default="cot", help='prompt template')
    parser.add_argument("--temperature", type=float, default=0.7, help='sampling temperature (ignored when passk is 1)')
    parser.add_argument("--top_p", type=float, default=0.95, help='nucleus sampling threshold (ignored when passk is 1)')
    parser.add_argument("--start", type=int, default=0, help='index of the first dataset row to evaluate')
    parser.add_argument("--end", type=int, default=MAX_INT, help='index one past the last dataset row to evaluate')
    # Kept so existing command lines that pass --batch_size keep working.
    parser.add_argument("--batch_size", type=int, default=60)
    parser.add_argument("--passk", type=int, default=1, help='number of generation passes')
    parser.add_argument("--tensor_parallel_size", type=int, default=1, help='tensor-parallel degree')
    parser.add_argument("--seed", type=int, default=0, help='random seed')
    parser.add_argument("--gpu", type=float, default=0.9, help='GPU memory utilization fraction')
    return parser.parse_args()

if __name__ == "__main__":
    args = parse_args()

    # Seed the Python, NumPy and torch RNGs with the requested seed.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)
    # Ask cuDNN for deterministic kernels and disable its autotuner.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Forward the relevant CLI options to the evaluation entry point by name.
    forwarded = (
        "model", "data", "template", "start", "end", "temperature",
        "top_p", "passk", "tensor_parallel_size", "seed", "gpu",
    )
    run_kwargs = {option: getattr(args, option) for option in forwarded}
    # The return value is captured but not used any further in this script.
    acc = jsonl_test(**run_kwargs)