import ujson as json
import multiprocessing as mp
import os
import accuracy_utils
import datasets
from tqdm import tqdm
import tyro
from transformers import AutoTokenizer

from math_verify import parse, verify, LatexExtractionConfig
from latex2sympy2_extended import NormalizationConfig
from dataclasses import dataclass
import random
import deepseek_utils


@dataclass
class Args:
    """Command-line options for scoring generation shards and aggregating them.

    ``__post_init__`` derives two extra attributes from ``path_dir``:
    ``basename`` (its final path component) and ``processed_dir`` (a sibling
    directory named ``<basename>_processed``, created if it does not exist).
    """

    # Dataset id handed to ``datasets.load_dataset`` (split "train"); also the
    # prefix of the dataset id that the aggregate step uploads to.
    dataset_id: str = None
    # Directory holding the raw generation shards, one ``<index>.json`` each.
    path_dir: str = None
    # First shard index to process (inclusive).
    start_shard: int = 0
    # One past the last shard index to process; -1 means "up to the highest
    # shard index found in path_dir".
    end_shard: int = -1
    # Number of dataset rows covered by each shard.
    shard_size: int = 512
    # Number of responses every prompt is expected to have.
    num_responses_per_prompt: int = 16
    # When set, merge the processed shards and upload instead of scoring.
    aggregate: bool = False

    # Tokenizer used to decode responses and to encode the prompts.
    tokenizer_id: str = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"

    def __post_init__(self):
        if self.path_dir is None:
            raise TypeError("path_dir is required (got None)")
        # Normalise first: with a trailing separator ("runs/exp1/") the raw
        # basename is empty, which would put "_processed" *inside* path_dir
        # and leave an empty run name in the uploaded dataset id.
        normalized = os.path.normpath(self.path_dir)
        parent = os.path.dirname(normalized)
        self.basename = os.path.basename(normalized)
        self.processed_dir = os.path.join(parent, f"{self.basename}_processed")
        os.makedirs(self.processed_dir, exist_ok=True)


def load_json_file(file_path):
    """Load one JSON file and report the outcome instead of raising.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        A dict (not a tuple) with these keys:

        * ``file_path``: the path that was passed in.
        * ``data``: the parsed content, or ``None`` when loading failed.
        * ``status``: ``"valid"``, ``"empty"`` (zero-byte file),
          ``"invalid_json"`` (the parser rejected the content) or ``"error"``
          (missing file or any other exception).
        * ``error``: a human-readable message, or ``None`` on success.
    """
    result = {
        "file_path": file_path,
        "data": None,
        "status": "unknown",
        "error": None
    }

    def failed(status, message):
        # Record a failure outcome on the shared result dict and hand it back.
        result["status"] = status
        result["error"] = message
        return result

    try:
        if not os.path.exists(file_path):
            return failed("error", "File does not exist")

        if os.path.getsize(file_path) == 0:
            return failed("empty", "File is empty")

        # JSON text is UTF-8; do not depend on the platform's locale encoding.
        with open(file_path, 'r', encoding="utf-8") as f:
            result["data"] = json.load(f)
    except json.JSONDecodeError as e:
        return failed("invalid_json", f"Invalid JSON: {str(e)}")
    except Exception as e:
        # Deliberately broad: callers (including pool workers) inspect
        # ``status`` rather than handling exceptions.
        return failed("error", f"Error loading file: {str(e)}")

    result["status"] = "valid"
    return result


def load_json_files_parallel(file_paths, num_processes=None):
    """
    Load multiple JSON files in parallel with a process pool.

    Args:
        file_paths: Iterable of paths to JSON files.
        num_processes: Upper bound on the number of worker processes
            (defaults to the CPU count). Never more workers than files
            are started.

    Returns:
        List of the result dicts produced by ``load_json_file`` (keys
        ``file_path``, ``data``, ``status``, ``error``), in the same order as
        ``file_paths``. An empty input yields an empty list.
    """
    file_paths = list(file_paths)
    if not file_paths:
        # Nothing to do; avoid spinning up a pool just to map over nothing.
        return []

    # If num_processes is not specified, use the number of CPU cores
    if num_processes is None:
        num_processes = mp.cpu_count()
    # Extra workers beyond one per file would only cost start-up time.
    num_processes = min(num_processes, len(file_paths))
    print(f"Parallel loading files with {num_processes} processes: ", file_paths)

    with mp.Pool(processes=num_processes) as pool:
        # pool.map preserves input order in its result list.
        results = pool.map(load_json_file, file_paths)

    return results


def equivalence_relation(expr1, expr2, reverse=False):
    """Return whether two answer expressions are judged equivalent.

    Each expression is wrapped in ``$...$`` and parsed with ``parse`` using
    its default configuration; the two parsed results are handed to
    ``verify``.

    Args:
        expr1: Expression passed to ``verify`` first (second when ``reverse``).
        expr2: Expression passed to ``verify`` second (first when ``reverse``).
        reverse: Swap the order in which the two expressions are given to
            ``verify``. Callers in this file treat the order as significant.

    Returns:
        The result of ``verify``. If parsing or verification raises, the
        failure is printed and ``False`` is returned, so that callers which
        sum or accumulate the result never receive ``None``.
    """
    first, second = (expr2, expr1) if reverse else (expr1, expr2)
    try:
        return verify(parse(f"${first}$"), parse(f"${second}$"))
    except Exception as exc:
        # A comparison that cannot be carried out counts as "not equivalent".
        # Do not drop into a debugger here: this runs inside long batch jobs.
        print(f"equivalence_relation failed for {expr1!r} vs {expr2!r}: {exc!r}")
        return False

def compute_majority_vote(partitions, gt_answer):
    """Score the most frequent answer group(s) against the ground truth.

    Args:
        partitions: Groups of answers; the first element of a group is used
            as its representative.
        gt_answer: Ground-truth answer the representatives are compared to.

    Returns:
        The sum of ``equivalence_relation(representative, gt_answer)`` over
        all groups tied for the largest size, divided by the number of such
        groups.
    """
    top_size = max(len(group) for group in partitions)
    winners = [group for group in partitions if len(group) == top_size]
    hits = sum(equivalence_relation(group[0], gt_answer) for group in winners)
    return hits / len(winners)


def _list_shard_indices(directory):
    """Return the sorted shard indices of the ``<index>.json`` files in ``directory``.

    Entries whose stem is not a plain decimal number (stray JSON files,
    temporary files) are ignored rather than crashing ``int()``.
    """
    indices = []
    for name in os.listdir(directory):
        stem, ext = os.path.splitext(name)
        if ext == ".json" and stem.isdecimal():
            indices.append(int(stem))
    return sorted(indices)


def _build_row(args, tokenizer, dataset_shard, i, responses):
    """Score every response for prompt ``i`` of ``dataset_shard`` and return its output row."""
    new_row = {
        "problem": dataset_shard["problem"][i],
        "solution": dataset_shard["solution"][i],
        "answer": dataset_shard["answer"][i],
        "problem_type": dataset_shard["problem_type"][i],
        "question_type": dataset_shard["question_type"][i],
        "source": dataset_shard["source"][i],
        "uuid": dataset_shard["uuid"][i],
        "generated_response": [],
        "generated_answer": [],
        "processed_answer": [],
        "reward": [],
        "reverse_reward": [],
        "min_length": 1e10,
        "max_length": 0,
    }
    gt_answer = new_row["answer"]
    for generated_response in responses:
        # Responses are stored as token ids; decode them before extracting the answer.
        generated_text = tokenizer.decode(generated_response, skip_special_tokens=False)
        generated_answer = deepseek_utils.remove_thinking_text(generated_text)
        processed_answer = accuracy_utils.process_sample(generated_answer)

        # order here is important!
        reward = equivalence_relation(gt_answer, processed_answer)
        reverse_reward = equivalence_relation(gt_answer, processed_answer, reverse=True)

        new_row["generated_response"].append(generated_response)
        new_row["generated_answer"].append(generated_answer)
        new_row["processed_answer"].append(processed_answer)
        new_row["reward"].append(reward)
        new_row["reverse_reward"].append(reverse_reward)

        new_row["min_length"] = min(new_row["min_length"], len(generated_response))
        new_row["max_length"] = max(new_row["max_length"], len(generated_response))

    # The "@16" key names are fixed even though the number of responses is
    # configurable; the aggregate step reads these exact keys.
    new_row["pass@1"] = sum(new_row["reward"]) / args.num_responses_per_prompt
    new_row["pass@16"] = any(new_row["reward"])
    partitions = accuracy_utils.equivalence_partition(new_row["processed_answer"], equivalence_relation)
    new_row["maj@16"] = compute_majority_vote(partitions, new_row["answer"])
    return new_row


def _process_generation_shards(args, tokenizer):
    """Score the raw shards in ``args.path_dir`` and write one processed JSON per shard."""
    shards = _list_shard_indices(args.path_dir)
    if args.end_shard == -1:
        if not shards:
            raise ValueError(f"No shard files found in {args.path_dir}")
        args.end_shard = max(shards)+1
    print(f"There are {len(shards)} shards in total: {shards}")

    # shards should be consecutive
    relevant_shards = list(range(args.start_shard, args.end_shard))
    available = set(shards)
    for shard in relevant_shards:
        assert shard in available, f"Shard {shard} not found in {args.path_dir}"

    # Only shards without a processed output need to be loaded and scored.
    pending = []
    for shard_idx in relevant_shards:
        if os.path.exists(os.path.join(args.processed_dir, f"{shard_idx}.json")):
            print(f"Shard {shard_idx} already processed, skipping...")
        else:
            pending.append(shard_idx)
    if not pending:
        return

    print("Loading shards:", pending)
    files = [os.path.join(args.path_dir, f"{shard}.json") for shard in pending]

    # Load all pending files in parallel
    generation_data = load_json_files_parallel(files, num_processes=16)
    # all the shards should have 'valid' status
    for shard_idx, loaded in zip(pending, generation_data, strict=True):
        status = loaded["status"]
        assert status == "valid", f"{shard_idx} has invalid status {status}"

    dataset = datasets.load_dataset(args.dataset_id, split='train')
    N = len(dataset)
    for shard_idx, loaded in zip(pending, generation_data, strict=True):
        save_shard_path = os.path.join(args.processed_dir, f"{shard_idx}.json")

        # Shard k covers dataset rows [k * shard_size, (k + 1) * shard_size), clipped to N.
        start_idx = shard_idx * args.shard_size
        end_idx = min((shard_idx + 1) * args.shard_size, N)
        dataset_shard = dataset[start_idx:end_idx]
        generated_shard = loaded["data"]
        assert len(generated_shard) == (end_idx-start_idx), f"Shard {shard_idx}: Expected {end_idx-start_idx} samples, got {len(generated_shard)}"

        new_dataset_shard = []
        for i in tqdm(range(end_idx - start_idx), desc=f"Processing shard {shard_idx}"):
            assert len(generated_shard[i]) == args.num_responses_per_prompt, f"Shard {shard_idx} ({i}): Expected {args.num_responses_per_prompt} responses, got {len(generated_shard[i])}"
            new_dataset_shard.append(_build_row(args, tokenizer, dataset_shard, i, generated_shard[i]))

        # Write to a temporary name and rename afterwards: an interrupted run
        # must not leave a truncated "<index>.json" that a later run would
        # treat as already processed. The ".tmp" suffix keeps the partial
        # file out of the "*.json" shard listing.
        tmp_shard_path = f"{save_shard_path}.tmp"
        with open(tmp_shard_path, 'w') as f:
            json.dump(new_dataset_shard, f)
        os.replace(tmp_shard_path, save_shard_path)


def _aggregate_and_upload(args, tokenizer):
    """Merge all processed shards, print summary metrics, split, and push to the Hub."""
    shards = _list_shard_indices(args.processed_dir)

    files = [os.path.join(args.processed_dir, f"{shard}.json") for shard in shards]
    json_data = load_json_files_parallel(files)

    data = []
    for shard, loaded in zip(shards, json_data, strict=True):
        # Report the shard id itself, not its position in the listing.
        assert loaded["status"] == "valid", f"{shard} has invalid status {loaded['status']}"
        data.extend(loaded["data"])
    N = len(data)
    print(f"Loaded {N} samples")
    if N == 0:
        raise ValueError(f"No processed samples found in {args.processed_dir}")

    hf_dataset = {
        "problem": [],
        "solution": [],
        "answer": [],
        "problem_type": [],
        "question_type": [],
        "source": [],
        "uuid": [],
        "processed_answer": [],
        "reward": [],
        "min_length": [],
        "max_length": [],
        "pass@1": [],
        "pass@16": [],
        "cons@16": [],
        "roll_in_ids": [],
        "roll_outs_ids": [],
    }
    # Columns copied from the processed rows under the same name.
    same_name_columns = (
        "problem", "solution", "answer", "problem_type", "question_type",
        "source", "uuid", "processed_answer", "min_length", "max_length",
        "pass@1", "pass@16",
    )
    for row in tqdm(data, desc="Converting to HF"):
        for column in same_name_columns:
            hf_dataset[column].append(row[column])
        # A response counts as correct if either verification order accepted it.
        # NOTE(review): pass@1 / pass@16 / maj@16 were computed from the forward
        # reward only, so they can disagree with this column — confirm intended.
        or_reward = [r or rr for r, rr in zip(row["reward"], row["reverse_reward"], strict=True)]
        hf_dataset["reward"].append(or_reward)
        hf_dataset["cons@16"].append(row["maj@16"])

        roll_in = deepseek_utils.format_roll_in(row["problem"])
        roll_in_ids = tokenizer.encode(roll_in, add_special_tokens=False)
        hf_dataset["roll_in_ids"].append(roll_in_ids)
        roll_out_ids = row.pop("generated_response")
        for j in range(len(roll_out_ids)):
            assert len(roll_out_ids[j]) <= 16384, f"Roll out {j} is too long: {len(roll_out_ids[j])}"
        hf_dataset["roll_outs_ids"].append(roll_out_ids)
    del data

    pass_1 = sum(hf_dataset["pass@1"]) / len(hf_dataset["pass@1"])
    pass_16 = sum(hf_dataset["pass@16"]) / len(hf_dataset["pass@16"])
    cons_16 = sum(hf_dataset["cons@16"]) / len(hf_dataset["cons@16"])
    print(f"pass@1: {pass_1}")
    print(f"pass@16: {pass_16}")
    print(f"cons@16: {cons_16}")

    # split off some for validation and test
    val_size = test_size = 1000
    if N <= val_size + test_size:
        print(f"Warning: only {N} samples; the train split will be empty")
    # Fixed seed on the global generator keeps the split reproducible.
    random.seed(1337)
    shuffled_indices = random.sample(range(N), N)
    val_indices = shuffled_indices[:val_size]
    test_indices = shuffled_indices[val_size:val_size+test_size]
    train_indices = shuffled_indices[val_size+test_size:]

    train_dataset = {k: [v[i] for i in train_indices] for k, v in hf_dataset.items()}
    val_dataset = {k: [v[i] for i in val_indices] for k, v in hf_dataset.items()}
    test_dataset = {k: [v[i] for i in test_indices] for k, v in hf_dataset.items()}
    hf_dataset = datasets.DatasetDict({
        "train": datasets.Dataset.from_dict(train_dataset),
        "validation": datasets.Dataset.from_dict(val_dataset),
        "test": datasets.Dataset.from_dict(test_dataset),
    })
    new_dataset_id = f"{args.dataset_id}_{args.basename}_tokenized"
    print("Uploading to HF")
    hf_dataset.push_to_hub(new_dataset_id)


def main(args):
    """Run the scoring pass or, with ``args.aggregate``, the merge-and-upload pass.

    Scoring mode reads ``<index>.json`` shards of token-id responses from
    ``args.path_dir``, scores each response against the dataset's answers,
    and writes one processed JSON per shard into ``args.processed_dir``
    (shards that already have an output file are skipped). ``args.end_shard``
    is resolved in place when it is -1.

    Aggregate mode loads every processed shard, prints the mean pass@1,
    pass@16 and cons@16, carves out 1000-sample validation and test splits
    with a fixed seed, and pushes the result to the Hub as
    ``<dataset_id>_<basename>_tokenized``.

    Args:
        args: Parsed ``Args`` instance.

    Raises:
        ValueError: If no shard files exist when ``end_shard`` is -1, or if
            aggregate mode finds no processed samples.
        AssertionError: If a shard is missing, failed to load, or does not
            have the expected number of samples/responses.
    """
    # The tokenizer decodes responses while scoring and encodes prompts while aggregating.
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_id)
    if args.aggregate:
        _aggregate_and_upload(args, tokenizer)
    else:
        _process_generation_shards(args, tokenizer)


if __name__ == "__main__":
    # Build Args from the command line with tyro and hand it to main().
    main(tyro.cli(Args))