from typing import Optional, Literal, Tuple
from argparse import Namespace
import torch
import torch.nn as nn
import torch.nn.functional as F

class NoEmbed(nn.Module):
    """
    Identity "embedding": every time step of a single-lead signal becomes one
    token with a single feature (emb_dim == 1).

    Args:
        add_cls_token: If truthy, a learnable CLS token of shape (1, 1, 1),
            initialised to zero, is prepended along the token axis.
    """

    def __init__(self, add_cls_token) -> None:
        super().__init__()

        # Raw samples are used as features, so the token width is always 1.
        # This must be set before the CLS token is created, which reads it.
        self.emb_dim = 1

        self.add_cls_token = add_cls_token
        if self.add_cls_token:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, self.emb_dim))  # (1,1,D=1)

    def _to_BCT(self, x: torch.Tensor) -> torch.Tensor:
        """
        Convert input to (B, C, T).

        (B, T) gets a singleton channel axis. A 3D input is treated as
        (B, T, C) and permuted when its second axis is longer than its third;
        otherwise it is returned unchanged.

        Raises:
            ValueError: If x is neither 2D nor 3D.
        """
        if x.dim() == 2:
            x = x.unsqueeze(1)
        elif x.dim() == 3:
            _, dim1, dim2 = x.shape
            if dim1 > dim2:  # Case of (B,T,C)
                x = x.permute(0, 2, 1)
        else:
            raise ValueError(f"Unsupported input shape: {x.shape}")
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (B, T), (B, 1, T) or (B, T, 1).
        Returns:
            feat: (B, T (+1 if CLS), 1)
        Raises:
            AssertionError: If the input has more than one lead.
            ValueError: If x is neither 2D nor 3D.
        """
        x = self._to_BCT(x)  # (B, C=1, T)
        B, C, T = x.shape
        assert C == 1, f"num_lead must be 1, but got {C}"

        # Move time to the token axis first so the CLS token can be
        # concatenated along it; both tensors then share a feature width of 1.
        x = x.permute(0, 2, 1)  # (B, T, 1)

        if self.add_cls_token:
            cls = self.cls_token.expand(B, 1, self.emb_dim)  # (B,1,1)
            x = torch.cat([cls, x], dim=1)                   # (B, 1+T, 1)

        return x

class STFTEmbed(nn.Module):
    """
    Embed a multi-lead signal from its per-frame STFT magnitudes.

    Each STFT frame becomes one token: the magnitude spectra of all leads
    (optionally compressed per lead by a shared linear layer) are concatenated
    and projected to ``emb_dim``.

    Constructor arguments that are left as ``None`` (``n_fft``, ``hop_length``,
    ``win_length``, ``stft_out_dim``) fall back to the attribute of the same
    name on ``params`` and then to a built-in default.

    Input:
        x: (B, num_lead, T) or (B, T) / (B, T, num_lead)
    Output:
        feat: (B, num_frames (+CLS), emb_dim)
    """

    def __init__(
        self,
        params: Namespace,
        add_cls_token: bool = True,
        # STFT
        n_fft: Optional[int] = None,
        hop_length: Optional[int] = None,
        win_length: Optional[int] = None,
        window_type: str = "hann",
        use_log_mag: bool = False,
        eps: float = 1e-6,
        center: bool = True,
        normalized: bool = False,
        pad_mode: str = "reflect",
        # Frequency compression
        stft_out_dim: Optional[int] = None,   # If None, directly embed C*F without compression
    ) -> None:
        super().__init__()
        self.num_lead = int(params.num_lead)
        self.emb_dim = int(params.emb_dim)
        self.add_cls_token = add_cls_token

        # STFT geometry: explicit argument > params attribute > default.
        self.n_fft = int(n_fft or getattr(params, "n_fft", 256))
        if hop_length is None:
            hop_length = getattr(params, "hop_length", self.n_fft // 4)
        self.hop_length = int(hop_length)
        if win_length is None:
            win_length = getattr(params, "win_length", self.n_fft)
        self.win_length = int(win_length)
        self.n_freq_bins = self.n_fft // 2 + 1  # one-sided spectrum

        # Remaining STFT options.
        self.window_type = window_type
        self.use_log_mag = bool(use_log_mag or getattr(params, "use_log_mag", False))
        self.eps = float(eps)
        self.center = bool(center)
        self.normalized = bool(normalized)
        self.pad_mode = pad_mode

        # Optional per-lead frequency compression (F -> stft_out_dim).
        if stft_out_dim is None:
            stft_out_dim = getattr(params, "stft_out_dim", 0)
        self.stft_out_dim = stft_out_dim
        compress = bool(self.stft_out_dim) and self.stft_out_dim > 0
        self.freq_proj = nn.Linear(self.n_freq_bins, self.stft_out_dim) if compress else None
        per_lead_dim = self.stft_out_dim if compress else self.n_freq_bins

        # Final projection of the concatenated per-lead features.
        self.embed = nn.Linear(self.num_lead * per_lead_dim, self.emb_dim)

        # Learnable CLS token, zero-initialised.
        if self.add_cls_token:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, self.emb_dim))  # (1,1,D)

    def _build_window(self, device, dtype):
        """Return the periodic analysis window; any type other than "hamming" yields Hann."""
        make_window = torch.hamming_window if self.window_type == "hamming" else torch.hann_window
        return make_window(self.win_length, periodic=True, dtype=dtype, device=device)

    def _to_BCT(self, x: torch.Tensor) -> torch.Tensor:
        """
        Bring the input to (B, C, T).

        2D input gains a channel axis; 3D input is permuted when its second
        axis is longer than its third (taken to be (B, T, C)).
        """
        n_dims = x.dim()
        if n_dims == 2:
            return x.unsqueeze(1)
        if n_dims != 3:
            raise ValueError(f"Unsupported input shape: {x.shape}")
        _, dim1, dim2 = x.shape
        return x.permute(0, 2, 1) if dim1 > dim2 else x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (B, T), (B, C, T) or (B, T, C).
        Returns:
            feat: (B, num_frames (+1 if CLS), emb_dim)
        """
        x = self._to_BCT(x)
        B, C, T = x.shape
        assert C == self.num_lead, f"num_lead mismatch: got {C}, expected {self.num_lead}"

        # torch.stft takes 1D/2D input only, so leads are folded into the batch.
        spec = torch.stft(
            x.reshape(B * C, T),
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=self._build_window(x.device, x.dtype),
            center=self.center,
            pad_mode=self.pad_mode,
            normalized=self.normalized,
            return_complex=True,
        )  # (B*C, F, T')
        n_freq, n_frames = spec.shape[1], spec.shape[2]

        mag = spec.view(B, C, n_freq, n_frames).abs()  # (B, C, F, T')
        if self.use_log_mag:
            mag = torch.log(mag + self.eps)

        if self.freq_proj is None:
            # No compression: concatenate raw spectra of all leads per frame.
            tokens = mag.permute(0, 3, 1, 2).reshape(B, n_frames, C * n_freq)  # (B, T', C*F)
        else:
            # Compress each (lead, frame) spectrum, then regroup per frame.
            spectra = mag.permute(0, 1, 3, 2).reshape(B * C * n_frames, n_freq)
            compressed = self.freq_proj(spectra).view(B, C, n_frames, -1)      # (B, C, T', D)
            width = compressed.shape[-1]
            tokens = compressed.permute(0, 2, 1, 3).reshape(B, n_frames, C * width)  # (B, T', C*D)

        feat = self.embed(tokens)  # (B, T', emb_dim)

        if not self.add_cls_token:
            return feat
        cls = self.cls_token.expand(B, 1, self.emb_dim)  # (B,1,D)
        return torch.cat([cls, feat], dim=1)             # (B, 1+T', D)

class LinearEmbed(nn.Module):
    """
    Chunk-wise linear embedding.

    The signal is cut into non-overlapping chunks of ``params.lin_chunk_len``
    samples; the samples of all leads within a chunk are concatenated and
    projected to ``params.emb_dim`` with a single linear layer.

    Args:
        params: Namespace providing ``num_lead``, ``lin_chunk_len`` and ``emb_dim``.
        add_cls_token: Prepend a CLS token (in chunk space, i.e. before the
            linear layer) to every sequence.
        cls_token_learnable: Register the CLS token as a Parameter instead of
            a non-persistent buffer.
        cls_token_init: "normal", "uniform", or anything else for zeros.
    """

    def __init__(
        self,
        params: Namespace,
        add_cls_token: bool = True,
        cls_token_learnable: bool = False,
        cls_token_init: str = "zeros",
    ) -> None:
        super().__init__()

        self.num_lead = int(params.num_lead)
        self.chunk_len = int(params.lin_chunk_len)

        chunk_dim = self.num_lead * self.chunk_len
        self.embed = nn.Linear(chunk_dim, int(params.emb_dim))

        self.add_cls_token = bool(add_cls_token)
        self.cls_token_learnable = bool(cls_token_learnable)

        if self.add_cls_token:
            self._setup_cls_token(chunk_dim, self.cls_token_learnable, cls_token_init)

    # ---- helpers -------------------------------------------------------------
    def _setup_cls_token(self, chunk_dim: int, learnable: bool, init: str) -> None:
        """Initialize CLS token; register as Parameter if learnable, otherwise as buffer."""
        t = torch.zeros(1, 1, chunk_dim)
        if init == "normal":
            nn.init.normal_(t, mean=0.0, std=0.02)
        elif init == "uniform":
            nn.init.uniform_(t, a=-0.02, b=0.02)
        # Any other value keeps the zero initialisation.

        if learnable:
            self.cls_token = nn.Parameter(t)
        else:
            # Non-persistent: follows .to()/.cuda() but is not saved in state_dict.
            self.register_buffer("cls_token", t, persistent=False)

    def _prepend_cls(self, x: torch.Tensor) -> torch.Tensor:
        """Prepend CLS token to (B, L, chunk_dim); return unchanged if disabled."""
        if not self.add_cls_token:
            return x
        bs = x.size(0)
        cls = self.cls_token.to(device=x.device, dtype=x.dtype).expand(bs, 1, -1)
        return torch.cat((cls, x), dim=1)

    def forward(self, x: torch.Tensor):
        """
        Args:
            x (torch.Tensor): (batch_size, seqlen), (batch_size, seqlen, num_lead)
                or (batch_size, num_lead, seqlen). When both trailing axes equal
                num_lead, the input is interpreted as (batch_size, seqlen, num_lead).
        Returns:
            feat (torch.Tensor): (batch_size, num_chunks (+1 if CLS), emb_dim).
        Raises:
            ValueError: If x is neither 2D nor 3D.
            AssertionError: If the lead count does not match or seqlen is not
                a multiple of chunk_len.
        """
        if x.dim() == 2:
            x = x.unsqueeze(1)
        elif x.dim() == 3:
            # Channel-last input is moved to (B, C, T). Anything else is taken
            # to be (B, C, T) already and checked by the assert below.
            if x.size(2) == self.num_lead:
                x = torch.swapaxes(x, 1, 2)
        else:
            raise ValueError(f"Unsupported input shape: {tuple(x.shape)}")

        assert x.size(1) == self.num_lead, (
            f"num_lead mismatch: got {x.size(1)}, expected {self.num_lead}"
        )
        assert x.size(2) % self.chunk_len == 0, (
            f"seqlen {x.size(2)} is not a multiple of chunk_len {self.chunk_len}"
        )

        bs = x.size(0)
        num_chunks = x.size(2) // self.chunk_len
        # (B, C, num_chunks, chunk_len) -> (B, num_chunks, C, chunk_len)
        x = torch.reshape(x, (bs, self.num_lead, num_chunks, self.chunk_len))
        x = x.permute(0, 2, 1, 3)

        # (B, num_chunks, C * chunk_len)
        x = torch.reshape(x, (bs, num_chunks, self.num_lead * self.chunk_len))

        # The CLS token is added in chunk space so it goes through the same
        # linear layer as the data chunks.
        x = self._prepend_cls(x)

        feat = self.embed(x)
        return feat
    
class STFTLinearFuse(nn.Module):
    """
    Module that fuses STFT features and Linear-transformed features.

    Both sub-modules are built without their own CLS token; a single CLS token
    is prepended after fusion when ``add_cls_token`` is set.

    Args:
        x: (B, T) / (B, C, T) / (B, T, C)

    Returns:
        feat: (B, L_out (+CLS), D_out)
            - fuse='concat_time': D_out = emb_dim, L_out = T_stft + T_linear (+1 if add_cls)
            - fuse='concat_feat': L_out = max(T_stft, T_linear) (+1 if add_cls);
              D_out = emb_dim when project_out is True (linear projection 2D→D),
              otherwise D_out = 2 * emb_dim
            - fuse='sum'/'mean': D_out = emb_dim, L_out = max(T_stft, T_linear) (+1 if add_cls)

    Raises:
        ValueError: If ``fuse`` is not one of the supported modes.
    """

    _FUSE_MODES = ("concat_feat", "sum", "mean", "concat_time")

    def __init__(
        self,
        params: Namespace,
        *,
        # Fusion method
        fuse: Literal["concat_feat", "sum", "mean", "concat_time"] = "concat_feat",
        # Interpolation method to align sequence length (used except concat_time)
        align_mode: Literal["linear", "nearest"] = "linear",
        # Whether to project 2D → D for concat_feat
        project_out: bool = True,
        # CLS token configuration for final output
        add_cls_token: bool = True,
        cls_token_learnable: bool = False,
        cls_token_init: str = "zeros",
        # Additional arguments for sub-modules
        stft_kwargs: Optional[dict] = None,
        linear_kwargs: Optional[dict] = None,
    ) -> None:
        super().__init__()
        # Fail at construction rather than on the first forward pass.
        if fuse not in self._FUSE_MODES:
            raise ValueError(f"Unsupported fuse: {fuse}")

        stft_kwargs = stft_kwargs or {}
        linear_kwargs = linear_kwargs or {}

        self.emb_dim = int(params.emb_dim)
        self.fuse = fuse
        self.align_mode = align_mode
        self.project_out = project_out
        self.add_cls_token = bool(add_cls_token)
        self.cls_token_learnable = bool(cls_token_learnable)

        # Generate sub-modules without CLS (add only one after fusion)
        self.stft = STFTEmbed(params, add_cls_token=False, **stft_kwargs)
        self.linear = LinearEmbed(params, add_cls_token=False, **linear_kwargs)

        # Projection layer for concat_feat (2D → D)
        if self.fuse == "concat_feat" and self.project_out:
            self.fuse_proj = nn.Linear(self.emb_dim * 2, self.emb_dim)
        else:
            self.fuse_proj = None

        if self.add_cls_token:
            # The CLS token must match the width of the fused output, which is
            # 2 * emb_dim for an unprojected feature concat.
            self._setup_cls_token(self._output_dim(), self.cls_token_learnable, cls_token_init)

    # ---- helpers -------------------------------------------------------------
    def _output_dim(self) -> int:
        """Return the feature width of the fused sequence."""
        if self.fuse == "concat_feat" and not self.project_out:
            return 2 * self.emb_dim
        return self.emb_dim

    def _setup_cls_token(self, dim: int, learnable: bool, init: str) -> None:
        """Create the (1, 1, dim) CLS token as a Parameter (learnable) or non-persistent buffer."""
        t = torch.zeros(1, 1, dim)
        if init == "normal":
            nn.init.normal_(t, mean=0.0, std=0.02)
        elif init == "uniform":
            nn.init.uniform_(t, a=-0.02, b=0.02)
        # Any other value keeps the zero initialisation.

        if learnable:
            self.cls_token = nn.Parameter(t)
        else:
            self.register_buffer("cls_token", t, persistent=False)

    def _prepend_cls(self, x: torch.Tensor) -> torch.Tensor:
        """Prepend the CLS token to (B, L, D); return x unchanged if disabled."""
        if not self.add_cls_token:
            return x
        bs = x.size(0)
        cls = self.cls_token.to(device=x.device, dtype=x.dtype).expand(bs, 1, -1)
        return torch.cat((cls, x), dim=1)

    @staticmethod
    def _interpolate_time(seq: torch.Tensor, target_len: int, mode: str) -> torch.Tensor:
        """
        Interpolate sequence (B, T, D) to target_len in time dimension.

        Raises:
            ValueError: If mode is neither "linear" nor "nearest" (only checked
                when resampling is actually required).
        """
        if seq.size(1) == target_len:
            return seq
        # F.interpolate expects (N,C,L), so swap axes
        x = seq.permute(0, 2, 1)  # (B, D, T)
        if mode == "linear":
            x = F.interpolate(x, size=target_len, mode="linear", align_corners=False)
        elif mode == "nearest":
            x = F.interpolate(x, size=target_len, mode="nearest")
        else:
            raise ValueError(f"Unsupported align_mode: {mode}")
        return x.permute(0, 2, 1)  # (B, target_len, D)

    def _align_pair(self, a: torch.Tensor, b: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Align sequence length of a,b: (B, Ta, D), (B, Tb, D) -> max(Ta, Tb).
        """
        target = max(a.size(1), b.size(1))
        # _interpolate_time is a no-op for a sequence already at the target length.
        a = self._interpolate_time(a, target, self.align_mode)
        b = self._interpolate_time(b, target, self.align_mode)
        return a, b

    # ---- forward -------------------------------------------------------------
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (B, T) / (B, C, T) / (B, T, C)
        Returns:
            (B, L_out (+CLS), D_out)
        """
        # The linear branch is given channel-last input. A (B, C, T) tensor is
        # therefore transposed for it; the STFT branch takes x as passed.
        lin_in = x
        num_lead = self.linear.num_lead
        if x.dim() == 3 and x.size(2) != num_lead and x.size(1) == num_lead:
            lin_in = x.transpose(1, 2)

        # (B, Ts, D), (B, Tl, D)
        stft_feat = self.stft(x)
        lin_feat = self.linear(lin_in)

        if self.fuse == "concat_time":
            # Concatenate along time dimension (without alignment)
            out = torch.cat([stft_feat, lin_feat], dim=1)  # (B, Ts+Tl, D)
        else:
            # Align sequence length before fusion
            stft_feat, lin_feat = self._align_pair(stft_feat, lin_feat)

            if self.fuse == "concat_feat":
                out = torch.cat([stft_feat, lin_feat], dim=-1)  # (B, L, 2D)
                if self.fuse_proj is not None:
                    out = self.fuse_proj(out)                   # (B, L, D)
            elif self.fuse == "sum":
                out = stft_feat + lin_feat                      # (B, L, D)
            elif self.fuse == "mean":
                out = 0.5 * (stft_feat + lin_feat)              # (B, L, D)
            else:
                raise ValueError(f"Unsupported fuse: {self.fuse}")

        # Add final CLS token (only one at top level)
        out = self._prepend_cls(out)
        return out
