"""
This module contains visual encodings for the Material-to-Context (M2C) encoder.
"""

import math
from typing import Dict, List, Tuple

import torch
import torch.nn.functional as F
from torch import nn
from torchvision.transforms import Lambda


class ResidualBlock(nn.Module):
    """
    A pre-activation residual block as commonly used in VAE architectures.

    The main path applies GroupNorm -> SiLU -> Conv2d -> Dropout twice.
    The skip path is the identity when the block keeps both the channel count
    and the stride at 1, and a 1x1 convolution otherwise.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        stride: int = 1,
        padding: int = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "circular",
        dropout: float = 0.1,
    ):
        """
        __init__ method for the ResidualBlock class.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            kernel_size (int): Size of the convolution kernel.
            stride (int): Stride of the block. It is applied once, by the
                first convolution, and mirrored on the shortcut so both paths
                produce tensors of the same spatial size.
            padding (int): Padding for the convolutions.
            groups (int): Number of groups for group normalization.
            bias (bool): Whether to use bias in the two main convolutions.
            padding_mode (str): Padding mode for the two main convolutions.
            dropout (float): Dropout rate.

        Raises:
            AssertionError: If in_channels or out_channels is not divisible
                by groups.
        """
        super().__init__()

        assert in_channels % groups == 0, "in_channels must be divisible by groups."
        assert out_channels % groups == 0, "out_channels must be divisible by groups."

        self.norm1 = nn.GroupNorm(groups, in_channels)
        self.conv1 = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            bias=bias,
            padding_mode=padding_mode,
        )

        self.norm2 = nn.GroupNorm(groups, out_channels)
        # Only conv1 is strided. Striding conv2 as well would shrink the main
        # path a second time and make it impossible to add the shortcut.
        self.conv2 = nn.Conv2d(
            out_channels,
            out_channels,
            kernel_size,
            1,
            padding,
            bias=bias,
            padding_mode=padding_mode,
        )

        self.dropout = nn.Dropout(dropout)

        is_strided = stride not in (1, (1, 1))
        if in_channels == out_channels and not is_strided:
            self.shortcut = nn.Identity()
        else:
            # 1x1 projection that matches the channel count and, when the
            # block is strided, the spatial size of the main path.
            self.shortcut = nn.Conv2d(
                in_channels, out_channels, kernel_size=1, stride=stride, padding=0
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward method for the ResidualBlock class.

        Args:
            x (torch.Tensor): Input tensor to be processed.
            [batch, in_channels, in_height, in_width]

        Returns:
            torch.Tensor: Processed tensor.
            [batch, out_channels, out_height, out_width]
            With the default kernel_size=3, padding=1 and stride=1 the spatial
            size equals the input size.
        """
        shortcut = self.shortcut(x)

        x = self.norm1(x)
        x = F.silu(x)
        x = self.conv1(x)
        x = self.dropout(x)

        x = self.norm2(x)
        x = F.silu(x)
        x = self.conv2(x)
        x = self.dropout(x)

        return x + shortcut


class AttentionBlock(nn.Module):
    """
    Residual self-attention over the spatial positions of a feature map,
    as commonly used in VAE architectures.
    """

    def __init__(
        self,
        channels: int,
        groups: int = 32,
        num_heads: int = 1,
        bias: bool = True,
        dropout: float = 0.1,
    ):
        """
        __init__ method for the AttentionBlock class.

        Args:
            channels (int): Number of channels (also the attention embedding size).
            groups (int): Number of groups for group normalization.
            num_heads (int): Number of heads in the multi-head attention.
            bias (bool): Whether to use bias in the attention layer.
            dropout (float): Dropout rate of the attention layer.
        """
        super().__init__()
        assert channels % groups == 0, "channels must be divisible by groups."

        self.norm = nn.GroupNorm(num_groups=groups, num_channels=channels)
        self.attention = nn.MultiheadAttention(
            channels,
            num_heads,
            dropout=dropout,
            bias=bias,
            batch_first=True,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward method for the AttentionBlock class.

        Args:
            x (torch.Tensor): Input feature map.
            [batch, channels, height, width]

        Returns:
            torch.Tensor: Input plus the self-attention output.
            [batch, channels, height, width]
        """
        residual = x
        normed = self.norm(x)

        batch, channels, height, width = normed.shape
        # One token per spatial position: (B, C, H, W) -> (B, H*W, C).
        tokens = normed.view(batch, channels, height * width).transpose(2, 1)
        attended, _ = self.attention(tokens, tokens, tokens, need_weights=False)
        # Back to the image layout: (B, H*W, C) -> (B, C, H, W).
        restored = attended.transpose(2, 1).view(batch, channels, height, width)

        return restored + residual


class SpatialEncoderBlock(nn.Module):
    """
    Stem for the spatial (pixel-domain) features: Conv2d -> SiLU -> Dropout.
    This block is intended to be used at the beginning of the encoder.
    """

    def __init__(
        self,
        in_channels: int = 1,
        out_channels: int = 1,
        kernel_size: int = 3,
        padding: int = 1,
        bias: bool = True,
        padding_mode: str = "circular",
        dropout: float = 0.1,
    ):
        """
        __init__ method for the SpatialEncoderBlock class.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            kernel_size (int): Size of the convolution kernel.
            padding (int): Padding for the convolution.
            bias (bool): Whether to use bias in the convolution.
            padding_mode (str): Padding mode for the convolution.
            dropout (float): Dropout rate.
        """
        super().__init__()

        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            padding=padding,
            padding_mode=padding_mode,
            bias=bias,
        )
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward method for the SpatialEncoderBlock class.

        Args:
            x (torch.Tensor): Input tensor to be processed.
            [batch, in_channels, in_height, in_width]

        Returns:
            torch.Tensor: Processed tensor.
            [batch, out_channels, out_height, out_width]
        """
        activated = F.silu(self.conv(x))
        return self.dropout(activated)


class SpectralEncoderBlock(nn.Module):
    """
    Encode periodic structure via 2D FFT.
    This block is intended to be used at the beginning of the encoder.

    If use_phase=False, the output is translation-invariant (magnitude only).
    If use_phase=True, we output [|F|, cos(phi), sin(phi)] channels (not shift-invariant).
    """

    def __init__(
        self,
        in_channels: int = 1,
        out_channels: int = 1,
        kernel_size: int = 3,
        padding: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        dropout: float = 0.1,
        use_phase: bool = False,
    ):
        """
        __init__ method for the SpectralEncoderBlock class.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            kernel_size (int): Size of the convolution kernel.
            padding (int): Padding for the convolution.
            bias (bool): Whether to use bias in the convolution.
            padding_mode (str): Padding mode for the convolution.
            dropout (float): Dropout rate.
            use_phase (bool): Whether to feed cos/sin of the FFT phase to the
                convolution in addition to the log-magnitude.
        """
        super().__init__()
        self.use_phase = use_phase

        # Magnitude only -> 1 map per input channel; with phase -> 3 maps.
        maps_per_channel = 1
        if use_phase:
            maps_per_channel = 3

        self.conv = nn.Conv2d(
            in_channels=maps_per_channel * in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            padding=padding,
            padding_mode=padding_mode,
            bias=bias,
        )
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward method for the SpectralEncoderBlock class.

        Args:
            x (torch.Tensor): Input tensor to be processed.
            [batch, in_channels, in_height, in_width]

        Returns:
            torch.Tensor: Processed tensor.
            [batch, out_channels, out_height, out_width]
        """
        spectrum = torch.fft.fft2(x, dim=(-2, -1), norm="ortho")
        log_magnitude = torch.log1p(torch.abs(spectrum) + 1e-6)

        features = log_magnitude  # (B, C, H, W)
        if self.use_phase:
            angle = torch.angle(spectrum)
            # Stack along channels: (B, 3*C, H, W)
            features = torch.cat(
                [log_magnitude, torch.cos(angle), torch.sin(angle)], dim=1
            )

        return self.dropout(F.silu(self.conv(features)))


class DownsamplingBlock(nn.Module):
    """
    A simple downsampling block: two residual blocks, one convolution and a
    2x2 average pooling that halves the spatial resolution.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        padding: int = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "circular",
        dropout: float = 0.1,
    ):
        """
        __init__ method for the DownsamplingBlock class.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            kernel_size (int): Size of the convolution kernel.
            padding (int): Padding for the convolution.
            groups (int): Number of groups for group normalization in the
                residual blocks. The same value is also passed as the
                `groups` argument of the convolution before the pooling.
            bias (bool): Whether to use bias in the convolution.
            padding_mode (str): Padding mode for the convolution.
            dropout (float): Dropout rate.
        """
        super().__init__()

        # Settings shared by both residual blocks (stride is fixed to 1).
        residual_kwargs = dict(
            kernel_size=kernel_size,
            stride=1,
            padding=padding,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            dropout=dropout,
        )
        self.residual_1 = ResidualBlock(in_channels, out_channels, **residual_kwargs)
        self.residual_2 = ResidualBlock(out_channels, out_channels, **residual_kwargs)

        self.conv = nn.Conv2d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=1,
            padding=padding,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
        )
        self.downsample = nn.AvgPool2d(kernel_size=2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward method for the DownsamplingBlock class.

        Args:
            x (torch.Tensor): Input tensor to be processed.
            [batch, in_channels, in_height, in_width]

        Returns:
            torch.Tensor: Processed tensor.
            [batch, out_channels, out_height//2, out_width//2]
        """
        features = self.residual_2(self.residual_1(x))
        return self.downsample(self.conv(features))


class UpsamplingBlock(nn.Module):
    """
    A simple upsampling block: 2x bilinear interpolation, one convolution and
    three residual blocks.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        padding: int = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "circular",
        dropout: float = 0.1,
    ):
        """
        __init__ method for the UpsamplingBlock class.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            kernel_size (int): Size of the convolution kernel.
            padding (int): Padding for the convolution.
            groups (int): Number of groups for group normalization in the
                residual blocks. The same value is also passed as the
                `groups` argument of the convolution after the upsampling.
            bias (bool): Whether to use bias in the convolution.
            padding_mode (str): Padding mode for the convolution.
            dropout (float): Dropout rate.
        """
        super().__init__()

        self.upsample = nn.Upsample(scale_factor=2, mode="bilinear")

        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=kernel_size,
            stride=1,
            padding=padding,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
        )

        # Settings shared by the three residual blocks (stride is fixed to 1).
        residual_kwargs = dict(
            kernel_size=kernel_size,
            stride=1,
            padding=padding,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            dropout=dropout,
        )
        # Only the first residual block changes the channel count.
        self.residual_1 = ResidualBlock(in_channels, out_channels, **residual_kwargs)
        self.residual_2 = ResidualBlock(out_channels, out_channels, **residual_kwargs)
        self.residual_3 = ResidualBlock(out_channels, out_channels, **residual_kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward method for the UpsamplingBlock class.

        Args:
            x (torch.Tensor): Input tensor to be processed.
            [batch, in_channels, in_height, in_width]

        Returns:
            torch.Tensor: Processed tensor.
            [batch, out_channels, out_height*2, out_width*2]
        """
        enlarged = self.conv(self.upsample(x))

        refined = self.residual_1(enlarged)
        refined = self.residual_2(refined)
        return self.residual_3(refined)


class AttentionPool2d(nn.Module):
    """
    An attention-based pooling layer that reduces a feature map to one vector
    per sample.

    implementation from: https://github.com/revantteotia/clip-training/blob/main/model/model.py
    """

    def __init__(
        self, spatial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None
    ):
        """
        __init__ method for the AttentionPool2d class.

        Args:
            spatial_dim (int): Side length of the input feature map; the
                positional embedding holds spatial_dim**2 + 1 entries.
            embed_dim (int): Embedding dimension (channels of the input).
            num_heads (int): Number of heads in the multi-head attention.
            output_dim (int): Output dimension. If None, output_dim = embed_dim.
        """
        super().__init__()
        # One position per spatial location plus one for the mean token.
        num_positions = spatial_dim**2 + 1
        self.positional_embedding = nn.Parameter(
            torch.randn(num_positions, embed_dim) / embed_dim**0.5
        )
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        """
        Forward method for the AttentionPool2d class.

        Args:
            x (torch.Tensor): Input feature map.
            [batch, embed_dim, height, width]

        Returns:
            torch.Tensor: Attention output at the mean-token position.
            [batch, output_dim]
        """
        batch, channels = x.shape[0], x.shape[1]
        tokens = x.reshape(batch, channels, x.shape[2] * x.shape[3]).permute(
            2, 0, 1
        )  # NCHW -> (HW)NC

        # Prepend the spatial mean as an extra token.
        mean_token = tokens.mean(dim=0, keepdim=True)
        tokens = torch.cat([mean_token, tokens], dim=0)  # (HW+1)NC
        tokens = tokens + self.positional_embedding[:, None, :].to(
            tokens.dtype
        )  # (HW+1)NC

        qkv_bias = torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias])
        pooled, _ = F.multi_head_attention_forward(
            query=tokens,
            key=tokens,
            value=tokens,
            embed_dim_to_check=tokens.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=qkv_bias,
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False,
        )

        # Index 0 along the token axis is the prepended mean token.
        return pooled[0]


class VisualEncoder(nn.Module):
    """
    Encode the input tensor into the latent space.

    The input goes through a stem (spatial, spectral, or both merged by a 1x1
    convolution), a stack of downsampling blocks, a bottleneck with residual
    and attention blocks, and a final attention pooling that yields one
    latent vector per sample.
    """

    def __init__(
        self,
        spatial_dim: int = 64,
        in_channels: int = 1,
        latent_dim: int = 4,
        channels: List[int] = [128, 256, 512],
        num_heads: int = 8,
        groups: int = 32,
        dropout: float = 0.1,
        modality: str = "both",
        use_phase: bool = False,
    ):
        """
        __init__ method for the VisualEncoder class.

        Args:
            spatial_dim (int): Spatial dimension of the input tensor. The
                attention pool is sized for feature maps of side
                spatial_dim // 2 ** len(channels).
            in_channels (int): Number of input channels.
            latent_dim (int): Dimension of the latent space.
            channels (List[int]): Number of channels in each downsampling
                block; one block (and one 2x downsampling) per entry.
            num_heads (int): Number of heads in the multi-head attention.
                The final attention pool uses gcd(num_heads, latent_dim)
                heads, because its embedding size (latent_dim) must be
                divisible by its number of heads. This equals num_heads
                whenever latent_dim is divisible by num_heads.
            groups (int): Number of groups in the group normalization.
            dropout (float): Dropout rate.
            modality (str): Which stem to use: 'pixmap' (spatial only),
                'spectral' (FFT only) or 'both'.
            use_phase (bool): Forwarded to the spectral stem; ignored for
                modality='pixmap'.

        Raises:
            ValueError: If modality is not one of the supported values.
        """
        super().__init__()
        self.modality = modality

        # Private copy: never alias the (mutable) default list or a caller's
        # list, and accept any sequence such as a tuple.
        channels = list(channels)

        if modality == "spectral":
            padding_mode = "zeros"
            self.encoder = SpectralEncoderBlock(
                in_channels,
                channels[0],
                dropout=dropout,
                use_phase=use_phase,
                padding_mode=padding_mode,
            )
        elif modality == "pixmap":
            padding_mode = "circular"
            self.encoder = SpatialEncoderBlock(
                in_channels, channels[0], dropout=dropout, padding_mode=padding_mode
            )
        elif modality == "both":
            padding_mode = "circular"
            # Each stem produces half of the channels; they are concatenated
            # and mixed by a 1x1 convolution in forward().
            self.spatial_encoder = SpatialEncoderBlock(
                in_channels,
                channels[0] // 2,
                dropout=dropout,
                padding_mode=padding_mode,
            )
            self.spectral_encoder = SpectralEncoderBlock(
                in_channels,
                channels[0] // 2,
                dropout=dropout,
                use_phase=use_phase,
                padding_mode="zeros",
            )
            self.merge_conv = nn.Conv2d(
                channels[0],
                channels[0],
                kernel_size=1,
                padding=0,
                padding_mode=padding_mode,
            )
        else:
            raise ValueError(
                f"Invalid modality: {modality}. Choose from 'pixmap', 'spectral', 'both'."
            )

        self.downsampling_blocks = nn.Sequential(
            *[
                DownsamplingBlock(
                    _in_channels,
                    _out_channels,
                    groups=groups,
                    dropout=dropout,
                    padding_mode=padding_mode,
                )
                for _in_channels, _out_channels in zip(
                    [channels[0]] + channels[:-1], channels
                )
            ]
        )

        self.bottleneck = nn.Sequential(
            ResidualBlock(
                channels[-1],
                channels[-1],
                groups=groups,
                dropout=dropout,
                padding_mode=padding_mode,
            ),
            ResidualBlock(
                channels[-1],
                channels[-1],
                groups=groups,
                dropout=dropout,
                padding_mode=padding_mode,
            ),
            ResidualBlock(
                channels[-1],
                channels[-1],
                groups=groups,
                dropout=dropout,
                padding_mode=padding_mode,
            ),
            AttentionBlock(
                channels[-1], groups=groups, num_heads=num_heads, dropout=dropout
            ),
            ResidualBlock(
                channels[-1],
                channels[-1],
                groups=groups,
                dropout=dropout,
                padding_mode=padding_mode,
            ),
            nn.GroupNorm(groups, channels[-1]),
            nn.SiLU(),
            nn.Conv2d(
                channels[-1],
                latent_dim * 2,
                kernel_size=3,
                padding=1,
                padding_mode=padding_mode,
            ),
            nn.Conv2d(latent_dim * 2, latent_dim, kernel_size=1, padding=0),
        )

        # The pool attends over latent_dim-sized embeddings, so its head count
        # must divide latent_dim. Passing num_heads unchanged made the default
        # configuration (latent_dim=4, num_heads=8) fail in forward().
        pool_heads = math.gcd(num_heads, latent_dim)
        self.attn_pool = AttentionPool2d(
            spatial_dim=spatial_dim // (2 ** len(channels)),
            embed_dim=latent_dim,
            num_heads=pool_heads,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward method for the VisualEncoder class.

        Args:
            x (torch.Tensor): Input tensor to be encoded.
            [batch, in_channels, height, width]

        Returns:
            torch.Tensor: Encoded tensor in the latent space.
            [batch, latent_dim]
        """
        assert x.dim() == 4, "Input tensor must have 4 dimensions."

        if self.modality == "both":
            x_spatial = self.spatial_encoder(x)
            x_spectral = self.spectral_encoder(x)
            x = torch.cat([x_spatial, x_spectral], dim=1)
            x = self.merge_conv(x)
        else:
            x = self.encoder(x)
        x = self.downsampling_blocks(x)
        x = self.bottleneck(x)

        return self.attn_pool(x)


if __name__ == "__main__":
    # Smoke test: encode a random batch and print the latent shape.
    x = torch.randn(2, 1, 64, 64)
    model = VisualEncoder(
        spatial_dim=64,
        in_channels=1,
        latent_dim=768,
        channels=[128, 256, 512],
        num_heads=8,
        groups=32,
        dropout=0.1,
        # VisualEncoder has no `use_spectral` argument; the spectral branch is
        # selected through `modality` ("both" = spatial + spectral stems).
        modality="both",
        use_phase=False,
    )

    # forward() takes the input tensor only.
    y = model(x)
    print(y.shape)
