"""
This module contains the M2C encoder architecture.
"""

import os
import sys
from typing import Any, Dict, List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

# Make sibling modules (visual_encodings, sequence_encodings) importable
# when this file is run directly as a script.
sys.path.append(os.path.dirname(__file__))

from visual_encodings import VisualEncoder  # noqa: E402
from sequence_encodings import SequentialEncoder  # noqa: E402


class M2CEncoder(nn.Module):
    """
    Implementation of the M2C Encoder.

    The M2C encoder converts each layer of a material stack into a latent vector
    with a visual encoder:
    (n_layers, 64, 64) -> (n_layers, latent_dim)

    After that, a sequence encoder encodes the sequence of layer embeddings
    (together with the layer thicknesses). A CLS pooling can be used to get a
    single representation of the entire structure.
    """

    def __init__(
        self,
        # Visual Encoder parameters
        spatial_dim: int = 64,
        in_channels: int = 1,
        latent_dim: int = 128,
        channels: Optional[List[int]] = None,
        num_vision_heads: int = 8,
        groups: int = 8,
        vision_dropout: float = 0.1,
        modality: str = "both",
        use_phase: bool = False,
        # Sequence Encoder parameters
        out_dim: int = 768,
        max_seq_len: int = 8,
        num_heads: int = 16,
        num_layers: int = 6,
        d_ff: int = 2048,
        dropout: float = 0.1,
        kwargs_pe: Optional[Dict[str, Any]] = None,
        use_cls: bool = True,
        trainable_pe: bool = True,
    ) -> None:
        """
        Initialize the M2C Encoder.

        Args:
            spatial_dim (int): Spatial dimension of the input layers (assumed square).
            in_channels (int): Number of input channels per layer (1 for grayscale).
            latent_dim (int): Dimension of the latent space for each layer.
                              It defines the space where each layer is encoded and
                              is also the working dimension of the sequence encoder.
            channels (List[int], optional): Channels of the convolutional layers of
                              the visual encoder. Defaults to [32, 64, 128].
            num_vision_heads (int): Number of attention heads for the visual encoder.
            groups (int): Number of groups for the group normalization.
            vision_dropout (float): Dropout rate for the visual encoder.
            modality (str): Input modality forwarded to the visual encoder
                            (default "both"). Presumably selects spatial and/or
                            spectral (2D-FFT) inputs; confirm accepted values
                            in VisualEncoder.
            use_phase (bool): Whether to use phase encoding.
            out_dim (int): Dimension of the output embedding.
                           It defines the space where the entire structure will be encoded at the end.
                           Note that the working dimension of the sequence encoder is latent_dim.
            max_seq_len (int): Maximum number of layers in the structure.
            num_heads (int): Number of attention heads for the sequence encoder.
            num_layers (int): Number of layers for the sequence encoder.
            d_ff (int): Dimension of the feed forward layer in the sequence encoder.
            dropout (float): Dropout rate for the sequence encoder.
            kwargs_pe (dict, optional): Additional arguments for the positional encoding.
                           Defaults to {"ini_freq_scale": 1.0,
                           "tunable_freq_scale": True, "dropout": 0.0}.
            use_cls (bool): Whether to use CLS token for pooling.
            trainable_pe (bool): Whether to use trainable positional encoding.
        """
        super().__init__()

        # Build default containers per call instead of sharing module-level
        # mutable defaults between instances. Copy caller-supplied containers
        # so the sub-encoders never hold or mutate the caller's objects.
        channels = [32, 64, 128] if channels is None else list(channels)
        if kwargs_pe is None:
            kwargs_pe = {
                "ini_freq_scale": 1.0,
                "tunable_freq_scale": True,
                "dropout": 0.0,
            }
        else:
            kwargs_pe = dict(kwargs_pe)

        self.visual_encoder = VisualEncoder(
            spatial_dim=spatial_dim,
            in_channels=in_channels,
            latent_dim=latent_dim,
            channels=channels,
            num_heads=num_vision_heads,
            groups=groups,
            dropout=vision_dropout,
            modality=modality,
            use_phase=use_phase,
        )
        self.sequence_encoder = SequentialEncoder(
            dim=latent_dim,
            out_dim=out_dim,
            max_seq_len=max_seq_len,
            num_heads=num_heads,
            num_layers=num_layers,
            d_ff=d_ff,
            dropout=dropout,
            kwargs_pe=kwargs_pe,
            use_cls=use_cls,
            trainable_pe=trainable_pe,
        )

    def forward(
        self,
        x: torch.Tensor,
        thicknesses: torch.Tensor,
        src_mask: torch.Tensor,
        pool: bool = True,
    ) -> torch.Tensor:
        """
        Forward method for the M2C-Encoder.

        Args:
            x (torch.Tensor): Input layers, either single-channel of shape
                (batch_size, n_layers, spatial_dim, spatial_dim) or multi-channel
                of shape (batch_size, n_layers, in_channels, spatial_dim, spatial_dim).
            thicknesses (torch.Tensor): Thicknesses of the layers of shape (batch_size, n_layers).
            src_mask (torch.Tensor): Source mask for attention of shape (batch_size, n_layers).
            pool (bool): If True, return only the first token of the sequence
                encoder output (presumably the CLS token when use_cls=True;
                confirm in SequentialEncoder).

        Returns:
            torch.Tensor: (batch_size, out_dim) if pool is True, else the full
                sequence encoder output of shape (batch_size, seq_len, out_dim).
                seq_len is n_layers, or presumably n_layers + 1 when a CLS token
                is prepended; confirm in SequentialEncoder.

        Raises:
            ValueError: If x is neither 4-D nor 5-D.
        """
        # Fold the layer axis into the batch so every layer is encoded
        # independently by the visual encoder.
        if x.dim() == 4:
            # (B, n_layers, H, W) -> (B * n_layers, 1, H, W)
            B, n_layers, H, W = x.shape
            x = x.reshape(B * n_layers, 1, H, W)
        elif x.dim() == 5:
            # (B, n_layers, C, H, W) -> (B * n_layers, C, H, W)
            B, n_layers, C, H, W = x.shape
            x = x.reshape(B * n_layers, C, H, W)
        else:
            raise ValueError(
                "Expected x of shape (B, n_layers, H, W) or (B, n_layers, C, H, W), "
                f"got {tuple(x.shape)}"
            )

        # Restore the (B, n_layers) axes: -> (B, n_layers, latent_dim)
        visual_features = self.visual_encoder(x).reshape(B, n_layers, -1)

        # Encode the sequence of layer embeddings.
        sequence_output = self.sequence_encoder(visual_features, thicknesses, src_mask)

        if pool:
            # First token of the encoded sequence.
            return sequence_output[:, 0, :]
        return sequence_output


if __name__ == "__main__":
    # Smoke test: a batch of 2 structures, each with 5 layers of 64x64 maps.
    batch_size, n_layers, side = 2, 5, 64
    layers = torch.randn(batch_size, n_layers, side, side)
    layer_thicknesses = torch.randn(batch_size, n_layers)
    padding_mask = torch.zeros(batch_size, n_layers, dtype=torch.bool)

    pe_options = {
        "ini_freq_scale": 1.0,
        "tunable_freq_scale": True,
        "dropout": 0.0,
    }
    encoder = M2CEncoder(
        spatial_dim=side,
        in_channels=1,
        latent_dim=128,
        channels=[32, 64, 128],
        num_vision_heads=8,
        groups=8,
        vision_dropout=0.1,
        modality="both",
        use_phase=False,
        out_dim=768,
        max_seq_len=n_layers,
        num_heads=16,
        num_layers=6,
        d_ff=2048,
        dropout=0.1,
        kwargs_pe=pe_options,
        use_cls=True,
        trainable_pe=True,
    )

    # Pooled output first, then the full sequence output.
    # Author's expected shapes: (2, 768), then (2, 5+1, 768).
    for pool_flag in (True, False):
        result = encoder(layers, layer_thicknesses, padding_mask, pool=pool_flag)
        print(result.shape)
