"""Dataset construction for math domain.

Generates training and validation datasets for math problem solving with advisor feedback.
Uses Math500 dataset as the source for mathematical problems.
The advisor provides guidance directly without seeing an initial attempt.

Example usage:
    python advisor_models/math/construct_dataset.py \
        --output_dir data/math \
        --train_ratio 0.8
"""

from __future__ import annotations

import argparse
import json
import os
import random
import sys
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Any, Dict, List

import datasets
from tqdm import tqdm

from config import (
    ADVISOR_SYSTEM_PROMPT,
    ADVISOR_INSTRUCTION,
)


def build_advisor_prompt(problem: str) -> List[Dict[str, str]]:
    """Return the two-message chat (system, then user) sent to the advisor.

    The system turn carries ``ADVISOR_SYSTEM_PROMPT``; the user turn is
    ``ADVISOR_INSTRUCTION`` with its ``{problem}`` placeholder filled in.
    """
    instruction_text = ADVISOR_INSTRUCTION.format(problem=problem)
    system_turn = {"role": "system", "content": ADVISOR_SYSTEM_PROMPT}
    user_turn = {"role": "user", "content": instruction_text}
    return [system_turn, user_turn]


def load_math500_data(
    path: str | os.PathLike | None = None,
) -> List[Dict[str, Any]]:
    """Load Math500 problems from a JSONL file (one JSON object per line).

    Args:
        path: JSONL file to read. Defaults to ``math500.jsonl`` in the same
            directory as this module.

    Returns:
        One parsed dict per non-blank line, in file order.

    Raises:
        FileNotFoundError: If the file does not exist (or is not a file).
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    if path is None:
        math500_path = Path(__file__).parent / "math500.jsonl"
    else:
        math500_path = Path(path)

    if not math500_path.is_file():
        raise FileNotFoundError(f"Math500 dataset not found at {math500_path}")

    problems = []
    # Decode explicitly as UTF-8: problem text may contain non-ASCII
    # characters and the platform default encoding is not guaranteed.
    with math500_path.open("r", encoding="utf-8") as f:
        for line in f:
            stripped = line.strip()
            if not stripped:
                # Tolerate blank lines (e.g. an extra newline at EOF) instead
                # of failing on json.loads("").
                continue
            problems.append(json.loads(stripped))

    print(f"Loaded {len(problems)} problems from Math500 dataset")
    return problems


def split_math500_data(
    problems: List[Dict[str, Any]],
    train_ratio: float = 0.8,
    rng: random.Random | None = None,
) -> tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
    """Shuffle *problems* and split them into train and validation lists.

    Args:
        problems: Records to split. The input list itself is not modified.
        train_ratio: Fraction in ``[0, 1]`` of records assigned to the train
            split; the train count is ``int(len(problems) * train_ratio)``.
        rng: Generator used for the shuffle. Defaults to the global
            ``random`` module state, so ``random.seed(...)`` beforehand keeps
            the split reproducible.

    Returns:
        ``(train, validation)``; together they contain every input record
        exactly once.

    Raises:
        ValueError: If ``train_ratio`` is outside ``[0, 1]`` (or NaN).
    """
    # Validate up front: a negative ratio yields a negative slice index and
    # would silently send most of the data to train; > 1 empties validation.
    if not 0.0 <= train_ratio <= 1.0:
        raise ValueError(f"train_ratio must be between 0 and 1, got {train_ratio}")

    # Shuffle a copy so the caller's list keeps its order.
    shuffled_problems = list(problems)
    if rng is None:
        random.shuffle(shuffled_problems)
    else:
        rng.shuffle(shuffled_problems)

    split_idx = int(len(shuffled_problems) * train_ratio)
    train_problems = shuffled_problems[:split_idx]
    val_problems = shuffled_problems[split_idx:]

    print(
        f"Split into {len(train_problems)} train and {len(val_problems)} validation problems"
    )
    return train_problems, val_problems


def process_math_problem(problem_data: Dict[str, Any]) -> Dict[str, Any]:
    """Turn one raw problem record into a training row.

    Reads the ``problem`` and ``answer`` fields of *problem_data* (a missing
    field raises ``KeyError``) and returns a dict holding the advisor chat
    prompt, the environment class ``"math"``, a reward spec whose
    ``ground_truth`` is the answer, and the original problem text.
    """
    question_text = problem_data["problem"]
    reference_answer = problem_data["answer"]

    return dict(
        prompt=build_advisor_prompt(question_text),
        env_class="math",
        reward_spec=dict(ground_truth=reference_answer),
        original_question=question_text,
    )


def process_problems(
    problems: List[Dict[str, Any]], max_workers: int = 12
) -> List[Dict[str, Any]]:
    """Convert raw problems into training rows using a thread pool.

    Rows come back in the same order as *problems*. Conversion is
    best-effort: a problem whose processing raises is reported on stderr
    (with its index in *problems* and the exception repr) and skipped, so the
    result may be shorter than the input. A summary line reports how many
    problems were skipped.

    Args:
        problems: Raw records passed one-by-one to ``process_math_problem``.
        max_workers: Thread-pool size.

    Returns:
        The successfully processed rows, in input order.
    """
    rows: List[Dict[str, Any]] = []
    failures = 0
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [
            executor.submit(process_math_problem, problem) for problem in problems
        ]
        # Results are collected in submission order, preserving input order.
        for index, future in enumerate(tqdm(futures, desc="Processing problems")):
            try:
                rows.append(future.result())
            except Exception as exc:  # best-effort: one bad record must not abort the build
                failures += 1
                # tqdm.write keeps the progress bar intact; repr() makes e.g.
                # KeyError('answer') readable instead of a bare "'answer'".
                tqdm.write(
                    f"Error processing problem {index}: {exc!r}", file=sys.stderr
                )

    if failures:
        print(
            f"Skipped {failures} of {len(problems)} problems due to errors",
            file=sys.stderr,
        )
    return rows


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Construct math dataset")
    parser.add_argument(
        "--suffix",
        type=str,
        default="",
        help="Suffix to add to output files",
    )
    parser.add_argument(
        "--train_ratio",
        type=float,
        default=0.8,
        help="Ratio of data to use for training (rest goes to validation)",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="data/math",
        help="Output directory for dataset",
    )
    parser.add_argument(
        "--max_workers",
        type=int,
        default=12,
        help="Number of parallel workers for processing",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed for the train/validation shuffle",
    )

    args = parser.parse_args()
    # Reject out-of-range ratios with a CLI error: a negative value would
    # otherwise produce a silently wrong split.
    if not 0.0 <= args.train_ratio <= 1.0:
        parser.error(f"--train_ratio must be between 0 and 1, got {args.train_ratio}")
    # Seed the global RNG so the shuffle performed while splitting is
    # reproducible across runs (default 42 matches the previous behavior).
    random.seed(args.seed)

    print("Loading Math500 dataset...")
    problems = load_math500_data()

    print("Splitting data into train/validation...")
    train_problems, val_problems = split_math500_data(problems, args.train_ratio)

    print(
        f"Processing {len(train_problems)} training and {len(val_problems)} validation problems..."
    )

    # Process problems to create training examples
    train_rows = process_problems(train_problems, args.max_workers)
    val_rows = process_problems(val_problems, args.max_workers)

    # Write to parquet
    os.makedirs(args.output_dir, exist_ok=True)
    suffix = f"_{args.suffix}" if args.suffix else ""
    train_parquet_path = os.path.join(args.output_dir, f"train{suffix}.parquet")
    val_parquet_path = os.path.join(args.output_dir, f"validation{suffix}.parquet")

    datasets.Dataset.from_list(train_rows).to_parquet(train_parquet_path)
    datasets.Dataset.from_list(val_rows).to_parquet(val_parquet_path)

    print(
        f"Wrote {len(train_rows)} training and {len(val_rows)} validation examples to {args.output_dir}"
    )
