import random
from typing import List, Optional, Tuple

from datasets import load_dataset

from core.wrappers.example import Example


def _subsample_list(items: List[Example], ratio: float, *, seed: int) -> List[Example]:
    """Deterministically subsample a list to roughly ``ratio`` of its size.

    The relative order of the kept items is preserved, and the same
    ``(len(items), ratio, seed)`` combination always keeps the same positions.

    Args:
        items: Items to subsample. ``None`` is treated like an empty list.
        ratio: Fraction of items to keep. Values ``>= 1.0`` keep everything;
            values ``<= 0.0`` keep only the first item.
        seed: Seed for the private RNG that picks which positions survive.

    Returns:
        A new list holding the kept items (never the input object itself).
        A non-empty input always yields at least one item.
    """
    if not items:
        return []
    n = len(items)
    ratio = float(ratio)
    if ratio >= 1.0:
        # Return a copy, like every other path, so that mutating the result
        # can never silently modify the caller's list.
        return list(items)
    if ratio <= 0.0:
        # Keep at least 1 item if non-empty to avoid empty splits in training.
        return items[:1]
    # Round down, but never below one item.
    k = max(1, int(n * ratio))
    # A private RNG keeps the selection reproducible without touching the
    # global ``random`` state. Shuffling the full index list (rather than
    # calling ``rng.sample``) keeps the chosen positions stable for a given
    # seed.
    rng = random.Random(int(seed))
    idx = list(range(n))
    rng.shuffle(idx)
    chosen = set(idx[:k])
    return [x for i, x in enumerate(items) if i in chosen]


def dataset_engine(
    train_size: int = 1000, seed: int = 42, subsample_ratio: Optional[float] = None, **kwargs
) -> Tuple[List[Example], List[Example], List[Example]]:
    """
    Load and prepare the BigCodeBench dataset as train/val/test splits.

    Every record is randomly tagged as 'train' or 'test' (``train_size`` of
    them as 'train'). The 'train' pool is then cut, in dataset order, into a
    training part and a validation tail: the last 5% when the pool has at
    least 20 examples, otherwise its last single example.

    The global ``random`` state is not modified; all shuffling uses private
    generators derived from ``seed``.

    Args:
        train_size (int): Number of examples to tag as train (before the
            train/val cut). Defaults to 1000.
        seed (int): Random seed for reproducibility. Defaults to 42.
        subsample_ratio (Optional[float]): If given, the train and test sets
            are each subsampled to roughly this fraction. The validation set
            is left at full size. Defaults to None (no subsampling).
        **kwargs: When ``subsample_ratio`` is None, the first non-None value
            among the keys "ratio", "fraction", "subsample" and
            "subsample_fraction" is used as the ratio. Other keys are ignored.

    Returns:
        Tuple[List[Example], List[Example], List[Example]]: trainset, valset, testset
    """
    raw_data = load_dataset("bigcode/bigcodebench", split="v0.1.3")

    # Assign split tags. Clamping keeps the tag list exactly as long as the
    # dataset even for an out-of-range train_size (the resulting splits are
    # the same as without clamping: all-train or all-test).
    total = len(raw_data)
    n_train = max(0, min(train_size, total))
    splits = ['train'] * n_train + ['test'] * (total - n_train)
    # A private generator yields the same permutation as seeding the module
    # RNG with `seed`, without reseeding global state as a side effect.
    random.Random(seed).shuffle(splits)

    # Build Examples
    examples = [
        Example(
            question=ex['instruct_prompt'],
            code=ex['code_prompt'] + ex['canonical_solution'],
            unit_tests=ex['test'],
            task_id=ex['task_id'],
            entry_point='task_func'
        ).with_inputs("question")
        for ex in raw_data
    ]

    # Split into sets
    train_val: List[Example] = []
    testset: List[Example] = []
    for ex, tag in zip(examples, splits):
        if tag == 'train':
            train_val.append(ex)
        else:
            testset.append(ex)

    # Hold out the tail of the train pool for validation: 5% for pools of 20+
    # examples, otherwise exactly one example (nothing if the pool is empty).
    pool_size = len(train_val)
    split_idx = int(pool_size * 0.95) if pool_size >= 20 else max(pool_size - 1, 0)
    trainset = train_val[:split_idx]
    valset = train_val[split_idx:]

    # Optional subsampling (e.g., paper setting: use half of train and test).
    # Accept a few common alias keys via kwargs to make CLI/yaml wiring easy.
    if subsample_ratio is None:
        for alias in ("ratio", "fraction", "subsample", "subsample_fraction"):
            if kwargs.get(alias) is not None:
                subsample_ratio = float(kwargs[alias])
                break
    if subsample_ratio is not None:
        ratio = float(subsample_ratio)
        # Use different seeds per split for stability.
        trainset = _subsample_list(trainset, ratio, seed=seed + 101)
        # NOTE(review): valset is deliberately not subsampled here; a call
        # using seed + 202 used to exist but was disabled. Confirm intent.
        testset = _subsample_list(testset, ratio, seed=seed + 303)

    return trainset, valset, testset


if __name__ == "__main__":
    # Manual smoke run: load the default splits, report their sizes and show
    # the first training example.
    trainset, valset, testset = dataset_engine()
    summary = (
        f"Loaded {len(trainset)} training, {len(valset)} validation, "
        f"and {len(testset)} test examples."
    )
    print(summary)
    print(trainset[0])
    # Example of the printed first item (abridged):
    # Example({
    # 'question': 'Calculates the average of the sums of absolute differences between each pair of consecutive numbers ...',
    # 'unit_tests': "import unittest\nfrom unittest.mock import patch\nfrom random import seed, shuffle\nimport itertools\nclass TestCases(unittest.TestCase):\n    ...",
    # 'task_id': 'BigCodeBench/0',
    # 'entry_point': 'task_func'
    # }) (input_keys={'question'})
