"""
Biosignals-Text Contrastive Learning Model

A multimodal model for biosignal-text alignment combining:
- Conv-Transformer encoder for biosignals (time series)
- Transformer text encoder  
- Transformer text decoder for caption generation
- Signal reconstruction decoder for self-supervised learning
"""

from dataclasses import dataclass
from typing import Optional, Tuple, Union

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F


# ============================================================================
# Configuration Classes
# ============================================================================

@dataclass
class BiosignalsCfg:
    """Hyper-parameters of the conv + transformer biosignal encoder.

    The three conv lists are consumed in lockstep (one entry per conv block);
    any list left as ``None`` is replaced by a four-stage default in
    ``__post_init__``.
    """
    # --- input signal ---
    input_channels: int = 12   # channels of the incoming time series
    signal_length: int = 3840  # samples per channel
    sampling_rate: int = 128   # Hz

    # --- convolutional front-end (None -> default filled in below) ---
    conv_layers: list = None   # output channels of each conv block
    kernel_sizes: list = None  # kernel size of each conv block
    strides: list = None       # stride of each conv block

    # --- transformer stack ---
    transformer_layers: int = 6
    transformer_width: int = 768
    transformer_heads: int = 12
    mlp_ratio: float = 4.0

    # --- pooling / regularisation ---
    pool_type: str = 'attn'    # one of 'avg', 'max', 'cls', 'attn'
    dropout: float = 0.1

    def __post_init__(self):
        # Fresh list objects are built on every call so instances never share
        # a default list.
        fallbacks = (
            ("conv_layers", [64, 128, 256, 512]),
            ("kernel_sizes", [7, 5, 3, 3]),
            ("strides", [2, 2, 2, 2]),
        )
        for attr_name, default_value in fallbacks:
            if getattr(self, attr_name) is None:
                setattr(self, attr_name, default_value)


@dataclass
class TextCfg:
    """Hyper-parameters of the text transformer."""
    vocab_size: int = 50257    # rows of the token-embedding table
    context_length: int = 256  # rows of the positional table / causal-mask size
    width: int = 768           # model (hidden) dimension
    heads: int = 12            # attention heads per layer
    layers: int = 12           # number of transformer blocks
    mlp_ratio: float = 4.0     # MLP hidden size as a multiple of ``width``


@dataclass
class DecoderCfg:
    """Hyper-parameters of the cross-attending caption decoder."""
    context_length: int = 256  # rows of the positional table / causal-mask size
    width: int = 768           # model (hidden) dimension
    heads: int = 12            # attention heads per layer
    layers: int = 6            # number of decoder layers
    mlp_ratio: float = 4.0     # MLP hidden size as a multiple of ``width``


# ============================================================================
# Basic Building Blocks
# ============================================================================

class LayerNorm(nn.LayerNorm):
    """``nn.LayerNorm`` whose output is cast back to the dtype of its input.

    Normalisation itself runs in whatever dtype ``F.layer_norm`` picks for the
    given input and parameters; only the result is converted.
    """

    def forward(self, x: torch.Tensor):
        input_dtype = x.dtype
        normalised = F.layer_norm(
            x,
            self.normalized_shape,
            weight=self.weight,
            bias=self.bias,
            eps=self.eps,
        )
        return normalised.to(input_dtype)


class QuickGELU(nn.Module):
    """Sigmoid-based GELU approximation: ``x * sigmoid(1.702 * x)``."""

    def forward(self, x: torch.Tensor):
        gate = torch.sigmoid(x * 1.702)
        return gate * x


class Conv1dBlock(nn.Module):
    """Conv1d -> BatchNorm1d -> ReLU -> Dropout, applied in that order.

    The convolution is padded with ``kernel_size // 2`` samples on each side.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dropout=0.1):
        super().__init__()
        half_kernel = kernel_size // 2
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=half_kernel,
        )
        self.norm = nn.BatchNorm1d(out_channels)
        self.act = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        features = self.norm(self.conv(x))
        return self.dropout(self.act(features))


class TransformerBlock(nn.Module):
    """Post-norm transformer layer: self-attention, then a two-layer MLP.

    Each sub-layer is wrapped as ``LayerNorm(x + sublayer(x))``.  Inputs are
    batch-first, i.e. ``(batch, seq, width)``.
    """

    def __init__(
        self,
        width: int,
        heads: int,
        mlp_ratio: float = 4.0,
        dropout: float = 0.1
    ):
        super().__init__()
        hidden_width = int(width * mlp_ratio)

        self.attention = nn.MultiheadAttention(
            width, heads, dropout=dropout, batch_first=True
        )
        self.ln_1 = LayerNorm(width)
        self.mlp = nn.Sequential(
            nn.Linear(width, hidden_width),
            QuickGELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_width, width),
            nn.Dropout(dropout),
        )
        self.ln_2 = LayerNorm(width)

    def forward(self, x, attn_mask=None):
        # Attention sub-layer (residual, then norm); the attention weights
        # returned alongside the output are not needed.
        attended = self.attention(x, x, x, attn_mask=attn_mask)[0]
        x = self.ln_1(x + attended)

        # Feed-forward sub-layer (residual, then norm).
        return self.ln_2(x + self.mlp(x))


# ============================================================================
# Biosignals Encoder (Conv-Transformer)
# ============================================================================

class BiosignalsEncoder(nn.Module):
    """
    Biosignals encoder using 1D convolutions followed by transformer layers.
    Converts multi-channel time series to embeddings.

    Pipeline: conv stack -> linear projection to ``transformer_width`` ->
    learned positional embeddings -> CLS token appended at the END of the
    sequence -> transformer blocks -> final LayerNorm -> pooling ->
    projection to ``embed_dim``.
    """

    # Pooling strategies understood by forward().
    _POOL_TYPES = ('avg', 'max', 'cls', 'attn')

    def __init__(
        self,
        cfg: BiosignalsCfg,
        embed_dim: int = 512,
        output_tokens: bool = False,
    ):
        """
        Args:
            cfg: Encoder hyper-parameters.
            embed_dim: Size of the pooled output embedding.
            output_tokens: If True, forward() also returns the per-position
                transformer outputs (without the CLS token).

        Raises:
            ValueError: If ``cfg.pool_type`` is not a supported pooling mode,
                or the conv stack would reduce the signal to zero length.
        """
        super().__init__()
        # Fail at construction time instead of with an UnboundLocalError on
        # the first forward pass.
        if cfg.pool_type not in self._POOL_TYPES:
            raise ValueError(
                f"Unsupported pool_type {cfg.pool_type!r}; "
                f"expected one of {self._POOL_TYPES}"
            )

        self.cfg = cfg
        self.embed_dim = embed_dim
        self.output_tokens = output_tokens
        self.pool_type = cfg.pool_type

        # Convolutional feature extraction
        conv_layers = []
        in_channels = cfg.input_channels
        conv_length = cfg.signal_length

        for out_channels, kernel_size, stride in zip(
            cfg.conv_layers, cfg.kernel_sizes, cfg.strides
        ):
            conv_layers.append(
                Conv1dBlock(in_channels, out_channels, kernel_size, stride, cfg.dropout)
            )
            in_channels = out_channels
            # Conv1d output length for padding = kernel_size // 2, dilation 1
            # (the padding Conv1dBlock applies).  Computed analytically so no
            # dummy batch is pushed through the BatchNorm layers, which would
            # update their running statistics with random noise.
            conv_length = (conv_length + 2 * (kernel_size // 2) - kernel_size) // stride + 1

        self.conv_layers = nn.Sequential(*conv_layers)

        if conv_length < 1:
            raise ValueError(
                f"signal_length={cfg.signal_length} is too short for the configured "
                f"conv stack (resulting length {conv_length})"
            )
        self.conv_output_length = conv_length

        # Width of the last conv block actually built (zip() above may stop
        # early when the three config lists differ in length).
        self.conv_output_dim = in_channels

        # Project to transformer dimension
        self.proj_conv = nn.Linear(self.conv_output_dim, cfg.transformer_width)

        # Positional embeddings (one per conv output position; none for CLS)
        self.pos_embed = nn.Parameter(
            torch.randn(1, self.conv_output_length, cfg.transformer_width)
        )

        # Class token
        self.cls_token = nn.Parameter(
            torch.randn(1, 1, cfg.transformer_width)
        )

        # Transformer layers
        self.transformer = nn.ModuleList([
            TransformerBlock(
                cfg.transformer_width,
                cfg.transformer_heads,
                cfg.mlp_ratio,
                cfg.dropout
            )
            for _ in range(cfg.transformer_layers)
        ])

        self.ln_final = LayerNorm(cfg.transformer_width)

        # Output projection
        self.proj_out = nn.Linear(cfg.transformer_width, embed_dim)

        # Attention pooling
        if self.pool_type == 'attn':
            self.attn_pool = nn.MultiheadAttention(
                cfg.transformer_width,
                cfg.transformer_heads,
                batch_first=True
            )

    def forward(
        self, x: torch.Tensor
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Args:
            x: (batch_size, channels, signal_length)
        Returns:
            If ``output_tokens`` is False: embedding of shape
            (batch_size, embed_dim).
            If ``output_tokens`` is True: tuple ``(embedding, tokens)`` where
            tokens is (batch_size, seq_len, transformer_width) with the CLS
            position removed.
        Raises:
            ValueError: If the conv stack output length does not match the
                length the positional embeddings were built for.
        """
        batch_size = x.shape[0]

        # Conv feature extraction
        x = self.conv_layers(x)  # (B, conv_dim, conv_len)
        x = x.transpose(1, 2)  # (B, conv_len, conv_dim)

        # Explicit check: a mismatching length would otherwise surface as an
        # opaque broadcasting error (or broadcast silently when conv_len == 1).
        if x.shape[1] != self.conv_output_length:
            raise ValueError(
                f"Conv output length {x.shape[1]} does not match the expected "
                f"{self.conv_output_length} (configured signal_length="
                f"{self.cfg.signal_length})"
            )

        x = self.proj_conv(x)  # (B, conv_len, width)

        # Add positional embeddings
        x = x + self.pos_embed

        # Append class token as the LAST position
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat([x, cls_tokens], dim=1)

        # Transformer layers
        for layer in self.transformer:
            x = layer(x)

        x = self.ln_final(x)

        signal_tokens = x[:, :-1]  # everything except the trailing CLS token

        # Pooling
        if self.pool_type == 'cls':
            pooled = x[:, -1]
        elif self.pool_type == 'avg':
            pooled = signal_tokens.mean(dim=1)
        elif self.pool_type == 'max':
            pooled = signal_tokens.max(dim=1)[0]
        elif self.pool_type == 'attn':
            # CLS output acts as the single query over the signal tokens.
            query = x[:, -1:]
            pooled, _ = self.attn_pool(query, signal_tokens, signal_tokens)
            pooled = pooled.squeeze(1)
        else:
            # Only reachable if pool_type was changed after construction.
            raise ValueError(f"Unsupported pool_type {self.pool_type!r}")

        # Project to output dimension
        embedding = self.proj_out(pooled)

        if self.output_tokens:
            return embedding, signal_tokens  # Exclude CLS token for decoder
        return embedding


# ============================================================================
# Text Encoder
# ============================================================================

class TextEncoder(nn.Module):
    """Transformer-based text encoder with causal self-attention.

    Token ids are embedded, summed with learned positional embeddings, passed
    through ``cfg.layers`` transformer blocks under a causal mask, and
    layer-normalised.  The output at the last sequence position is projected
    to ``embed_dim`` and used as the sequence-level feature.
    """

    def __init__(self, cfg: TextCfg, embed_dim: int = 512):
        """
        Args:
            cfg: Text transformer hyper-parameters.
            embed_dim: Size of the pooled text feature.
        """
        super().__init__()
        self.cfg = cfg

        self.token_embedding = nn.Embedding(cfg.vocab_size, cfg.width)
        self.pos_embedding = nn.Parameter(torch.randn(cfg.context_length, cfg.width))

        self.transformer = nn.ModuleList([
            TransformerBlock(cfg.width, cfg.heads, cfg.mlp_ratio)
            for _ in range(cfg.layers)
        ])

        self.ln_final = LayerNorm(cfg.width)
        self.proj = nn.Linear(cfg.width, embed_dim)

        # Causal mask: True strictly above the diagonal.  For boolean masks
        # nn.MultiheadAttention treats True as "may not attend".
        self.register_buffer(
            "causal_mask",
            torch.triu(torch.ones(cfg.context_length, cfg.context_length), diagonal=1).bool()
        )

    def forward(self, text: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            text: (batch_size, seq_len) token ids
        Returns:
            text_features: (batch_size, embed_dim)
            token_embeddings: (batch_size, seq_len, width)
        Raises:
            ValueError: If seq_len exceeds ``cfg.context_length``.
        """
        seq_len = text.shape[1]

        # The positional table and causal mask only cover context_length
        # positions; longer inputs would otherwise fail with an opaque
        # shape-mismatch error.
        if seq_len > self.cfg.context_length:
            raise ValueError(
                f"Input sequence length {seq_len} exceeds "
                f"context_length {self.cfg.context_length}"
            )

        x = self.token_embedding(text)
        x = x + self.pos_embedding[:seq_len]

        # Causal attention mask
        mask = self.causal_mask[:seq_len, :seq_len]

        for layer in self.transformer:
            x = layer(x, attn_mask=mask)

        x = self.ln_final(x)

        # Pool from the final sequence position.
        # NOTE(review): this is the EOS position only when inputs end with EOS
        # and are not right-padded — confirm against the tokenizer/collation.
        text_features = self.proj(x[:, -1])

        return text_features, x


# ============================================================================
# Multimodal Decoder
# ============================================================================

class MultimodalDecoder(nn.Module):
    """
    Multimodal decoder for caption generation.
    Uses cross-attention between text and biosignal tokens.

    Each layer applies, in post-norm form: causal self-attention over the
    text, cross-attention from text (queries) to the projected biosignal
    tokens (keys/values), and a two-layer MLP.
    """

    def __init__(
        self,
        cfg: DecoderCfg,
        vocab_size: int,
        biosignal_width: int = 768
    ):
        """
        Args:
            cfg: Decoder hyper-parameters.
            vocab_size: Size of the token embedding table and output logits.
            biosignal_width: Feature width of the incoming biosignal tokens.
        """
        super().__init__()
        self.cfg = cfg

        self.token_embedding = nn.Embedding(vocab_size, cfg.width)
        self.pos_embedding = nn.Parameter(torch.randn(cfg.context_length, cfg.width))

        # Project biosignal tokens to decoder width
        self.biosignal_proj = nn.Linear(biosignal_width, cfg.width)

        # Decoder layers with cross-attention
        mlp_width = int(cfg.width * cfg.mlp_ratio)
        self.layers = nn.ModuleList()
        for _ in range(cfg.layers):
            self.layers.append(nn.ModuleDict({
                'self_attn': nn.MultiheadAttention(cfg.width, cfg.heads, batch_first=True),
                'cross_attn': nn.MultiheadAttention(cfg.width, cfg.heads, batch_first=True),
                'ln_1': LayerNorm(cfg.width),
                'ln_2': LayerNorm(cfg.width),
                'ln_3': LayerNorm(cfg.width),
                'mlp': nn.Sequential(
                    nn.Linear(cfg.width, mlp_width),
                    QuickGELU(),
                    nn.Linear(mlp_width, cfg.width),
                )
            }))

        self.ln_final = LayerNorm(cfg.width)
        self.output_proj = nn.Linear(cfg.width, vocab_size)

        # Causal mask: True strictly above the diagonal.  For boolean masks
        # nn.MultiheadAttention treats True as "may not attend".
        self.register_buffer(
            "causal_mask",
            torch.triu(torch.ones(cfg.context_length, cfg.context_length), diagonal=1).bool()
        )

    def forward(
        self,
        biosignal_tokens: torch.Tensor,
        text: torch.Tensor
    ) -> torch.Tensor:
        """
        Args:
            biosignal_tokens: (batch_size, num_tokens, biosignal_width)
            text: (batch_size, seq_len) token ids
        Returns:
            logits: (batch_size, seq_len, vocab_size)
        Raises:
            ValueError: If seq_len exceeds ``cfg.context_length``.
        """
        seq_len = text.shape[1]

        # The positional table and causal mask only cover context_length
        # positions; longer inputs would otherwise fail with an opaque
        # shape-mismatch error.
        if seq_len > self.cfg.context_length:
            raise ValueError(
                f"Input sequence length {seq_len} exceeds "
                f"context_length {self.cfg.context_length}"
            )

        # Embed text tokens
        x = self.token_embedding(text)
        x = x + self.pos_embedding[:seq_len]

        # Project biosignal tokens
        memory = self.biosignal_proj(biosignal_tokens)

        # Causal mask for self-attention
        causal_mask = self.causal_mask[:seq_len, :seq_len]

        # Decoder layers
        for layer in self.layers:
            # Self-attention
            attn_out, _ = layer['self_attn'](x, x, x, attn_mask=causal_mask)
            x = layer['ln_1'](x + attn_out)

            # Cross-attention with biosignal tokens (no mask: every text
            # position may attend to every biosignal token)
            cross_out, _ = layer['cross_attn'](x, memory, memory)
            x = layer['ln_2'](x + cross_out)

            # MLP
            x = layer['ln_3'](x + layer['mlp'](x))

        x = self.ln_final(x)
        logits = self.output_proj(x)

        return logits


# ============================================================================
# Signal Reconstruction Decoder
# ============================================================================

class SignalReconstructionDecoder(nn.Module):
    """
    Small transformer head that maps encoder tokens back to a raw signal.

    The tokens are refined by a few transformer blocks, layer-normalised,
    mean-pooled over the sequence axis, and a two-layer MLP emits the whole
    ``output_channels x output_length`` signal in one shot.
    """

    def __init__(
        self,
        input_dim: int = 768,
        num_layers: int = 2,
        num_heads: int = 8,
        output_channels: int = 12,
        output_length: int = 3840,
    ):
        super().__init__()

        blocks = [
            TransformerBlock(input_dim, num_heads, mlp_ratio=2.0)
            for _ in range(num_layers)
        ]
        self.transformer = nn.ModuleList(blocks)

        self.ln_final = LayerNorm(input_dim)

        # Bottleneck MLP from pooled feature to the flattened signal.
        bottleneck_dim = input_dim // 2
        flat_signal_dim = output_channels * output_length
        self.to_signal = nn.Sequential(
            nn.Linear(input_dim, bottleneck_dim),
            nn.ReLU(),
            nn.Linear(bottleneck_dim, flat_signal_dim),
        )

        self.output_channels = output_channels
        self.output_length = output_length

    def forward(self, encoder_features: torch.Tensor) -> torch.Tensor:
        """
        Args:
            encoder_features: (B, seq_len, input_dim)
        Returns:
            reconstructed: (B, output_channels, output_length)
        """
        hidden = encoder_features
        for block in self.transformer:
            hidden = block(hidden)

        # Normalise, then collapse the sequence axis by averaging -> (B, dim)
        pooled = self.ln_final(hidden).mean(dim=1)

        # Emit the flattened signal and fold it back into (B, C, L)
        flat_signal = self.to_signal(pooled)
        return flat_signal.reshape(
            encoder_features.shape[0], self.output_channels, self.output_length
        )


# ============================================================================
# Full Model
# ============================================================================

class BiosignalsTextModel(nn.Module):
    """
    Complete biosignals-text model combining:
    - Biosignals encoder
    - Text encoder
    - Multimodal decoder
    - Optional signal reconstruction decoder
    """

    def __init__(
        self,
        embed_dim: int = 512,
        biosignals_cfg: BiosignalsCfg = None,
        text_cfg: TextCfg = None,
        decoder_cfg: DecoderCfg = None,
        use_signal_decoder: bool = True,
        init_logit_scale: float = np.log(1 / 0.07),
    ):
        """
        Args:
            embed_dim: Size of the shared biosignal/text embedding space.
            biosignals_cfg: Biosignal encoder config (defaults if None).
            text_cfg: Text encoder config (defaults if None).
            decoder_cfg: Caption decoder config (defaults if None).
            use_signal_decoder: Whether to build the reconstruction head.
            init_logit_scale: Initial value of the learnable log-scale;
                forward() returns its exponential as ``logit_scale``.
        """
        super().__init__()

        if biosignals_cfg is None:
            biosignals_cfg = BiosignalsCfg()
        if text_cfg is None:
            text_cfg = TextCfg()
        if decoder_cfg is None:
            decoder_cfg = DecoderCfg()

        self.embed_dim = embed_dim

        # Biosignals encoder (tokens are always requested: both decoders
        # consume them)
        self.biosignals_encoder = BiosignalsEncoder(
            biosignals_cfg,
            embed_dim=embed_dim,
            output_tokens=True
        )

        # Text encoder
        self.text_encoder = TextEncoder(text_cfg, embed_dim=embed_dim)

        # Multimodal decoder for captioning
        self.text_decoder = MultimodalDecoder(
            decoder_cfg,
            vocab_size=text_cfg.vocab_size,
            biosignal_width=biosignals_cfg.transformer_width
        )

        # Signal reconstruction decoder
        self.use_signal_decoder = use_signal_decoder
        if use_signal_decoder:
            self.signal_decoder = SignalReconstructionDecoder(
                input_dim=biosignals_cfg.transformer_width,
                num_layers=2,
                num_heads=biosignals_cfg.transformer_heads,
                output_channels=biosignals_cfg.input_channels,
                output_length=biosignals_cfg.signal_length,
            )

        # Learnable temperature for contrastive loss (stored in log space)
        self.logit_scale = nn.Parameter(torch.ones([]) * init_logit_scale)

    def encode_biosignals(self, biosignals: torch.Tensor, normalize: bool = True):
        """Encode biosignals; returns ``(embedding, tokens)``.

        The embedding is L2-normalised along the last dim when ``normalize``
        is True; the tokens are returned unchanged.
        """
        embedding, tokens = self.biosignals_encoder(biosignals)
        if normalize:
            embedding = F.normalize(embedding, dim=-1)
        return embedding, tokens

    def encode_text(self, text: torch.Tensor, normalize: bool = True):
        """Encode text; returns ``(embedding, tokens)``.

        The embedding is L2-normalised along the last dim when ``normalize``
        is True; the tokens are returned unchanged.
        """
        embedding, tokens = self.text_encoder(text)
        if normalize:
            embedding = F.normalize(embedding, dim=-1)
        return embedding, tokens

    def forward(
        self,
        biosignals: torch.Tensor,
        text: Optional[torch.Tensor] = None,
        output_labels: bool = True,
    ):
        """
        Forward pass for training.

        Args:
            biosignals: (B, channels, length) input signals
            text: (B, seq_len) token ids
            output_labels: If True, the decoder is fed ``text[:, :-1]`` and
                ``text[:, 1:]`` is returned as labels (next-token targets);
                otherwise the decoder sees the full text and no labels are
                returned.

        Returns:
            Dictionary with ``biosignal_features`` and ``logit_scale``; plus
            ``text_features``, ``logits`` (and ``labels``) when text is given;
            plus ``reconstructed_signal`` / ``original_signal`` when the
            signal decoder is enabled.
        """
        # Encode biosignals
        biosignal_features, biosignal_tokens = self.encode_biosignals(biosignals)

        out = {
            "biosignal_features": biosignal_features,
            "logit_scale": self.logit_scale.exp()
        }

        if text is not None:
            # Encode text (per-token outputs are not needed here)
            text_features, _ = self.encode_text(text)
            out["text_features"] = text_features

            # Prepare labels
            labels = text[:, 1:] if output_labels else None
            text_input = text[:, :-1] if output_labels else text

            # Generate caption logits
            logits = self.text_decoder(biosignal_tokens, text_input)
            out["logits"] = logits

            if labels is not None:
                out["labels"] = labels

        # Signal reconstruction
        if self.use_signal_decoder:
            reconstructed = self.signal_decoder(biosignal_tokens)
            out["reconstructed_signal"] = reconstructed
            out["original_signal"] = biosignals

        return out

    @torch.no_grad()
    def generate(
        self,
        biosignals: torch.Tensor,
        max_length: int = 77,
        temperature: float = 1.0,
        top_k: int = 50,
        eos_token_id: int = 2,
        bos_token_id: int = 1,
    ) -> torch.Tensor:
        """
        Generate captions for biosignals by autoregressive sampling.

        Runs without gradient tracking.  The module's train/eval mode is not
        changed here; call ``model.eval()`` first if dropout should be off.

        Args:
            biosignals: (B, channels, length) input signals
            max_length: Maximum sequence length including the BOS token;
                clamped so the decoder never sees more than its
                ``context_length`` positions
            temperature: Sampling temperature; 0 selects greedy (argmax)
                decoding
            top_k: Sample only among the k most likely tokens; values <= 0
                disable the filter, values above the vocabulary size are
                clamped
            eos_token_id: End of sequence token
            bos_token_id: Beginning of sequence token

        Returns:
            generated: (B, T) token ids with T <= max_length.  Every row
            starts with ``bos_token_id``; once a row has produced
            ``eos_token_id`` all its later positions are ``eos_token_id``.
        """
        device = biosignals.device
        batch_size = biosignals.shape[0]

        # The decoder input has at most max_length - 1 tokens; keep that
        # within the decoder's positional table / causal mask.
        max_length = min(max_length, self.text_decoder.cfg.context_length + 1)

        # Encode biosignals
        _, biosignal_tokens = self.encode_biosignals(biosignals)

        # Initialize with BOS token
        generated = torch.full(
            (batch_size, 1), bos_token_id, dtype=torch.long, device=device
        )
        # Per-sequence completion flags: a row is done once it emitted EOS.
        finished = torch.zeros(batch_size, dtype=torch.bool, device=device)

        for _ in range(max_length - 1):
            # Get logits for next token
            logits = self.text_decoder(biosignal_tokens, generated)
            next_logits = logits[:, -1, :]

            if temperature == 0:
                # Greedy decoding; avoids dividing the logits by zero.
                next_token = next_logits.argmax(dim=-1, keepdim=True)
            else:
                next_logits = next_logits / temperature

                # Top-k sampling
                if top_k > 0:
                    k = min(top_k, next_logits.size(-1))
                    values, indices = torch.topk(next_logits, k)
                    next_logits = torch.full_like(next_logits, float('-inf'))
                    next_logits.scatter_(1, indices, values)

                probs = F.softmax(next_logits, dim=-1)
                next_token = torch.multinomial(probs, 1)

            # Rows that already finished keep emitting EOS instead of
            # sampling arbitrary tokens after their end.
            next_token = next_token.masked_fill(finished.unsqueeze(1), eos_token_id)
            generated = torch.cat([generated, next_token], dim=1)

            # Stop once every row has produced EOS (not necessarily on the
            # same step).
            finished = finished | (next_token.squeeze(1) == eos_token_id)
            if finished.all():
                break

        return generated


# ============================================================================
# Factory Function
# ============================================================================

def create_model(
    input_channels: int = 12,
    signal_length: int = 3840,
    embed_dim: int = 512,
    vocab_size: int = 50257,
    use_signal_decoder: bool = True,
) -> BiosignalsTextModel:
    """
    Build a BiosignalsTextModel, overriding only a handful of settings.

    Every configuration field not exposed here keeps the default declared on
    its config dataclass.

    Args:
        input_channels: Number of biosignal channels
        signal_length: Length of input signals
        embed_dim: Embedding dimension
        vocab_size: Vocabulary size
        use_signal_decoder: Whether to include signal reconstruction decoder

    Returns:
        Configured BiosignalsTextModel
    """
    return BiosignalsTextModel(
        embed_dim=embed_dim,
        biosignals_cfg=BiosignalsCfg(
            input_channels=input_channels,
            signal_length=signal_length,
        ),
        text_cfg=TextCfg(vocab_size=vocab_size),
        decoder_cfg=DecoderCfg(),
        use_signal_decoder=use_signal_decoder,
    )

