"""Output adapters for Perceiver IO."""

import math

from torch import Tensor, nn


class OutputAdapter(nn.Module):
    """Base class for Perceiver IO output adapters.

    Stores an ``output_shape`` supplied by the subclass and exposes it
    read-only. Subclasses must override :attr:`names` (the names of the
    outputs they produce) and implement ``forward``.
    """

    def __init__(self, output_shape: tuple[int, ...]):
        """Initialise the adapter.

        :param output_shape: shape reported by :attr:`output_shape`; stored
            as given, without validation or copying.
        """
        super().__init__()
        self._output_shape = output_shape

    @property
    def names(self) -> list[str]:
        """Names of the outputs produced by this adapter.

        :raises NotImplementedError: always; subclasses must override this.
        """
        # Name the concrete class so a missing override is easy to locate.
        raise NotImplementedError(
            f"{type(self).__name__} must override the 'names' property"
        )

    @property
    def output_shape(self) -> tuple[int, ...]:
        """Shape given to the constructor, returned unchanged.

        :return: the ``output_shape`` passed at construction time
        """
        return self._output_shape


class OccupancyOA(OutputAdapter):
    """Single-map occupancy output adapter.

    A ``Linear(num_output_channels, 1)`` layer projects every position's
    features to one value. The per-position values are then reshaped into a
    ``(batch, 1, *image_shape)`` grid. No activation is applied.
    """

    def __init__(
        self,
        num_output_channels: int,
        name: str = "agents_occ",
        image_shape: list[int] | None = None,
    ):
        """Build the adapter.

        :param num_output_channels: feature size of each input position
            (input size of the projection layer)
        :param name: output name reported by :attr:`names`
        :param image_shape: spatial shape of the output grid (e.g. ``[H, W]``);
            defaults to ``[200, 200]``
        """
        resolved_shape = [200, 200] if image_shape is None else image_shape

        # PyTorch requires the parent __init__ to run before any attribute is
        # assigned on an nn.Module, so derive the output shape from a local.
        super().__init__(
            output_shape=(math.prod(resolved_shape), num_output_channels)
        )
        self._name = name
        self.image_shape = resolved_shape
        self.linear = nn.Linear(num_output_channels, 1)

    @property
    def names(self) -> list[str]:
        """Single-element list containing this adapter's output name."""
        return [self._name]

    def forward(self, x: Tensor) -> Tensor:
        """Project per-position features to an occupancy grid.

        :param x: tensor of shape
            ``(batch, prod(image_shape), num_output_channels)``
        :return: raw (pre-activation) occupancy values of shape
            ``(batch, 1, *image_shape)``
        """
        out: Tensor = self.linear(x)  # (batch, positions, 1)
        out = out.permute(0, 2, 1)  # (batch, 1, positions)
        return out.reshape(x.shape[0], 1, *self.image_shape)


class ClassOccupancyOA(OutputAdapter):
    """Per-class occupancy output adapter.

    Holds one ``Linear(num_output_channels, 1)`` head per class. ``forward``
    returns a dict that maps ``f"{name}_occ"`` to a
    ``(batch, 1, *image_shape)`` tensor of raw (pre-activation) values.
    """

    def __init__(
        self,
        names: list[str],
        num_output_channels: int,
        image_shape: list[int] | None = None,
    ):
        """Build the adapter.

        :param names: class names. Each one gets its own head stored under
            ``f"{name}_occ"``. Duplicate names collapse into a single head,
            and the last one wins. Keys must be valid submodule names
            (PyTorch rejects names containing ``"."``).
        :param num_output_channels: feature size of each input position
        :param image_shape: spatial shape of each output grid (e.g.
            ``[H, W]``); defaults to ``[200, 200]``
        """
        resolved_shape = [200, 200] if image_shape is None else image_shape

        # PyTorch requires the parent __init__ to run before any attribute is
        # assigned on an nn.Module, so derive the output shape from a local.
        super().__init__(
            output_shape=(math.prod(resolved_shape), num_output_channels)
        )
        self.image_shape = resolved_shape

        self.linear_layers = nn.ModuleDict(
            {f"{name}_occ": nn.Linear(num_output_channels, 1) for name in names}
        )

    @property
    def names(self) -> list[str]:
        """Output keys, i.e. ``f"{name}_occ"`` for each configured class."""
        return list(self.linear_layers.keys())

    def forward(self, x: Tensor) -> dict[str, Tensor]:
        """Apply every class head and reshape each result into a grid.

        :param x: tensor of shape
            ``(batch, prod(image_shape), num_output_channels)``
        :return: mapping from output key to a ``(batch, 1, *image_shape)``
            tensor; empty if no class names were configured
        """
        # Target shape is identical for every head, so compute it once.
        target_shape = (x.shape[0], 1, *self.image_shape)
        outputs: dict[str, Tensor] = {}
        for key, head in self.linear_layers.items():
            projected: Tensor = head(x)  # (batch, positions, 1)
            outputs[key] = projected.permute(0, 2, 1).reshape(target_shape)

        return outputs


class MatchingOA(OutputAdapter):
    """Pass-through adapter that wraps its input in a one-entry dict.

    ``forward`` returns the input tensor unchanged under this adapter's name.
    The reported ``output_shape`` is ``(1,)``.
    """

    def __init__(self, name: str = "matching"):
        """Create the adapter.

        :param name: key under which ``forward`` returns its input
        """
        super().__init__(output_shape=(1,))
        self._name = name

    @property
    def names(self) -> list[str]:
        """Single-element list containing this adapter's output name."""
        own_name = self._name
        return [own_name]

    def forward(self, x: Tensor):
        """Return ``x`` untouched, keyed by this adapter's name.

        :param x: input tensor (not modified)
        :return: ``{name: x}``
        """
        key = self._name
        return {key: x}
