import argparse
import json
import os
import sys
import json
import numpy as np
from utils import extract_answer, set_seed, sample_indices, dataset_paths
from grader import grade_answer
sys.path.append("./sglang/python")
import sglang as sgl
from sglang import function, system, user, assistant, gen, RuntimeEndpoint
from sglang.lang.chat_template import get_chat_template


@function
def get_response(s, args, qid, model, speed, problem, ground_truth_answer, save_path):
    """Sample solutions for one problem, grade them, and save the results as JSON.

    The prompt is built from ``problem`` (plus an optional speed hint), the
    state is forked ``args.num_sample`` times, one solution is generated per
    fork, and each extracted answer is graded against
    ``ground_truth_answer``. All results are written to ``save_path``.

    Args:
        s: Program state supplied by the ``@function`` decorator
            (presumably an SGLang program state -- confirm against sglang).
        args: Parsed CLI namespace; ``num_sample``, ``max_tokens``,
            ``temperature`` and ``top_p`` are read from it.
        qid: Identifier of the problem, stored verbatim in the output.
        model: Model name; matched case-insensitively to pick the prompt
            layout.
        speed: ``"fast"`` appends a "solve quickly" hint, ``"normal"`` adds
            nothing.
        problem: Problem statement text.
        ground_truth_answer: Reference answer passed to ``grade_answer``.
        save_path: Destination path of the JSON result file.

    Raises:
        ValueError: If ``speed`` is neither ``"fast"`` nor ``"normal"``.
    """
    if speed == "fast":
        speed_prompt = "\nSolve the problem as quickly as possible."
    elif speed == "normal":
        speed_prompt = ""
    else:
        # Fail early with a clear message instead of hitting an
        # UnboundLocalError on ``speed_prompt`` further down.
        raise ValueError(f"Unsupported speed {speed!r}; expected 'fast' or 'normal'.")

    instruction = "Please reason step by step, and put your final answer within \\boxed{}"
    model_key = model.lower()
    if "deepseek-r1" in model_key or "deepscaler" in model_key:
        # These models get the instruction inside the user turn.
        s += user(f"{instruction}. {problem}{speed_prompt}")
    elif "qwen" in model_key and "instruct" in model_key:
        # Qwen instruct models get the instruction as a system message.
        s += system(instruction)
        s += user(f"{problem}{speed_prompt}")
    # NOTE(review): any other model name falls through without a prompt, so
    # generation starts from an empty conversation -- confirm this is intended.

    forks = s.fork(args.num_sample)
    num_tokens_list = []
    solution_list = []
    answer_list = []
    correct_list = []
    for fork in forks:
        fork += assistant(gen("solution", args.max_tokens, temperature=args.temperature, top_p=args.top_p))
        num_tokens_list.append(fork.get_meta_info("solution")["completion_tokens"])
        solution = fork["solution"]
        solution_list.append(solution)
        answer = extract_answer(solution)
        answer_list.append(answer)
        correct_list.append(1 if grade_answer(answer, ground_truth_answer) else 0)

    res = {
        "qid": qid,
        "speed": speed,
        "speed_prompt": speed_prompt,
        "problem": problem,
        "ground_truth_answer": ground_truth_answer,
        "max_tokens": args.max_tokens,
        "temperature": args.temperature,
        "top_p": args.top_p,
        "num_tokens_list": num_tokens_list,
        "answer_list": answer_list,
        "correct_list": correct_list,
        "solution_list": solution_list,
    }
    with open(save_path, "w", encoding="utf-8") as f:
        json.dump(res, f, indent=4)
    print(f">>> Save to [{save_path}]")


def parse_args():
    """Build the command-line parser for the sampling script and parse ``sys.argv``.

    Returns:
        argparse.Namespace: One attribute per option listed below, each
        falling back to its default when the flag is not given.
    """
    # (flag, type, default) -- kept in the order the options are registered.
    option_table = (
        ("--seed", int, 42),
        ("--dataset", str, 'math_train'),
        ("--num_inst", int, 5000),
        ("--num_sample", int, 4),
        ("--model_name", str, "DeepSeek-R1-Distill-Qwen-1.5B"),
        ("--model_path", str, "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"),
        ("--chat_template_type", str, "deepseek-v3"),
        ("--host", str, "http://localhost:30000"),
        ("--num_threads", int, 96),
        ("--speed", str, "normal"),
        ("--max_tokens", int, 8192),
        ("--temperature", float, 0.6),
        ("--top_p", float, 0.95),
        ("--start", int, 0),
        ("--end", int, -1),
    )

    cli = argparse.ArgumentParser()
    for flag, caster, fallback in option_table:
        cli.add_argument(flag, type=caster, default=fallback)
    return cli.parse_args()


if __name__ == "__main__":
    # Script entry point: load the dataset, pick the problems to run, and
    # send them through ``get_response`` in a batch against the backend.
    args = parse_args()

    ###################
    ## Set seed
    ###################
    set_seed(args.seed)

    ###################
    ## Load dataset
    ###################
    input_path = dataset_paths[args.dataset]
    # Context manager so the dataset file handle is closed right after loading.
    with open(input_path, "r") as f:
        data = json.load(f)

    ###################
    ## Sample indices
    ###################
    sampled_indices = sample_indices(data, args.dataset, args.num_inst)
    # ``--end -1`` means "no upper bound": use the dataset length.
    if args.end == -1:
        args.end = len(data)

    ###################
    ## Load backend
    ###################
    backend = RuntimeEndpoint(args.host)

    save_dir = f"./data/generated/{args.model_name}_{args.dataset}_T={args.temperature}_top_p={args.top_p}_token={args.max_tokens}_{args.speed}_num={args.num_sample}/raw"
    # One input dict per sampled problem whose index lies in [start, end).
    input_dict_list = [
        {
            "args": args,
            "qid": qid,
            "model": args.model_name,
            "speed": args.speed,
            "problem": data[qid]["problem"],
            "ground_truth_answer": data[qid]["ground_truth_answer"],
            "save_path": os.path.join(save_dir, f"{qid}.json"),
        }
        for qid in sampled_indices
        if args.start <= qid < args.end
    ]
    # Every save path shares the same parent, so create it once -- and only
    # when there is something to write.
    if input_dict_list:
        os.makedirs(save_dir, exist_ok=True)

    ###################
    ## Get responses
    ###################
    chat_template = get_chat_template(args.chat_template_type)
    print(f">>> Chat template = {args.chat_template_type}")
    backend.chat_template = chat_template
    get_response.run_batch(
        input_dict_list,
        backend=backend,
        num_threads=args.num_threads,
        progress_bar=True,
    )

    # os._exit skips normal interpreter cleanup, including flushing stdio
    # buffers, so flush explicitly to avoid losing buffered output.
    # NOTE(review): the hard exit is presumably there to avoid hanging on
    # lingering worker threads -- confirm before replacing with sys.exit.
    sys.stdout.flush()
    sys.stderr.flush()
    os._exit(0)