import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Any
from collections import defaultdict
from melp.backbone.vit1d_cls import vit_nano, vit_tiny, vit_small, vit_middle, vit_base, vit_large, vit_xl


class PSGModalityEncoderCLS(nn.Module):
    """
    PSG modality encoder built on a ViT backbone from ``vit1d_cls``.

    Pipeline: backbone -> optional projection head -> optional L2-norm.
    Usage: ``emb = encoder(x)``

    The backbone is used through three entry points in this class:
    - ``backbone(x)``: pooled embedding (the CLS token, per the original
      author's notes -- not verifiable from this file).
    - ``backbone.forward_avg_pool(x)``: pooled embedding, presumably the mean
      of the patch tokens -- TODO confirm against ``vit1d_cls``.
    - ``backbone.forward_encoding(x, ...)``: returns a pair whose second
      element is used here as the patch embeddings.

    Attributes:
        token_len: number of samples per window (``freq * win_sec``).
        patch_size: patch size forwarded to the backbone.
        backbone: the ViT module selected by ``encoder_name``.
        downproj, att_pool: always ``None`` in this variant.
        proj_head: 2-layer MLP projection head, or ``None`` when disabled.
    """

    def __init__(self, *,
                 encoder_name: str,
                 proj_out: int = 256,
                 proj_hidden: int = 512,
                 freq: int = 64,
                 win_sec: int = 30,
                 channel: int = 12,
                 lead_wise=0,
                 patch_size=40,
                 patch_size_ch=4,
                 is_proj_head=1,
                 ):
        """
        Args:
            encoder_name: one of ``vit_nano``, ``vit_tiny``, ``vit_small``,
                ``vit_middle``, ``vit_base``, ``vit_large``, ``vit_xl``.
            proj_out: output width of the projection head.
            proj_hidden: hidden width of the projection head.
            freq: sampling frequency; multiplied by ``win_sec`` to get the
                sequence length handed to the backbone.
            win_sec: window length in seconds.
            channel: forwarded to the backbone as ``num_leads``.
            lead_wise: forwarded unchanged to the backbone.
            patch_size: forwarded unchanged to the backbone.
            patch_size_ch: forwarded unchanged to the backbone.
            is_proj_head: the projection head is built only when this
                equals 1; otherwise ``proj_head`` is ``None``.

        Raises:
            ValueError: if ``encoder_name`` is not a supported backbone name.
        """
        super().__init__()

        seq_len = freq * win_sec  # with the defaults: 64 * 30 = 1920
        self.token_len = seq_len
        self.patch_size = patch_size

        # -------- pick the backbone factory by name --------
        backbone_factories = {
            "vit_nano": vit_nano,
            "vit_tiny": vit_tiny,
            "vit_small": vit_small,
            "vit_middle": vit_middle,
            "vit_base": vit_base,
            "vit_large": vit_large,
            "vit_xl": vit_xl,
        }
        if encoder_name not in backbone_factories:
            raise ValueError(f"Unknown encoder_name for CLS variant: {encoder_name}")

        # Every size variant is constructed with the same keyword arguments.
        make_backbone = backbone_factories[encoder_name]
        self.backbone = make_backbone(
            num_leads=channel,
            seq_len=seq_len,
            patch_size=patch_size,
            lead_wise=lead_wise,
            patch_size_ch=patch_size_ch,
        )

        embed_dim = self.backbone.width

        # Kept as None placeholders; the original note says the ViT path
        # does not use a 1x1 conv down-projection.
        self.downproj = None
        self.att_pool = None

        # -------- optional projection head --------
        if is_proj_head != 1:
            self.proj_head = None
        else:
            self.proj_head = nn.Sequential(
                nn.Linear(embed_dim, proj_hidden),
                nn.LayerNorm(proj_hidden),
                nn.ReLU(inplace=True),
                nn.Linear(proj_hidden, proj_out),
                nn.LayerNorm(proj_out),
            )

    # ——— forward ———
    def forward(self, x, normalize = True, is_patch = False, use_avg_pool = False):
        """
        Encode ``x`` into either patch embeddings or one pooled embedding.

        Args:
            x: input PSG data; expected shape (bz, C, L) per the original
                documentation (not checked here).
            normalize: if truthy, L2-normalize the result along the last dim.
            is_patch: if truthy, return the patch embeddings from
                ``backbone.forward_encoding``. Takes precedence over
                ``use_avg_pool``; the projection head is NOT applied.
            use_avg_pool: only consulted when ``is_patch`` is falsy. If
                truthy, use ``backbone.forward_avg_pool(x)`` instead of
                ``backbone(x)`` (useful for ablation studies).

        Returns:
            is_patch truthy: patch embeddings, documented as (bz, N, D).
            otherwise: a pooled embedding, documented as (bz, D), passed
                through ``proj_head`` when one exists (last dim becomes
                ``proj_out``).
        """
        if is_patch:
            # NOTE(review): the first element of the returned pair is
            # discarded; the meaning of return_sequence=False is defined in
            # vit1d_cls -- confirm it still yields all patch tokens.
            _, feats = self.backbone.forward_encoding(
                x, return_sequence=False, add_lead_for_patch=False
            )
        else:
            # Both pooled variants share the same projection step.
            if use_avg_pool:
                feats = self.backbone.forward_avg_pool(x)
            else:
                feats = self.backbone(x)

            if self.proj_head is not None:
                feats = self.proj_head(feats)

        return F.normalize(feats, dim=-1) if normalize else feats

