#!/usr/bin/env python3
# -*- coding: utf-8 -*-


from datasets import load_dataset
from dataclasses import dataclass
from typing import List
import json


@dataclass
class AIMEItem:
    """One AIME problem together with its reference answer."""

    # Name of the exam the problem belongs to, e.g. "AIME2025-I".
    subset: str
    # Position of the problem inside its subset (noted as 1..15).
    item_id: int
    # Running index across all subsets of a version (noted as 1..30).
    global_id: int
    # Problem statement text.
    question: str
    # Gold integer answer, kept as a string.
    answer: str


@dataclass
class LCBItem:
    """One LiveCodeBench coding problem with its prompt and test data."""

    question_id: str
    question_title: str
    # Source platform, e.g. atcoder / leetcode / codeforces.
    platform: str
    # Difficulty label, e.g. easy / medium / hard.
    difficulty: str
    question_content: str
    prompt: str
    starter_code: str
    # Public test cases, stored as a JSON string.
    public_test_cases: str
    # Private test cases, stored as an encoded string.
    private_test_cases: str
    # Extra metadata, stored as a JSON string.
    metadata: str


@dataclass
class GPQAItem:
    """One GPQA multiple-choice question and its ready-made prompt."""

    id: int
    question: str
    # Letter of the correct choice: A, B, C or D.
    correct_answer: str
    # The answer options offered for the question.
    answer_choices: List[str]
    prompt: str


@dataclass
class Math500Item:
    """One MATH-500 problem with its answer, level and prompt."""

    id: int
    # Problem statement in LaTeX.
    problem: str
    answer: str
    level: int
    prompt: str


@dataclass
class HiToMItem:
    """One Hi-ToM question with its answer and bookkeeping fields."""

    id: int
    answer: str
    prompt: str
    path: str
    # Communication setting label, e.g. "Tell" / "No_Tell".
    tell_type: str
    # Length label, e.g. "length_1".
    length: str
    # Order index (noted as 0..5).
    order_index: int


def load_aime_dataset(version: str = "2025") -> List[AIMEItem]:
    """Load the AIME problems of one competition year.

    Args:
        version: Competition year to load, either "2024" or "2025".

    Returns:
        The problems as AIMEItem records in load order. ``global_id`` is a
        1-based running index over the returned list.

    Raises:
        ValueError: If ``version`` is neither "2024" nor "2025".
    """
    if version not in ("2024", "2025"):
        # The message is kept on one line: a replacement field that spans
        # several physical lines is a SyntaxError before Python 3.12.
        raise ValueError(
            f"Unsupported AIME version: {version}. Use '2024' or '2025'")

    items: List[AIMEItem] = []

    # NOTE(review): the dataset repository id passed to load_dataset below is
    # an empty string in this file -- confirm the intended repository.
    if version == "2025":
        # The 2025 problems are split over two configurations that share the
        # same row layout, so one loop handles both.
        for subset in ("AIME2025-I", "AIME2025-II"):
            rows = load_dataset("", name=subset, split="test")
            for position, row in enumerate(rows, start=1):
                items.append(
                    AIMEItem(
                        subset=subset,
                        item_id=position,
                        global_id=len(items) + 1,
                        question=row["question"],
                        answer=str(row["answer"]).strip(),
                    )
                )
        return items

    # version == "2024": a single split whose rows carry an "ID" column.
    rows = load_dataset("", split="train")
    for gid, row in enumerate(rows, start=1):
        parts = row["ID"].split("-")
        if len(parts) == 3:
            # Three dash-separated parts are read as year, subset suffix and
            # problem number.
            year, subset_suffix, item_id_str = parts
            subset = f"AIME{year}-{subset_suffix}"
            item_id = int(item_id_str)
        else:
            # Any other ID shape: fall back to a placeholder subset and reuse
            # the running index as the item id.
            subset = "AIME2024-Unknown"
            item_id = gid

        items.append(
            AIMEItem(
                subset=subset,
                item_id=item_id,
                global_id=gid,
                question=row["Problem"],
                answer=str(row["Answer"]).strip(),
            )
        )

    return items


def load_lcb_dataset(jsonl_path: str = "lcb_v6_with_prompts.jsonl") -> List[LCBItem]:
    """Read LiveCodeBench problems from a JSONL file.

    Every non-blank line is parsed as a JSON object and turned into one
    LCBItem. All ten LCBItem fields are required keys of that object; a
    missing key raises KeyError.

    Args:
        jsonl_path: Path of the UTF-8 encoded JSONL file.

    Returns:
        The parsed records in file order.
    """
    # JSON keys match the LCBItem field names one-to-one; the order below is
    # the order in which the keys are looked up.
    required_keys = (
        'question_id',
        'question_title',
        'platform',
        'difficulty',
        'question_content',
        'prompt',
        'starter_code',
        'public_test_cases',
        'private_test_cases',
        'metadata',
    )
    records: List[LCBItem] = []

    with open(jsonl_path, 'r', encoding='utf-8') as handle:
        for raw_line in handle:
            if not raw_line.strip():
                continue  # blank lines carry no record
            payload = json.loads(raw_line)
            records.append(LCBItem(**{key: payload[key] for key in required_keys}))

    return records


def load_hitom_dataset(jsonl_path: str = "hitom_dataset.jsonl") -> List[HiToMItem]:
    """Read Hi-ToM questions from a JSONL file.

    Every non-blank line is parsed as a JSON object. Only ``id`` is required;
    missing string fields default to ``''`` and a missing ``order_index``
    defaults to ``-1``.

    Args:
        jsonl_path: Path of the UTF-8 encoded JSONL file.

    Returns:
        The parsed records in file order.
    """
    records: List[HiToMItem] = []
    with open(jsonl_path, 'r', encoding='utf-8') as handle:
        for raw_line in handle:
            if not raw_line.strip():
                continue  # skip blank lines
            entry = json.loads(raw_line)
            record = HiToMItem(
                id=entry['id'],
                answer=entry.get('answer', ''),
                prompt=entry.get('prompt', ''),
                path=entry.get('path', ''),
                tell_type=entry.get('tell_type', ''),
                length=entry.get('length', ''),
                order_index=entry.get('order_index', -1),
            )
            records.append(record)
    return records


def load_gpqa_dataset(jsonl_path: str = "gpqa_dataset.jsonl") -> List[GPQAItem]:
    """Read GPQA questions from a JSONL file.

    Every non-blank line is parsed as a JSON object with the required keys
    ``id``, ``question``, ``answer``, ``answer_choices`` and ``prompt``; the
    ``answer`` key is stored as ``correct_answer``.

    Args:
        jsonl_path: Path of the UTF-8 encoded JSONL file.

    Returns:
        The parsed records in file order.
    """
    with open(jsonl_path, 'r', encoding='utf-8') as handle:
        # Lazy generator: each line is parsed right before its record is
        # built, so lines are still handled strictly one at a time.
        parsed = (json.loads(raw) for raw in handle if raw.strip())
        return [
            GPQAItem(
                id=entry['id'],
                question=entry['question'],
                correct_answer=entry['answer'],
                answer_choices=entry['answer_choices'],
                prompt=entry['prompt'],
            )
            for entry in parsed
        ]


def load_math500_dataset(jsonl_path: str = "math500_level5.jsonl") -> List[Math500Item]:
    """Read MATH-500 level-5 problems from a JSONL file.

    Every non-blank line is parsed as a JSON object. Only ``id`` is required;
    missing string fields default to ``''`` and a missing ``level`` defaults
    to ``0``.

    Args:
        jsonl_path: Path of the UTF-8 encoded JSONL file.

    Returns:
        The parsed records in file order.
    """
    records: List[Math500Item] = []

    with open(jsonl_path, 'r', encoding='utf-8') as handle:
        for raw_line in handle:
            if not raw_line.strip():
                continue  # skip blank lines
            entry = json.loads(raw_line)
            fields = {
                'id': entry['id'],
                'problem': entry.get('problem', ''),
                'answer': entry.get('answer', ''),
                'level': entry.get('level', 0),
                'prompt': entry.get('prompt', ''),
            }
            records.append(Math500Item(**fields))

    return records


def get_dataset_info(version: str = "2025") -> dict:
    """Summarise one AIME version.

    Args:
        version: AIME version, passed through to ``load_aime_dataset``.

    Returns:
        A dict with the keys ``version``, ``total_problems`` (number of
        loaded items), ``subsets`` (subset name -> item count, in order of
        first appearance) and ``items`` (the loaded items themselves).
    """
    problems = load_aime_dataset(version)

    per_subset = {}
    for problem in problems:
        per_subset[problem.subset] = per_subset.get(problem.subset, 0) + 1

    summary = {
        "version": version,
        "total_problems": len(problems),
        "subsets": per_subset,
        "items": problems
    }
    return summary


if __name__ == "__main__":
    # Manual smoke test: load both AIME versions, print their summaries,
    # then show the first three 2025 problems.
    print("Testing AIME dataset loading...")

    summaries = {}
    for version, banner in (
        ("2025", "\n=== AIME 2025 ==="),
        ("2024", "\n=== AIME 2024 ==="),
    ):
        print(banner)
        summary = get_dataset_info(version)
        summaries[version] = summary
        print(f"Total problems: {summary['total_problems']}")
        print(f"Subsets: {summary['subsets']}")

    print("\n=== Sample problems ===")
    first_three = summaries["2025"]['items'][:3]
    for number, sample in enumerate(first_three, start=1):
        print(
            f"Problem {number}: {sample.subset} Q{sample.item_id} (G{sample.global_id})")
        print(f"Question: {sample.question[:100]}...")
        print(f"Answer: {sample.answer}")
        print()
