import os
import json
import time

from transformers import AutoTokenizer
import logging
import random
import numpy as np
import torch
from tqdm import tqdm
from math_answer_check import extract_boxed_answer
from math_verify import parse, verify

from vllm import LLM, SamplingParams
from datasets import load_dataset
# Configure the root logger once at import time: INFO level, with each record
# formatted as "<timestamp> - <level> - <message>".
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Module-level logger used by every function in this file.
logger = logging.getLogger(__name__)


# Maps the short model names used throughout this script to the model
# directories handed to AutoTokenizer.from_pretrained and vllm.LLM.
# The prompt builders below branch on substrings of these keys
# ("instruct", "dapo", "qwq", "dpsk"), so renaming a key can change which
# prompt format a model receives.
MODEL_PATHS = {
    "dpsk-distill-qwen2.5-32b": "user/models/DeepSeek-R1-Distill-Qwen-32B",
    "qwq-32b": "user/models/QwQ-32B",
    "dapo-qwen-32b": "user/models/DAPO-Qwen-32B",
    "qwen2.5-32b-instruct": "user/models/Qwen2.5-32B-Instruct",
    "qwen2.5-7b-instruct": "user/models/Qwen2.5-7B-Instruct",
    "qwen2.5-math-7b-instruct": "user/models/Qwen2.5-Math-7B-Instruct",
}




#TEMPERATURES = [0.0, 0.2, 0.5, 0.8, 1.0]
# NOTE(review): TEMPERATURES is not referenced anywhere in the visible part of
# this file; the sampling temperatures actually used are hard-coded in the
# solve_* functions.
TEMPERATURES = [0.0,0.6]
# System prompts used to ask a model for a problem "sketch". get_messages()
# selects one by its list index (prompt_id), and
# generate_problem_decomposition() runs every prompt in this list in turn.
#   index 0: 2–4 high-level reasoning steps, no numbers or formulas.
#   index 1: a numbered skeleton of very short points.
DECOMPOSITION_SYSTEM_PROMPTS = [
    """You are a reasoning strategist.
Your job is to break down a complex problem into 2–4 high-level reasoning steps.
Focus only on outlining the general approach or strategy.
Do not include any numbers, formulas, or final answers.
Avoid specific calculations or details—only describe the logic behind solving the problem.""",

    """You're an organizer responsible for only giving the skeleton (not the full content) for answering the question.
Provide the skeleton in a list of points (numbered 1., 2., 3., etc.) to answer the question.
Each skeleton point should be very short with only 3–5 words.
Generally, the skeleton should have 3–10 points.
Do not answer the question—only structure the thinking process."""
]


# User-turn template paired with the system prompts above; formatted with
# problem=<problem text> in get_messages().
USER_PROMPT_TEMPLATE = "Please break down the following problem:\n\nProblem: {problem}"


def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch random generators with one value.

    Args:
        seed: Seed passed unchanged to every generator. Defaults to 42.
    """
    for apply_seed in (random.seed, np.random.seed, torch.manual_seed):
        apply_seed(seed)

    # The explicit all-device CUDA seeding only runs when CUDA is available.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    # Any other library with its own RNG would need to be seeded here as well.




def load_math500(sketch_dataset=None):
    """Load MATH-500 problems, or a saved sketch file for MATH-500.

    Args:
        sketch_dataset: Base name (without the ``.jsonl`` suffix) of a sketch
            file in the math500 sketch directory. When ``None`` the raw test
            split is loaded instead.

    Returns:
        list[dict]: One dict per problem. For a sketch file, every element
        that is not a dict is dropped.

    Raises:
        FileNotFoundError: If the selected file does not exist.
        json.JSONDecodeError: If the file content is not valid JSON.
    """
    if sketch_dataset is None:
        raw_path = "user/user1/lm-evaluation-harness/lm_eval/tasks/math_500/test.jsonl"
        # Explicit UTF-8 so reading does not depend on the platform locale.
        with open(raw_path, "r", encoding="utf-8") as f:
            # One JSON object per line; blank lines (e.g. a trailing newline
            # at the end of the file) are skipped rather than crashing.
            return [json.loads(line) for line in f if line.strip()]

    sketch_path = f"user/user1/lm-evaluation-harness/pivotal_step/results/sketch_dataset_full/math500/{sketch_dataset}.jsonl"
    with open(sketch_path, "r", encoding="utf-8") as f:
        # Despite the ".jsonl" extension this file is a single JSON array.
        # Presumably it was written by generate_problem_decomposition, which
        # stores the system prompt string as the first element.
        full_list = json.load(f)
    return [item for item in full_list if isinstance(item, dict)]



def load_aime(sketch_dataset=None):
    """Load AIME 2024 problems, or a saved sketch file for AIME.

    Args:
        sketch_dataset: Base name (without the ``.jsonl`` suffix) of a sketch
            file in the aime sketch directory. When ``None`` the raw problem
            file is loaded instead.

    Returns:
        list[dict]: One dict per problem. For a sketch file, every element
        that is not a dict is dropped.

    Raises:
        FileNotFoundError: If the selected file does not exist.
        json.JSONDecodeError: If the file content is not valid JSON.
    """
    if sketch_dataset is None:
        raw_path = "user/user1/lm-evaluation-harness/aime2024/aime24_nofigures.jsonl"
        # Explicit UTF-8 so reading does not depend on the platform locale.
        with open(raw_path, "r", encoding="utf-8") as f:
            # One JSON object per line; blank lines are skipped rather than
            # crashing json.loads.
            return [json.loads(line) for line in f if line.strip()]

    sketch_path = f"user/user1/lm-evaluation-harness/pivotal_step/results/sketch_dataset_full/aime/{sketch_dataset}.jsonl"
    with open(sketch_path, "r", encoding="utf-8") as f:
        # Despite the ".jsonl" extension this file is a single JSON array.
        # Presumably it was written by generate_problem_decomposition, which
        # stores the system prompt string as the first element.
        full_list = json.load(f)
    return [item for item in full_list if isinstance(item, dict)]

def _gsm8k_final_answer(solution):
    """Return the text after the last ``####`` marker of a solution, or ""."""
    _, marker, tail = solution.rpartition("####")
    if not marker:
        return ""
    # Thousands separators are dropped so "1,000" is handled as one number.
    return tail.strip().replace(",", "")


def load_gsm8k(sketch_dataset=None):
    """Load GSM8K test problems, or a saved sketch file for GSM8K.

    Args:
        sketch_dataset: Base name (without the ``.jsonl`` suffix) of a sketch
            file in the gsm8k sketch directory. When ``None`` the raw parquet
            test split is loaded instead.

    Returns:
        list[dict]: For the raw split, dicts with ``"problem"`` (the
        question), ``"solution"`` (the full reference answer text) and
        ``"answer"`` (the final answer extracted from the solution). For a
        sketch file, the dict elements of the stored JSON array.

    Raises:
        FileNotFoundError: If the selected sketch file does not exist.
        json.JSONDecodeError: If the sketch file is not valid JSON.
    """
    if sketch_dataset is None:
        dataset = load_dataset("parquet", data_files="user/user1/lm-evaluation-harness/openai-gsm8k/main/test-00000-of-00001.parquet")
        dataset_dict = dataset["train"].to_dict()
        questions = dataset_dict["question"]
        answers = dataset_dict["answer"]

        # GSM8K reference answers end with "#### <final answer>". The solve_*
        # functions in this file read the gold answer from the "answer" key,
        # which was previously missing, so verification always compared
        # against an empty string.
        return [
            {"problem": q, "solution": a, "answer": _gsm8k_final_answer(a)}
            for q, a in zip(questions, answers)
        ]

    sketch_path = f"user/user1/lm-evaluation-harness/pivotal_step/results/sketch_dataset_full/gsm8k/{sketch_dataset}.jsonl"
    # Explicit UTF-8 so reading does not depend on the platform locale.
    with open(sketch_path, "r", encoding="utf-8") as f:
        # Despite the ".jsonl" extension this file is a single JSON array.
        full_list = json.load(f)
    return [item for item in full_list if isinstance(item, dict)]


def get_no_thinking_messages(model_name, problem):
    """Build the "no thinking" prompt for ``problem``.

    The format is chosen by substring checks on ``model_name``, tested in
    this order: "instruct", "dapo", then "qwq"/"dpsk".

    Args:
        model_name: Short model name; only its substrings are inspected.
        problem: Problem statement text.

    Returns:
        A list of chat-message dicts, or for "dapo" models a single raw
        prompt string. The dapo string and the qwq/dpsk message list both end
        with an assistant statement that thinking has already finished.
    """
    if "instruct" in model_name:
        system_text = "You are a helpful math assistant. Please reason step by step to solve the problem, and put your final answer within \\boxed{}."
        return [
            {"role": "system", "content": system_text},
            {"role": "user", "content": problem},
        ]

    if "dapo" in model_name:
        # Raw completion prompt (not chat messages); the assistant turn is
        # already started with the "finished thinking" statement.
        return f"A conversation between user and assistant. The user asks a question, and the assistant solves it. The time limit is set to 20,480 tokens. If the assistant's response exceeds this limit, a progressively increasing penalty with the number of tokens exceeded will be applied.\nuser\nSolve the following math problem step by step. The last line of your response should be of the form Answer: $Answer (without quotes) where $Answer is the answer to the problem.\n{problem}\nRemember to put your answer on its own line after \"Answer:\".\nassistant: okay I have finished thinking."

    if any(tag in model_name for tag in ("qwq", "dpsk")):
        user_turn = {
            "role": "user",
            "content": f"Please reason step by step, and put your final answer within \\boxed{{}}.\n\n{problem}",
        }
        finished_turn = {
            "role": "assistant",
            "content": "Okay I have finished thinking.",
        }
        return [user_turn, finished_turn]

    # Unknown model family: send the bare problem as the only user turn.
    return [{"role": "user", "content": problem}]

def get_baseline_messages(model_name, problem):
    """Build the plain (no sketch) solving prompt for ``problem``.

    The format is chosen by substring checks on ``model_name``, tested in
    this order: "instruct", "dapo", then "qwq"/"dpsk".

    Args:
        model_name: Short model name; only its substrings are inspected.
        problem: Problem statement text.

    Returns:
        A list of chat-message dicts, or for "dapo" models a single raw
        prompt string that ends at the start of the assistant turn.
    """
    if "instruct" in model_name:
        system_text = "You are a helpful math assistant. Please reason step by step to solve the problem, and put your final answer within \\boxed{}."
        return [
            {"role": "system", "content": system_text},
            {"role": "user", "content": problem},
        ]

    if "dapo" in model_name:
        # Raw completion prompt (not chat messages).
        return f"A conversation between user and assistant. The user asks a question, and the assistant solves it. The time limit is set to 20,480 tokens. If the assistant's response exceeds this limit, a progressively increasing penalty with the number of tokens exceeded will be applied.\nuser\nSolve the following math problem step by step. The last line of your response should be of the form Answer: $Answer (without quotes) where $Answer is the answer to the problem.\n{problem}\nRemember to put your answer on its own line after \"Answer:\".\nassistant"

    if any(tag in model_name for tag in ("qwq", "dpsk")):
        user_text = f"Please reason step by step, and put your final answer within \\boxed{{}}.\n\n{problem}"
        return [{"role": "user", "content": user_text}]

    # Unknown model family: send the bare problem as the only user turn.
    return [{"role": "user", "content": problem}]

def get_messages(model_name, problem, prompt_id=0, decomposition=None):
    """Build the prompt for sketch generation or for sketch-guided solving.

    Args:
        model_name: Short model name; only its substrings are inspected, in
            the order "instruct", "dpsk"/"qwq", "dapo".
        problem: Problem statement text.
        prompt_id: Index into ``DECOMPOSITION_SYSTEM_PROMPTS``; only used
            when ``decomposition`` is ``None``.
        decomposition: Sketch text to follow. ``None`` selects
            sketch-generation mode; any other value (including ``""``)
            selects solving mode.

    Returns:
        A list of chat-message dicts, or for "dapo" models in solving mode a
        single raw prompt string.

    Raises:
        IndexError: If ``prompt_id`` is out of range in sketch-generation
            mode.
    """
    if decomposition is None:
        # Sketch generation uses the same chat format for every model.
        return [
            {"role": "system", "content": DECOMPOSITION_SYSTEM_PROMPTS[prompt_id]},
            {"role": "user", "content": USER_PROMPT_TEMPLATE.format(problem=problem)},
        ]

    if "instruct" in model_name:
        # This is a plain string, not an f-string, so the braces must not be
        # doubled; "{{}}" here would show the model a literal "\boxed{{}}".
        system_prompt = "You are a helpful math assistant. Use only the following steps to solve the problem. Do not change or add steps. Show the work for each step briefly, and place the final answer in \\boxed{}."
        return [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": f"Problem: {problem}\n\nSteps: {decomposition}"},
        ]

    if "dpsk" in model_name or "qwq" in model_name:
        return [
            {
                "role": "user",
                "content": f"Use only the following steps to solve the problem. Do not change or add steps. Show the work for each step briefly, and place the final answer in \\boxed{{}}.\n\nProblem: {problem}\n\nSteps: {decomposition}"
            }
        ]

    if "dapo" in model_name:
        # Raw completion prompt (not chat messages).
        # NOTE(review): the continuation lines of this literal carry eight
        # leading spaces that end up in the prompt, and unlike the dapo
        # baseline prompt it has no "Answer:" instruction. Left unchanged —
        # confirm whether that is intended.
        return f"""A conversation between a user and an assistant. The user asks a math question, and the assistant solves it. The assistant must use only the given steps to solve the problem. No steps may be changed, removed, or added. For each step, briefly show the work.

        The total response must not exceed 20,480 tokens. If this limit is exceeded, a progressively increasing penalty will be applied.

        user
        Solve the following math problem step by step. Use only the steps listed below to solve the problem.

        Problem: {problem}
        Steps: {decomposition}

        assistant"""

    # Default: just the user prompt (the sketch is not used).
    return [{"role": "user", "content": problem}]

def generate_problem_decomposition(dataset_name, model_name, output_dir, num_problems=None):
    """Generate a sketch (decomposition) for every problem and save it.

    For each entry of ``DECOMPOSITION_SYSTEM_PROMPTS`` the model is queried
    once per problem, and the records are written to
    ``{output_dir}/{dataset_name}/{model_name}_decomposition_{prompt_id}.jsonl``.
    Despite the extension, each file is a single indented JSON array whose
    first element is the system prompt string, followed by one dict per
    problem.

    Args:
        dataset_name: One of "math500", "aime" or "gsm8k".
        model_name: Key into ``MODEL_PATHS``.
        output_dir: Root directory for the output files.
        num_problems: If truthy, keep only the first ``num_problems``
            problems. Ignored for "aime".

    Raises:
        ValueError: If ``dataset_name`` is unknown.
        KeyError: If ``model_name`` is not in ``MODEL_PATHS``.
    """
    if dataset_name == "math500":
        problems = load_math500()
    elif dataset_name == "aime":
        problems = load_aime()
    elif dataset_name == "gsm8k":
        problems = load_gsm8k()
    else:
        raise ValueError(f"Unknown dataset: {dataset_name}")

    if num_problems and dataset_name != "aime":
        problems = problems[:num_problems]

    model_path = MODEL_PATHS[model_name]
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    llm_kwargs = {
        "model": model_path,
        "tensor_parallel_size": 4,
        "gpu_memory_utilization": 0.8,
        "trust_remote_code": True
    }

    llm = LLM(**llm_kwargs)
    # No temperature is set here, so the SamplingParams default applies.
    sampling_params = SamplingParams(
        top_p=0.95,
        max_tokens=8192,
    )

    total_problems = len(problems)

    logger.info(f"Starting decomposition of {total_problems} problems in {dataset_name} dataset...")

    for prompt_id in range(len(DECOMPOSITION_SYSTEM_PROMPTS)):

        # The system prompt is stored as the first array element; the loaders
        # in this file drop it again by keeping only dict elements.
        all_results = [DECOMPOSITION_SYSTEM_PROMPTS[prompt_id]]
        for i, problem_data in enumerate(tqdm(problems, desc=f"Prompt {prompt_id}")):
            problem_id = problem_data.get("id", None)
            problem = problem_data.get("problem", "")
            gold_solution = problem_data.get("solution", "")
            gold_answer = problem_data.get("answer", "")
            level = problem_data.get("level", None)

            try:
                start_time = time.time()
                messages = get_messages(model_name, problem, prompt_id)
                text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
                outputs = llm.generate([text], sampling_params)
                decomposition = outputs[0].outputs[0].text
                decomposition_time = time.time() - start_time
                decomposition_tokens = count_tokens(decomposition, tokenizer)
                metrics = {
                    "decomposition_time": decomposition_time,
                    "decomposition_tokens": decomposition_tokens,
                }
                result = {
                    "problem_id": problem_id,
                    "problem": problem,
                    "decomposition": decomposition,
                    "gold_solution": gold_solution,
                    "gold_answer": gold_answer,
                    "level": level,
                    "metrics": metrics
                }
            except Exception as e:
                logger.error(f"[Prompt {prompt_id}] Decomposition failed on problem {problem_id}: {e}")
                # Record the failure explicitly, in the same shape the
                # solve_* functions use. Without this assignment `result`
                # was either unbound (NameError on the first problem) or
                # still held the previous problem's record, which was then
                # appended a second time.
                result = {
                    "problem_id": problem_id,
                    "problem": problem,
                    "gold_solution": gold_solution,
                    "gold_answer": gold_answer,
                    "level": level,
                    "error": str(e)
                }

            all_results.append(result)

        output_path = f"{output_dir}/{dataset_name}/{model_name}_decomposition_{prompt_id}.jsonl"

        os.makedirs(os.path.dirname(output_path), exist_ok=True)

        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(all_results, f, ensure_ascii=False, indent=2)

        logger.info(f"Saved results to {output_path}")

    logger.info(f"Completed decomposition for all {dataset_name} problems across {len(DECOMPOSITION_SYSTEM_PROMPTS)} prompts.")





def count_tokens(text: str, tokenizer) -> int:
    """Return the length of ``tokenizer.encode(text)``.

    Args:
        text: String to tokenise.
        tokenizer: Any object with an ``encode(text)`` method returning a
            sized sequence.
    """
    token_ids = tokenizer.encode(text)
    return len(token_ids)





def solve_and_save_collaboratively(dataset_name, model_name, output_dir, num_problems=None, sketch_dataset=None):
    """Solve problems while following a pre-computed sketch, then save results.

    Each problem record supplies its own ``"decomposition"`` text, which is
    passed to ``get_messages`` together with the problem. All results are
    written as one indented JSON array to
    ``{output_dir}/{dataset_name}/{model_name}_sketch_dataset_{sketch_dataset}.jsonl``.

    Args:
        dataset_name: One of "math500", "aime" or "gsm8k".
        model_name: Key into ``MODEL_PATHS``.
        output_dir: Root directory for the output file.
        num_problems: If truthy, keep only the first ``num_problems``
            problems. Ignored for "aime".
        sketch_dataset: Sketch file name forwarded to the dataset loader.

    Raises:
        ValueError: If ``dataset_name`` is unknown.
        KeyError: If ``model_name`` is not in ``MODEL_PATHS``.
    """
    if dataset_name == "math500":
        problems = load_math500(sketch_dataset)
    elif dataset_name == "aime":
        problems = load_aime(sketch_dataset)
    elif dataset_name == "gsm8k":
        problems = load_gsm8k(sketch_dataset)
    else:
        raise ValueError(f"Unknown dataset: {dataset_name}")

    if dataset_name != "aime" and num_problems:
        problems = problems[:num_problems]

    model_path = MODEL_PATHS[model_name]
    print(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    engine_options = dict(
        model=model_path,
        tensor_parallel_size=4,
        gpu_memory_utilization=0.8,
        trust_remote_code=True,
    )
    llm = LLM(**engine_options)

    # The DAPO model gets its own sampling settings.
    if model_name == "dapo-qwen-32b":
        top_p, temperature = 0.7, 1
    else:
        top_p, temperature = 0.95, 0.6
    sampling_params = SamplingParams(top_p=top_p, temperature=temperature, max_tokens=32768)

    logger.info(f"Starting collaborative solving of {len(problems)} problems in {dataset_name} dataset...")

    all_results = []

    for index, problem_data in enumerate(tqdm(problems)):
        decomposition = problem_data.get("decomposition", "")
        # Fields shared by the success record and the error record.
        record = {
            "problem_id": problem_data.get("problem_id", str(index)),
            "problem": problem_data.get("problem", ""),
            "decomposition": decomposition,
            "gold_solution": problem_data.get("gold_solution", ""),
            "gold_answer": problem_data.get("gold_answer", ""),
            "level": problem_data.get("level", None),
        }
        try:
            start_time = time.time()
            messages = get_messages(model_name, record["problem"], decomposition=decomposition)
            if 'dapo' in model_name:
                # get_messages returns a raw prompt string for dapo models.
                outputs = llm.generate(messages, sampling_params)
            else:
                text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
                outputs = llm.generate([text], sampling_params)

            solution = outputs[0].outputs[0].text
            solution_time = time.time() - start_time
            solution_tokens = count_tokens(solution, tokenizer)
            #prediction = parse(extract_boxed_answer(solution))
            prediction = parse(solution)
            gold = parse(record["gold_answer"])

            # Cost of producing the sketch, as stored with the sketch record
            # (0 when absent).
            sketch_metrics = problem_data.get("metrics", {})
            sketch_time = sketch_metrics.get("decomposition_time", 0)
            sketch_tokens = sketch_metrics.get("decomposition_tokens", 0)
            metrics = {
                "decomposition_time": sketch_time,
                "decomposition_tokens": sketch_tokens,
                "solution_time": solution_time,
                "solution_tokens": solution_tokens,
                "total_time": sketch_time + solution_time,
                "total_tokens": sketch_tokens + solution_tokens,
                "verify": verify(gold, prediction)
            }
            print(metrics["verify"])
            result = {**record, "model_response": solution, "metrics": metrics}
        except Exception as e:
            logger.error(f"Solution failed on problem {record['problem_id']}: {e}")
            result = {**record, "error": str(e)}

        all_results.append(result)

    output_path = f"{output_dir}/{dataset_name}/{model_name}_sketch_dataset_{sketch_dataset}.jsonl"

    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(all_results, f, ensure_ascii=False, indent=2)
    logger.info(f"Saved results to {output_path}")



def solve_and_save_baseline(dataset_name, model_name, output_dir, num_problems=None):
    """Solve each problem directly (no sketch) and save the results.

    All results are written as one indented JSON array to
    ``{output_dir}/{dataset_name}/{model_name}_baseline.jsonl``.

    Args:
        dataset_name: One of "math500", "aime" or "gsm8k".
        model_name: Key into ``MODEL_PATHS``.
        output_dir: Root directory for the output file.
        num_problems: If truthy, keep only the first ``num_problems``
            problems. Ignored for "aime".

    Raises:
        ValueError: If ``dataset_name`` is unknown.
        KeyError: If ``model_name`` is not in ``MODEL_PATHS``.
    """
    if dataset_name == "math500":
        problems = load_math500()
    elif dataset_name == "aime":
        problems = load_aime()
    elif dataset_name == "gsm8k":
        problems = load_gsm8k()
    else:
        raise ValueError(f"Unknown dataset: {dataset_name}")

    if num_problems and dataset_name != "aime":
        problems = problems[:num_problems]

    model_path = MODEL_PATHS[model_name]
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    llm = LLM(
        model=model_path,
        tensor_parallel_size=4,
        gpu_memory_utilization=0.8,
        trust_remote_code=True
    )
    # The DAPO model gets its own sampling settings.
    if model_name == "dapo-qwen-32b":
        sampling_params = SamplingParams(
            top_p=0.7,
            temperature=1,
            max_tokens=32768,
        )
    else:
        sampling_params = SamplingParams(
            top_p=0.95,
            temperature=0.6,
            max_tokens=32768,
        )

    total_problems = len(problems)

    logger.info(f"Starting baseline solving of {total_problems} problems in {dataset_name} dataset...")

    all_results = []

    for problem_data in tqdm(problems):
        problem_id = problem_data.get("id", None)
        problem = problem_data.get("problem", "")
        gold_solution = problem_data.get("solution", "")
        gold_answer = problem_data.get("answer", "")
        level = problem_data.get("level", None)
        try:
            start_time = time.time()
            messages = get_baseline_messages(model_name, problem)
            if "dapo" in model_name:
                # get_baseline_messages returns a raw prompt string for dapo.
                outputs = llm.generate(messages, sampling_params)
            else:
                text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
                outputs = llm.generate([text], sampling_params)

            # Every candidate returned for this single prompt.
            solutions = [output.text for output in outputs[0].outputs]
            # Keep the shortest candidate (by character count).
            solution = min(solutions, key=len)
            solution_time = time.time() - start_time
            solution_tokens = count_tokens(solution, tokenizer)
            prediction = parse(solution)
            gold = parse(gold_answer)
            metrics = {
                "solution_time": solution_time,
                "solution_tokens": solution_tokens,
                "verify": verify(gold, prediction),
                # This function has no `best_of` parameter; the old reference
                # to that name raised NameError for every problem, so each
                # one was saved as an error. Record the number of candidates
                # actually returned instead.
                "best_of": len(solutions),
                "all_solutions": solutions  # Store all solutions for analysis
            }
            result = {
                "problem_id": problem_id,
                "problem": problem,
                "gold_solution": gold_solution,
                "gold_answer": gold_answer,
                "level": level,
                "model_response": solution,
                "metrics": metrics
            }
        except Exception as e:
            logger.error(f"Solution failed on problem {problem_id}: {e}")
            result = {
                "problem_id": problem_id,
                "problem": problem,
                "gold_solution": gold_solution,
                "gold_answer": gold_answer,
                "level": level,
                "error": str(e)
            }

        all_results.append(result)

    output_path = f"{output_dir}/{dataset_name}/{model_name}_baseline.jsonl"

    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(all_results, f, ensure_ascii=False, indent=2)
    logger.info(f"Saved results to {output_path}")


def solve_and_save_baseline_best_of(dataset_name, model_name, output_dir, num_problems=None, best_of=5):
    """Solve problems with best-of-N sampling, keeping the shortest answer.

    For every problem the model is sampled ``best_of`` times with the same
    prompt. The candidate with the fewest characters is reported as the model
    response, and statistics over all candidates are stored under
    ``"metrics"``. Results are written as one indented JSON array to
    ``{output_dir}/{dataset_name}/{model_name}_baseline_best_of_{best_of}.jsonl``.

    Args:
        dataset_name (str): Name of the dataset to use ("math500", "aime" or
            "gsm8k").
        model_name (str): Key into ``MODEL_PATHS``.
        output_dir (str): Directory to save results.
        num_problems (int, optional): If truthy, keep only the first
            ``num_problems`` problems. Ignored for "aime". Defaults to None.
        best_of (int, optional): Number of candidate solutions to generate
            per problem; must be at least 1. Defaults to 5.

    Raises:
        ValueError: If ``best_of`` is below 1 or ``dataset_name`` is unknown.
        KeyError: If ``model_name`` is not in ``MODEL_PATHS``.
    """
    # Fail before loading the model: with no candidates every problem would
    # only fail later inside min() and be saved as an error record.
    if best_of < 1:
        raise ValueError(f"best_of must be at least 1, got {best_of}")

    if dataset_name == "math500":
        problems = load_math500()
    elif dataset_name == "aime":
        problems = load_aime()
    elif dataset_name == "gsm8k":
        problems = load_gsm8k()
    else:
        raise ValueError(f"Unknown dataset: {dataset_name}")

    if num_problems and dataset_name != "aime":
        problems = problems[:num_problems]

    model_path = MODEL_PATHS[model_name]
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    llm = LLM(
        model=model_path,
        tensor_parallel_size=4,
        gpu_memory_utilization=0.8,
        trust_remote_code=True
    )
    # The DAPO model gets its own sampling settings.
    if model_name == "dapo-qwen-32b":
        sampling_params = SamplingParams(
            top_p=0.7,
            temperature=1,
            max_tokens=32768,
        )
    else:
        sampling_params = SamplingParams(
            top_p=0.95,
            temperature=0.6,
            max_tokens=32768,
        )

    total_problems = len(problems)

    logger.info(f"Starting best-of-{best_of} baseline solving of {total_problems} problems in {dataset_name} dataset...")

    all_results = []

    for problem_data in tqdm(problems):
        problem_id = problem_data.get("id", None)
        problem = problem_data.get("problem", "")
        gold_solution = problem_data.get("solution", "")
        gold_answer = problem_data.get("answer", "")
        level = problem_data.get("level", None)
        try:
            start_time = time.time()
            messages = get_baseline_messages(model_name, problem)

            # The prompt is identical for every sample, so build it once.
            if "dapo" in model_name:
                # get_baseline_messages returns a raw prompt string for dapo.
                prompts = messages
            else:
                prompts = [tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)]

            # Generate the candidates one request at a time.
            solutions = []
            for _ in range(best_of):
                outputs = llm.generate(prompts, sampling_params)
                solutions.append(outputs[0].outputs[0].text)

            # Token lengths for all candidates.
            solution_token_lengths = [count_tokens(sol, tokenizer) for sol in solutions]
            # Select the shortest candidate by character count; the token
            # statistics below are reported separately.
            solution = min(solutions, key=len)
            shortest_index = solutions.index(solution)
            solution_time = time.time() - start_time
            # Reuse the count computed above instead of tokenising again.
            solution_tokens = solution_token_lengths[shortest_index]
            prediction = parse(solution)
            gold = parse(gold_answer)

            # Verify every candidate against the gold answer.
            all_predictions = [parse(sol) for sol in solutions]
            all_verifications = [verify(gold, pred) for pred in all_predictions]

            metrics = {
                "solution_time": solution_time,
                "solution_tokens": solution_tokens,
                "verify": verify(gold, prediction),
                "best_of": best_of,
                "all_solutions": solutions,  # Store all solutions for analysis
                "all_verifications": all_verifications,  # Store verification results for all solutions
                "all_token_lengths": solution_token_lengths,  # Store token lengths for all solutions
                "shortest_solution_index": shortest_index,  # Index of the selected shortest solution
                # First candidate with the highest verification value, i.e.
                # the first verified one if any, otherwise index 0.
                "best_solution_index": all_verifications.index(max(all_verifications)),
                "shortest_token_length": min(solution_token_lengths),  # Length of shortest solution in tokens
                "longest_token_length": max(solution_token_lengths),  # Length of longest solution in tokens
                "avg_token_length": sum(solution_token_lengths) / len(solution_token_lengths)  # Average token length
            }

            result = {
                "problem_id": problem_id,
                "problem": problem,
                "gold_solution": gold_solution,
                "gold_answer": gold_answer,
                "level": level,
                "model_response": solution,
                "metrics": metrics
            }
        except Exception as e:
            logger.error(f"Solution failed on problem {problem_id}: {e}")
            result = {
                "problem_id": problem_id,
                "problem": problem,
                "gold_solution": gold_solution,
                "gold_answer": gold_answer,
                "level": level,
                "error": str(e)
            }

        all_results.append(result)

    output_path = f"{output_dir}/{dataset_name}/{model_name}_baseline_best_of_{best_of}.jsonl"

    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(all_results, f, ensure_ascii=False, indent=2)
    logger.info(f"Saved results to {output_path}")




def solve_and_save_no_thinking_baseline(dataset_name, model_name, output_dir, num_problems=None):
    """Solve problems with the "no thinking" prompts and save the results.

    Prompts come from ``get_no_thinking_messages``. All results are written
    as one indented JSON array to
    ``{output_dir}/{dataset_name}/{model_name}_no_thinking_baseline.jsonl``.

    Args:
        dataset_name: One of "math500", "aime" or "gsm8k".
        model_name: Key into ``MODEL_PATHS``.
        output_dir: Root directory for the output file.
        num_problems: If truthy, keep only the first ``num_problems``
            problems. Ignored for "aime".

    Raises:
        ValueError: If ``dataset_name`` is unknown.
        KeyError: If ``model_name`` is not in ``MODEL_PATHS``.
    """
    if dataset_name == "math500":
        problems = load_math500()
    elif dataset_name == "aime":
        problems = load_aime()
    elif dataset_name == "gsm8k":
        problems = load_gsm8k()
    else:
        raise ValueError(f"Unknown dataset: {dataset_name}")

    if dataset_name != "aime" and num_problems:
        problems = problems[:num_problems]

    model_path = MODEL_PATHS[model_name]
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    engine_options = dict(
        model=model_path,
        tensor_parallel_size=4,
        gpu_memory_utilization=0.8,
        trust_remote_code=True,
    )
    llm = LLM(**engine_options)

    # The DAPO model gets its own sampling settings.
    if model_name == "dapo-qwen-32b":
        top_p, temperature = 0.7, 1
    else:
        top_p, temperature = 0.95, 0.6
    sampling_params = SamplingParams(top_p=top_p, temperature=temperature, max_tokens=32768)

    logger.info(f"Starting baseline solving of {len(problems)} problems in {dataset_name} dataset...")

    all_results = []

    for problem_data in tqdm(problems):
        # Fields shared by the success record and the error record.
        record = {
            "problem_id": problem_data.get("id", None),
            "problem": problem_data.get("problem", ""),
            "gold_solution": problem_data.get("solution", ""),
            "gold_answer": problem_data.get("answer", ""),
            "level": problem_data.get("level", None),
        }
        try:
            start_time = time.time()
            messages = get_no_thinking_messages(model_name, record["problem"])
            if "dapo" in model_name:
                # get_no_thinking_messages returns a raw prompt string here.
                outputs = llm.generate(messages, sampling_params)
            else:
                # NOTE(review): for qwq/dpsk models the message list already
                # ends with an assistant turn. With add_generation_prompt=True
                # the chat template presumably opens a new assistant turn
                # rather than continuing that one — confirm this is the
                # intended "no thinking" pre-fill.
                text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
                outputs = llm.generate([text], sampling_params)

            # Every candidate returned for this single prompt.
            solutions = [candidate.text for candidate in outputs[0].outputs]
            # Keep the shortest candidate (by character count).
            solution = min(solutions, key=len)
            solution_time = time.time() - start_time
            solution_tokens = count_tokens(solution, tokenizer)
            prediction = parse(solution)
            gold = parse(record["gold_answer"])
            metrics = {
                "solution_time": solution_time,
                "solution_tokens": solution_tokens,
                "verify": verify(gold, prediction),
                "all_solutions": solutions  # Store all solutions for analysis
            }
            result = {**record, "model_response": solution, "metrics": metrics}
        except Exception as e:
            logger.error(f"Solution failed on problem {record['problem_id']}: {e}")
            result = {**record, "error": str(e)}

        all_results.append(result)

    output_path = f"{output_dir}/{dataset_name}/{model_name}_no_thinking_baseline.jsonl"

    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(all_results, f, ensure_ascii=False, indent=2)
    logger.info(f"Saved results to {output_path}")





def main():
    """Seed the RNGs and run the currently enabled experiment.

    As written, this runs the no-thinking baseline for every reasoning model
    on AIME (all problems), GSM8K (first 200) and MATH-500 (first 100). The
    commented-out blocks are alternative experiments kept for reference.
    """
    set_seed()

    # llm_q = openai.OpenAI(
    #     api_key="distill-qwen2.5-32b",  # vLLM doesn't check this
    #     base_url="http://localhost:8000/v1"
    # )

    # Only referenced by the commented-out experiments below.
    sketch_model_name = [#"qwen2.5-7b-instruct",
        #"qwen2.5-math-7b-instruct",
        "qwen2.5-32b-instruct"]

    reasoning_model_name = ["dapo-qwen-32b",
        "qwq-32b","dpsk-distill-qwen2.5-32b"]

    sketch_dataset_name_general = [#"qwen2.5-7b-instruct_decomposition_0","qwen2.5-7b-instruct_decomposition_1",
                           #"qwen2.5-math-7b-instruct_decomposition_0","qwen2.5-math-7b-instruct_decomposition_1",
                           "qwen2.5-32b-instruct_decomposition_0","qwen2.5-32b-instruct_decomposition_1"]
    sketch_dataset_name_reasoning = ["qwq-32b_decomposition_0","qwq-32b_decomposition_1"]

    # for dataset_name in ["gsm8k", "math500"]:
    #         generate_problem_decomposition(
    #         dataset_name=dataset_name,
    #         model_name='qwen2.5-32b-instruct',
    #         output_dir=f"./results/sketch_dataset_full",
    #         num_problems=None)

    # for sketch_model in sketch_model_name:
    #     for dataset_name in ["gsm8k", "math500", "aime"]:

    #         # solve_and_save_baseline(
    #         #     dataset_name,
    #         #     model_name=reasoning_model,
    #         #     output_dir=f"./results/reasoning_baseline",
    #         #     num_problems=50,
    #         # )

    #         for sketch_dataset in sketch_dataset_name_general:
    #             solve_and_save_collaboratively(
    #                 dataset_name,
    #                 model_name=sketch_model,
    #                 output_dir=f"./results/reasoning",
    #                 num_problems=50,
    #                 sketch_dataset=sketch_dataset
    #             )
    #         for sketch_dataset in sketch_dataset_name_reasoning:
    #             solve_and_save_collaboratively(
    #                 dataset_name,
    #                 model_name=sketch_model,
    #                 output_dir=f"./results/reasoning",
    #                 num_problems=50,
    #                 sketch_dataset=sketch_dataset
    #             )

    # Problems to run per dataset, in run order; None means "all problems".
    problem_budget = {"aime": None, "gsm8k": 200, "math500": 100}

    for reasoning_model in reasoning_model_name:
        for dataset_name, num_problems in problem_budget.items():
            solve_and_save_no_thinking_baseline(
                dataset_name,
                model_name=reasoning_model,
                output_dir="./results/reasoning_baseline_no_thinking",
                num_problems=num_problems,
            )

            # solve_and_save_baseline_best_of(
            #     dataset_name,
            #     model_name=reasoning_model,
            #     output_dir=f"./results/reasoning_baseline_half_best_of_5",
            #     num_problems=num_problems,
            #)
            # if not (reasoning_model == "qwq-32b" and dataset_name == "gsm8k"):
            #     solve_and_save_baseline(
            #         dataset_name,
            #         model_name=reasoning_model,
            #         output_dir=f"./results/reasoning_baseline_half",
            #         num_problems=num_problems,
            #     )

            # for sketch_dataset in sketch_dataset_name_general:
            #     solve_and_save_collaboratively(
            #         dataset_name,
            #         model_name=reasoning_model,
            #         output_dir=f"./results/reasoning_half",
            #         num_problems=num_problems,
            #         sketch_dataset=sketch_dataset
            #     )


if __name__ == "__main__":
    main()
