from collections.abc import Iterator
from itertools import product
from typing import Any

from inspect_ai.dataset import MemoryDataset, Sample


def generate_combinations(d: dict[str, dict[str, Any]]) -> Iterator[tuple[dict[str, Any], ...]]:
    """
    Lazily enumerate every way of picking exactly one level from each category.

    Each category in ``d`` contributes one record per level, of the form
    ``{'category': <category>, 'level': <level>, 'value': <value>}``. The result is the cartesian product of those
    per-category record lists, so every yielded tuple holds one record per category, in the insertion order of ``d``.

    An empty ``d`` yields a single empty tuple; a category with no levels makes the whole product empty.

    :param d: Map from category name to a map from level name to the value for that level
    :return: A one-shot iterator of tuples of records (not a list; wrap it in ``list()`` if it must be reused)
    """
    # One list of records per category; the inner list enumerates that category's levels
    records_per_category = [
        [
            {'category': category, 'level': level, 'value': value}
            for level, value in category_levels.items()
        ]
        for category, category_levels in d.items()
    ]
    # product picks one record from each category list for every combination
    return product(*records_per_category)


def create_difficulty_dataset(sample: Sample, difficulty_map: dict[str, dict[str, Any]] | None) -> MemoryDataset:
    """
    Create a dataset from a single sample and a difficulty map. The difficulty map is a dictionary where the keys are
    the difficulty categories and the values are maps from difficulty levels within that category to the value to
    apply to the prompt for that level.

    The sample's input field should be a format string with placeholders for each of the difficulty categories whose
    level values are not callable. For example, if the sample's input is "Question: ... \n\n Hint: {hint}", then the
    difficulty map should have a key "hint" with a value that is a dictionary mapping difficulty levels to strings to
    insert into the prompt.

    A level value may instead be a callable. Callable values are not passed to ``str.format``; after the placeholders
    have been filled in and the result stripped, each callable is called with the current prompt and its return value
    becomes the new prompt. Callables are applied in the order of the categories in the difficulty map.

    We combine every combination of difficulty levels to create a new sample for each combination. i.e. if there are
    two difficulty categories (hint and explanation) with three levels of hint difficulty and three levels of
    explanation difficulty, we will create 9 samples.

    Each new sample's metadata is a shallow copy of the original metadata with the chosen level stored under each
    category name, plus a 'difficulty' entry holding the full category -> level map. If the original sample has an
    id, '_<category>_<level>' is appended to it for every category; ids that are not strings (e.g. integers) are
    converted to strings first.

    :param sample: The sample to base the dataset on
    :param difficulty_map: The difficulty map
    :return: The dataset
    """
    # If there are no difficulty levels, return the sample as is
    if not difficulty_map:
        return MemoryDataset([sample])

    samples = []

    # create a new sample for each combination of difficulty levels
    for combination in generate_combinations(difficulty_map):
        substitutions = {c['category']: c['value'] for c in combination}
        levels = {c['category']: c['level'] for c in combination}

        metadata = sample.metadata.copy() if sample.metadata is not None else {}
        # levels are exposed both as top-level metadata keys and grouped under 'difficulty'
        metadata.update(levels)
        metadata['difficulty'] = levels

        # first format anything that is not callable
        format_args = {k: v for k, v in substitutions.items() if not callable(v)}
        prompt = sample.input.format(**format_args).strip()
        # then let each callable transform the formatted prompt
        for value in substitutions.values():
            if callable(value):
                prompt = value(prompt)

        new_id = sample.id
        if new_id is not None:
            suffix = ''.join(f'_{category}_{level}' for category, level in levels.items())
            # build via f-string rather than += so that non-string ids do not raise TypeError
            new_id = f'{new_id}{suffix}'

        new_sample = Sample(
            id=new_id,
            input=prompt,
            target=sample.target,
            choices=sample.choices,
            metadata=metadata,
            files=sample.files
        )
        samples.append(new_sample)

    return MemoryDataset(samples)
