#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""channel_vision_transformer.py

Wavelet‑enhanced Channel‑Wise Vision Transformer for 1‑D Physiological Signals
-----------------------------------------------------------------------------
This module integrates a learnable wavelet decomposition front‑end with a
frequency‑aware masking strategy and a Rotary‑position‑encoded Transformer
backbone.  It is designed for large‑scale **self‑supervised or supervised**
pre‑training on multichannel time‑series such as ECG, EMG, or EEG.

Key Features
============
1. **Learnable WaveFilter** – each Conv1D kernel is initialised from a chosen
   PyWavelets basis (e.g., "db6") and then fine‑tuned during training.
2. **AdaptiveWaveletSelector** – dynamically blends multiple wavelet bases
   according to the input’s global statistics, enabling signal‑adaptive
   time‑frequency decomposition.
3. **CrossScaleCAFFN** – Cross‑scale Channel‑Aggregation FFN with residual
   attention fusion for robust multi‑resolution feature mixing.
4. **Patch‑level Frequency‑Guided Masking** – selects patches for masking
   proportionally to their spectral energy to focus learning on informative
   bands.
5. **Rotary Position Embedding (RoPE)** – parameter‑free positional encoding
   that naturally generalises to sequences longer than those seen in training.
"""

from __future__ import annotations

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import pywt

from models.modules.patching2D import PatchEmbed2D 
from models.modules.attention import CustomAttentionBlock
from models.modules.rotary import RotaryPositionEmbedding

################################################################################
# 1) Learnable WaveFilter
################################################################################

def load_wavelet_kernel(wave_name: str, kernel_size: int):
    """Build low-pass / high-pass kernels from a PyWavelets decomposition basis.

    The decomposition filters ``dec_lo`` and ``dec_hi`` of the named wavelet
    are converted to float32 tensors.  A filter whose tap count already equals
    *kernel_size* is returned untouched.  Any other filter is linearly
    interpolated to *kernel_size* taps and then rescaled so that the sum of
    its absolute values is ``kernel_size / 2``.

    Args:
        wave_name: Name understood by ``pywt.Wavelet`` (e.g. "db6", "sym4").
        kernel_size: Number of taps each returned kernel should have.

    Returns:
        Tuple ``(low_pass, high_pass)`` of 1-D float32 tensors.
    """
    basis = pywt.Wavelet(wave_name)  # e.g., "db6", "sym4", ...

    kernels = []
    for taps in (basis.dec_lo, basis.dec_hi):
        coeffs = torch.tensor(taps, dtype=torch.float32)
        if coeffs.numel() == kernel_size:
            # Native length already matches: keep the coefficients as they are.
            kernels.append(coeffs)
            continue

        # Stretch / shrink to the requested length (1-D linear interpolation).
        resized = F.interpolate(
            coeffs.view(1, 1, -1),
            size=kernel_size,
            mode="linear",
            align_corners=True,
        ).squeeze()
        # Rescale so the absolute values sum to kernel_size / 2; the epsilon
        # guards against division by zero.
        gain = (kernel_size * 0.5) / (resized.abs().sum() + 1e-8)
        resized *= gain
        kernels.append(resized)

    low_pass, high_pass = kernels
    return low_pass, high_pass


class LearnableWaveFilter(nn.Module):
    """Pair of 1-D convolutions initialised with wavelet coefficients.

    One convolution starts as the low-pass decomposition filter and the other
    as the high-pass filter of the chosen wavelet; both remain trainable.
    """

    def __init__(
        self,
        in_ch: int = 8,
        kernel_size: int = 16,
        wave_init: str = "db6",
        separate_per_channel: bool = True,
    ) -> None:
        """Create a learnable wavelet filter bank.

        Args:
            in_ch:   Number of input channels.
            kernel_size: Length of 1-D convolution kernels.
            wave_init: Name of the wavelet basis to initialise weights.
            separate_per_channel: If *True* each channel gets its own kernel
                                  (depthwise convolution).  If *False* a dense
                                  ``in_ch x in_ch`` convolution is used; it is
                                  initialised so that every channel is filtered
                                  by the wavelet and all cross-channel taps
                                  start at zero.
        """
        super().__init__()
        self.in_ch = in_ch

        low_init, high_init = load_wavelet_kernel(wave_init, kernel_size)
        groups = in_ch if separate_per_channel else 1

        self.low_filter = nn.Conv1d(
            in_ch, in_ch, kernel_size, stride=1, padding=kernel_size // 2,
            groups=groups, bias=False
        )
        self.high_filter = nn.Conv1d(
            in_ch, in_ch, kernel_size, stride=1, padding=kernel_size // 2,
            groups=groups, bias=False
        )

        self._init_from_wavelet(self.low_filter, low_init, separate_per_channel)
        self._init_from_wavelet(self.high_filter, high_init, separate_per_channel)

    @staticmethod
    def _init_from_wavelet(conv: nn.Conv1d, kernel: torch.Tensor,
                           depthwise: bool) -> None:
        """Overwrite *conv*'s weight with the wavelet *kernel*."""
        kernel = kernel.reshape(-1)
        with torch.no_grad():
            weight = conv.weight
            if depthwise:
                # Weight is (in_ch, 1, K): every channel gets the same taps.
                weight.copy_(kernel.expand_as(weight))
            else:
                # Weight is (in_ch, in_ch, K).  Writing only entry [0, 0]
                # would leave every other tap at Conv1d's random default, so
                # clear the whole tensor and place the wavelet on the channel
                # diagonal instead.
                weight.zero_()
                diag = torch.arange(weight.size(0), device=weight.device)
                weight[diag, diag] = kernel

    def forward(self, x: torch.Tensor):
        """Filter *x* and keep every second output sample.

        Args:
            x: Input of shape (B, in_ch, T).

        Returns:
            Tuple ``(approx, detail)``: the low-pass and high-pass convolution
            outputs, each decimated by a factor of two along the last axis.
        """
        approx = self.low_filter(x)[..., ::2]
        detail = self.high_filter(x)[..., ::2]
        return approx, detail


class AdaptiveWaveletSelector(nn.Module):
    """Blend multiple learnable wavelet filters using content-aware weights.

    A small MLP looks at the per-channel temporal mean of the input and emits
    one softmax weight per wavelet basis.  The outputs of all filters are then
    combined as a weighted sum, separately for the approximation and the
    detail branch.
    """

    def __init__(
        self,
        in_ch: int = 8,
        wavelet_names: list[str] | None = None,
        kernel_size: int = 16,
        separate_per_channel: bool = True,
    ) -> None:
        """Create one ``LearnableWaveFilter`` per basis plus the selector MLP.

        Args:
            in_ch: Number of input channels.
            wavelet_names: Wavelet bases to blend; defaults to
                ``["db4", "db6", "sym4", "coif3"]``.
            kernel_size: Kernel length shared by all filters.
            separate_per_channel: Forwarded to every ``LearnableWaveFilter``.

        Raises:
            ValueError: If *wavelet_names* is an empty sequence.
        """
        super().__init__()
        if wavelet_names is None:
            wavelet_names = ["db4", "db6", "sym4", "coif3"]
        if len(wavelet_names) == 0:
            raise ValueError("wavelet_names must contain at least one wavelet")

        self.filters = nn.ModuleList([
            LearnableWaveFilter(
                in_ch, kernel_size, wname, separate_per_channel
            )
            for wname in wavelet_names
        ])
        self.selector = nn.Sequential(
            nn.AdaptiveAvgPool1d(1),  # global channel statistics
            nn.Flatten(),
            nn.Linear(in_ch, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, len(wavelet_names)),
            nn.Softmax(dim=1),
        )

    def forward(self, x: torch.Tensor):
        """Return the weighted blend of all filters' outputs.

        Args:
            x: Input of shape (B, C, T).

        Returns:
            Tuple ``(approx, detail)`` with the same shape as the individual
            filters' outputs (i.e. already decimated along time).
        """
        weights = self.selector(x)  # (B, num_wavelets)

        # The filters return decimated signals, so the accumulators must take
        # their shape from the filter outputs rather than from ``x``; an
        # accumulator shaped like ``x`` cannot be added to them.
        approx_sum = None
        detail_sum = None
        for i, wf in enumerate(self.filters):
            approx_i, detail_i = wf(x)
            w_i = weights[:, i].view(-1, 1, 1)
            if approx_sum is None:
                approx_sum = approx_i * w_i
                detail_sum = detail_i * w_i
            else:
                approx_sum = approx_sum + approx_i * w_i
                detail_sum = detail_sum + detail_i * w_i
        return approx_sum, detail_sum


################################################################################
# 2) Feed‑Forward Networks (Channel Aggregation + Cross‑Scale Attention)
################################################################################

class ElementScale(nn.Module):
    """Multiply the input by a trainable scale tensor."""

    def __init__(self, shape, init_value: float = 1.0):
        """Allocate the scale parameter.

        Args:
            shape: Size of the scale tensor; it must broadcast against the
                inputs passed to ``forward``.
            init_value: Constant every entry of the scale starts at.
        """
        super().__init__()
        initial = torch.full(size=shape, fill_value=init_value)
        self.scale = nn.Parameter(initial)

    def forward(self, x: torch.Tensor):
        """Return ``x`` scaled element-wise (with broadcasting)."""
        return torch.mul(x, self.scale)


class ChannelAggregationFFN(nn.Module):
    """Convolutional feed-forward block with a channel-decomposition step.

    The block expands the channel dimension with a 1x1 convolution, applies a
    depthwise convolution, GELU and dropout, re-weights the hidden features
    using a single-channel summary of themselves, and projects back to the
    input width with another 1x1 convolution.
    """

    def __init__(
        self,
        embed_dims: int,
        ffn_ratio: float = 4.0,
        kernel_size: int = 3,
        dropout: float = 0.1,
    ) -> None:
        """Build the layers.

        Args:
            embed_dims: Number of input and output channels.
            ffn_ratio: Hidden width as a multiple of *embed_dims*.
            kernel_size: Kernel size of the depthwise convolution.
            dropout: Dropout probability (the same layer is used twice).
        """
        super().__init__()
        width = int(embed_dims * ffn_ratio)

        # Expansion -> depthwise mixing -> projection.
        self.fc1 = nn.Conv2d(embed_dims, width, kernel_size=1)
        self.dwconv = nn.Conv2d(
            width,
            width,
            kernel_size,
            padding=kernel_size // 2,
            groups=width,
        )
        self.act = nn.GELU()
        self.drop = nn.Dropout(dropout)
        self.fc2 = nn.Conv2d(width, embed_dims, kernel_size=1)

        # Decomposition branch: collapse the hidden channels to one map and
        # scale the remainder by a per-channel factor that starts near zero.
        self.decompose = nn.Conv2d(width, 1, kernel_size=1)
        self.decompose_act = nn.GELU()
        self.sigma = ElementScale([1, width, 1, 1], init_value=1e-5)

    def forward(self, x: torch.Tensor):
        """Apply the block to a (B, embed_dims, H, W) tensor."""
        hidden = self.drop(self.act(self.dwconv(self.fc1(x))))

        # Subtract the single-channel summary (broadcast over channels).
        summary = self.decompose_act(self.decompose(hidden))
        remainder = hidden - summary
        # NOTE(review): this evaluates to (1 + sigma) * (hidden - summary);
        # confirm that is intended rather than hidden + sigma * (hidden - summary).
        hidden = self.sigma(remainder) + remainder

        return self.drop(self.fc2(hidden))


class CrossScaleCAFFN(nn.Module):
    """Channel Aggregation FFN + cross-scale attention fusion.

    The input is first processed by a ``ChannelAggregationFFN``.  When feature
    maps from earlier scales are supplied, every spatial position of the FFN
    output attends over all positions of those maps and the attended result is
    added back, scaled by a learnable scalar.
    """

    def __init__(
        self,
        embed_dims: int,
        ffn_ratio: float = 4.0,
        kernel_size: int = 3,
        dropout: float = 0.1,
        num_heads: int = 4,
    ) -> None:
        """Build the FFN and the cross-attention layer.

        Args:
            embed_dims: Channel count of inputs, outputs and previous features.
            ffn_ratio: Hidden-width multiplier of the base FFN.
            kernel_size: Depthwise kernel size of the base FFN.
            dropout: Dropout probability of the base FFN.
            num_heads: Number of attention heads (``nn.MultiheadAttention``
                requires it to divide *embed_dims*).  Defaults to 4.
        """
        super().__init__()
        self.base_ffn = ChannelAggregationFFN(embed_dims, ffn_ratio,
                                              kernel_size, dropout)
        self.cross_attn = nn.MultiheadAttention(embed_dims, num_heads=num_heads,
                                                batch_first=True)
        self.attn_scale = nn.Parameter(torch.tensor(0.1))

    def forward(self, x: torch.Tensor, prev_feats: list[torch.Tensor] | None = None):
        """Run the FFN and optionally fuse features from earlier scales.

        Args:
            x: Input of shape (B, C, H, W).
            prev_feats: Optional list of earlier feature maps, each of shape
                (B, C, H_i, W_i).  ``None`` or an empty list skips the fusion.

        Returns:
            Tensor of shape (B, C, H, W).
        """
        out = self.base_ffn(x)
        if not prev_feats:
            return out

        B, C, H, W = out.shape
        # Flatten spatial positions into a token axis: (B, HW, C).
        q = out.permute(0, 2, 3, 1).reshape(B, -1, C)
        kv = torch.cat([
            pf.permute(0, 2, 3, 1).reshape(pf.size(0), -1, C)
            for pf in prev_feats
        ], dim=1)  # (B, sum of H_i * W_i, C)

        # The attention map itself is never used, so do not ask for it; this
        # avoids materialising the averaged (B, HW, sum HW) weight matrix.
        attn_out, _ = self.cross_attn(q, kv, kv, need_weights=False)
        # reshape (not view): the layer may hand back a non-contiguous tensor.
        attn_out = attn_out.reshape(B, H, W, C).permute(0, 3, 1, 2)
        return out + self.attn_scale * attn_out


################################################################################
# 3) Multi‑Head Gate + Wavelet Decomposition with Soft Gating
################################################################################

class MultiHeadGate(nn.Module):
    """Channel-wise gating weights predicted via attention over channel means.

    The input is averaged over time, projected to queries, keys and values,
    and split into ``num_heads`` groups of channels.  Attention is computed
    between those groups, and a sigmoid of the result yields one gate value in
    (0, 1) per channel.
    """

    def __init__(self, in_channels: int, num_heads: int = 4):
        """Build the three projections.

        Args:
            in_channels: Number of channels of the signals to be gated.
            num_heads: Number of channel groups; must divide *in_channels*.

        Raises:
            ValueError: If *num_heads* is not positive or does not divide
                *in_channels* (the per-head reshape in ``forward`` would
                otherwise fail with an opaque shape error).
        """
        super().__init__()
        if num_heads <= 0 or in_channels % num_heads != 0:
            raise ValueError(
                f"in_channels ({in_channels}) must be divisible by "
                f"num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.q_proj = nn.Linear(in_channels, in_channels)
        self.k_proj = nn.Linear(in_channels, in_channels)
        self.v_proj = nn.Linear(in_channels, in_channels)

    def forward(self, x: torch.Tensor):
        """Compute the gate for a (B, C, T) input.

        Returns:
            Tensor of shape (B, C, 1) with values produced by a sigmoid.
        """
        B, C, _ = x.shape
        x_pool = x.mean(dim=-1)  # (B, C)
        head_dim = C // self.num_heads

        Q = self.q_proj(x_pool).view(B, self.num_heads, head_dim)
        K = self.k_proj(x_pool).view(B, self.num_heads, head_dim)
        V = self.v_proj(x_pool).view(B, self.num_heads, head_dim)

        # Heads act as the tokens here: attention is (B, num_heads, num_heads).
        attn = (Q @ K.transpose(1, 2)) / math.sqrt(head_dim)
        attn = F.softmax(attn, dim=-1)
        out_v = attn @ V  # (B, num_heads, head_dim)
        gate = torch.sigmoid(out_v.reshape(B, C))  # (B, C)
        return gate.unsqueeze(-1)  # (B, C, 1)


class SoftGateWaveletDecomp(nn.Module):
    """Multi-level wavelet analysis with soft gating and cross-scale FFNs.

    At every level the current approximation is filtered by the adaptive
    wavelet bank, the filter outputs are resampled to the input length, and a
    learned gate blends them with the running approximation / detail signals.
    A ``CrossScaleCAFFN`` then refines the stacked pair, attending to the
    stacks produced at earlier levels.
    """

    def __init__(
        self,
        in_channels: int = 8,
        max_level: int = 3,
        kernel_size: int = 16,
        wavelet_names: list[str] | None = None,
        use_separate_channel: bool = True,
        ffn_ratio: float = 4.0,
        ffn_kernel_size: int = 5,
        ffn_drop: float = 0.1,
    ) -> None:
        """Build the filter bank, the gate and one FFN per level.

        Args:
            in_channels: Number of signal channels.
            max_level: Number of decomposition iterations.
            kernel_size: Kernel length of the wavelet filters.
            wavelet_names: Bases handed to ``AdaptiveWaveletSelector``.
            use_separate_channel: Whether the wavelet filters are depthwise.
            ffn_ratio: Hidden-width multiplier of each FFN.
            ffn_kernel_size: Depthwise kernel size of each FFN.
            ffn_drop: Dropout probability of each FFN.
        """
        super().__init__()
        self.max_level = max_level
        self.in_channels = in_channels

        self.wave_filter = AdaptiveWaveletSelector(
            in_ch=in_channels,
            kernel_size=kernel_size,
            wavelet_names=wavelet_names,
            separate_per_channel=use_separate_channel,
        )
        self.gate = MultiHeadGate(in_channels)

        # One FFN per level; it sees approximation and detail stacked, hence
        # twice the channel count.
        level_ffns = []
        for _ in range(max_level):
            level_ffns.append(
                CrossScaleCAFFN(
                    embed_dims=2 * in_channels,
                    ffn_ratio=ffn_ratio,
                    kernel_size=ffn_kernel_size,
                    dropout=ffn_drop,
                )
            )
        self.sub_ffn = nn.ModuleList(level_ffns)

        # Residual weights start at zero, so each FFN initially contributes
        # nothing to its level's output.
        level_scales = []
        for _ in range(max_level):
            level_scales.append(nn.Parameter(torch.zeros(1)))
        self.res_scale = nn.ParameterList(level_scales)

    def forward(self, x: torch.Tensor):
        """Decompose a (B, C, T) signal.

        Returns:
            Tensor of shape (B, (max_level + 1) * C, T): the detail signal of
            every level in order, followed by the final approximation.
        """
        _, n_ch, length = x.shape
        details = torch.zeros_like(x)
        approx = x
        bands = []
        history = []

        for level in range(self.max_level):
            next_approx, next_detail = self.wave_filter(approx)

            # Bring the filter outputs back to the original length so they
            # can be blended with the running signals.
            next_approx = F.interpolate(next_approx, size=length, mode="nearest")
            next_detail = F.interpolate(next_detail, size=length, mode="nearest")

            # Gate is computed from the approximation *before* it is updated.
            keep = self.gate(approx)  # (B, C, 1)
            approx = keep * approx + (1 - keep) * next_approx
            details = keep * details + (1 - keep) * next_detail

            # Stack both signals and add a dummy height axis: (B, 2C, 1, T).
            stacked = torch.cat([approx, details], dim=1).unsqueeze(2)
            refined = self.sub_ffn[level](stacked, history)
            refined = (stacked + self.res_scale[level] * refined).squeeze(2)

            approx = refined[:, :n_ch]
            details = refined[:, n_ch:]
            bands.append(details)
            # Later levels attend to the pre-FFN stack of this level.
            history.append(stacked)

        bands.append(approx)  # final approximation band
        return torch.cat(bands, dim=1)


################################################################################
# 4) Frequency‑Guided Patch Masking
################################################################################

def patch_level_frequency_guided_masking(
    x: torch.Tensor,
    patch_width: int,
    mask_ratio: float = 0.4,
    importance_ratio: float = 0.6,
):
    """Select patches to mask based on relative spectral energy.

    Each channel is cut into non-overlapping patches.  A patch's score is a
    mix of uniform random noise and its share of the summed FFT magnitudes of
    all patches in the sample; the patches with the *highest* scores are
    masked.

    Args:
        x:            Input signal (B, C, T)
        patch_width:  Number of timesteps per patch (must divide T).
        mask_ratio:   Fraction of patches to mask (0-1).  Values outside that
                      range are treated as "mask nothing" / "mask everything".
        importance_ratio: Weight given to spectral energy vs. random noise.
    Returns:
        patches_masked: Patches with masked regions zeroed (B, CP, W).
        mask:           Binary float mask where 1 = masked   (B, CP)
        ids_restore:    Indices to restore original order after shuffling.
    """
    B, C, T = x.shape
    N = T // patch_width  # patches per channel
    total = C * N

    patches = x.reshape(B, total, patch_width)
    energy = torch.fft.rfft(patches, dim=2).abs().sum(dim=2)  # (B, total)
    # NOTE(review): importance sums to 1 per sample, so single entries are on
    # the order of 1 / total while the noise is uniform in [0, 1); confirm the
    # two terms are meant to be mixed on these scales.
    importance = energy / (energy.sum(dim=1, keepdim=True) + 1e-6)
    noise = torch.rand_like(importance)
    scores = (1 - importance_ratio) * noise + importance_ratio * importance

    # ``1 - mask_ratio`` is inexact in floating point (1 - 0.9 is slightly
    # below 0.1), so plain truncation can drop one patch too many; the small
    # tolerance absorbs that.  Clamping keeps the slice bound valid even for
    # out-of-range ratios (a negative bound would slice from the end).
    keep = math.floor(total * (1.0 - mask_ratio) + 1e-6)
    keep = max(0, min(total, keep))

    ids_sorted = torch.argsort(scores, dim=1)  # ascending: lowest scores first
    ids_keep = ids_sorted[:, :keep]
    ids_restore = ids_sorted.argsort(dim=1)

    mask = torch.ones(B, total, device=x.device)
    mask.scatter_(1, ids_keep, 0)

    # Zero all masked patches in one vectorised operation.
    patches_masked = patches.masked_fill(mask.bool().unsqueeze(-1), 0)

    return patches_masked, mask, ids_restore


################################################################################
# 5) Channel Vision Transformer – main model
################################################################################

class ChannelVisionTransformer(nn.Module):
    """Transformer encoder over patches of a wavelet-decomposed 1-D signal.

    The input is expanded into several frequency bands by
    ``SoftGateWaveletDecomp``, every band is cut into fixed-width patches, the
    patches are linearly embedded, optionally replaced by a learned mask
    token, position-encoded and passed through a stack of attention blocks.
    """

    def __init__(
        self,
        # Wavelet front-end
        in_ch: int = 16,
        max_level: int = 3,
        wave_kernel_size: int = 16,
        wavelet_names: list[str] | None = None,
        use_separate_channel: bool = True,
        # Cross-scale FFN params
        ffn_ratio: float = 4.0,
        ffn_kernel_size: int = 5,
        ffn_drop: float = 0.1,
        # Patch params
        patch_width: int = 64,
        # Transformer params
        embed_dim: int = 512,
        depth: int = 12,
        num_heads: int = 8,
        mlp_ratio: float = 4.0,
        norm_layer=nn.LayerNorm,
        drop_path: float = 0.1,
        attention_type: str = "default",
        masking_ratio: float = 0.75,
        importance_ratio: float = 0.6,
        use_masking: bool = True,
        # RoPE
        max_seq_len: int = 2048,
    ) -> None:
        """Assemble the front-end, the patch embedding and the encoder.

        Args:
            in_ch: Number of channels of the raw signal.
            max_level: Number of wavelet decomposition levels.
            wave_kernel_size: Kernel length of the wavelet filters.
            wavelet_names: Wavelet bases blended by the front-end.
            use_separate_channel: Whether the wavelet filters are depthwise.
            ffn_ratio: Hidden-width multiplier of the front-end FFNs.
            ffn_kernel_size: Depthwise kernel size of the front-end FFNs.
            ffn_drop: Dropout probability of the front-end FFNs.
            patch_width: Timesteps per patch.
            embed_dim: Token dimension.
            depth: Number of attention blocks.
            num_heads: Passed to every ``CustomAttentionBlock``.
            mlp_ratio: Passed to every ``CustomAttentionBlock``.
            norm_layer: Normalisation class for the blocks and the final norm.
            drop_path: Passed to every ``CustomAttentionBlock``.
            attention_type: Passed to every ``CustomAttentionBlock``.
            masking_ratio: Fraction of patches masked when masking is on.
            importance_ratio: Weight of spectral energy in the mask scores.
            use_masking: Whether ``forward`` masks patches.
            max_seq_len: Passed to ``RotaryPositionEmbedding``.
        """
        super().__init__()
        self.use_masking = use_masking
        self.patch_width = patch_width
        self.masking_ratio = masking_ratio
        self.importance_ratio = importance_ratio

        # Wavelet decomposition front-end; it emits one band per level plus
        # the final approximation, each with ``in_ch`` channels.
        self.wavelet_decomp = SoftGateWaveletDecomp(
            in_channels=in_ch,
            max_level=max_level,
            kernel_size=wave_kernel_size,
            wavelet_names=wavelet_names,
            use_separate_channel=use_separate_channel,
            ffn_ratio=ffn_ratio,
            ffn_kernel_size=ffn_kernel_size,
            ffn_drop=ffn_drop,
        )
        self.wave_decomp_ch = (max_level + 1) * in_ch

        # Patch projection, mask token and positional encoding.
        self.patch_to_embedding = nn.Linear(patch_width, embed_dim)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.rotary_pos_embed = RotaryPositionEmbedding(embed_dim, max_seq_len)

        # Transformer blocks.
        encoder_layers = []
        for layer_idx in range(depth):
            encoder_layers.append(
                CustomAttentionBlock(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=True,
                    norm_layer=norm_layer,
                    attention_type=attention_type,
                    block_idx=layer_idx,
                    num_channels=self.wave_decomp_ch,
                    drop_path=drop_path,
                )
            )
        self.blocks = nn.ModuleList(encoder_layers)
        self.norm = norm_layer(embed_dim)
        self._init_weights()

    # ---------------------------------------------------------------------
    # Initialisation helpers
    # ---------------------------------------------------------------------
    def _init_weights(self):
        """Re-initialise Linear / LayerNorm layers and the mask token."""
        # Visits every submodule, including those of the wavelet front-end.
        self.apply(self._init_module)
        nn.init.normal_(self.mask_token, std=0.02)

        projection = self.patch_to_embedding
        nn.init.xavier_uniform_(projection.weight)
        if projection.bias is not None:
            nn.init.zeros_(projection.bias)

    @staticmethod
    def _init_module(m):
        """Xavier-initialise Linear layers and reset LayerNorm to identity."""
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
            return
        if isinstance(m, nn.LayerNorm):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)

    # ---------------------------------------------------------------------
    # Forward pass
    # ---------------------------------------------------------------------
    def forward(self, x: torch.Tensor):
        """Encode a batch of signals.

        Args:
            x (Tensor): Input of shape (B, C, T).

        Returns:
            If *use_masking* is True: ``(tokens, mask, ids_restore)``,
            otherwise only ``tokens``.  ``tokens`` comes from the final norm
            applied to the (B, total, embed_dim) token sequence, where
            ``total`` is the number of patches over all bands.
        """
        # Wavelet time-frequency representation: (B, C', T).
        spec = self.wavelet_decomp(x)
        batch, n_bands, length = spec.shape
        width = self.patch_width
        per_band = length // width  # patches per channel
        n_tokens = n_bands * per_band

        # Decide which patches are hidden from the encoder.
        if not self.use_masking:
            mask = torch.zeros(batch, n_tokens, device=x.device)
            ids_restore = (
                torch.arange(n_tokens, device=x.device).unsqueeze(0).repeat(batch, 1)
            )
        else:
            # Only the mask and the restore indices are needed here; the
            # zeroed patches returned first are discarded.
            _, mask, ids_restore = patch_level_frequency_guided_masking(
                spec, width, self.masking_ratio, self.importance_ratio,
            )

        # Patch projection: (B, total, width) -> (B, total, D).
        patches = spec.view(batch, n_bands, per_band, width)
        patches = patches.reshape(batch, n_tokens, width)
        tokens = self.patch_to_embedding(patches)

        # Swap the embeddings of masked patches for the learned mask token.
        is_masked = mask.unsqueeze(-1) == 1
        placeholder = self.mask_token.expand(batch, n_tokens, -1)
        tokens = torch.where(is_masked, placeholder, tokens)

        # Positional encoding and Transformer.
        tokens = self.rotary_pos_embed(tokens)
        for block in self.blocks:
            tokens = block(tokens)
        tokens = self.norm(tokens)

        if not self.use_masking:
            return tokens
        return tokens, mask, ids_restore
