import argparse
import json
import os
import sys
import json
import re
import random
from transformers import AutoTokenizer
import numpy as np
from utils import extract_answer, set_seed, sample_indices, dataset_paths
from grader import grade_answer
sys.path.append("./sglang/python")
import sglang as sgl
from sglang import function, system, user, gen, RuntimeEndpoint
from sglang.lang.chat_template import get_chat_template


@function
def conclude(s, args, qid, yid, model, tokenizer, problem, ground_truth_answer, original_num_tokens, original_solution, correct, history, chunks, max_tokens, save_path):
    """Continue a truncated reasoning prefix and sample final answers from it.

    The earliest entry of ``history`` whose mean ``inner_correct_list``
    reaches ``args.threshold`` supplies the partial solution. That prefix is
    placed at the start of the assistant turn, ``args.num_conclusion``
    continuations are sampled from it, each one is graded against
    ``ground_truth_answer``, and the results are written to ``save_path`` as
    JSON. Nothing is generated or written when no entry reaches the threshold.

    Args:
        s: sglang program state the prompt and prefix are appended to.
        args: Parsed CLI namespace; ``threshold``, ``max_tokens``,
            ``num_conclusion``, ``temperature`` and ``top_p`` are read.
        qid: Problem identifier, copied into the saved record.
        yid: Not read by this function.
        model: Model name; its lowercase form selects the prompt layout.
        tokenizer: Passed to ``get_token_length`` to size the prefix.
        problem: Problem statement placed in the user turn.
        ground_truth_answer: Reference answer used for grading.
        original_num_tokens: Not read by this function.
        original_solution: Not read by this function.
        correct: Not read by this function.
        history: Sequence of dicts carrying ``inner_correct_list`` and
            ``partial_solution``.
        chunks: Not read by this function.
        max_tokens: Not read; the generation budget is derived from
            ``args.max_tokens`` instead.
        save_path: Destination JSON file for the result record.
    """
    # Only entries that precede the first one holding exactly one inner
    # sample are candidates; without such an entry, all of history is.
    candidate_end = len(history)
    for position, entry in enumerate(history):
        if len(entry["inner_correct_list"]) == 1:
            candidate_end = position
            break
    accuracies = [
        np.mean(history[position]["inner_correct_list"])
        for position in range(candidate_end)
    ]

    # Earliest candidate whose mean correctness meets the threshold.
    k_target = next(
        (position for position, accuracy in enumerate(accuracies) if accuracy >= args.threshold),
        None,
    )
    if k_target is None:
        return

    # Keep the text from "<think>" onwards (nothing if the tag is absent) and
    # cut it off before a closing think tag.
    prefix = history[k_target]["partial_solution"]
    think_start = prefix.find("<think>")
    prefix = "" if think_start == -1 else prefix[think_start:]
    prefix = prefix.split("\n</think>", 1)[0].strip()

    # Prompt layout depends on the model family; other models get no prompt.
    model_key = model.lower()
    if "deepseek-r1" in model_key or "deepscaler" in model_key:
        s += user(f"Please reason step by step, and put your final answer within \\boxed{{}}. {problem}")
    elif "qwen" in model_key and "instruct" in model_key:
        s += system(f"Please reason step by step, and put your final answer within \\boxed{{}}")
        s += user(f"{problem}")

    s += sgl.assistant_begin()
    s += prefix

    # With no closing think tag anywhere in the state text, either emit an
    # empty think block or nudge the model towards stating its final answer.
    if "</think>" not in s.text():
        s += (
            "<think>\n\n</think>\n\n"
            if prefix == ""
            else " Hmm, I think this is enough to derive the final answer.\n\n**Final Answer**\n"
        )

    # The prefix is charged against the token budget, with a floor of 512.
    prefix_tokens = get_token_length(tokenizer, prefix) - 1
    generation_budget = max(args.max_tokens - prefix_tokens, 512)

    branches = s.fork(args.num_conclusion)
    num_tokens_list, solution_list, answer_list, correct_list = [], [], [], []
    for branch in branches:
        branch += gen("solution", generation_budget, temperature=args.temperature, top_p=args.top_p)
        num_tokens_list.append(branch.get_meta_info("solution")["completion_tokens"])
        completion = branch["solution"]
        solution_list.append(completion)
        extracted = extract_answer(completion)
        is_correct = 1 if grade_answer(extracted, ground_truth_answer) else 0
        answer_list.append(extracted)
        correct_list.append(is_correct)

    record = {
        "qid": qid,
        "problem": problem,
        "ground_truth_answer": ground_truth_answer,
        "partial_solution": s.text(),
        "max_tokens": generation_budget,
        "temperature": args.temperature,
        "top_p": args.top_p,
        "num_tokens_list": num_tokens_list,
        "answer_list": answer_list,
        "correct_list": correct_list,
        "solution_list": solution_list,
    }
    with open(save_path, "w", encoding="utf-8") as out_file:
        json.dump(record, out_file, indent=4)
    print(f">>> Save to [{save_path}]")


def parse_args():
    """Build the command-line parser and parse the process arguments.

    Every option is a ``--flag VALUE`` pair with a fixed type and default;
    the options are registered in the order listed below.

    Returns:
        argparse.Namespace: One attribute per flag (without the leading
        dashes), parsed from ``sys.argv``.
    """
    # (flag, type, default), in registration order.
    option_table = (
        ("--seed", int, 42),
        ("--dataset", str, 'math_train'),
        ("--num_inst", int, 5000),
        ("--model_name", str, "DeepSeek-R1-Distill-Qwen-1.5B"),
        ("--model_path", str, "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"),
        ("--chat_template_type", str, "deepseek-v3"),
        ("--speed", str, "normal"),
        ("--num_threads", int, 96),
        ("--host", str, "http://localhost:30000"),
        ("--max_tokens", int, 8192),
        ("--temperature", float, 0.6),
        ("--top_p", float, 0.95),
        ("--start", int, 0),
        ("--end", int, -1),
        ("--num_conclusion", int, 4),
        ("--threshold", float, 1.0),
        ("--num_raw", int, 4),
        ("--load_dir", str, ""),
        ("--raw_subdir", str, "raw"),
        ("--rollout_subdir", str, "rollout"),
        ("--used_count_per_problem", int, 2),
    )

    cli = argparse.ArgumentParser()
    for flag, value_type, default_value in option_table:
        cli.add_argument(flag, type=value_type, default=default_value)
    return cli.parse_args()


def list_files(directory):
    """Return the names of the regular files directly inside ``directory``.

    Sub-directories are left out and the listing is not recursive. An empty
    list is returned when ``directory`` is not an existing directory.
    """
    file_names = []
    if os.path.isdir(directory):
        for entry in os.listdir(directory):
            if os.path.isfile(os.path.join(directory, entry)):
                file_names.append(entry)
    return file_names


def get_token_length(tokenizer, solution):
    """Return how many input ids ``tokenizer`` produces for ``solution``.

    Args:
        tokenizer: Callable that takes the text and returns a mapping with an
            ``"input_ids"`` sequence.
        solution: Text to tokenize.

    Returns:
        int: ``len(tokenizer(solution)["input_ids"])``.

    Raises:
        Exception: Whatever the tokenizer call or the ``"input_ids"`` lookup
            raises. The offending text is printed before the error is
            re-raised.
    """
    try:
        return len(tokenizer(solution)["input_ids"])
    except Exception:
        # Dump the offending text for debugging, then surface the real error
        # rather than failing later on an unbound local variable.
        print(solution)
        raise


if __name__ == "__main__":
    # Script entry point: collect every (problem, rollout) pair that has at
    # least one history entry reaching the accuracy threshold, then run
    # `conclude` on all of them in a single batch against the sglang endpoint.
    args = parse_args()

    # Seed via the project helper.
    set_seed(args.seed)

    # Load the dataset; the handle is closed as soon as parsing is done.
    input_path = dataset_paths[args.dataset]
    with open(input_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    # Pick the problem indices to process; --end == -1 means "up to the end".
    sampled_indices = sample_indices(data, args.dataset, args.num_inst)
    if args.end == -1:
        args.end = len(data)

    # Generation backend and tokenizer.
    backend = RuntimeEndpoint(args.host)
    tokenizer = AutoTokenizer.from_pretrained(args.model_path, use_fast=True)

    # Results go to <load_dir>/conclude/<qid>/<yid>.json.
    save_dir = os.path.join(args.load_dir, "conclude")
    print(f">>> Save dir = [{save_dir}]")

    input_dict_list = []
    for qid in sampled_indices:
        if not (args.start <= qid < args.end):
            continue

        raw_file_path = os.path.join(args.load_dir, args.raw_subdir, f"{qid}.json")
        if not os.path.exists(raw_file_path):
            continue
        with open(raw_file_path, "r", encoding="utf-8") as f:
            raw_data = json.load(f)

        problem = raw_data["problem"]
        ground_truth_answer = raw_data["ground_truth_answer"]
        solution_list = raw_data["solution_list"]
        if not ground_truth_answer:
            continue

        # Match rollout files by their exact "<yid>.json" name, so stray
        # files in the directory (backups, hidden files, ...) are ignored
        # instead of breaking an integer parse of their names.
        rollout_file_dir = os.path.join(args.load_dir, args.rollout_subdir, f"{qid}")
        rollout_file_names = set(list_files(rollout_file_dir))

        for yid in range(len(solution_list)):
            rollout_file_name = f"{yid}.json"
            if rollout_file_name not in rollout_file_names:
                continue

            rollout_file_path = os.path.join(rollout_file_dir, rollout_file_name)
            with open(rollout_file_path, "r", encoding="utf-8") as f:
                rollout_data = json.load(f)

            chunks = rollout_data["chunks"]
            solution = rollout_data["solution"]
            correct = rollout_data["correct"]
            num_tokens = rollout_data["num_tokens"]
            history = rollout_data["history"]

            # Candidates are the entries before the first one that holds
            # exactly one inner sample (all of history if there is none).
            final_k = next(
                (i for i, item in enumerate(history) if len(item["inner_correct_list"]) == 1),
                len(history),
            )
            pdist = [np.mean(history[k]["inner_correct_list"]) for k in range(final_k)]

            # Skip rollouts where no candidate reaches the threshold;
            # `conclude` works out the exact index itself.
            if not any(p >= args.threshold for p in pdist):
                continue

            save_path = os.path.join(save_dir, f"{qid}/{yid}.json")
            os.makedirs(os.path.dirname(save_path), exist_ok=True)

            input_dict_list.append(
                {
                    "args": args,
                    "qid": qid,
                    "yid": yid,
                    "model": args.model_name,
                    "tokenizer": tokenizer,
                    "problem": problem,
                    "ground_truth_answer": ground_truth_answer,
                    "original_num_tokens": num_tokens,
                    "original_solution": solution,
                    "correct": correct,
                    "chunks": chunks,
                    "history": history,
                    "max_tokens": args.max_tokens,
                    "save_path": save_path,
                }
            )

    chat_template = get_chat_template(args.chat_template_type)
    backend.chat_template = chat_template
    conclude.run_batch(
        input_dict_list,
        backend=backend,
        num_threads=args.num_threads,
        progress_bar=True,
    )

    # os._exit() skips interpreter cleanup, including flushing stdio buffers,
    # so flush explicitly or buffered output is lost when stdout is redirected.
    # NOTE(review): the hard exit is presumably there to avoid hanging on
    # background threads at shutdown - confirm before replacing it.
    sys.stdout.flush()
    sys.stderr.flush()
    os._exit(0)