import logging
from abc import ABC, abstractmethod
from typing import Any, Dict, Generic, Optional

import torch
from torch import nn

from fusion_bench import TorchModelType

# Logger named after this module, so its records can be filtered per module.
log = logging.getLogger(name=__name__)


class ModulatedModel(nn.Module, Generic[TorchModelType]):
    """
    A model wrapper that uses task-specific modulators to adapt a shared backbone
    for different tasks.

    The wrapper holds a shared ``backbone`` and an ``nn.ModuleDict`` of
    task-specific modulators keyed by task name. At most one modulator is
    active at a time: :meth:`set_task` calls ``remove`` on the previously
    active modulator and ``apply`` on the requested one, and :meth:`forward`
    then runs the backbone as modulated.

    Args:
        backbone: The shared model that the modulators adapt.
        modulators: Initial mapping from task name to modulator; the keys are
            used as ``nn.ModuleDict`` keys.
    """

    # Task whose modulator is currently applied to the backbone, or None when
    # no modulation is active. Assignments in set_task/_unset_task create an
    # instance attribute that shadows this class-level default.
    _current_task: Optional[str] = None

    def __init__(
        self,
        backbone: TorchModelType,
        modulators: Dict[str, "TaskModulator[TorchModelType]"],
    ):
        super().__init__()
        self.backbone = backbone
        self.modulators = nn.ModuleDict(modulators)

    def add_modulator(self, task_name: str, modulator: "TaskModulator[TorchModelType]"):
        """
        Register a new task-specific modulator.

        The modulator is only stored here; it is not applied to the backbone
        until ``set_task(task_name)`` is called.

        Raises:
            ValueError: If a modulator is already registered for ``task_name``.
        """
        if task_name in self.modulators:
            raise ValueError(f"Modulator for task '{task_name}' already exists.")
        self.modulators[task_name] = modulator

    def remove_modulator(self, task_name: str):
        """
        Unregister an existing task-specific modulator.

        If ``task_name`` is the active task, its modulation is removed from the
        backbone and the current task is cleared before the modulator is
        deleted, so the wrapper never refers to a modulator it no longer holds.

        Raises:
            ValueError: If no modulator is registered for ``task_name``.
        """
        if task_name not in self.modulators:
            raise ValueError(f"Modulator for task '{task_name}' does not exist.")
        if self._current_task == task_name:
            log.warning(
                f"Removing modulator for current task '{task_name}'. "
                "Its modulation will be removed and the current task unset."
            )
            # Must run before the `del` below: afterwards the modulator is no
            # longer reachable, so the backbone could never be un-modulated and
            # the next set_task() would fail looking up the stale task name.
            self._unset_task()
        del self.modulators[task_name]

    def _unset_task(self):
        """Remove the active modulator's modulation (if any) and clear the current task."""
        if self._current_task is None:
            return
        self.modulators[self._current_task].remove(self)
        # Cleared here so that modulators' `remove` implementations do not
        # have to reach into this wrapper's private state.
        self._current_task = None

    def set_task(self, task_name: str):
        """
        Activate the modulator for ``task_name``.

        Any previously active modulator is removed first, so at most one
        modulator is applied to the backbone at a time. Re-selecting the
        current task is a no-op.

        Raises:
            ValueError: If no modulator is registered for ``task_name``.
        """
        if task_name not in self.modulators:
            raise ValueError(
                f"Task '{task_name}' not found in modulators. Available tasks: {list(self.modulators.keys())}"
            )
        if self._current_task == task_name:
            return

        # Unset the previous task (no-op when nothing is active).
        self._unset_task()

        # Set the new task; the name is recorded only after `apply` succeeds.
        self.modulators[task_name].apply(self)
        self._current_task = task_name

    @property
    def current_task(self) -> Optional[str]:
        """Name of the task whose modulator is currently applied, or None."""
        return self._current_task

    def forward(self, *args, **kwargs) -> Any:
        """
        Run the backbone with the currently active task modulation.

        Args:
            *args: Positional arguments forwarded to the backbone.
            **kwargs: Keyword arguments forwarded to the backbone.

        Returns:
            Whatever the (modulated) backbone returns.

        Raises:
            ValueError: If no task has been activated via :meth:`set_task`.
        """
        if self._current_task is None:
            raise ValueError(
                "No task specified. Call set_task() before running forward."
            )

        return self.backbone(*args, **kwargs)


class TaskModulator(nn.Module, Generic[TorchModelType], ABC):
    """
    Abstract base for lightweight, task-specific parameterizations that
    adapt a shared backbone.

    A concrete modulator overrides both hooks:

    * :meth:`apply` installs this task's modulation on a ``ModulatedModel``.
    * :meth:`remove` undoes whatever :meth:`apply` installed.

    Both hooks are declared with ``@abstractmethod``, so a subclass that
    omits either one cannot be instantiated.

    NOTE: the ``apply`` hook shadows :meth:`torch.nn.Module.apply` (the
    recursive ``fn`` visitor). ``nn.Module.apply`` recurses by calling
    ``.apply(fn)`` on each child, so invoking it on a module that contains
    modulators will reach this hook with ``fn`` as ``modulated_model``.
    """

    @abstractmethod
    def apply(self, modulated_model: "ModulatedModel[TorchModelType]"):
        """
        Install this task's modulation on the wrapper's backbone.

        Args:
            modulated_model: Wrapper whose backbone is to be modulated.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError("Subclasses must implement the apply method.")

    @abstractmethod
    def remove(self, modulated_model: "ModulatedModel[TorchModelType]"):
        """
        Undo this task's modulation on the wrapper's backbone.

        ``ModulatedModel.set_task`` calls this on the previously active
        modulator before applying the newly selected one.

        Args:
            modulated_model: Wrapper whose backbone is to be restored.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError("Subclasses must implement the remove method.")
