"""
This module implements the Vision Transformer (ViT) architecture and a ViT-based VAE,
mainly for use as the Band Diagram (BD) encoder.

Code inspired by: https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit.py
"""

import os
import sys
from typing import List, Sequence

sys.path.append(os.path.dirname(__file__))

import torch
import torch.nn.functional as F
from einops import repeat
from einops.layers.torch import Rearrange
from torch import nn

from sequence_encodings import SequentialEncoderBlock
from visual_encodings import AttentionBlock, ResidualBlock, UpsamplingBlock


def pair(t):
    """Return ``t`` unchanged when it is already a tuple, otherwise ``(t, t)``."""
    if isinstance(t, tuple):
        return t
    return (t, t)


class Transformer(nn.Module):
    """
    Stack of ``depth`` ``SequentialEncoderBlock`` layers applied in sequence.

    Args:
        dim: Token dimension.
        depth: Number of stacked blocks.
        num_heads: Number of attention heads per block.
        d_ff: ``d_ff`` argument forwarded to every block.
        dropout: Dropout rate forwarded to every block.
    """

    def __init__(self, dim, depth, num_heads, d_ff, dropout=0.0):
        super().__init__()
        self.layers = nn.ModuleList(
            [
                SequentialEncoderBlock(
                    dim=dim,
                    num_heads=num_heads,
                    d_ff=d_ff,
                    dropout=dropout,
                )
                for _ in range(depth)
            ]
        )

    def forward(self, x, mask=None):
        """Run ``x`` through every block, handing the same ``mask`` to each one."""
        out = x
        for layer in self.layers:
            out = layer(out, mask)
        return out


class ViT(nn.Module):
    """
    Vision Transformer mapping an image to a single ``out_dim`` vector.

    The image is split into non-overlapping patches. Each patch is embedded to
    ``dim`` features, and a learnable CLS token is prepended. Learnable
    positional embeddings are then added and the sequence is processed by
    ``Transformer``. Finally, either the CLS output (``pool="cls"``) or the mean
    over all tokens, CLS included (``pool="mean"``), is projected by an MLP head.

    Args:
        image_size: Image size as ``int`` or ``(H, W)``.
        patch_size: Patch size as ``int`` or ``(Ph, Pw)``; must divide ``image_size``.
        dim: Token embedding dimension.
        out_dim: Output dimension of the MLP head.
        depth: Number of transformer blocks.
        heads: Number of attention heads per block.
        mlp_dim: Passed as ``d_ff`` to each transformer block.
        pool: ``"cls"`` or ``"mean"``.
        channels: Number of input image channels.
        dropout: Dropout rate inside the transformer blocks.
        emb_dropout: Dropout applied after adding the positional embeddings.
        patch_dropout: In training mode, probability of masking each patch
            token (the CLS token is never masked).
    """

    def __init__(
        self,
        *,
        image_size,
        patch_size,
        dim,
        out_dim,
        depth,
        heads,
        mlp_dim,
        pool="cls",
        channels=1,
        dropout=0.0,
        emb_dropout=0.0,
        patch_dropout=0.25,
    ):
        super().__init__()
        self.patch_dropout = patch_dropout

        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, (
            "Image dimensions must be divisible by the patch size."
        )

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        assert pool in {"cls", "mean"}, (
            "pool type must be either cls (cls token) or mean (mean pooling)"
        )

        # (B, C, H, W) -> (B, num_patches, patch_dim) -> (B, num_patches, dim)
        self.to_patch_embedding = nn.Sequential(
            Rearrange(
                "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
                p1=patch_height,
                p2=patch_width,
            ),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        # One extra position for the prepended CLS token.
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, mlp_dim, dropout)

        self.pool = pool

        self.mlp_head = nn.Sequential(
            nn.Linear(dim, out_dim * 2),
            nn.GELU(),
            nn.LayerNorm(out_dim * 2),
            nn.Linear(out_dim * 2, out_dim),
        )

    def forward(self, img):
        """
        Encode an image batch.

        Args:
            img: Image batch of shape (B, channels, H, W), as required by the
                patch ``Rearrange`` pattern.

        Returns:
            Tensor of shape (B, out_dim).
        """
        x = self.to_patch_embedding(img)  # (B, N, dim)
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, "1 1 d -> b 1 d", b=b)
        x = torch.cat((cls_tokens, x), dim=1)  # (B, N + 1, dim)
        x = x + self.pos_embedding[:, : (n + 1)]
        x = self.dropout(x)

        # Token mask where True marks a dropped token. It is created directly on
        # x.device with a bool dtype, which avoids a host allocation plus copy
        # on every forward pass.
        # NOTE(review): assumes SequentialEncoderBlock treats True as "masked".
        mask = torch.zeros(b, n + 1, dtype=torch.bool, device=x.device)
        if self.patch_dropout > 0.0 and self.training:
            drop_mask = torch.rand(b, n + 1, device=x.device) < self.patch_dropout
            drop_mask[:, 0] = False  # never drop the CLS token
            mask = mask | drop_mask

        x = self.transformer(x, mask=mask)

        x = x.mean(dim=1) if self.pool == "mean" else x[:, 0]

        return self.mlp_head(x)


class ViTVAEEncoder(nn.Module):
    """
    Wraps the ViT to produce VAE latents on a grid.

    The wrapped ViT's patch embedding, CLS token, positional embedding and
    transformer are reused. Each output patch token is projected to a mean and
    a log-variance of size ``latent_channels``, and a latent is sampled with the
    reparameterization trick. The CLS output is also passed through the ViT's
    ``mlp_head`` to give a global context vector.
    """

    def __init__(
        self,
        vit,  # ViT instance
        image_size,  # (H, W) or int
        patch_size,  # (Ph, Pw) or int; should match vit init
        latent_channels=4,  # C_lat (e.g., 4 for SD-like latents)
        emb_dropout=0.0,
    ):
        """
        Args:
            vit: ViT whose submodules and parameters are shared with this encoder.
            image_size: ``(H, W)`` or int; should match the ViT setup.
            patch_size: ``(Ph, Pw)`` or int; should match the ViT setup.
            latent_channels: Number of latent channels ``C_lat``.
            emb_dropout: Dropout applied after adding the positional embeddings.
                It is used in place of the ViT's own embedding dropout.
        """
        super().__init__()
        self.vit = vit
        H, W = pair(image_size)
        Ph, Pw = pair(patch_size)
        assert H % Ph == 0 and W % Pw == 0, (
            "Image dimensions must be divisible by the patch size."
        )
        self.Hp, self.Wp = H // Ph, W // Pw
        self.num_patches = self.Hp * self.Wp

        # project each *patch token* to mu and logvar of size C_lat
        D = vit.pos_embedding.shape[-1]  # token dim
        self.drop = nn.Dropout(emb_dropout)
        self.head_mu = nn.Linear(D, latent_channels)
        self.head_logv = nn.Linear(D, latent_channels)

    def forward(self, img):
        """
        Encode an image batch into a latent grid.

        Args:
            img: Image batch of shape (B, C, H, W).

        Returns:
            Tuple ``(z, mu, logvar, cls_encoding)``:
            - ``z``, ``mu`` and ``logvar`` have shape (B, C_lat, Hp, Wp). ``z`` is
              sampled as ``mu + exp(0.5 * logvar) * eps`` with ``eps ~ N(0, 1)``.
            - ``cls_encoding`` is ``vit.mlp_head`` applied to the CLS token output,
              with shape (B, out_dim) for a ``ViT``.
        """
        x = self.vit.to_patch_embedding(img)  # (B, L, D)
        B, n, _ = x.shape
        assert n == self.num_patches, "image_size/patch_size must match vit setup"

        cls_tokens = repeat(self.vit.cls_token, "1 1 d -> b 1 d", b=B)
        x = torch.cat((cls_tokens, x), dim=1)  # (B, L+1, D)
        x = x + self.vit.pos_embedding[:, : (n + 1)]
        x = self.drop(x)

        # Token (patch) dropout mask, same semantics as ViT.forward:
        # True marks a dropped token; the CLS token is never dropped.
        mask = torch.zeros(B, n + 1, dtype=torch.bool, device=x.device)
        if self.vit.patch_dropout > 0.0 and self.training:
            drop = torch.rand(B, n + 1, device=x.device) < self.vit.patch_dropout
            drop[:, 0] = False  # never drop CLS
            mask = mask | drop

        x = self.vit.transformer(x, mask=mask)  # (B, L+1, D)

        # ---- discard CLS; keep patch tokens only ----
        tokens = x[:, 1:, :]  # (B, L, D), L = Hp*Wp

        # Global context vector from the CLS token (not used by the VAE heads).
        cls_encoding = self.vit.mlp_head(x[:, 0, :])  # (B, D_out)

        # ---- VAE heads per token ----
        mu = self.head_mu(tokens)  # (B, L, C_lat)
        logv = self.head_logv(tokens)  # (B, L, C_lat)
        std = (0.5 * logv).exp()
        eps = torch.randn_like(std)
        z_tok = mu + std * eps  # (B, L, C_lat)

        # ---- reshape tokens back to grid ----
        # Tokens are ordered row-major over (Hp, Wp) by the "(h w)" patch
        # Rearrange, so a plain reshape restores the spatial layout.
        def to_grid(t):  # (B, L, C_lat) -> (B, C_lat, Hp, Wp)
            return t.transpose(1, 2).reshape(B, -1, self.Hp, self.Wp)

        z = to_grid(z_tok)
        mu = to_grid(mu)
        logv = to_grid(logv)

        return z, mu, logv, cls_encoding


class Decoder(nn.Module):
    """
    Decode the latent tensor back to the original space.

    Pipeline:
    1. A bottleneck at latent resolution: a 1x1 conv, a 3x3 conv to
       ``channels[0]``, then residual and attention blocks.
    2. One ``UpsamplingBlock`` per entry of ``channels``.
    3. GroupNorm -> SiLU -> 3x3 conv to ``out_channels``.
    """

    def __init__(
        self,
        out_channels: int = 1,
        latent_dim: int = 4,
        channels: Sequence[int] = (128, 64, 32),
        num_heads: int = 8,
        groups: int = 16,
        dropout: float = 0.1,
    ):
        """
        __init__ method for the Decoder class.

        Args:
            out_channels (int): Number of output channels.
            latent_dim (int): Number of channels of the latent tensor.
            channels (Sequence[int]): Output channels of each upsampling block
                (a list or tuple). ``channels[0]`` is also the bottleneck width.
                Must be non-empty.
            num_heads (int): Number of heads in the bottleneck attention block.
            groups (int): Number of groups in the group normalization layers.
                ``channels[-1]`` must be divisible by it because of the final
                ``GroupNorm``; the other blocks presumably have the same
                requirement.
            dropout (float): Dropout rate of the residual, attention and
                upsampling blocks.

        Raises:
            ValueError: If ``channels`` is empty.
        """
        super().__init__()

        # Work on a private list copy. The (in, out) pairs below are built with
        # list concatenation, and the caller's sequence is never aliased.
        channels = list(channels)
        if not channels:
            raise ValueError("channels must contain at least one entry")

        self.bottleneck = nn.Sequential(
            nn.Conv2d(latent_dim, latent_dim, kernel_size=1, padding=0),
            nn.Conv2d(
                latent_dim,
                channels[0],
                kernel_size=3,
                padding=1,
                padding_mode="zeros",
            ),
            ResidualBlock(channels[0], channels[0], groups=groups, dropout=dropout),
            AttentionBlock(
                channels[0], groups=groups, num_heads=num_heads, dropout=dropout
            ),
            ResidualBlock(channels[0], channels[0], groups=groups, dropout=dropout),
            ResidualBlock(channels[0], channels[0], groups=groups, dropout=dropout),
            ResidualBlock(channels[0], channels[0], groups=groups, dropout=dropout),
        )

        # One block per entry: channels[0] -> channels[0], then
        # channels[i-1] -> channels[i] for the remaining entries.
        self.upsampling_blocks = nn.Sequential(
            *[
                UpsamplingBlock(
                    _in_channels, _out_channels, groups=groups, dropout=dropout
                )
                for _in_channels, _out_channels in zip(
                    [channels[0]] + channels[:-1], channels
                )
            ]
        )

        self.output = nn.Sequential(
            nn.GroupNorm(groups, channels[-1]),
            nn.SiLU(),
            nn.Conv2d(
                channels[-1],
                out_channels,
                kernel_size=3,
                padding=1,
                padding_mode="replicate",
            ),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward method for the Decoder class.

        Args:
            x (torch.Tensor): Latent tensor to be decoded, of shape
                [batch, latent_dim, height // 2**len(channels), width // 2**len(channels)].
                This assumes each UpsamplingBlock doubles the spatial size.

        Returns:
            torch.Tensor: Decoded tensor in the original space, of shape
                [batch, out_channels, height, width].
        """
        x = self.bottleneck(x)
        x = self.upsampling_blocks(x)
        x = self.output(x)

        return x


class ViTVAE(nn.Module):
    """
    Variational auto-encoder with a ViT encoder and a convolutional decoder.

    ``ViTVAEEncoder`` maps an image to a latent grid (plus a global CLS
    encoding). ``Decoder`` maps the latent grid back to image space.

    Args:
        vit: ViT instance wrapped by the encoder.
        image_size: ``(H, W)`` or int; should match the ViT setup.
        patch_size: ``(Ph, Pw)`` or int; should match the ViT setup and be square.
        out_channels: Number of channels of the reconstruction.
        latent_channels: Number of latent channels.
        channels: Per-stage channels of the decoder (list or tuple).
        decoder_heads: Attention heads of the decoder bottleneck.
        decoder_groups: GroupNorm groups of the decoder.
        dropout: Decoder dropout rate (the encoder is built with its default).
    """

    def __init__(
        self,
        vit,
        image_size,
        patch_size,
        out_channels=1,
        latent_channels=4,
        channels=(256, 128, 32),
        decoder_heads=8,
        decoder_groups=16,
        dropout=0.1,
    ):
        super().__init__()
        self.encoder = ViTVAEEncoder(
            vit=vit,
            image_size=image_size,
            patch_size=patch_size,
            latent_channels=latent_channels,
        )
        Ph, Pw = pair(patch_size)
        assert Ph == Pw, "Decoder assumes square patches (patch height == width)."
        self.decoder = Decoder(
            out_channels=out_channels,
            latent_dim=latent_channels,
            # Pass a fresh list: avoids sharing a mutable default and accepts
            # tuples as well as lists.
            channels=list(channels),
            num_heads=decoder_heads,
            groups=decoder_groups,
            dropout=dropout,
        )

    def encode(self, x):
        """Return ``(z, mu, logvar, cls_encoding)`` from the encoder."""
        return self.encoder(x)

    def decode(self, z):
        """Decode a latent grid ``z`` back to image space."""
        return self.decoder(z)

    def forward(self, x):
        """Encode, decode the sampled latent, and return
        ``(x_hat, mu, logvar, z, cls_encoding)``."""
        z, mu, logvar, cls_encoding = self.encode(x)
        x_hat = self.decode(z)
        return x_hat, mu, logvar, z, cls_encoding


if __name__ == "__main__":
    # Smoke test: run each model on a random batch and print the output shapes.
    # Modules start in training mode, so patch dropout is active here.
    x = torch.randn(2, 1, 256, 128)
    model = ViT(
        image_size=(256, 128),
        patch_size=(8, 8),
        dim=512,
        out_dim=768,
        depth=6,
        heads=8,
        mlp_dim=1024,
        pool="cls",
        channels=1,
        dropout=0.0,
        emb_dropout=0.0,
        patch_dropout=0.25,
    )
    preds = model(x)
    print(preds.shape)  # (2, 768): (batch, out_dim)

    vit_vae = ViTVAEEncoder(
        vit=model,
        image_size=(256, 128),
        patch_size=(8, 8),
        latent_channels=4,
        emb_dropout=0.0,
    )
    z, mu, logvar, cls_encoding = vit_vae(x)
    print(z.shape, mu.shape, logvar.shape, cls_encoding.shape)
    # (2, 4, 32, 16) (2, 4, 32, 16) (2, 4, 32, 16) (2, 768)

    decoder = Decoder(
        out_channels=1,
        latent_dim=4,
        channels=[128, 64, 32],
        num_heads=8,
        groups=16,
        dropout=0.1,
    )
    x_hat = decoder(z)
    print(x_hat.shape)  # (2, 1, 256, 128)

    vit_vae_full = ViTVAE(
        vit=model,
        image_size=(256, 128),
        patch_size=(8, 8),
        out_channels=1,
        latent_channels=4,
        channels=[128, 64, 32],
        decoder_heads=8,
        decoder_groups=16,
        dropout=0.1,
    )
    x_hat, mu, logvar, z, cls_encoding = vit_vae_full(x)
    print(
        x_hat.shape, mu.shape, logvar.shape, z.shape, cls_encoding.shape
    )  # (2, 1, 256, 128) (2, 4, 32, 16) (2, 4, 32, 16) (2, 4, 32, 16) (2, 768)
