#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""mae_decoder.py

MAE‑style decoder for Channel‑Wise Vision Transformer
====================================================
This module reconstructs the raw content of masked 1-D patches from the encoder
output tokens.  It treats all tokens (masked + visible) as a single sequence,
relying on positional information that is already embedded by the encoder;
hence, no extra positional embedding is introduced here.
"""

from __future__ import annotations

import torch
import torch.nn as nn

# Local import – expected to be available in the codebase
from models.modules.attention import CustomAttentionBlock


class MAEDecoder(nn.Module):
    """Transformer decoder that predicts raw patch content from encoder tokens.

    Pipeline: linear projection (``embed_dim`` -> ``decoder_embed_dim``), a stack
    of ``decoder_depth`` ``CustomAttentionBlock`` layers, a final normalisation,
    and a linear head (``decoder_embed_dim`` -> ``decoder_output_dim``).  No
    positional embedding and no mask tokens are added inside this module.
    """

    def __init__(
        self,
        # Encoder interface
        decoder_output_dim: int | None = None,  # defaults to patch_width if None
        embed_dim: int = 512,                   # encoder embedding dimension
        # Decoder hyper-parameters
        decoder_embed_dim: int = 512,
        decoder_num_heads: int = 8,
        decoder_depth: int = 6,
        mlp_ratio: float = 4.0,
        norm_layer: type[nn.LayerNorm] = nn.LayerNorm,
        attention_type: str = "default",
        drop_path: float = 0.0,
        # Patch configuration
        patch_width: int = 64,
    ) -> None:
        """Create a light BERT-like decoder.

        Args:
            decoder_output_dim: Dimensionality of each reconstructed patch.
                                 Any falsy value (``None`` or ``0``) falls back
                                 to *patch_width*.
            embed_dim:           Dimensionality of encoder tokens fed to decoder.
            decoder_embed_dim:   Internal embedding size of the decoder.
            decoder_num_heads:   Number of attention heads per block.
            decoder_depth:       Number of Transformer blocks.
            mlp_ratio:           Width multiplier for MLP hidden layer.
            norm_layer:          Callable building the normalisation layer
                                 (e.g. ``nn.LayerNorm`` or a ``functools.partial``
                                 of it); called with ``decoder_embed_dim``.
            attention_type:      CustomAttentionBlock attention variant.
            drop_path:           Stochastic depth probability (same value for
                                 every block).
            patch_width:         Temporal size of each patch (ground-truth).
        """
        super().__init__()

        self.patch_width = patch_width
        # `or` (not `is None`) is kept deliberately: existing callers passing 0
        # keep getting patch_width.
        self.decoder_output_dim = decoder_output_dim or patch_width

        # ------------------------------------------------------------------
        # Layers
        # ------------------------------------------------------------------
        # 1) Token projection from encoder dim -> decoder dim
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)

        # 2) Transformer decoder blocks (no additional positional encoding)
        self.decoder_blocks = nn.ModuleList([
            CustomAttentionBlock(
                dim=decoder_embed_dim,
                num_heads=decoder_num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=True,
                norm_layer=norm_layer,
                attention_type=attention_type,
                block_idx=i,
                drop_path=drop_path,
            )
            for i in range(decoder_depth)
        ])
        self.decoder_norm = norm_layer(decoder_embed_dim)

        # 3) Final projection back to raw patch dimension
        self.decoder_pred = nn.Linear(decoder_embed_dim, self.decoder_output_dim, bias=True)

        self._init_weights()

    # ----------------------------------------------------------------------
    # Weight initialisation helpers
    # ----------------------------------------------------------------------
    def _init_weights(self) -> None:
        """Initialise Linear and LayerNorm parameters of every submodule.

        Linear: Xavier-uniform weight, zero bias (if present).
        LayerNorm: unit weight, zero bias, each only if present.  Both are
        absent with ``elementwise_affine=False``, and the bias is absent with
        ``bias=False``.

        ``self.modules()`` recurses into ``decoder_blocks``.  Any ``nn.Linear``
        or ``nn.LayerNorm`` inside a ``CustomAttentionBlock`` is therefore
        re-initialised here, replacing whatever init that block applied.
        """
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.LayerNorm):
                # Affine parameters are optional on LayerNorm; guard against None.
                if m.weight is not None:
                    nn.init.ones_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    # ----------------------------------------------------------------------
    # Forward pass
    # ----------------------------------------------------------------------
    def forward(
        self, x: torch.Tensor, ids_restore: torch.Tensor | None = None
    ) -> torch.Tensor:
        """Reconstruct masked patches.

        Args:
            x:           Encoder output tokens of shape (B, P, *embed_dim*), where
                         *P = C'·N* is the flattened patch sequence.
            ids_restore: Indices allowing the caller to reorder patches back to
                         their original temporal/channel order.  Not used by
                         this module; accepted (optionally) for interface
                         compatibility with the encoder's outputs.

        Returns:
            Tensor of shape (B, P, *decoder_output_dim*) representing the
            predicted raw (or spectrally transformed) patch content.
        """
        # Linear projection to decoder embedding space (acts on the last dim)
        x = self.decoder_embed(x)

        # Transformer decoding (positional info already present in the tokens)
        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)

        # Project to patch dimension
        return self.decoder_pred(x)
