from itertools import product
from patching_gemma import logger

from patching_gemma.tasks.tasks_gemma import FineGrainedTypesTask
from patching_gemma.tasks.tasks_phi2 import FineGrainedTypesTaskPhi2
from patching_gemma.tasks.tasks_smollm import FineGrainedTypesTaskSmollm
from patching_gemma.tasks.tasks_different_corruptions import FineGrainedTypesDifferentCorruptionsTask

from patching_gemma.tasks.present_past.present_past_ambiguous import PresentPastAmbiguousTask
from patching_gemma.tasks.present_past.present_past_ambiguous_different_corruption import PresentPastAmbiguousDifferentCorruptionsTask
from patching_gemma.tasks.present_past.present_past_ambiguous_different_corruption_separate_full_corruption import PresentPastAmbiguousDifferentCorruptionsSeparateFullCorruptionTask

# Every task class that gets registered in NAME_TO_TASK below. Registration
# walks this list in order, so if two classes ever produce the same task name,
# the entry appearing later in this list is the one that ends up registered.
ALL_TASKS = [
    # FineGrainedTypes* variants
    FineGrainedTypesTask,
    FineGrainedTypesTaskPhi2,
    FineGrainedTypesTaskSmollm,
    FineGrainedTypesDifferentCorruptionsTask,
    # PresentPast* variants
    PresentPastAmbiguousTask,
    PresentPastAmbiguousDifferentCorruptionsTask,
    PresentPastAmbiguousDifferentCorruptionsSeparateFullCorruptionTask,
]

class InstantiateTask:
    """Deferred constructor for a task class.

    Stores a task class and (optionally) a fixed set of keyword arguments;
    calling the instance constructs the task with exactly those arguments.

    Args:
        task_class: The class (or any callable) to instantiate.
        kwargs: Keyword arguments passed to ``task_class`` on every call, or
            ``None`` to call it with no arguments.
    """

    def __init__(self, task_class, kwargs=None):
        self.task_class = task_class
        self.kwargs = kwargs

    def __call__(self, *_args, **_kwargs):
        """Instantiate the stored task class.

        Any positional/keyword arguments given at call time are accepted but
        ignored; only the kwargs stored at construction are forwarded.

        Returns:
            A new ``task_class`` instance.
        """
        # Check the *stored* kwargs. The previous version tested the call-time
        # ``**kwargs`` dict, which is never None, so factories created without
        # kwargs crashed on ``task_class(**None)``.
        if self.kwargs is not None:
            return self.task_class(**self.kwargs)
        return self.task_class()

    def __repr__(self):
        return f"{type(self).__name__}({self.task_class!r}, {self.kwargs!r})"

def _build_name_to_task(task_classes):
    """Build a mapping from task name to an ``InstantiateTask`` factory.

    A task class with an ``ALLOWED`` attribute is registered once per
    combination in the Cartesian product of ``ALLOWED``'s values. Each
    combination becomes the kwargs dict passed to ``get_name`` and stored in
    the factory. ``ALLOWED`` is presumably a mapping of argument name to
    allowed values (confirm against the task classes).

    A task class without ``ALLOWED`` is registered once under
    ``task_class.get_name()`` with a factory that calls it with no arguments.

    If two registrations yield the same name, the later one wins and a
    warning is logged.

    Args:
        task_classes: Iterable of task classes to register, in order.

    Returns:
        dict mapping task name -> ``InstantiateTask``.
    """
    name_to_task = {}

    def _register(name, factory):
        # Previously a collision silently overwrote the earlier entry; keep
        # that behaviour but make it visible.
        if name in name_to_task:
            logger.warning(
                f"Task name {name!r} registered more than once; keeping the last one"
            )
        name_to_task[name] = factory

    for task_class in task_classes:
        if not hasattr(task_class, "ALLOWED"):
            _register(task_class.get_name(), InstantiateTask(task_class))
            continue

        allowed = task_class.ALLOWED
        arg_names = list(allowed)
        value_options = [allowed[arg_name] for arg_name in arg_names]
        # An empty ALLOWED yields a single empty combination, i.e. get_name().
        for values in product(*value_options):
            args = dict(zip(arg_names, values))
            # Pre-formatted message: valid for stdlib logging and loguru alike
            # (the old call passed a tuple as the format string).
            logger.debug(f"Registering {task_class.__name__} with {args}")
            _register(task_class.get_name(**args), InstantiateTask(task_class, args))

    return name_to_task


NAME_TO_TASK = _build_name_to_task(ALL_TASKS)