from datasets import Dataset, load_dataset
from typing import Optional
import random


def preprocess_dataset(
    dataset_name: str = "gsm8k",
    additional_prompt: str = (
        "Please reason step by step, and put your final answer within \\boxed{{}}.\n"
    ),
    n: int = 1000,
    max_answer_chars: Optional[int] = None,
    difficulty: Optional[str] = None,
    seed: Optional[int] = 42,
) -> Dataset:
    """Load a supported math dataset, normalise its columns and sample prompts.

    Steps, in order:

    1. Validate the arguments (before any dataset is loaded).
    2. Load the configured split with ``load_dataset``.
    3. For ``olympiadbench``, keep only rows whose ``modality`` is
       ``"Text-only"``.
    4. For ``deepmath`` / ``deepmath_test``, shuffle all row indices with
       ``seed``, give the first half to ``deepmath`` and the second half to
       ``deepmath_test``, and keep at most ``n * 10`` rows of that half.
    5. Copy the dataset-specific source columns into ``question`` and
       ``answer`` (plus ``difficulty`` for deepmath). Source columns that are
       absent from a row are skipped.
    6. For ``olympiadbench``, replace a non-empty list answer with its first
       element.
    7. If ``max_answer_chars`` is given, drop rows whose answer (as a string)
       is longer than that; rows without an answer are kept.
    8. Sample ``min(n, len(dataset))`` rows with ``seed``.
    9. Add a ``prompt`` column holding a single user message whose content is
       ``question + "\n" + additional_prompt``.

    Sampling uses private ``random.Random(seed)`` instances, so the module
    level ``random`` state of the caller is left untouched while the selected
    rows for a given seed stay the same as with ``random.seed(seed)``.

    Args:
        dataset_name: Key of the dataset to load. Supported keys are
            ``deepmath``, ``deepmath_test``, ``aime_train``, ``math_test``,
            ``olympiadbench``, ``omni_math`` and ``amc12_22_24``. Note that
            the default ``"gsm8k"`` is not among them and therefore raises.
        additional_prompt: Text appended (after a newline) to every question.
        n: Maximum number of examples to return; must be non-negative.
        max_answer_chars: Optional upper bound on the answer length.
        difficulty: Must be ``None``; difficulty filtering is not supported.
        seed: Seed for the shuffles. With ``None`` every shuffle is seeded
            from system entropy, so the deepmath train/test halves differ
            between calls and are not guaranteed to be disjoint.

    Returns:
        The processed dataset with ``question``, ``answer`` and ``prompt``
        columns added by the steps above.

    Raises:
        ValueError: If ``dataset_name`` is not supported, ``difficulty`` is
            not ``None``, or ``n`` is negative.
    """
    # name -> (repo id, split, config name). The config-name slot is None for
    # every entry and is not passed to load_dataset.
    dataset_configs = {
        'deepmath': ("zwhe99/DeepMath-103K", "train", None),
        'deepmath_test': ("zwhe99/DeepMath-103K", "train", None),
        "aime_train": ("gneubig/aime-1983-2024", "train", None),
        "math_test": ("nlile/hendrycks-MATH-benchmark", "test", None),
        "olympiadbench": ("math-ai/olympiadbench", "test", None),
        "omni_math": ("KbsdJames/Omni-MATH", "test", None),
        "amc12_22_24": ("rulins/amc12_22-24", "train", None),
    }

    # target column -> source column, per dataset.
    deepmath_fields = {
        "question": "question",
        "answer": "final_answer",
        "difficulty": "difficulty",
    }
    problem_answer_fields = {  # Common pattern for many math datasets
        "question": "problem",
        "answer": "answer",
    }
    field_mappings = {
        "deepmath": deepmath_fields,
        "deepmath_test": deepmath_fields,
        "aime_train": {
            "question": "Question",
            "answer": "Answer",
        },
        "math_test": problem_answer_fields,
        "olympiadbench": {
            "question": "question",
            "answer": "final_answer",
        },
        "omni_math": problem_answer_fields,
        "amc12_22_24": problem_answer_fields,
    }

    # Fail fast: reject bad arguments before downloading or processing data.
    if dataset_name not in dataset_configs:
        raise ValueError(f"Dataset {dataset_name} not supported")
    if difficulty is not None:
        raise ValueError(
            "Difficulty filtering is not supported for any of the remaining datasets"
        )
    if n < 0:
        raise ValueError(f"n must be non-negative, got {n}")

    repo_id, split, _config_name = dataset_configs[dataset_name]
    dataset: Dataset = load_dataset(repo_id, split=split)

    if dataset_name == "olympiadbench":
        # Filter for Text-only modality
        dataset = dataset.filter(lambda x: x.get("modality", "") == "Text-only")

    if dataset_name in ("deepmath", "deepmath_test"):
        # Both deepmath variants read the same "train" split; a seeded shuffle
        # cut in the middle yields the train half and the held-out half.
        total = len(dataset)
        shuffled = random.Random(seed).sample(range(total), total)
        half = total // 2
        if dataset_name == "deepmath_test":
            pool = shuffled[half:]
        else:
            pool = shuffled[:half]
        # Keep a 10x oversample so later filtering still leaves enough rows.
        dataset = dataset.select(pool[: n * 10])

    # Normalise column names to "question" / "answer".
    mapping = field_mappings[dataset_name]
    dataset = dataset.map(
        lambda x: {
            target_field: x[source_field]
            for target_field, source_field in mapping.items()
            if source_field in x
        }
    )

    if dataset_name == "olympiadbench":
        # Convert final_answer list to string (first element of the list).
        dataset = dataset.map(
            lambda x: {
                "answer": x["answer"][0]
                if isinstance(x["answer"], list) and len(x["answer"]) > 0
                else x["answer"]
            }
        )

    # Optionally filter out examples with very long answers
    if max_answer_chars is not None:
        def _answer_within_limit(example, max_chars: int) -> bool:
            ans = example.get("answer", None)
            # Rows without an answer are kept; non-string answers are
            # measured by their string form.
            return ans is None or len(str(ans)) <= max_chars

        dataset = dataset.filter(
            _answer_within_limit, fn_kwargs={"max_chars": max_answer_chars}
        )

    # A fresh generator (rather than continuing the one used above) keeps the
    # selected rows identical to reseeding the global RNG with the same seed.
    indices = random.Random(seed).sample(
        range(len(dataset)), min(n, len(dataset))
    )
    dataset = dataset.select(indices)

    # NOTE(review): additional_prompt is concatenated verbatim and never passed
    # through str.format here, so the doubled braces in the default reach the
    # prompt as literal "{{}}" - confirm that this is intended.
    dataset = dataset.map(
        lambda x: {
            "prompt": [
                {"role": "user", "content": x["question"] + "\n" + additional_prompt},
            ],
        }
    )

    return dataset