import json
import json
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
import re
import importlib.util
import os
import argparse
import vllm.envs as envs
import random
import time
from datetime import datetime
from tqdm import tqdm
from utils.utils import set_seed, load_jsonl, save_jsonl, construct_prompt, readjsonl2list
from utils.parser import *
from utils.data_loader import load_data
from utils.math_normalization import *
from utils.grader import *
import pickle
from math import comb
import numpy as np

# envs.VLLM_HOST_IP="0.0.0.0" or "127.0.0.1"

def parse_list(arg):
    """Turn a comma-separated CLI string into the list of its pieces.

    Pieces are returned verbatim: no whitespace stripping, and empty pieces are
    kept, e.g. ``"a,b,,c"`` -> ``["a", "b", "", "c"]``.
    """
    pieces = arg.split(",")
    return pieces

def save_completions(completions, filepath):
    """Pickle ``completions`` into ``filepath``, overwriting any existing file.

    The default pickle protocol is used; read it back with ``pickle.load``.
    """
    payload = pickle.dumps(completions)
    with open(filepath, mode="wb") as handle:
        handle.write(payload)

def parse_args():
    """Build the command-line parser, parse ``sys.argv`` and normalise sampling flags.

    With ``--temperature 0`` (greedy decoding) ``top_p`` is forced to 1.
    The parsed ``--stop`` list is echoed to stdout before returning.

    Returns:
        argparse.Namespace with all options below.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument

    # Model and sampling budget.
    add('--model_name_or_path', default="./", type=str, help="model dir")
    add('--n_sampling', default=1, type=int, help="n for sampling")
    add("--k", default=1, type=int, help="Value of k for pass@k calculation")

    # Dataset selection and slicing.
    add("--data_dir", type=str, default="./data")
    add('--data_name', default="math", type=str, help='identify how to extract answer')
    add("--split", type=str, default="test")
    add('--start_idx', default=0, type=int, help="data[start:end]")
    add('--end_idx', default=-1, type=int, help="data[start:end], if -1, data[start:]")

    # Decoding / prompting.
    add("--temperature", type=float, default=0)
    add("--max_tokens", type=int, default=2048)
    add("--prompt_type", type=str, default="qwen-base")
    add("--num_rounds", type=int, default=0)
    add("--prompt_file_path", type=str, default="./prompts")
    add("--surround_with_messages", action="store_true")
    add("--use_few_shot", action="store_true")
    add("--output_dir", type=str, default="./outputs_v6")
    add('--stop', type=parse_list)
    add("--top_p", type=float, default=1)
    add("--seed", type=int, default=0)
    add("--dtype", type=str, default='auto')
    add("--completions_save_dir", type=str, default='./completions')

    parsed = parser.parse_args()

    if parsed.temperature == 0:
        # Greedy decoding: nucleus sampling has to be disabled.
        parsed.top_p = 1
    print(f"current stop list: {parsed.stop}")
    return parsed

def get_conversation_prompt_by_messages(tokenizer, messages):
    """Render chat ``messages`` into one prompt string with the tokenizer's chat template.

    Args:
        tokenizer: object exposing ``apply_chat_template`` (a HF tokenizer here).
        messages: list of ``{"role": ..., "content": ...}`` dicts.

    Returns:
        The rendered template as text (``tokenize=False``), with the generation
        prompt requested (``add_generation_prompt=True``).
    """
    template_options = {"tokenize": False, "add_generation_prompt": True}
    rendered = tokenizer.apply_chat_template(messages, **template_options)
    return rendered



def _split_sampling(n_sampling, max_factor=64):
    """Split ``n_sampling`` into ``factor`` copies per pass times ``passes`` passes.

    ``factor`` is the largest divisor of ``n_sampling`` in ``[2, max_factor]``,
    or 1 when there is none (e.g. ``n_sampling == 1`` or a prime above
    ``max_factor``); ``passes == n_sampling // factor``.
    """
    factor = 1
    for candidate in range(2, max_factor + 1):
        if n_sampling % candidate == 0:
            factor = candidate
    return factor, n_sampling // factor


def _refine_query(query, output, epoch):
    """Fold one model ``output`` into ``query`` for the next refinement round.

    The summary (text inside ``<summary>``, or after ``</think>`` if no tag) is
    appended to the query in round 0, and replaces the previous summary in later
    rounds. From round 1 on, the critique (inside ``<critique>``, or the text
    before ``<think>``) decides whether the sample is finished: it is finished when
    the text after "Overall judgment:" is exactly "Correct".

    Returns:
        ``(new_query, finished)``.
    """
    # The "</summary>" stop string is not part of vLLM's returned text by default,
    # so re-append it to keep the split-based parsing below uniform.
    output = output.rstrip() + "\n</summary>"
    if "<summary>" in output:
        summary = output.split("<summary>")[-1].split("</summary>")[0].strip()
    else:
        summary = output.split("</think>")[-1].split("</summary>")[0].strip()
    summary_block = "<summary>\n" + summary + "\n</summary>\n\n"

    if epoch == 0:
        # First round has no critique yet: append the summary and keep going.
        return query + summary_block, False

    if "<critique>" in output:
        critique = output.split("<critique>")[-1].split("</critique>")[0].strip()
    else:
        critique = output.split("<think>")[0].strip()
    # Drop the summary block appended in the previous round. rsplit (not split)
    # so that a literal "<summary>" inside the prompt itself is left intact.
    base = query.rsplit("<summary>", 1)[0]
    finished = critique.split("Overall judgment:")[-1].strip() == "Correct"
    return base + summary_block, finished


def _run_refinement_rounds(llm, prompts, request_params, num_rounds):
    """Generate for every prompt, re-prompting unfinished ones up to ``num_rounds`` more times.

    Args:
        llm: vLLM ``LLM`` instance.
        prompts: list of prompt strings.
        request_params: ``request_params[i]`` is the SamplingParams for ``prompts[i]``
            (reused in every round).
        num_rounds: number of extra refinement rounds after the first generation.

    Returns:
        The final, stripped query text for each prompt, in input order.
    """
    remaining = list(enumerate(prompts))
    finished = []
    for epoch in range(num_rounds + 1):
        print("=" * 50, "Epoch", epoch)
        if not remaining:
            break
        print(f"Remaining: {len(remaining)}")
        outputs = llm.generate(
            [query for _, query in remaining],
            [request_params[idx] for idx, _ in remaining],
        )
        # Re-align outputs with `remaining`; assumes vLLM's numeric, increasing request ids.
        outputs = sorted(outputs, key=lambda out: int(out.request_id))
        texts = [out.outputs[0].text for out in outputs]
        if len(texts) != len(remaining):
            raise RuntimeError(f"expected {len(remaining)} outputs, got {len(texts)}")

        next_remaining = []
        for (idx, query), text in zip(remaining, texts):
            new_query, done = _refine_query(query, text, epoch)
            if done:
                finished.append((idx, new_query))
            else:
                next_remaining.append((idx, new_query))
        remaining = next_remaining

    finished.extend(remaining)
    finished.sort(key=lambda item: item[0])
    return [query.strip() for _, query in finished]


def _generate(args, examples, model_name, factor, generation_epoch):
    """Build prompts, run vLLM with refinement rounds, and collect responses per example.

    Returns one dict per example with keys ``question``, ``generated_responses``
    (``factor * generation_epoch`` strings) and, when present in the example,
    ``id`` and ``source``.
    """
    model_name_or_path = args.model_name_or_path
    os.makedirs(f'{args.output_dir}/{model_name}/{args.data_name}', exist_ok=True)
    os.makedirs(f'{args.completions_save_dir}/{model_name}/{args.data_name}', exist_ok=True)

    # Only needed here, to render chat templates for new generations.
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)

    # Fall back to one GPU when CUDA_VISIBLE_DEVICES is unset; ignore empty entries
    # (e.g. a trailing comma) so tensor_parallel_size is not inflated.
    visible = os.environ.get('CUDA_VISIBLE_DEVICES', '0')
    available_gpus = [gpu for gpu in visible.split(',') if gpu.strip()] or ['0']
    if len(available_gpus) == 1:
        envs.VLLM_HOST_IP = "0.0.0.0"
    print(f"available_gpus: {available_gpus}")

    prompt_batch = []
    for example in tqdm(examples, total=len(examples)):
        question = parse_question(example, args.data_name)
        # NOTE(review): few_shot_prompt, question_format and system_prompt are not
        # defined in this file; presumably supplied by a star import — confirm.
        if args.use_few_shot:
            cur_prompt = few_shot_prompt + question_format.format(question=question)
        else:
            cur_prompt = question_format.format(question=question)
        if args.surround_with_messages:
            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": cur_prompt}
            ]
            cur_prompt = get_conversation_prompt_by_messages(tokenizer=tokenizer, messages=messages)
        prompt_batch.append(cur_prompt)
    if prompt_batch:
        print(prompt_batch[0])

    # "</summary>" drives the round protocol; any user --stop strings are added to it.
    stop = ["</summary>"] + [s for s in (args.stop or []) if s and s != "</summary>"]

    def sampling_params_for(sample_idx):
        # Per-sample seed: vLLM seeds each request's generator from `seed`, so one
        # shared seed would make every duplicated copy of a prompt an identical draw.
        return SamplingParams(temperature=args.temperature,
                              max_tokens=args.max_tokens,
                              n=1,
                              top_p=args.top_p,
                              stop=stop,
                              seed=args.seed + sample_idx)

    llm = LLM(model=model_name_or_path,
              tensor_parallel_size=len(available_gpus),
              trust_remote_code=True,
              swap_space=60,
              gpu_memory_utilization=0.92,
              seed=args.seed)

    file_outputs = []
    for cur_generation_epoch in range(generation_epoch):
        start_time = time.time()
        # Copy j of each prompt in pass p is overall sample p * factor + j.
        pass_params = [sampling_params_for(cur_generation_epoch * factor + copy)
                       for copy in range(factor)]
        expanded = [prompt for prompt in prompt_batch for _ in range(factor)]
        request_params = [pass_params[idx % factor] for idx in range(len(expanded))]
        codes = _run_refinement_rounds(llm, expanded, request_params, args.num_rounds)
        print(f"generation epoch {cur_generation_epoch}: {time.time() - start_time:.1f}s")

        for idx, example in enumerate(examples):
            generated_responses = codes[idx * factor: (idx + 1) * factor]
            if cur_generation_epoch == 0:
                record = {
                    "question": parse_question(example, args.data_name),
                    "generated_responses": generated_responses,
                }
                if "id" in example:
                    record["id"] = example["id"]
                if "source" in example:
                    record["source"] = example["source"]
                file_outputs.append(record)
            else:
                file_outputs[idx]['generated_responses'] += generated_responses
    return file_outputs


def _pass_at_k(n, c, k):
    """Unbiased pass@k for ``c`` correct out of ``n`` samples: 1 - C(n-c, k) / C(n, k)."""
    if c == 0:
        return 0
    if n - c < k:
        return 1
    return 1 - (comb(n - c, k) / comb(n, k))


def infer(args):
    """Generate (or reload) responses for ``data[start_idx:end_idx]``, score them, and save.

    Generation is skipped when an output file with the exact same name already
    exists; its records are re-scored instead. Results are written (via a
    ``.tmp`` file) to ``<output_dir>/<model>/<data_name>/....jsonl``, and accuracy /
    pass@k lines are appended to the matching ``.txt`` log.

    Note: ``args.end_idx`` is overwritten with ``len(examples)`` when it is -1.
    """
    model_name_or_path = args.model_name_or_path
    print(f"current eval model: {model_name_or_path}")

    factor, generation_epoch = _split_sampling(args.n_sampling)
    print(f"use n = {factor}, generation epoch is: {generation_epoch}")

    examples = load_data(args.data_name, args.split, args.data_dir)
    if args.end_idx == -1:
        args.end_idx = len(examples)
    examples = examples[args.start_idx:args.end_idx]

    model_name = "/".join(model_name_or_path.split("/")[-3:])
    out_file_prefix = f'{args.split}_{args.prompt_type}_t{args.temperature}_p{args.top_p}_seed{args.seed}'
    out_stem = (f'{args.output_dir}/{model_name}/{args.data_name}/{out_file_prefix}'
                f'_k{args.n_sampling}_s{args.start_idx}_e{args.end_idx}_l{args.max_tokens}_r{args.num_rounds}')
    out_file = out_stem + '.jsonl'
    out_log = out_stem + '.txt'

    if os.path.exists(out_file):
        print(f"Completely same name file({out_file}) exist, skip generation, save file and check correct")
        file_outputs = readjsonl2list(out_file)
    else:
        file_outputs = _generate(args, examples, model_name, factor, generation_epoch)
        print("llm generate done")
        print(len(file_outputs))

    k = args.k
    pass_at_k_list = []
    correct_cnt = 0
    for i, d in enumerate(tqdm(examples, "check correct...")):
        _, gt_ans = parse_ground_truth(d, args.data_name)
        record = file_outputs[i]
        generated_answers = [extract_answer(response, args.data_name)
                             for response in record['generated_responses']]
        is_correct_list = [check_is_correct(answer, gt_ans) for answer in generated_answers]
        is_correct = any(is_correct_list)
        if is_correct:
            correct_cnt += 1

        record['generated_answers'] = generated_answers
        record['gold_answer'] = gt_ans
        record['is_correct'] = is_correct

        # pass@k is only meaningful with more than one sample per question.
        if len(is_correct_list) > 1:
            pass_at_k_list.append(_pass_at_k(len(generated_answers), sum(is_correct_list), k))

    # Write to a temp file first, then swap it in; os.replace (unlike os.rename
    # on Windows) also overwrites an existing out_file from the reload path.
    temp_out_file = out_file + ".tmp"
    with open(temp_out_file, 'w', encoding='utf-8') as f:
        for d in tqdm(file_outputs, "writing generation to jsonl file..."):
            f.write(json.dumps(d, ensure_ascii=False))
            f.write("\n")
    os.replace(temp_out_file, out_file)

    total = len(examples)
    acc = correct_cnt / total if total else 0.0
    print(f"correct cnt / total cnt: {correct_cnt}/{total}")
    print(f"Acc: {acc:.4f}")

    with open(out_log, 'a', encoding='utf-8') as file:
        file.write(f"correct cnt / total cnt: {correct_cnt}/{total}" + '\n')
        file.write(f"Acc: {acc:.4f}" + '\n\n')

    if pass_at_k_list:
        average_pass_at_k = sum(pass_at_k_list) / len(pass_at_k_list)
        line = f"Pass@{k}: {sum(pass_at_k_list)}/{len(pass_at_k_list)} = {average_pass_at_k:.4f}"
    else:
        line = f"Pass@1: {correct_cnt}/{total} = {acc:.4f}"
    print(line)
    with open(out_log, 'a', encoding='utf-8') as file:
        file.write(line + "\n")

    
    

if __name__ == "__main__":
    # Script entry: parse CLI flags, seed via utils.set_seed, then run
    # generation + scoring.
    cli_args = parse_args()
    set_seed(cli_args.seed)
    infer(cli_args)
