import argparse
import json
import os
import sys
import json
import re
import random
from utils import extract_answer, set_seed, sample_indices, dataset_paths, chunk_starters
from grader import grade_answer
from transformers import AutoTokenizer
sys.path.append("./sglang/python")
import sglang as sgl
from sglang import function, system, user, gen, RuntimeEndpoint
from sglang.lang.chat_template import get_chat_template


def answer_gen_prefix(model):
    """Return the text used to force a model to state its final answer immediately.

    The prefix is appended after a (possibly partial) reasoning trace and ends with
    an opened ``\\boxed{``, so whatever the model generates next is the answer.

    Args:
        model: Model name; matching is case-insensitive.

    Raises:
        NotImplementedError: if the model is neither DeepSeek-R1/DeepScaleR-style
            nor a Qwen "instruct" model.
    """
    name = model.lower()
    if "deepseek-r1" in name or "deepscaler" in name:
        return "... Wait, I suddenly got the final answer to the whole problem.\n\n**Final Answer**\n\n\\[ \\boxed{"
    if "qwen" in name and "instruct" in name:
        return "... Wait, I suddenly got the final answer to the whole problem.\n\nThe final answer is\n\\[ \\boxed{"
    raise NotImplementedError


def _split_on_starters(chunk, pattern, starter_set):
    """Split ``chunk`` into pieces that each begin at a starter phrase.

    ``pattern`` is a compiled regex whose single capturing group is the alternation
    of all starters, so ``pattern.split`` keeps the starters as separate items.
    Text before the first starter becomes its own piece.
    """
    result_chunks = []
    current_chunk = ""
    for piece in (part.strip() for part in pattern.split(chunk)):
        if not piece:
            continue
        if piece in starter_set:
            # A starter opens a new sub-chunk; flush the one being built.
            if current_chunk:
                result_chunks.append(current_chunk.strip())
            current_chunk = piece
        else:
            # Continue the current sub-chunk; the boundary whitespace between a
            # starter and its following text is normalized to a single space.
            current_chunk += " " + piece
    if current_chunk:
        result_chunks.append(current_chunk.strip())
    return result_chunks


def divide_into_chunks(text, starters=None):
    """Split a model trace into reasoning sub-chunks followed by the answer segment.

    Pass 1 cuts ``text`` after every ``"</think>\\n\\n"``. The delimiter stays attached
    to the preceding segment (stripped to ``"</think>"``). Whatever follows the last
    delimiter forms the final segment.

    Pass 2 only runs when there are exactly one or two segments. The first segment
    is split at every starter phrase (each sub-chunk begins with a starter). With
    two segments, the second one is appended unchanged as the last chunk. With any
    other segment count, no chunks are returned.

    Args:
        text: Full model output to split.
        starters: Phrases that begin a new reasoning chunk. Defaults to
            ``chunk_starters`` from ``utils``.

    Returns:
        ``(processed_chunks, num_segments)`` where ``num_segments`` is the pass-1
        segment count (so callers can tell why ``processed_chunks`` may be empty).

    NOTE(review): sub-pieces are stripped before being compared with the starters,
    so a starter with leading/trailing whitespace never registers as a boundary.
    Confirm ``chunk_starters`` contains no such entries.
    """
    if starters is None:
        starters = chunk_starters

    # Pass 1: cut after each "</think>\n\n" (the capturing group keeps the delimiter).
    merged_chunks = []
    current_chunk = ""
    for part in re.split(r"(</think>\n\n)", text):
        if part.startswith("</think>\n\n"):
            # A delimiter with nothing before it is dropped.
            if current_chunk:
                merged_chunks.append((current_chunk + part).strip())
                current_chunk = ""
        else:
            current_chunk += part
    if current_chunk.strip():
        merged_chunks.append(current_chunk.strip())

    # Pass 2: split the reasoning segment at starter phrases.
    processed_chunks = []
    if len(merged_chunks) in (1, 2):
        if starters:
            pattern = re.compile("(" + "|".join(re.escape(starter) for starter in starters) + ")")
            processed_chunks.extend(_split_on_starters(merged_chunks[0], pattern, set(starters)))
        else:
            # No starters: an empty alternation would split between every character.
            processed_chunks.append(merged_chunks[0])
        if len(merged_chunks) == 2:
            processed_chunks.append(merged_chunks[1])

    return processed_chunks, len(merged_chunks)


def truncate_after_last_dot(s, return_remnant=False):
    """Cut ``s`` just after its last ``"."``.

    If ``s`` contains no ``"."``, it is returned whole.

    Args:
        s: Input string.
        return_remnant: When true, also return the text that was cut off
            (``""`` when there is no ``"."``).

    Returns:
        The kept prefix, or ``(kept_prefix, remnant)`` if ``return_remnant``.
    """
    before, dot, after = s.rpartition(".")
    if dot:
        kept, rest = before + dot, after
    else:
        kept, rest = s, ""
    if return_remnant:
        return kept, rest
    return kept


@function
def rollout(s, args, qid, model, problem, ground_truth_answer, num_tokens, solution, correct, max_tokens, save_path):
    """Probe, after each chunk of a reasoning trace, whether a forced answer is correct.

    ``solution`` is split with ``divide_into_chunks`` and replayed into the assistant
    turn one chunk at a time (chunks joined by ``args.delimiter``). For every prefix,
    starting with the empty one:

    - If the newest chunk already contains ``\\boxed{``, its answer is extracted and
      graded directly.
    - Otherwise the state is forked. The thinking section is closed if still open,
      ``answer_gen_prefix(model)`` is appended, and ``args.num_rollout`` completions
      are sampled and graded against ``ground_truth_answer``.

    All results are written as JSON to ``save_path``; nothing is returned.

    Args:
        s: SGLang program state (supplied by the ``@function`` decorator).
        args: Parsed CLI namespace; reads ``delimiter``, ``num_rollout``,
            ``temperature`` and ``top_p``.
        qid: Problem index, copied into the output record.
        model: Model name; selects the prompt format and the forced-answer prefix.
        problem: Problem statement placed in the user turn.
        ground_truth_answer: Reference answer passed to ``grade_answer``.
        num_tokens: Copied verbatim into the output record (presumably the token
            count of ``solution`` — confirm with the raw-data producer).
        solution: Full model trace to be chunked and replayed.
        correct: Copied verbatim into the output record.
        max_tokens: Generation budget for each forced-answer completion.
        save_path: Output JSON path; its parent directory must already exist.
    """
    lowered_model = model.lower()
    if "deepseek-r1" in lowered_model or "deepscaler" in lowered_model:
        # R1-style models: instruction and problem share a single user turn.
        s += user(f"Please reason step by step, and put your final answer within \\boxed{{}}. {problem}")
    elif "qwen" in lowered_model and "instruct" in lowered_model:
        s += system(f"Please reason step by step, and put your final answer within \\boxed{{}}")
        s += user(f"{problem}")

    chunks, _ = divide_into_chunks(text=solution)
    answer_prefix = answer_gen_prefix(model)
    history = []
    s += sgl.assistant_begin()
    for idx in range(len(chunks) + 1):
        # idx == 0 probes the empty prefix; idx == k probes chunks[0..k-1].
        chunk_text = chunks[idx - 1] if idx > 0 else ""
        if idx == 1:
            s += chunk_text
        elif idx > 1:
            s += args.delimiter + chunk_text
        # The parent state is not modified again in this iteration; read it once.
        partial_solution = s.text()

        if "\\boxed{" in chunk_text:
            # The trace already states an answer in this chunk: grade it, no sampling.
            inner_answer = extract_answer(chunk_text)
            inner_solution_list = []
            inner_answer_list = [inner_answer]
            inner_correct_list = [int(grade_answer(inner_answer, ground_truth_answer))]
        else:
            outer_fork = s.fork(1)[0]
            if idx == 0:
                outer_fork += "<think>\n"
            if "</think>" not in partial_solution:
                # Close the thinking section so the answer prefix lands after it.
                outer_fork += "\n</think>\n\n"
            outer_fork += answer_prefix
            inner_forks = outer_fork.fork(args.num_rollout)
            # Submit every sample before reading any result. Reading
            # inner_fork["answer"] waits for that fork's generation, so interleaving
            # submit/read would run the rollouts strictly one after another.
            for inner_fork in inner_forks:
                inner_fork += gen("answer", max_tokens, temperature=args.temperature, top_p=args.top_p)
            inner_solution_list = []
            inner_answer_list = []
            inner_correct_list = []
            for inner_fork in inner_forks:
                completion = inner_fork["answer"]
                # answer_prefix ends with an opened "\boxed{", so re-attach it
                # before extracting the boxed answer.
                inner_answer = extract_answer(answer_prefix + completion)
                inner_solution_list.append(completion)
                inner_answer_list.append(inner_answer)
                inner_correct_list.append(int(grade_answer(inner_answer, ground_truth_answer)))

        history.append(
            {
                "chunk_idx": idx,
                "chunk_text": chunk_text,
                "partial_solution": partial_solution,
                "inner_solution_list": inner_solution_list,
                "inner_answer_list": inner_answer_list,
                "inner_correct_list": inner_correct_list,
            }
        )

    res = {
        "model": model,
        "qid": qid,
        "problem": problem,
        "ground_truth_answer": ground_truth_answer,
        "num_tokens": num_tokens,
        "solution": solution,
        "correct": correct,
        "num_chunks": len(chunks),
        "chunks": chunks,
        "answer_prefix": answer_prefix,
        "temperature": args.temperature,
        "max_tokens": max_tokens,
        "top_p": args.top_p,
        "num_rollout": args.num_rollout,
        "history": history,
    }
    with open(save_path, "w", encoding="utf-8") as f:
        json.dump(res, f, indent=4)
    print(f">>> Save to [{save_path}]")


def parse_args():
    """Parse the command line for the rollout script.

    Every option takes a single typed value with the default shown in the table
    below. Options are registered in table order.

    Returns:
        argparse.Namespace holding all options.
    """
    option_table = (
        ("--seed", int, 42),
        ("--dataset", str, 'math_train'),
        ("--num_inst", int, 5000),
        ("--model_name", str, "DeepSeek-R1-Distill-Qwen-1.5B"),
        ("--model_path", str, "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"),
        ("--chat_template_type", str, "deepseek-v3"),
        ("--speed", str, "normal"),
        ("--num_threads", int, 96),
        ("--host", str, "http://localhost:30000"),
        ("--max_tokens", int, 8192),
        ("--temperature", float, 0.6),
        ("--top_p", float, 0.95),
        ("--start", int, 0),
        ("--end", int, -1),
        ("--num_rollout", int, 10),
        ("--delimiter", str, "\n\n"),
        ("--load_dir", str, ""),
        ("--used_count_per_problem", int, 2),
        ("--num_raw", int, 4),
    )
    cli = argparse.ArgumentParser()
    for flag, value_type, default_value in option_table:
        cli.add_argument(flag, type=value_type, default=default_value)
    return cli.parse_args()


def list_files(directory):
    """Return the names (not paths) of regular files directly inside ``directory``.

    Subdirectories are skipped; order follows ``os.listdir``. A missing path, or
    one that is not a directory, yields an empty list.
    """
    found = []
    if os.path.isdir(directory):
        for entry_name in os.listdir(directory):
            if os.path.isfile(os.path.join(directory, entry_name)):
                found.append(entry_name)
    return found


def get_token_length(tokenizer, solution):
    """Return the number of token ids ``tokenizer`` produces for ``solution``.

    Args:
        tokenizer: Callable returning a mapping with an ``"input_ids"`` sequence
            (e.g. a Hugging Face tokenizer).
        solution: Text to tokenize.

    Raises:
        Whatever the tokenizer (or the ``"input_ids"`` lookup) raises. ``solution``
        is printed first so the offending text shows up in the log.
    """
    try:
        return len(tokenizer(solution)["input_ids"])
    except Exception:
        # Print the failing input for debugging, then surface the real error
        # instead of failing later on an unbound local.
        print(solution)
        raise


if __name__ == "__main__":
    # Build one `rollout` job per (problem, raw solution) pair that has no result
    # file yet, then run all jobs as a batch against the SGLang server.
    args = parse_args()

    ###################
    ## Set seed
    ###################
    set_seed(args.seed)

    ###################
    ## Load dataset (used only to pick problem indices and bound --end)
    ###################
    input_path = dataset_paths[args.dataset]
    with open(input_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    ###################
    ## Sample indices
    ###################
    sampled_indices = sample_indices(data, args.dataset, args.num_inst)
    if args.end == -1:
        args.end = len(data)

    ###################
    ## Load backend and tokenizer
    ###################
    backend = RuntimeEndpoint(args.host)
    tokenizer = AutoTokenizer.from_pretrained(args.model_path, use_fast=True)

    ###################
    ## Collect pending jobs
    ## inputs:  <load_dir>/raw/<qid>.json
    ## outputs: <load_dir>/rollout/<qid>/<yid>.json
    ###################
    save_dir = os.path.join(args.load_dir, "rollout")

    input_dict_list = []
    for qid in sampled_indices:
        if not (args.start <= qid < args.end):
            continue
        raw_file_path = os.path.join(args.load_dir, "raw", f"{qid}.json")
        if not os.path.exists(raw_file_path):
            continue
        with open(raw_file_path, "r", encoding="utf-8") as f:
            raw_data = json.load(f)

        problem = raw_data["problem"]
        ground_truth_answer = raw_data["ground_truth_answer"]
        num_tokens_list = raw_data["num_tokens_list"]
        solution_list = raw_data["solution_list"]
        correct_list = raw_data["correct_list"]

        if not ground_truth_answer:
            # Without a reference answer there is nothing to grade against.
            continue

        # Result files are named "<yid>.json"; ignore anything else in the directory
        # instead of crashing on it.
        qid_dir = os.path.join(save_dir, f"{qid}")
        existing_indices = set()
        for file_name in list_files(qid_dir):
            stem, ext = os.path.splitext(file_name)
            if ext == ".json" and stem.isdecimal():
                existing_indices.add(int(stem))

        pending_yids = [
            yid for yid in range(min(len(solution_list), args.num_raw))
            if yid not in existing_indices
        ]
        if not pending_yids:
            continue

        os.makedirs(qid_dir, exist_ok=True)
        # Loop-invariant per problem: leave room for at least the reference answer.
        max_tokens = max(args.max_tokens, get_token_length(tokenizer, ground_truth_answer) + 16)
        for yid in pending_yids:
            input_dict_list.append(
                {
                    "args": args,
                    "qid": qid,
                    "model": args.model_name,
                    "problem": problem,
                    "ground_truth_answer": ground_truth_answer,
                    "num_tokens": num_tokens_list[yid],
                    "solution": solution_list[yid],
                    "correct": correct_list[yid],
                    "max_tokens": max_tokens,
                    "save_path": os.path.join(save_dir, f"{qid}/{yid}.json"),
                }
            )

    ###################
    ## Run
    ###################
    backend.chat_template = get_chat_template(args.chat_template_type)
    rollout.run_batch(
        input_dict_list,
        backend=backend,
        num_threads=args.num_threads,
        progress_bar=True,
    )
    # os._exit skips interpreter cleanup, including stdio buffer flushing, so flush
    # explicitly or the last log lines can be lost when output is piped.
    # NOTE(review): the hard exit presumably avoids hanging on lingering
    # backend/worker threads; confirm it is still needed.
    sys.stdout.flush()
    sys.stderr.flush()
    os._exit(0)