import os
import json
import torch
import numpy as np
import transformers
import random
import re


# Dataset name -> JSON file with that dataset's examples. The paths are
# relative ("./data/..."), so they resolve against the current working directory.
dataset_paths = dict(
    math_train="./data/dataset/math_train.json",
    math_test="./data/dataset/math_test.json",
    gsm8k="./data/dataset/gsm8k.json",
    gaokao="./data/dataset/gaokao.json",
    amc23="./data/dataset/amc23.json",
    aime2024="./data/dataset/aime2024.json",
    aime2025="./data/dataset/aime2025.json",
)


def try_extract(output, pattern):
    """Return the last match of *pattern* in *output*, whitespace-stripped.

    The search uses ``re.findall`` with ``re.DOTALL``, so ``.`` in *pattern*
    also matches newlines. When nothing matches, the string ``"None"`` (not
    the ``None`` object) is returned.
    """
    hits = re.findall(pattern, output, re.DOTALL)
    if not hits:
        return "None"
    return hits[-1].strip()


def extract_answer(output):
    """Return the contents of the last complete ``boxed{...}`` group in *output*.

    The text is split on ``boxed{``. Each following segment is scanned with a
    brace-depth counter, so nested groups such as ``boxed{frac{1}{2}}`` come
    back whole. A segment whose opening brace is never closed contributes
    nothing. Returns the string ``"None"`` when no complete group is found.
    """
    found = []
    for segment in output.split('boxed{')[1:]:
        depth = 0
        for pos, ch in enumerate(segment):
            if ch == '{':
                depth += 1
                continue
            if ch != '}':
                continue
            depth -= 1
            if depth >= 0:
                continue
            # The unmatched '}' closes the boxed group.
            # NOTE(review): when it is immediately followed by '%', the slice
            # keeps the '}' itself (not the '%'); confirm this is intended.
            end = pos + 1 if segment[pos + 1:pos + 2] == '%' else pos
            found.append(segment[:end])
            break
    return found[-1] if found else "None"


def set_seed(seed: int):
    """Seed every random number generator used here so runs are repeatable.

    In order, the same *seed* goes to Python's ``random``, NumPy's global
    generator, PyTorch's CPU generator, PyTorch's generators on all CUDA
    devices, and ``transformers.set_seed``.
    """
    for seeder in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        transformers.set_seed,
    ):
        seeder(seed)


# Phrases that presumably open a new reflection / verification chunk in a
# model's reasoning trace (how they are matched is outside this view).
# Every entry needs its own trailing comma: adjacent string literals without
# one are silently concatenated by Python into a single element.
chunk_starters = [
    "Wait",
    "But wait",
    "Alternatively",
    "Is there another way to think about this?",
    "But let me double-check",
    "But hold on",
    "Let me double-check",
    "To ensure",
    "Is there any other",
    "Is there any chance",
    "But just to be thorough",
    "Just to recap",
    "Just to make sure",
    "Just to ensure",
    "But just to recap",
    "Another way to",
    "I guess another way",
    "Just to think about",
    "But, just to think about",
    "Is there any alternative",
    "Is there another way",
    "Let me just double-check",
    "But just to make sure",
    "Just to be thorough",
    "Hold on",
    "Let me think if there's another",
    "Is there a way this could go wrong?",
    "I wonder if there's a more general principle",
    "Let me recap",
    "But just to double-check",
    "Hmm, is there",
]


def _largest_prior_sample(indices_dir, dataset, num_inst):
    """Load the cached selection for *dataset* with the largest size below *num_inst*.

    Only files named exactly ``{dataset}_num_inst=<digits>.json`` inside
    *indices_dir* are considered. Returns ``(path, loaded_indices)``, or
    ``None`` when no such file exists.
    """
    # Escape the dataset name so characters like '.' are matched literally.
    name_re = re.compile(rf"{re.escape(dataset)}_num_inst=(\d+)\.json")
    best = None  # (prior_num, fname); the first file seen wins ties, as max() would
    for fname in os.listdir(indices_dir):
        match = name_re.fullmatch(fname)
        if match is None:
            continue
        prior_num = int(match.group(1))
        if prior_num < num_inst and (best is None or prior_num > best[0]):
            best = (prior_num, fname)
    if best is None:
        return None
    # os.listdir yields bare names; join with the directory before opening.
    prior_path = os.path.join(indices_dir, best[1])
    with open(prior_path, "r", encoding="utf-8") as f:
        return prior_path, json.load(f)


def sample_indices(data, dataset, num_inst, indices_dir="./indices"):
    """Return the indices of the examples in *data* to use for *dataset*.

    With ``num_inst == -1``, every index ``0..len(data)-1`` is returned and
    nothing touches the disk. Otherwise the selection is cached as JSON at
    ``{indices_dir}/{dataset}_num_inst={num_inst}.json``:

    * if that file exists, its contents are returned unchanged;
    * otherwise, the cached selection for the same dataset with the largest
      smaller ``num_inst`` (if any) is reused as a prefix. The rest is topped
      up with unused indices shuffled by the module-level ``random``
      generator. The result is written to the cache file and returned.

    Reusing the smaller selection as a prefix keeps smaller subsets nested
    inside larger ones. If too few unused indices remain, the result is
    shorter than ``num_inst``.

    Args:
        data: Sized collection of examples; only ``len(data)`` is used.
        dataset: Dataset name, embedded in the cache file names.
        num_inst: Number of indices to select, or -1 for all of them.
        indices_dir: Directory holding the cached selections; created on
            demand.

    Returns:
        A list of integer indices into *data*.
    """
    indices = list(range(len(data)))
    if num_inst == -1:
        return indices

    output_path = os.path.join(indices_dir, f"{dataset}_num_inst={num_inst}.json")
    if os.path.exists(output_path):
        with open(output_path, "r", encoding="utf-8") as f:
            return json.load(f)

    # Both the directory scan and the final write need the directory.
    os.makedirs(indices_dir, exist_ok=True)

    sampled_indices = []
    prior = _largest_prior_sample(indices_dir, dataset, num_inst)
    if prior is not None:
        prior_path, prior_sampled = prior
        sampled_indices.extend(prior_sampled)
        print(f"Loaded {len(prior_sampled)} prior sampled indices from {prior_path}")

    remaining = num_inst - len(sampled_indices)
    if remaining > 0:
        remaining_pool = list(set(indices) - set(sampled_indices))
        random.shuffle(remaining_pool)
        additional_sampled = remaining_pool[:remaining]
        sampled_indices.extend(additional_sampled)
        print(f"Randomly sampled additional {len(additional_sampled)} indices.")

    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(sampled_indices, f)
    print(f"Sampled indices saved to {output_path}")
    return sampled_indices