# textgrad/tasks/aime.py
import platformdirs
from .base import Dataset
from datasets import load_dataset

class AIME2024(Dataset):
    """AIME 2024 problems loaded from the HuggingFace Hub (``HuggingFaceH4/aime_2024``).

    The Hub dataset has a single "train" split; this class partitions it into
    train / val / test subsets (60% / 20% / 20%) using a fixed, seeded
    permutation so the same problems land in the same split on every run.
    Items are ``(problem, answer)`` pairs.
    """

    def __init__(self, root: str = None, split: str = "train", *args, **kwargs):
        """
        Load AIME 2024 and keep only the requested deterministic split.

        Args:
            root (string, optional): Cache directory passed to ``load_dataset``.
                Defaults to ``platformdirs.user_cache_dir("textgrad")``.
            split (string, optional): "train" (first 60% of a seeded
                permutation of the problems), "val" (next 20%) or
                "test" (final 20%).

        Raises:
            ValueError: If ``split`` is not "train", "val" or "test".
            RuntimeError: If the dataset cannot be loaded from HuggingFace.
        """
        # Validate up front, and with a real exception rather than `assert`,
        # so the check survives `python -O` and fails before any download.
        if split not in ("train", "val", "test"):
            raise ValueError(
                f"Unsupported split {split!r}; expected one of 'train', 'val', 'test'"
            )

        if root is None:
            root = platformdirs.user_cache_dir("textgrad")

        self.root = root
        self.split = split

        # Load the dataset from HuggingFace; it only provides a "train" split.
        try:
            full_dataset = load_dataset('HuggingFaceH4/aime_2024', cache_dir=root)
            self.data = full_dataset["train"]
        except Exception as e:
            # Chain the cause so the underlying loader error stays in the traceback.
            raise RuntimeError(f"Failed to load AIME 2024 dataset: {e}") from e

        # Carve out val/test splits for compatibility with the optimization code.
        # A local RandomState(1) yields exactly the permutation that
        # `np.random.seed(1); np.random.permutation(...)` did, without
        # resetting the process-wide NumPy RNG as a side effect.
        import numpy as np
        rng = np.random.RandomState(1)
        indices = rng.permutation(len(self.data))

        n = len(indices)
        train_end = int(0.6 * n)
        val_end = int(0.8 * n)

        if split == "train":
            split_indices = indices[:train_end]
        elif split == "val":
            split_indices = indices[train_end:val_end]
        else:  # "test"
            split_indices = indices[val_end:]

        self.data = self.data.select(split_indices)

        # Default task description.
        # NOTE(review): "mathemetical" is misspelled; left as-is because this text
        # is sent to the model and changing it alters the prompt.
        self._task_description = "You will answer a mathemetical reasoning question. Think step by step and return the answer. "

    def __getitem__(self, index):
        """Return the ``(problem, answer)`` pair at ``index`` within this split."""
        row = self.data[index]
        question = row["problem"]
        answer = row["answer"]
        return question, answer

    def __len__(self):
        """Return the number of problems in this split."""
        return len(self.data)

    def get_task_description(self):
        """Return the task instruction string used for this dataset."""
        return self._task_description