import argparse
import json
import os
# Resolve the script location to an absolute path first: when __file__ is a bare
# relative name (e.g. "python this_script.py" on interpreters that do not
# absolutise __main__.__file__), dirname() would return "" and parent_dir would
# collapse to "", turning later f"{parent_dir}/benchmark/..." paths into "/benchmark/...".
script_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(script_dir)
import time
import multiprocessing as mp

import shortuuid
from fastchat.llm_judge.common import load_questions
from fastchat.model import get_conversation_template
from tqdm import tqdm

from ea_model import EaModel
from kv_cache import initialize_past_key_values
from utils import *

import random
import torch
import numpy as np


def extract_code(model_output: str):
    """Return the text enclosed by the last two fence lines of ``model_output``.

    A "fence line" is any line containing three backticks. The lines strictly
    between the second-to-last and the last fence line are joined with
    newlines and returned. If fewer than two fence lines exist, an empty
    string is returned.
    """
    lines = model_output.split("\n")

    fence_positions = []
    for position, text in enumerate(lines):
        if "```" in text:
            fence_positions.append(position)

    if len(fence_positions) >= 2:
        opening, closing = fence_positions[-2:]
        return "\n".join(lines[opening + 1:closing])
    return ""

def get_qwen_question(question_dict):
    """Build the code-generation prompt for one problem.

    Args:
        question_dict: Mapping with a ``question_content`` entry and an
            optional ``starter_code`` entry.

    Returns:
        The prompt string. When ``starter_code`` is present and non-empty the
        prompt embeds it in a python fence; otherwise the prompt asks for a
        stdin/stdout program and embeds a ``# YOUR CODE HERE`` placeholder.

    Raises:
        KeyError: If ``question_content`` is missing.
    """
    FORMATTING_MESSAGE_WITH_STARTER_CODE = "You will use the following starter code to write the solution to the problem and enclose your code within delimiters."
    FORMATTING_WITHOUT_STARTER_CODE = "Read the inputs from stdin solve the problem and write the answer to stdout (do not directly test on the sample inputs). Enclose your code within delimiters as follows. Ensure that when the python program runs, it reads the inputs, runs the algorithm and writes output to STDOUT."

    prompt = "You will be given a question (problem specification) and will generate a correct Python program that matches the specification and passes all tests. You will NOT return anything except for the program.\n\n"
    prompt += f"Question:\n{question_dict['question_content']}\n\n"

    # Test for a non-empty value rather than mere key presence: records that
    # carry starter_code == "" must get the stdin/stdout instructions, not an
    # empty starter-code fence.
    starter_code = question_dict.get('starter_code')
    if starter_code:
        prompt += f"{FORMATTING_MESSAGE_WITH_STARTER_CODE}\n"
        prompt += f"```python\n{starter_code}\n```\n\n"
    else:
        prompt += f"{FORMATTING_WITHOUT_STARTER_CODE}\n\n"
        prompt += "```python\n# YOUR CODE HERE\n```\n\n"
    return prompt

def setup_seed(seed):
    """Seed torch (CPU and CUDA), NumPy and the stdlib ``random`` module.

    Also switches cuDNN into deterministic mode by setting
    ``torch.backends.cudnn.deterministic`` to ``True``.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    torch.backends.cudnn.deterministic = True


def run_eval(
        base_model_path,
        ea_model_path,
        model_id,
        question_file,
        question_begin,
        question_end,
        answer_file,
        max_new_tokens,
        num_choices,
        num_gpus_per_model,
        num_gpus_total,
        max_gpu_memory,
        temperature,
        args
):
    """Load the benchmark questions and generate answers, one worker per GPU.

    Questions come from evalplus when ``args.bench_name`` contains
    "humaneval" (falling back to ``question_file`` if that fails), otherwise
    from ``question_file``. Questions already answered in ``answer_file`` are
    skipped unless ``args.pool_type == "none"``. The remaining questions are
    split round-robin over up to ``num_gpus_total`` spawned processes, or run
    in-process with ``device_map="auto"`` when only one worker is needed.

    Args:
        base_model_path, ea_model_path, model_id, max_new_tokens, num_choices,
        num_gpus_per_model, max_gpu_memory, temperature:
            Forwarded unchanged to ``get_model_answers``.
        question_file: JSONL question file passed to ``load_questions``.
        question_begin, question_end: Optional slice bounds on the question list.
        answer_file: JSONL file that answers are appended to.
        num_gpus_total: Upper bound on the number of worker processes.
        args: Parsed CLI namespace; ``bench_name`` and ``pool_type`` are read
            here and the whole namespace is forwarded to the workers.
    """
    if "humaneval" in args.bench_name:
        try:
            from evalplus.data import get_human_eval_plus
            questions = []
            for task_id, problem in get_human_eval_plus().items():
                problem["question_id"] = task_id
                problem["problem"] = problem["prompt"]
                problem["answer"] = problem["canonical_solution"]
                questions.append(problem)
            # Honour the begin/end debug options on this path too, so both
            # loading paths select the same slice (None bounds are no-ops).
            questions = questions[question_begin:question_end]
        except Exception as e:
            # Deliberate best-effort: any failure falls back to the local file.
            print(f"Failed to load humaneval_plus dataset: {e}")
            print(f"Use local file {question_file} instead")
            questions = load_questions(question_file, question_begin, question_end)
    else:
        questions = load_questions(question_file, question_begin, question_end)

    # get_model_answers overwrites "question_id" with "id" (when present)
    # before writing an answer, so use the same key here; otherwise the resume
    # filter below would compare different identifiers and never match.
    for i, q in enumerate(questions):
        if "id" in q:
            q["question_id"] = q["id"]
        elif "question_id" not in q:
            q["question_id"] = i

    # os.makedirs("") raises, so only create a directory when the path has one.
    answer_dir = os.path.dirname(answer_file)
    if answer_dir:
        os.makedirs(answer_dir, exist_ok=True)

    # Resume support: drop questions already answered. Skipped for
    # pool_type == "none", as before.
    if os.path.exists(answer_file) and args.pool_type != "none":
        with open(answer_file, "r") as f:
            existing_answers = {
                json.loads(line)["question_id"] for line in f if line.strip()
            }
        print(f"Already exists {answer_file} : {len(existing_answers)} answers")
        questions = [q for q in questions if q["question_id"] not in existing_answers]

    num_proc = min(num_gpus_total, torch.cuda.device_count(), len(questions))
    print(f"Load {len(questions)} questions, using {num_proc} processes")

    if not questions:
        # Nothing left to answer: avoid loading the model for an empty loop.
        return

    if num_proc > 1:
        mp.set_start_method('spawn', force=True)
        # Keep a reference to the manager for as long as the lock proxy is in use.
        manager = mp.Manager()
        lock = manager.Lock()
        # Round-robin split: worker r gets questions r, r + num_proc, ...
        data_subsets = [questions[i::num_proc] for i in range(num_proc)]
        processes = []
        print(f"len(data_subsets)={len(data_subsets)}")
        for rank in range(num_proc):
            p = mp.Process(target=get_model_answers, args=(
                f"cuda:{rank}",
                base_model_path,
                ea_model_path,
                model_id,
                data_subsets[rank],
                answer_file,
                max_new_tokens,
                num_choices,
                num_gpus_per_model,
                max_gpu_memory,
                temperature,
                args,
                lock
            ))
            p.start()
            processes.append(p)
        for p in processes:
            p.join()
        # Surface crashed workers instead of finishing silently with a partial file.
        failed = [(rank, p.exitcode) for rank, p in enumerate(processes) if p.exitcode != 0]
        if failed:
            print(f"Warning: worker processes exited abnormally (rank, exitcode): {failed}")
    else:
        get_model_answers(
            "auto",
            base_model_path,
            ea_model_path,
            model_id,
            questions,
            answer_file,
            max_new_tokens,
            num_choices,
            num_gpus_per_model,
            max_gpu_memory,
            temperature,
            args,
            None
        )


def _build_think_step_split_map_ids(tokenizer, args):
    """Build the token-id map of think-step split tokens selected by ``args.split_type``."""
    think_step_split_map = {
        "\n\n": "\n",
        "\n\n\n": "\n",
        ".\n\n": ".\n",
        ".\n\n\n": ".\n",
        " \n\n": "\n",
        " \n\n\n": "\n"
    }
    contrast_discourse_markers = [
        "Wait", "Hmm", "Yet", "Instead", "Alternatively", "However", "But",
    ]
    discourse_markers = [
        "Wait", "Hmm", "Yet", "Instead", "Alternatively", "However", "But",
        "Therefore", "Thus", "So", "Well", "Okay", "Right",
        "Alright", "First", "Next", "Let", "Let's"
    ]
    if args.split_type == "contrast_discourse_markers":
        think_step_split_map = {t: args.think_end_token for t in contrast_discourse_markers}
    elif args.split_type == "all_discourse_markers":
        think_step_split_map = {t: args.think_end_token for t in discourse_markers}
    elif args.split_type == "paragraph" and not args.map_split_ids:
        # Without --map-split-ids every paragraph separator maps to itself.
        think_step_split_map = {k: k for k in think_step_split_map}

    think_step_split_map_ids = {}
    for k, v in think_step_split_map.items():
        k_ids = tokenizer.encode(k, add_special_tokens=False)
        # Only separators that are a single token can be matched by id.
        if len(k_ids) != 1:
            continue
        v_ids = tokenizer.encode(v, add_special_tokens=False)
        if not v_ids:
            # Nothing to map to; indexing v_ids[0] below would raise.
            print(f"Warn: think_step_split_map target {v!r} encodes to no tokens, skipping {k!r}")
            continue
        if len(v_ids) != 1:
            print(f"Warn: think_step_split_map target {v!r} id {v_ids} is not unique")
        print(f"think_step_split_tokens {k!r}({k_ids}) to {v!r}({v_ids})")
        think_step_split_map_ids[k_ids[0]] = v_ids[0]
    return think_step_split_map_ids


def _select_system_prompt(bench_name):
    """Return the system prompt used for ``bench_name``."""
    if bench_name in ["gsm8k", "math", "gpqa", "aime"]:
        return "Please reason step by step, and put your final answer within \\boxed{}."
    if "humaneval" in bench_name:
        return "You are an expert Python programmer."
    return "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe.  Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."


def _strip_special_tokens(text, tokenizer):
    """Remove every entry of ``tokenizer.special_tokens_map`` from ``text`` and strip whitespace."""
    for special_token in tokenizer.special_tokens_map.values():
        if isinstance(special_token, list):
            for special_tok in special_token:
                text = text.replace(special_tok, "")
        else:
            text = text.replace(special_token, "")
    return text.strip()


def _append_answer(answer_file, ans_json, lock):
    """Append ``ans_json`` as one JSON line to ``answer_file``, holding ``lock`` if one is given."""
    line = json.dumps(ans_json) + "\n"

    def _write():
        with open(answer_file, "a") as fout:
            fout.write(line)

    if lock is None:
        _write()
    else:
        with lock:
            _write()


@torch.inference_mode()
def get_model_answers(
    device_map,
    base_model_path,
    ea_model_path,
    model_id,
    questions,
    answer_file,
    max_new_tokens,
    num_choices,
    num_gpus_per_model,
    max_gpu_memory,
    temperature,
    args,
    lock
):
    """Generate answers for ``questions`` with an ``EaModel`` and append them to ``answer_file``.

    One JSON line is written per question, containing ``num_choices`` choices;
    each choice holds the decoded turn outputs plus the statistics returned by
    ``model.eagenerate``.

    Args:
        device_map: Passed to ``EaModel.from_pretrained`` (e.g. "cuda:0" or "auto").
        base_model_path: Base model path passed to ``EaModel.from_pretrained``.
        ea_model_path: EA model path passed to ``EaModel.from_pretrained``.
        model_id: Identifier stored in every answer record.
        questions: Iterable of question dicts. A ``question`` or ``problem``
            entry replaces ``turns``; an ``id`` entry replaces ``question_id``.
            The dicts are modified in place.
        answer_file: JSONL file that answers are appended to.
        max_new_tokens: Passed to ``eagenerate`` as both ``max_length`` and
            ``max_new_tokens``.
        num_choices: Number of completions generated per question.
        num_gpus_per_model: Unused in this function.
        max_gpu_memory: Unused in this function.
        temperature: Sampling temperature passed to ``eagenerate``.
        args: Parsed CLI namespace. Reads total_token, depth, top_k,
            confidence_loss_type, split_type, think_end_token, map_split_ids,
            bench_name, enable_thinking, top_p, repetition_penalty,
            enable_think_exit, exit_threshold, min_think, min_paragraph,
            window_size, pool_type, stop_think_prompt and stop_prob_threshold.
        lock: Optional lock guarding writes to ``answer_file``, or ``None``.
    """
    model = EaModel.from_pretrained(
        base_model_path=base_model_path,
        ea_model_path=ea_model_path,
        total_token=args.total_token,
        depth=args.depth,
        top_k=args.top_k,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        device_map=device_map,
        use_eagle3=True,
        confidence_loss_type=args.confidence_loss_type
    )
    model.eval()

    tokenizer = model.get_tokenizer()
    think_step_split_map_ids = _build_think_step_split_map_ids(tokenizer, args)

    print('Check model training state:', model.training)

    # The system prompt depends only on the benchmark, so pick it once.
    sys_p = _select_system_prompt(args.bench_name)

    for question in tqdm(questions):
        # Normalise dataset-specific field names onto "turns" / "question_id".
        if "question" in question:
            question["turns"] = [question["question"]]
        if "problem" in question:
            question["turns"] = [question["problem"]]
        if "id" in question:
            question["question_id"] = question["id"]

        choices = []
        for i in range(num_choices):
            # Re-seed per choice index so choice i uses the same seed for every question.
            torch.manual_seed(i)
            conv = [{
                "role": "system",
                "content": sys_p
            }]
            turns = []
            idxs = []
            new_tokens = []
            wall_time = []
            accept_length_lists = []
            # These three keep the value from the last turn; they are
            # pre-initialised so a question with no turns still yields a record
            # (paragraph_states was previously unbound in that case).
            confidence_list = []
            exit_score = None
            paragraph_states = None

            for j in range(len(question["turns"])):
                qs = question["turns"][j]
                conv.append({"role": "user", "content": qs})
                conversation = tokenizer.apply_chat_template(
                    conv,
                    tokenize=False,
                    add_generation_prompt=True,
                    enable_thinking=args.enable_thinking,
                )
                print(f"input: {conversation!r}")

                input_ids = tokenizer(
                    conversation,
                    return_tensors="pt",
                    max_length=2048,
                    add_special_tokens=False,
                ).input_ids.to(model.device)

                # Synchronise around the call so wall time covers the GPU work.
                torch.cuda.synchronize()
                start_time = time.time()
                output_ids, new_token, idx, accept_length_list, exit_score, confidence_list, paragraph_states = model.eagenerate(
                    input_ids,
                    max_length=max_new_tokens,
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    top_p=args.top_p,
                    top_k=args.top_k,
                    repetition_penalty=args.repetition_penalty,
                    think_step_split_map_ids=think_step_split_map_ids,
                    log=True,
                    input_text=qs,
                    enable_think_exit=args.enable_think_exit,
                    exit_threshold=args.exit_threshold,
                    min_think=args.min_think,
                    min_paragraph=args.min_paragraph,
                    window_size=args.window_size,
                    pool_type=args.pool_type,
                    stop_think_prompt_ids=tokenizer.encode(args.stop_think_prompt) if args.stop_think_prompt else None,
                    stop_prob_threshold=args.stop_prob_threshold,
                    split_type=args.split_type
                )
                torch.cuda.synchronize()
                total_time = time.time() - start_time

                # Keep only the tokens after the prompt.
                output_ids = output_ids[0][len(input_ids[0]):]
                output = tokenizer.decode(
                    output_ids,
                    spaces_between_special_tokens=False,
                )
                print(output)
                if "answer" in question:
                    print(f"gold_answer: {question['answer']}")

                output = _strip_special_tokens(output, tokenizer)

                turns.append(output)
                idxs.append(int(idx))
                new_tokens.append(int(new_token))
                wall_time.append(total_time)
                accept_length_lists += accept_length_list
                conv.append({
                    "role": "assistant",
                    "content": output
                })
            choices.append({
                "index": i,
                "turns": turns,
                "idxs": idxs,
                "exit_score": exit_score,
                "paragraph_states": paragraph_states,
                "new_tokens": new_tokens,
                "wall_time": wall_time,
                "accept_length": accept_length_lists,
                "confidence_list": confidence_list
            })

        # Dump answers
        ans_json = {
            "question_id": question["question_id"],
            "answer_id": shortuuid.uuid(),
            "model_id": model_id,
            "choices": choices,
            "tstamp": time.time(),
        }
        if "answer" in question:
            ans_json["gold_answer"] = question["answer"]
        if "humaneval" in args.bench_name:
            ans_json["task_id"] = question["question_id"]
            # The solution is extracted from the last turn of the first choice.
            if choices and choices[0]["turns"]:
                ans_json["solution"] = extract_code(choices[0]["turns"][-1])
            else:
                ans_json["solution"] = ""
        if "zebralogic" in args.bench_name:
            ans_json["id"] = question["id"]
            ans_json["solution"] = question["solution"]
            ans_json["size"] = question["size"]
        if "leetcode" in args.bench_name:
            ans_json["task_id"] = question["task_id"]
            ans_json["difficulty"] = question["difficulty"]
            ans_json["prompt"] = question["prompt"]
            ans_json["test"] = question["test"]
            ans_json["entry_point"] = question["entry_point"]

        _append_answer(answer_file, ans_json, lock)


def reorg_answer_file(answer_file):
    """Rewrite ``answer_file`` sorted by question id, keeping one line per id.

    When a question id occurs more than once, the last occurrence wins. Blank
    lines are ignored, and every kept line is newline-terminated so that a
    final line lacking a trailing newline cannot be glued to the record that
    follows it after sorting.

    Args:
        answer_file: Path to a JSONL file whose records have a ``question_id``.

    Raises:
        TypeError: If the question ids are not mutually orderable (for example
            a mix of ints and strings). The file is left untouched in that case.
    """
    answers = {}
    with open(answer_file, "r") as fin:
        for line in fin:
            if not line.strip():
                continue
            qid = json.loads(line)["question_id"]
            answers[qid] = line if line.endswith("\n") else line + "\n"

    # Sort before reopening for writing: a failed sort must not truncate the file.
    qids = sorted(answers)
    with open(answer_file, "w") as fout:
        fout.writelines(answers[qid] for qid in qids)


if __name__ == "__main__":
    # CLI entry point: parse options, derive a descriptive model id and answer
    # file name from them, then run the evaluation --repeat-times times.
    parser = argparse.ArgumentParser()
    parser.add_argument("--ea-model-path", type=str, required=True)
    parser.add_argument("--base-model-path", type=str, required=True)
    parser.add_argument("--load-in-8bit", action="store_true")
    parser.add_argument("--model-id", type=str, required=True)
    parser.add_argument(
        "--bench-name",
        type=str,
        default="mt_bench",
        help="The name of the benchmark question set.",
    )
    parser.add_argument(
        "--question-begin",
        type=int,
        help="A debug option. The begin index of questions.",
    )
    parser.add_argument(
        "--question-end", type=int, help="A debug option. The end index of questions."
    )
    parser.add_argument("--answer-file", type=str, help="The output answer file.")
    parser.add_argument("--output-dir", type=str, default=".")
    parser.add_argument(
        "--max-new-tokens",
        type=int,
        default=1024,
        help="The maximum number of new generated tokens.",
    )
    parser.add_argument(
        "--total-token",
        type=int,
        default=60,
        help="The total number of nodes in the draft tree",
    )
    parser.add_argument("--depth", type=int, default=5)
    # Both spellings were previously registered as two separate options that
    # wrote to the same destination (args.top_k); one action with two option
    # strings is equivalent. The value is passed both to
    # EaModel.from_pretrained and to eagenerate.
    parser.add_argument("--top-k", "--top_k", type=int, default=10)
    parser.add_argument(
        "--num-choices",
        type=int,
        default=1,
        help="How many completion choices to generate.",
    )
    parser.add_argument(
        "--num-gpus-per-model",
        type=int,
        default=1,
        help="The number of GPUs per model.",
    )
    parser.add_argument(
        "--num-gpus-total", type=int, default=8, help="The total number of GPUs."
    )
    parser.add_argument(
        "--max-gpu-memory",
        type=str,
        help="Maxmum GPU memory used for model weights per GPU.",
    )

    # Sampling options.
    parser.add_argument("--temperature", type=float, default=0.6)
    parser.add_argument("--top_p", type=float, default=0.0)
    parser.add_argument("--repetition_penalty", type=float, default=1.0)

    parser.add_argument("--tree-choices", type=str, default="mc_sim_7b_63")
    parser.add_argument("--seed", type=int, default=42)

    # Thinking / early think-exit options.
    parser.add_argument("--exit-threshold", type=float, nargs="+", default=[50])
    parser.add_argument("--stop-prob-threshold", type=float, default=0.0)
    parser.add_argument("--confidence-loss-type", type=str, default=None)
    parser.add_argument("--think-end-token", type=str, default="</think>")
    parser.add_argument("--enable-thinking", action="store_true")
    parser.add_argument("--enable-think-exit", action="store_true")
    parser.add_argument("--min-think", type=int, default=100)
    parser.add_argument("--min-paragraph", type=int, default=0)
    parser.add_argument("--window-size", type=int, default=100)
    parser.add_argument("--stop-think-prompt", type=str, default="")
    parser.add_argument("--map-split-ids", action="store_true")
    parser.add_argument("--repeat-times", type=int, default=1)
    parser.add_argument("--pool-type", type=str, default="ewma")
    parser.add_argument("--split-type", type=str, default="paragraph")

    args = parser.parse_args()

    think_exit_enabled = args.enable_thinking and args.enable_think_exit

    # Validate option combinations up front with a proper CLI error; an
    # assert would be stripped under "python -O".
    if (think_exit_enabled
            and args.pool_type == "paragraph_mean"
            and args.split_type != "paragraph"):
        parser.error("--pool-type paragraph_mean requires --split-type paragraph")

    setup_seed(args.seed)

    # Encode the run configuration into the model id; it is stored in every
    # answer record and used as the default answer file name.
    args.model_id = f"{args.model_id}-temperature{args.temperature}-max_len{args.max_new_tokens}"
    if think_exit_enabled:
        thresh_str = '_'.join(str(thresh) for thresh in args.exit_threshold)
        args.model_id += f"-threshold{thresh_str}_min{args.min_think}_{args.split_type}"
        if args.map_split_ids:
            args.model_id += "-map_split_ids"
        if args.min_paragraph:
            args.model_id += f"-paragraph_min{args.min_paragraph}"
        if args.pool_type == "ewma":
            args.model_id += "-ewma0.1"
        else:
            args.model_id += f"-pool_{args.pool_type}"
    else:
        args.model_id += "-baseline"

    # The question file does not depend on the repeat index.
    question_file = f"{parent_dir}/benchmark/{args.bench_name}/question.jsonl"

    for repeat_iter in range(args.repeat_times):
        if args.answer_file:
            # An explicit --answer-file is reused as-is for every repeat.
            answer_file = args.answer_file
        else:
            # Repeat 0 has no suffix; later repeats get "-repeat<N>".
            repeat_suffix = f"-repeat{repeat_iter}" if repeat_iter else ""
            answer_file = f"{args.output_dir}/{args.bench_name}/{args.model_id}{repeat_suffix}.jsonl"

        print(f"Output to {answer_file}")

        run_eval(
            args.base_model_path,
            args.ea_model_path,
            args.model_id,
            question_file,
            args.question_begin,
            args.question_end,
            answer_file,
            args.max_new_tokens,
            args.num_choices,
            args.num_gpus_per_model,
            args.num_gpus_total,
            args.max_gpu_memory,
            args.temperature,
            args
        )

