from functools import partial
from typing import Any

from torch.optim import SGD, Adam
from torch.optim.lr_scheduler import ExponentialLR, MultiStepLR, SequentialLR
from temporal_task_planner.trainer.dataset import (
    DishwasherArrangeDataset,
    PromptSituationPickPlace,
    SessionPreferenceDataset,
    pad_fn,
    prompt_pad_fn,
    input_pad_fn, 
    preference_classifier_pad_fn
)
from temporal_task_planner.trainer.logger import SingleModelLogger, DualModelLogger, SessionPreferenceClassifierLogger

class ModuleWrapper:
    """Deferred constructor for a callable named by string in this module.

    ``ModuleWrapper("Adam", lr=1e-3)`` resolves ``Adam`` among the names
    defined or imported at the top level of this module (e.g. the optimizers,
    schedulers, datasets, pad functions and loggers imported above). It binds
    the given keyword arguments now. Calling the wrapper later forwards any
    extra positional/keyword arguments and returns whatever the target returns.
    Presumably this lets a config refer to a target by name and supply the
    remaining arguments (such as model parameters) at build time — confirm
    against the callers.
    """

    def __init__(self, method_name: str, **kwargs: Any) -> None:
        """Resolve ``method_name`` and bind ``kwargs`` to it.

        Args:
            method_name: Name of a callable that is defined or imported at
                module level in this file. Only this module's globals are
                searched, not builtins, so the target must be imported here.
            **kwargs: Keyword arguments bound immediately. Keyword arguments
                passed at call time with the same name override them.

        Raises:
            KeyError: If ``method_name`` is not a module-level name here.
            TypeError: If the resolved object is not callable.
        """
        try:
            callable_method = globals()[method_name]
        except KeyError:
            # Keep the KeyError type for existing callers, but explain why
            # the lookup failed instead of echoing just the bare name.
            raise KeyError(
                f"ModuleWrapper: no module-level name {method_name!r}; "
                "import it in this module to make it available"
            ) from None
        if not callable(callable_method):
            # Globals also include non-callables (modules, typing helpers);
            # fail here with a clear message rather than inside partial().
            raise TypeError(f"ModuleWrapper: {method_name!r} is not callable")
        self.method_name = method_name
        self._partial = partial(callable_method, **kwargs)

    def __call__(self, *args: Any, **kwds: Any) -> Any:
        """Invoke the resolved callable with the bound and call-time arguments.

        Returns:
            Whatever the resolved callable returns (e.g. a constructed
            optimizer or dataset instance).
        """
        module = self._partial(*args, **kwds)
        return module
