"""
Prune intermediate patterns
"""
import argparse
import json
import os
import sys
import json
import re
import random
import requests
import copy
from utils import extract_answer, set_seed, sample_indices, dataset_paths
from grader import grade_answer
from transformers import AutoTokenizer
from rollout import divide_into_chunks
sys.path.append("./sglang/python")
import sglang as sgl
from sglang import function, system, user, assistant, gen, RuntimeEndpoint
from sglang.lang.chat_template import get_chat_template


def check_remove(args, problem, ground_truth_answer, text):
    """Roll out the reasoning model on ``text`` and grade the answer it produces.

    ``text`` is cut right after its last ``\\boxed{`` marker (or left whole
    when there is none), sent to the ``/generate`` endpoint on
    ``args.lrm_port`` as the start of the assistant turn, and the greedy
    completion is appended so the final answer can be extracted and graded.

    Args:
        args: Parsed CLI namespace; only ``lrm_port`` is read here.
        problem: Problem statement placed in the user turn.
        ground_truth_answer: Reference answer passed to ``grade_answer``.
        text: Candidate (pruned) reasoning text to continue from.

    Returns:
        Whatever ``grade_answer(pred, ground_truth_answer)`` returns.

    Raises:
        requests.HTTPError: If the endpoint answers with an error status.
    """
    marker = r'\boxed{'
    last_index = text.rfind(marker)
    if last_index == -1:
        text_trimmed = text
    else:
        # Keep everything up to and including the opening brace so the model
        # only has to fill in the answer itself.
        text_trimmed = text[:last_index + len(marker)]

    # get_token_length needs a tokenizer; the original call omitted it and
    # raised TypeError. The script entry point binds a module-level
    # ``tokenizer``; when it is absent (e.g. the module is imported), fall
    # back to the character count as a rough stand-in for the token count.
    shared_tokenizer = globals().get("tokenizer")
    if shared_tokenizer is not None:
        answer_budget = get_token_length(shared_tokenizer, ground_truth_answer)
    else:
        answer_budget = len(str(ground_truth_answer))

    response = requests.post(
        f"http://localhost:{args.lrm_port}/generate",
        json={
            "text": f"<｜User｜>Please reason step by step, and put your final answer within \\boxed{{}}. {problem}" + f"<｜Assistant｜>{text_trimmed}",
            "sampling_params": {
                "temperature": 0,
                "max_new_tokens": max(32, answer_budget + 16),
            },
        },
    )
    # Surface server-side failures directly instead of a KeyError on "text".
    response.raise_for_status()
    res = response.json()
    pred = extract_answer(text_trimmed + res["text"])
    return grade_answer(pred, ground_truth_answer)


def get_token_length(tokenizer, solution):
    """Return how many token ids ``tokenizer`` produces for ``solution``.

    Args:
        tokenizer: Callable that returns a mapping with an ``'input_ids'``
            entry when called on ``solution``.
        solution: Text to tokenize.

    Returns:
        ``len(tokenizer(solution)['input_ids'])``.

    Raises:
        Exception: Any error raised while tokenizing is re-raised after the
            offending input has been printed for debugging.
    """
    try:
        return len(tokenizer(solution)['input_ids'])
    except Exception:
        # Show the input that broke tokenization, then propagate the real
        # error instead of failing later with an unrelated UnboundLocalError.
        print(solution)
        raise


def extract_summary(text):
    """Pull the triple-quoted body that follows ``[FINAL_RECONSTRUCTED_REASONING]``.

    Returns the first such body with surrounding whitespace stripped, or
    ``None`` when the marker plus a closed triple-quoted section is not found.
    """
    found = re.search(
        r"\[FINAL_RECONSTRUCTED_REASONING\]\s*\"\"\"(.*?)\"\"\"",
        text,
        flags=re.DOTALL,  # the body normally spans several lines
    )
    return found.group(1).strip() if found else None


def extract_keep_remove(text):
    """Collect the decision text written after each ``[Chunk N]`` tag.

    Returns a list with one entry per ``[Chunk <digits>]`` occurrence, in
    order of appearance: whatever follows the tag (after whitespace) up to
    the end of that line.
    """
    decision_pattern = re.compile(r"\[Chunk \d+\]\s+(.*)")
    return [hit.group(1) for hit in decision_pattern.finditer(text)]


def is_pass(text, tokenizer):
    """Tell whether a reconstructed summary is acceptable.

    A summary passes when it contains exactly one ``\\boxed{...}`` expression
    and encodes to more than 100 tokens (special tokens excluded).
    """
    boxed_hits = re.findall(r'\\boxed\{.*?\}', text)
    if len(boxed_hits) != 1:
        return False
    token_ids = tokenizer.encode(text, add_special_tokens=False)
    return len(token_ids) > 100


def update_keep_remove(args, problem, ground_truth_answer, chunks, keep_remove, check_func):
    """Re-verify every proposed chunk removal, walking from the last chunk backwards.

    For each chunk marked for removal, the remaining chunks are joined with
    blank lines, leaving out that chunk and every later chunk whose removal
    is still in force, and passed to ``check_func``. If the check fails, the
    decision is reverted to ``"KEEP AS IS"``.

    Decisions are matched case-insensitively and ignoring surrounding
    whitespace, both for the chunk under test and for the later chunks.
    Later chunks are read from the verified list, so a chunk that has been
    restored to ``"KEEP AS IS"`` stays in the text used for earlier checks.

    Args:
        args: Forwarded unchanged to ``check_func``.
        problem: Forwarded unchanged to ``check_func``.
        ground_truth_answer: Forwarded unchanged to ``check_func``.
        chunks: Chunk texts; decision ``i`` refers to ``chunks[i]``.
        keep_remove: Per-chunk decision strings; not modified.
        check_func: ``check_func(args, problem, ground_truth_answer, text)``
            returning a truthy value when ``text`` still leads to the answer.

    Returns:
        A new list with the same length as ``keep_remove`` in which every
        rejected removal has been replaced by ``"KEEP AS IS"``.
    """
    def _is_remove(decision):
        # Single normalisation point so both checks agree on what "remove" is.
        return decision.strip().lower() == "remove"

    verified = list(keep_remove)
    for i in reversed(range(len(verified))):
        if not _is_remove(verified[i]):
            continue
        # Entries after i have already been processed, so `verified` holds
        # their final decisions.
        excluded = {i}
        excluded.update(j for j in range(i + 1, len(verified)) if _is_remove(verified[j]))
        remaining = [chunk for idx, chunk in enumerate(chunks) if idx not in excluded]
        if not check_func(args, problem, ground_truth_answer, "\n\n".join(remaining)):
            verified[i] = "KEEP AS IS"
    return verified


@function
def prune(s, args, qid, yid, yyid, model, tokenizer, original_problem, ground_truth_answer, solution, correct, max_tokens, save_path):
    """Ask the auxiliary model which reasoning chunks to drop, verify, and save the result.

    Steps:
      1. Split ``solution`` into chunks and list the chunks up to and
         including the one containing ``</think>`` in the user prompt.
      2. For each of ``args.num_try`` forks, generate KEEP/REMOVE decisions,
         re-verify the removals with ``update_keep_remove``/``check_remove``,
         then let the model write the reconstructed reasoning from the
         verified decisions.
      3. Accept the first reconstruction that ``is_pass`` approves, re-attach
         the single chunk after ``</think>`` if there is one, and dump a JSON
         record to ``save_path``.

    Returns ``None`` in every case. Nothing is written when the solution has
    no ``<think>``/``</think>`` pair, when no chunk contains ``</think>``,
    when no attempt yields an acceptable summary, or when more than one
    chunk follows the ``</think>`` chunk.

    Raises:
        NotImplementedError: If ``model`` is not a Llama-3 instruct model
            (decided by substring checks on the lower-cased name).
    """
    if not ("<think>" in solution and "</think>" in solution):
        return

    if "llama-3" in model.lower() and "inst" in model.lower():
        with open("prompt/aux_system.md", "r", encoding="utf-8") as f:
            system_prompt = f.read()
        s += system(system_prompt)
        with open("prompt/aux_user.md", "r", encoding="utf-8") as f:
            user_prompt = f.read()
    else:
        raise NotImplementedError

    chunks, _ = divide_into_chunks(text=solution)

    ## Chunks to str: list chunks up to (and including) the one closing the reasoning
    formatted_parts = []
    i_exit = None
    for i, chunk in enumerate(chunks):
        # Detect the closing tag before the tags are stripped for display.
        closes_reasoning = "</think>" in chunk
        chunk = chunk.replace("<think>", "").replace("</think>", "")
        formatted_parts.append(f"[Chunk {i+1}]\n{chunk}\n\n")
        if closes_reasoning:
            i_exit = i
            break
    if i_exit is None:
        # No chunk carries the closing tag, so the reasoning boundary is
        # unknown; bail out before spending any generation calls.
        return
    formatted_chunks = "".join(formatted_parts)

    # Plain string (not an f-string): braces must not be doubled here, or the
    # prompt would literally read "\boxed{{}}".
    problem = "Please reason step by step, and put your final answer within \\boxed{}. " + original_problem
    user_prompt = user_prompt.format(problem=problem, ground_truth_answer=ground_truth_answer, chunks=formatted_chunks)

    s += user(user_prompt)
    forks = s.fork(args.num_try)
    extracted_summary = None  # stays None when num_try is 0 or every attempt fails
    for fork in forks:
        fork += assistant(gen("prune", max_tokens, temperature=args.temperature, top_p=args.top_p, stop="[FINAL_RECONSTRUCTED_REASONING]"))
        keep_remove = extract_keep_remove(fork["prune"]) ## ["KEEP AS IS", "REMOVE", ...]

        ## Rollout to ensure correctness
        keep_remove2 = update_keep_remove(args, original_problem, ground_truth_answer, chunks[:-1], keep_remove, check_remove)
        reconstructed_output = "[OUTPUT]\n\n[CHUNK_FILTERING_RESULTS]\n"
        for r, decision in enumerate(keep_remove2):
            reconstructed_output += f"[Chunk {r+1}] {decision}\n"
        reconstructed_output += "\n[FINAL_RECONSTRUCTED_REASONING]\n\"\"\""

        # Fresh fork of the shared prompt: the assistant turn is pre-filled
        # with the verified decisions and the model only continues it.
        fork2 = s.fork(1)[0]
        fork2 += "<|start_header_id|>assistant<|end_header_id|>\n\n" + reconstructed_output
        fork2 += gen("summary", max_tokens, temperature=args.temperature, top_p=args.top_p)

        candidate = extract_summary(reconstructed_output + fork2["summary"])
        if candidate and is_pass(candidate, tokenizer):
            extracted_summary = candidate
            break

    if not extracted_summary:
        return

    # Chunks after the closing-tag chunk: exactly one is re-attached behind a
    # restored <think> block, more than one is rejected, none leaves the
    # summary as is.
    trailing_chunks = len(chunks) - 1 - i_exit
    if trailing_chunks > 1:
        return
    if trailing_chunks == 1:
        extracted_summary = "<think>\n" + extracted_summary + "\n</think>\n\n" + chunks[-1]
    num_tokens = len(tokenizer.encode(extracted_summary, add_special_tokens=False))
    original_num_tokens = len(tokenizer.encode(solution, add_special_tokens=False))

    res = {
        "model": model,
        "qid": qid,
        "yid": yid,
        "yyid": yyid,
        "problem": original_problem,
        "ground_truth_answer": ground_truth_answer,
        "correct": correct,
        "solution": solution,
        "num_chunks": len(chunks),
        "chunks": chunks,
        "keep_remove": keep_remove,
        "keep_remove2": keep_remove2,
        "reconstructed_output": reconstructed_output,
        "original_num_tokens": original_num_tokens,
        "num_tokens": num_tokens,
        "summary": extracted_summary,
        "generated": fork2["summary"],
        "temperature": args.temperature,
        "max_tokens": max_tokens,
        "top_p": args.top_p,
    }
    with open(save_path, "w", encoding="utf-8") as f:
        json.dump(res, f, indent=4)
    print(f">>> Save to [{save_path}]")
    return


def parse_args():
    """Build the command-line parser for the pruning script and parse ``sys.argv``.

    Every option is an optional ``--flag`` with a typed default; the table
    below lists them as ``(flag, type, default)`` in registration order.
    """
    option_specs = (
        ("--seed", int, 42),
        ("--num_threads", int, 96),
        ("--start", int, 0),
        ("--end", int, 10828),
        ("--dataset", str, 'math_train'),
        ("--num_inst", int, 5000),
        ("--model_name", str, "meta-llama/Llama-3.3-70B-Instruct"),
        ("--model_path", str, "meta-llama/Llama-3.3-70B-Instruct"),
        ("--host", str, "http://localhost:50000"),
        ("--lrm_port", str, "30000"),
        ("--chat_template_type", str, "llama-3-instruct"),
        ("--max_tokens", int, 16384),
        ("--temperature", float, 0.0),
        ("--top_p", float, 1.0),
        ("--delimiter", str, "\n\n"),
        ("--num_raw", int, 4),
        ("--num_try", int, 3),
        ("--load_dir", str, ""),
        ("--raw_subdir", str, "raw"),
        ("--conclude_subdir", str, "conclude"),
    )
    cli = argparse.ArgumentParser()
    for flag, value_type, fallback in option_specs:
        cli.add_argument(flag, type=value_type, default=fallback)
    return cli.parse_args()


def list_files(directory):
    """List the names of the regular files located directly in ``directory``.

    Returns an empty list when ``directory`` is not an existing directory;
    sub-directories are left out of the result.
    """
    if not os.path.isdir(directory):
        return []
    file_names = []
    for entry in os.listdir(directory):
        if os.path.isfile(os.path.join(directory, entry)):
            file_names.append(entry)
    return file_names


if __name__ == "__main__":
    # Script entry point: collect one candidate solution per (qid, yid) from
    # the raw/conclude directories and run `prune` on all of them in a batch.
    args = parse_args()
    input_path = dataset_paths[args.dataset]
    # Context manager so the dataset file handle is closed right away.
    with open(input_path, "r") as f:
        data = json.load(f)

    ###################
    ## Set seed
    ###################
    set_seed(args.seed)

    ###################
    ## Sample indices
    ###################
    sampled_indices = sample_indices(data, args.dataset, args.num_inst)
    if args.end == -1:
        # -1 means "no upper bound": process up to the end of the dataset.
        args.end = len(data)

    ###################
    ## Load backend
    ###################
    backend = RuntimeEndpoint(args.host)

    ###################
    ## Load tokenizer
    ###################
    tokenizer = AutoTokenizer.from_pretrained(args.model_path, use_fast=True)

    ###################
    ## Save dir
    ###################
    save_dir = os.path.join(args.load_dir, "prune")
    print(f">>> Save dir = [{save_dir}]")

    input_dict_list = []
    selected_id_set = set()
    for qid in sampled_indices:
        if not (args.start <= qid < args.end):
            continue
        file_path = os.path.join(args.load_dir, args.raw_subdir, f"{qid}.json")
        if not os.path.exists(file_path):
            continue
        with open(file_path, "r", encoding="utf-8") as f:
            raw_data = json.load(f)
        # Only the number of raw solutions is needed: it bounds the yid range.
        num_raw_solutions = len(raw_data["solution_list"])

        for yid in range(num_raw_solutions):
            conclude_path = os.path.join(args.load_dir, args.conclude_subdir, f"{qid}/{yid}.json")
            if not os.path.exists(conclude_path):
                continue
            with open(conclude_path, "r", encoding="utf-8") as f:
                # Separate name so the dataset (`data`) is not overwritten.
                conclude_data = json.load(f)

            # Keep the partial solution from its first "<think>" onwards;
            # skip entries without one.
            partial_solution = conclude_data["partial_solution"]
            index = partial_solution.find("<think>")
            if index == -1:
                continue
            partial_solution = partial_solution[index:]

            # Skip partial solutions that jump straight to the final answer.
            if "<think> Hmm, I think this is enough to derive the final answer.\n\n**Final Answer**\n" in partial_solution:
                continue

            # Skip partial solutions with an empty reasoning block.
            if "<think>\n\n</think>" in partial_solution:
                continue

            continuations = conclude_data["solution_list"]
            correct_list = conclude_data["correct_list"]
            num_tokens_list = conclude_data["num_tokens_list"]
            # Correct continuations only, shortest (by token count) first.
            yyids_filtered = [i for i, val in enumerate(correct_list) if val == 1]
            yyids_sorted = sorted(yyids_filtered, key=lambda i: num_tokens_list[i])

            for yyid in yyids_sorted:
                solution = partial_solution + continuations[yyid]
                before = solution.split("</think>")[0]
                # Reject reasoning that states a final answer more than once
                # before the closing tag.
                if (before.count("**Final Answer**") >= 2) or (before.count("boxed{") >= 2):
                    continue
                save_path = os.path.join(save_dir, f"{qid}/{yid}.json")
                os.makedirs(os.path.dirname(save_path), exist_ok=True)
                input_dict_list.append(
                    {
                        "args": args,
                        "qid": qid,
                        "yid": yid,
                        "yyid": yyid,
                        "model": args.model_name,
                        "tokenizer": tokenizer,
                        "original_problem": conclude_data["problem"],
                        "ground_truth_answer": conclude_data["ground_truth_answer"],
                        "solution": solution,
                        "correct": correct_list[yyid],
                        "max_tokens": args.max_tokens,
                        "save_path": save_path,
                    }
                )
                # Record the question id so the summary line below reports a
                # real count (the set was previously never filled).
                selected_id_set.add(qid)
                # First acceptable continuation wins for this (qid, yid).
                break

    print(f">>> {len(selected_id_set)} selected IDs...")
    chat_template = get_chat_template(args.chat_template_type)
    backend.chat_template = chat_template
    states = prune.run_batch(
        input_dict_list,
        backend=backend,
        num_threads=args.num_threads,
        progress_bar=True,
    )
    # Hard exit without interpreter cleanup; presumably to avoid hanging on
    # lingering worker threads -- TODO confirm.
    os._exit(0)