"""Input adapters for Perciever IO
"""

import enum
from typing import Optional

import einops
import torch
from torch import Tensor, nn


def _debug_plot(tensor: Tensor, figname: str) -> None:
    """Dump ``tensor`` to ``figname`` as a min-max normalised 8-bit image.

    Debugging aid only. ``cv2`` and ``numpy`` are imported inside the function
    so the module does not need them unless this helper is actually called.

    :param tensor: tensor to visualise; it is cloned, detached and moved to CPU first.
    :param figname: output path handed to ``cv2.imwrite``.
    """
    import cv2
    import numpy as np

    # Work on a private CPU copy so the caller's tensor (and its graph) is untouched.
    pixels = tensor.clone().detach().cpu().numpy()

    # Destination buffer for the values rescaled into [0, 255].
    scaled = np.zeros_like(pixels, dtype=np.uint8)
    lower, upper = 0, 255
    cv2.normalize(pixels, scaled, lower, upper, cv2.NORM_MINMAX, cv2.CV_8U)

    cv2.imwrite(figname, scaled)


def generate_positions_for_encoding(spatial_shape, v_min=-1.0, v_max=1.0):
    """
    Build a dense grid of evenly spaced coordinates covering ``spatial_shape``.

    Each axis is sampled with ``torch.linspace`` between ``v_min`` and ``v_max``
    (both inclusive), and the per-axis samples are combined with "ij" indexing.

    :param spatial_shape: number of grid points along each dimension.
    :param v_min: minimum coordinate value per dimension.
    :param v_max: maximum coordinate value per dimension.
    :return: position coordinates tensor of shape (*spatial_shape, len(spatial_shape)).
    """
    axes = []
    for size in spatial_shape:
        axes.append(torch.linspace(v_min, v_max, steps=size))

    # One coordinate grid per axis, each of shape spatial_shape.
    grids = torch.meshgrid(*axes, indexing="ij")

    # Coordinates go into a new trailing dimension.
    return torch.stack(grids, dim=-1)


def generate_position_encodings(
    p: Tensor,
    num_frequency_bands: int,
    max_frequencies: Optional[tuple[int, ...]] = None,
    include_positions: bool = True,
) -> Tensor:
    """Compute Fourier features for the positions ``p``.

    For every coordinate axis ``i`` the frequencies are
    ``linspace(1, max_frequencies[i] / 2, num_frequency_bands)``. The output
    channels are laid out as: the raw positions (optional), then the sine
    features of each axis in order, then the cosine features of each axis.

    :param p: positions of shape (*d, c) where c = len(d).
    :param num_frequency_bands: number of frequencies sampled per axis.
    :param max_frequencies: maximum frequency for each dimension (1-tuple for sequences,
           2-tuple for images, ...). If `None` values are derived from shape of p.
    :param include_positions: whether to include input positions p in returned encodings tensor.
    :returns: position encodings tensor of shape
           (*d, c * (2 * num_frequency_bands + include_positions)).
    """
    # Fall back to the grid resolution of p when no explicit limits are given.
    limits = p.shape[:-1] if max_frequencies is None else max_frequencies

    phases = []
    for axis, limit in enumerate(limits):
        bands = torch.linspace(1.0, limit / 2.0, num_frequency_bands, device=p.device)
        # (*d, 1) * (1, bands) -> (*d, bands)
        scaled = p[..., axis : axis + 1] * bands[None, ...]
        phases.append(torch.pi * scaled)

    pieces = [p] if include_positions else []
    pieces += [torch.sin(phase) for phase in phases]
    pieces += [torch.cos(phase) for phase in phases]

    return torch.cat(pieces, dim=-1)


def _sample_frequency_band(
    p: Tensor,
    num_frequency_bands: int,
    max_frequencies: list[int],
    include_positions: bool = True,
) -> Tensor:
    """
    Fourier-encode the leading ``len(max_frequencies)`` channels of ``p``.

    Channel ``i`` of ``p`` is multiplied by
    ``linspace(1, max_frequencies[i] / 2, num_frequency_bands)``; the scaled
    channels are concatenated, multiplied by pi, and passed through sine and
    cosine. Output layout: ``p`` itself (optional, all of its channels),
    then all sine features, then all cosine features.

    :param p: input features; only the last dimension is indexed.
    :param num_frequency_bands: number of frequencies sampled per encoded channel.
    :param max_frequencies: maximum frequency for each encoded channel.
    :param include_positions: whether to prepend ``p`` to the encodings.
    :returns: tensor whose last dimension holds the concatenated encodings.
    """
    scaled_channels = []
    for channel, limit in enumerate(max_frequencies):
        bands = torch.linspace(
            1.0, limit / 2.0, num_frequency_bands, device=p.device, dtype=p.dtype
        )
        scaled_channels.append(p[..., channel : channel + 1] * bands[None, ...])

    phase = torch.pi * torch.cat(scaled_channels, dim=-1)

    pieces = [torch.sin(phase), torch.cos(phase)]
    if include_positions:
        pieces.insert(0, p)

    return torch.cat(pieces, dim=-1)


class InputAdapter(nn.Module):
    """Base class for input adapters.

    Stores the number of channels the adapter produces and exposes it through
    the read-only ``out_channels`` property. Subclasses must override
    ``forward``; the base implementation always raises ``NotImplementedError``.
    """

    def __init__(self, out_channels):
        """
        :param out_channels: number of channels produced by the adapter.
        """
        super().__init__()
        self._out_channels = out_channels

    @property
    def out_channels(self):
        """Number of output channels given at construction time."""
        return self._out_channels

    def forward(self, x):
        """Adapt ``x``; must be implemented by subclasses."""
        raise NotImplementedError


class ImageIA(InputAdapter):
    """Input adapter that turns an image batch into tokens with Fourier position encodings.

    ``forward`` optionally folds square patches into the channel dimension,
    optionally applies a 1x1 convolution, flattens the spatial dimensions into
    a token axis and concatenates a precomputed position encoding to every token.

    :param image_shape: required per-sample input shape ``(channels, *spatial)``.
    :param num_frequency_bands: number of Fourier frequency bands per spatial dimension.
    :param patchify: patch edge length; ``1`` disables patching. The patching
        pattern used in ``forward`` is written for two spatial dimensions.
    :param conv_1x1: when not ``None``, number of output channels of a 1x1
        convolution applied after the (optional) patching step.
    :param in_channels: number of image channels fed to the 1x1 convolution
        (multiplied by ``patchify**2`` when patching is enabled); only used
        when ``conv_1x1`` is given.
    :raises ValueError: if ``patchify > 1`` and a spatial size is not divisible by it.
    """

    def __init__(
        self,
        image_shape: tuple[int, ...],
        num_frequency_bands: int,
        patchify: int = 1,
        conv_1x1: int | None = None,
        in_channels: int = 3,
    ):
        num_image_channels, *self.spatial_shape = image_shape
        self.image_shape = tuple(image_shape)
        self.num_frequency_bands = num_frequency_bands
        self.patch_size = patchify  # basically no patching if 1

        if patchify > 1:
            # forward() would fail on a non-divisible size; report it at construction.
            if any(s % patchify for s in self.spatial_shape):
                raise ValueError(
                    f"Spatial shape {tuple(self.spatial_shape)} is not divisible "
                    f"by patch size {patchify}"
                )
            self.spatial_shape = [s // self.patch_size for s in self.spatial_shape]
            # Patching folds each (patchify x patchify) block into the channel
            # dimension, so the token width grows accordingly.
            num_image_channels *= patchify**2

        if conv_1x1 is not None:
            # The 1x1 convolution fixes the number of image channels per token.
            num_image_channels = conv_1x1

        super().__init__(
            out_channels=num_image_channels + self._num_position_encoding_channels()
        )

        if conv_1x1 is not None:
            if patchify > 1:
                in_channels *= patchify**2
            self.conv_1x1 = nn.Conv2d(in_channels, conv_1x1, 1)
        else:
            self.conv_1x1 = None

        # create encodings for single example
        pos = generate_positions_for_encoding(self.spatial_shape)
        enc = generate_position_encodings(pos, self.num_frequency_bands)

        # flatten encodings along spatial dimensions
        enc = einops.rearrange(enc, "... c -> (...) c")

        # position encoding prototype (non-persistent: not stored in the state dict)
        self.register_buffer("position_encoding", enc, persistent=False)

    def _num_position_encoding_channels(self, include_positions: bool = True) -> int:
        """Return the number of position-encoding channels appended to each token."""
        return len(self.spatial_shape) * (
            2 * self.num_frequency_bands + include_positions
        )

    def forward(self, x):
        """Tokenize ``x`` and append position encodings.

        :param x: batch of images of shape ``(b, *image_shape)``.
        :returns: tensor of shape ``(b, num_tokens, out_channels)``.
        :raises ValueError: if the per-sample shape differs from ``image_shape``.
        """
        b, *d = x.shape

        if tuple(d) != self.image_shape:
            raise ValueError(
                f"Input image shape {tuple(d)} different from required shape {self.image_shape}"
            )

        if self.patch_size > 1:
            # move every patch's pixels into the channel dimension
            x = einops.rearrange(
                x,
                "b c (h dh) (w dw) -> b (dh dw c) h w",
                dh=self.patch_size,
                dw=self.patch_size,
            )

        if self.conv_1x1 is not None:
            x = self.conv_1x1(x)

        # flatten spatial dimensions into a token axis, channels last
        x = einops.rearrange(x, "b c ... -> b (...) c")

        # repeat position encoding along batch dimension
        x_enc = einops.repeat(self.position_encoding, "... -> b ...", b=b)
        return torch.cat([x, x_enc], dim=-1)


class AgentIA(InputAdapter):
    """Input adapter that tokenizes agent inputs.

    Depending on ``input_mode`` the agent features are passed through unchanged
    (``RAW``), replaced by Fourier features of their first three channels
    (``FPOS``), or Fourier features followed by the channels
    ``[pos_feats, pos_feats + other_feats)`` of the input (``FPOS_EXTRA``).
    For the Fourier modes the class id stored in the last input channel can be
    appended as a raw value, a one-hot vector or a learned embedding.

    :param input_mode: member (or case-insensitive member name) of ``InputMode``.
    :param pos_feats: number of positional features used to size the output.
    :param other_feats: number of extra features copied through in ``FPOS_EXTRA``
        mode (and counted in ``RAW`` mode).
    :param pos_freq: maximum frequency for the first two encoded channels.
    :param yaw_freq: maximum frequency for the third encoded channel.
    :param num_frequency_bands: number of Fourier bands per encoded channel.
    :param n_classes: number of classes for ``ONE_HOT`` / ``EMBEDDING`` modes.
    :param class_mode: member (or case-insensitive member name) of ``ClassMode``.
    :param embedding_dim: width of the learned class embedding.
    :raises KeyError: if a mode given as a string is not a valid member name.
    """

    class ClassMode(enum.Enum):
        """How the class id in the last input channel is added to each token."""

        NONE = enum.auto()
        VALUE = enum.auto()
        ONE_HOT = enum.auto()
        EMBEDDING = enum.auto()

    class InputMode(enum.Enum):
        """How the agent features are encoded."""

        RAW = enum.auto()
        FPOS = enum.auto()
        FPOS_EXTRA = enum.auto()  # Extra are appended dx, dy, dth

    def __init__(
        self,
        input_mode: str | InputMode,
        pos_feats: int = 3,
        other_feats: int = 0,
        pos_freq: int = 20,
        yaw_freq: int = 16,
        num_frequency_bands: int = 32,
        n_classes: int = 2,
        class_mode: str | ClassMode = ClassMode.NONE,
        embedding_dim: int = 64,
    ):
        self.pos_feats = pos_feats
        self.other_feats = other_feats
        self.n_classes = n_classes
        if isinstance(input_mode, str):
            input_mode = self.InputMode[input_mode.upper()]
        self.input_mode = input_mode
        if isinstance(class_mode, str):
            class_mode = self.ClassMode[class_mode.upper()]
        self.class_mode = class_mode

        # NOTE(review): the Fourier widths below scale with pos_feats, while
        # forward() always encodes exactly three channels; the two only agree
        # for pos_feats == 3 -- confirm other values are never used.
        num_input_channels = {
            AgentIA.InputMode.RAW: pos_feats + other_feats,
            AgentIA.InputMode.FPOS: num_frequency_bands * pos_feats * 2,
            AgentIA.InputMode.FPOS_EXTRA: num_frequency_bands * pos_feats * 2
            + other_feats,
        }[self.input_mode]

        if self.class_mode is AgentIA.ClassMode.VALUE:
            num_input_channels += 1
        elif self.class_mode is AgentIA.ClassMode.ONE_HOT:
            num_input_channels += n_classes
        elif self.class_mode is AgentIA.ClassMode.EMBEDDING:
            num_input_channels += embedding_dim

        self.yaw_freq = yaw_freq
        self.pos_freq = pos_freq
        self.num_frequency_bands = num_frequency_bands
        super().__init__(num_input_channels)

        # Parameters can only be registered after nn.Module.__init__ has run.
        if self.class_mode is AgentIA.ClassMode.EMBEDDING:
            self.class_embeddings = nn.Parameter(torch.empty(n_classes, embedding_dim))
            nn.init.xavier_uniform_(self.class_embeddings)

    def forward(self, x: Tensor, pad_mask: Tensor) -> tuple[Tensor, Tensor]:
        """Encode agent features; the pad mask is returned unchanged.

        Pad mask is true for invalid values for pytorch, we want the opposite in IA.

        :param x: agent features with channels in the last dimension; the class
            id, when used, is read from the last channel.
        :param pad_mask: mask passed through untouched.
        :returns: tuple of (encoded features, ``pad_mask``).
        """
        # NOTE(review): RAW mode returns x as-is and does not apply class_mode,
        # although out_channels accounts for it -- confirm callers expect this.
        if self.input_mode is AgentIA.InputMode.RAW:
            return x, pad_mask

        enc_x = _sample_frequency_band(
            x,
            num_frequency_bands=self.num_frequency_bands,
            max_frequencies=[self.pos_freq, self.pos_freq, self.yaw_freq],
            include_positions=False,
        )
        if self.input_mode is AgentIA.InputMode.FPOS_EXTRA:
            other_feats_range = slice(self.pos_feats, self.pos_feats + self.other_feats)
            enc_x = torch.cat([enc_x, x[..., other_feats_range]], dim=-1)

        if self.class_mode is AgentIA.ClassMode.VALUE:
            enc_x = torch.cat([enc_x, x[..., -1, None]], dim=-1)
        elif self.class_mode is AgentIA.ClassMode.ONE_HOT:
            # One extra slot for class id 0, which is dropped below so that id 0
            # maps to an all-zero vector.
            onehot = torch.zeros(
                [*enc_x.shape[:-1], self.n_classes + 1],
                device=enc_x.device,
                dtype=enc_x.dtype,
            )
            # Scatter along the channel (last) dimension so inputs of any rank
            # work; the buffer shares enc_x's dtype so the concat needs no cast.
            onehot.scatter_(-1, x[..., -1, None].to(torch.int64), 1.0)
            enc_x = torch.cat([enc_x, onehot[..., 1:]], dim=-1)
        elif self.class_mode is AgentIA.ClassMode.EMBEDDING:
            # int64 indices are the index dtype accepted by every torch version.
            class_idxs = x[..., -1].long()
            enc_x = torch.cat([enc_x, self.class_embeddings[class_idxs]], dim=-1)

        return enc_x, pad_mask


def pad_agent_mask(agent: Tensor, mask: Tensor | None = None):
    """Pad batch of agent features with one extra token of zeros.
    This can help prevent NaN if there are no observed agents.
    The mask is a "valid" mask, hence ones will be appended to mask.
    This is opposite MHA where "True" means an element is padding and to not attend.

    Args:
        agent (Tensor): Agent features of shape [...,N,C]
        mask (Tensor|None): Mask tensor of shape [..., N] or None

    Return:
        tuple[Tensor,Tensor|None]: Padded agent and mask (if not None)
    """
    # ensure that at least one dummy token isn't masked to prevent NaN's
    agent = torch.cat([agent, torch.zeros_like(agent[..., 0, None, :])], dim=-2)
    if mask is not None:
        mask = torch.cat([mask, torch.ones_like(mask[..., 0, None])], dim=-1)
    return agent, mask


def unpad_agent_mask(agent: Tensor, mask: Tensor | None = None):
    """Remove padding appended by `pad_agent_mask`.

    Args:
        agent (Tensor): Agent features of shape [...,N,C]
        mask (Tensor): Mask tensor of shape [..., N]

    Return:
        tuple[Tensor,Tensor]: Padded agent and mask
    """
    agent = agent[..., :-1, :]
    if mask is not None:
        mask = mask[..., :-1]
    return agent, mask
