"""
Multi-Scale Attention U-Net (MSA-UNet) model implementation
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import List, Tuple, Optional

class CrossScaleAttention(nn.Module):
    """Multi-head scaled dot-product attention between two token sequences.

    Queries are taken from one feature sequence and keys/values from another,
    which lets features of one scale attend to features of a different scale.
    All inputs must have ``in_channels`` features in their last dimension.
    Each head operates in a ``d_k``-dimensional space, and values use the same
    per-head width (``d_v == d_k``).
    """

    def __init__(self, in_channels: int, num_heads: int = 4, d_k: int = 64):
        super().__init__()
        self.in_channels = in_channels
        self.num_heads = num_heads
        self.d_k = d_k
        self.d_v = d_k

        qk_width = num_heads * self.d_k
        v_width = num_heads * self.d_v

        # Input projections carry no bias; only the output projection does.
        self.W_q = nn.Linear(in_channels, qk_width, bias=False)
        self.W_k = nn.Linear(in_channels, qk_width, bias=False)
        self.W_v = nn.Linear(in_channels, v_width, bias=False)
        self.W_o = nn.Linear(v_width, in_channels)

        # Raw dot products are divided by sqrt(d_k) before the softmax.
        self.scale = math.sqrt(d_k)

    def _split_heads(self, projected: torch.Tensor, batch: int, head_dim: int) -> torch.Tensor:
        """Reshape (B, N, H * head_dim) into (B, H, N, head_dim)."""
        tokens = projected.size(1)
        return projected.view(batch, tokens, self.num_heads, head_dim).transpose(1, 2)

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
        """
        Args:
            query: (B, N_q, C) - tokens that ask the question
            key: (B, N_k, C) - tokens that are compared against the queries
            value: (B, N_v, C) - tokens that are mixed according to the
                attention weights (N_v has to equal N_k for the matmul)
        Returns:
            output: (B, N_q, C) - one attended feature vector per query token
        """
        batch, num_queries, _ = query.shape

        q_heads = self._split_heads(self.W_q(query), batch, self.d_k)  # (B, H, N_q, d_k)
        k_heads = self._split_heads(self.W_k(key), batch, self.d_k)    # (B, H, N_k, d_k)
        v_heads = self._split_heads(self.W_v(value), batch, self.d_v)  # (B, H, N_v, d_v)

        # Similarity of every query with every key, normalised over the keys.
        logits = torch.matmul(q_heads, k_heads.transpose(-2, -1)) / self.scale
        weights = logits.softmax(dim=-1)                               # (B, H, N_q, N_k)

        # Weighted sum of values, then fold the heads back into one vector.
        context = torch.matmul(weights, v_heads)                       # (B, H, N_q, d_v)
        merged = context.transpose(1, 2).contiguous().view(
            batch, num_queries, self.num_heads * self.d_v
        )

        return self.W_o(merged)

class ScaleSelection(nn.Module):
    """Fuse same-sized feature maps using learned per-pixel scale weights.

    Every input map passes through one shared 1x1-conv bottleneck that
    reduces its channels to ``in_channels // 4``.  A dedicated 1x1 conv +
    sigmoid head per scale scores each reduced map; the scores are then
    softmax-normalised across scales and used to blend the reduced maps.
    """

    def __init__(self, in_channels: int, num_scales: int = 4):
        super().__init__()
        self.num_scales = num_scales
        reduced = in_channels // 4

        # One transformation shared by all scales.
        self.feature_transform = nn.Sequential(
            nn.Conv2d(in_channels, reduced, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(reduced, reduced, kernel_size=1),
        )

        # One single-channel scoring head per scale.
        heads = []
        for _ in range(num_scales):
            heads.append(nn.Sequential(nn.Conv2d(reduced, 1, kernel_size=1), nn.Sigmoid()))
        self.scale_weights = nn.ModuleList(heads)

    def forward(self, features: List[torch.Tensor]) -> torch.Tensor:
        """
        Args:
            features: feature maps, one per scale, each (B, in_channels, H, W);
                the i-th map is scored by the i-th head
        Returns:
            fused_features: (B, in_channels // 4, H, W) weighted blend of the
                reduced maps
        """
        reduced_maps = [self.feature_transform(fmap) for fmap in features]

        # Per-scale sigmoid scores, stacked to (B, num_scales, 1, H, W) and
        # normalised so the weights of all scales sum to one at every pixel.
        gates = [self.scale_weights[idx](fmap) for idx, fmap in enumerate(reduced_maps)]
        weights = torch.stack(gates, dim=1).softmax(dim=1)

        fused = torch.zeros_like(reduced_maps[0])
        for idx, fmap in enumerate(reduced_maps):
            fused = fused + weights[:, idx] * fmap

        return fused

class ChannelAttention(nn.Module):
    """Channel attention module for feature recalibration.

    Global average- and max-pooled descriptors of the input are passed
    through a shared two-layer 1x1-conv bottleneck, summed, and squashed with
    a sigmoid to give one gate per channel.

    Args:
        in_channels: number of channels of the incoming feature map.
        reduction: divisor used to size the hidden layer of the bottleneck.
            The hidden width is ``in_channels // reduction`` but never less
            than 1, so inputs with fewer channels than ``reduction`` still
            get a usable bottleneck.
    """

    def __init__(self, in_channels: int, reduction: int = 16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        # Integer division yields 0 when in_channels < reduction, which would
        # create a conv layer with no output channels; clamp to one channel.
        hidden_channels = max(in_channels // reduction, 1)

        self.fc = nn.Sequential(
            nn.Conv2d(in_channels, hidden_channels, 1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_channels, in_channels, 1, bias=False)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return per-channel gates of shape (B, C, 1, 1) for input (B, C, H, W)."""
        # Both pooled descriptors go through the same bottleneck weights.
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        out = avg_out + max_out
        return self.sigmoid(out)

class EncoderBlock(nn.Module):
    """Two conv-BN-ReLU stages, optional channel gating, then 2x2 max-pooling.

    The block returns both the pooled tensor (input of the next, coarser
    stage) and the full-resolution tensor taken before pooling, which the
    decoder uses as a skip connection.
    """

    def __init__(self, in_channels: int, out_channels: int, use_attention: bool = True):
        super().__init__()
        # 3x3 convolutions with padding 1 keep the spatial size unchanged.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.use_attention = use_attention
        # The attention sub-module only exists when it is enabled.
        if self.use_attention:
            self.channel_attention = ChannelAttention(out_channels)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(pooled, skip)`` for an input of shape (B, in_channels, H, W)."""
        features = self.relu(self.bn1(self.conv1(x)))
        features = self.relu(self.bn2(self.conv2(features)))

        if self.use_attention:
            # Re-weight every channel by its (B, C, 1, 1) gate.
            features = features * self.channel_attention(features)

        # The un-pooled features double as the skip connection.
        return self.pool(features), features

class DecoderBlock(nn.Module):
    """Decoder block: 2x transposed-conv upsampling, skip fusion, two conv-BN-ReLU stages.

    Args:
        in_channels: channels of the incoming (coarse) feature map; the
            transposed convolution halves this number while doubling the
            spatial size.
        out_channels: channels produced by the block.
        skip_channels: channels of the skip tensor concatenated after
            upsampling (0 when the block is used without a skip connection).
    """

    def __init__(self, in_channels: int, out_channels: int, skip_channels: int = 0):
        super(DecoderBlock, self).__init__()
        self.upsample = nn.ConvTranspose2d(in_channels, in_channels // 2, 2, 2)
        self.conv1 = nn.Conv2d(in_channels // 2 + skip_channels, out_channels, 3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    @staticmethod
    def _match_spatial_size(x: torch.Tensor, reference: torch.Tensor) -> torch.Tensor:
        """Zero-pad (or crop, for negative differences) ``x`` to the H/W of ``reference``."""
        diff_h = reference.size(2) - x.size(2)
        diff_w = reference.size(3) - x.size(3)
        if diff_h == 0 and diff_w == 0:
            return x
        # F.pad takes (left, right, top, bottom) for the last two dimensions;
        # the difference is split as evenly as possible between both sides.
        return F.pad(
            x,
            [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2],
        )

    def forward(self, x: torch.Tensor, skip: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Upsample ``x``, fuse it with ``skip`` (if given) and refine the result.

        Args:
            x: (B, in_channels, H, W) coarse features.
            skip: optional (B, skip_channels, H', W') encoder features.
        Returns:
            (B, out_channels, H', W') when a skip is given, otherwise
            (B, out_channels, 2H, 2W).
        """
        # Upsample
        x = self.upsample(x)

        # Concatenate with skip connection
        if skip is not None:
            # A 2x2 max-pool floors odd sizes, so the upsampled map can be
            # smaller than the skip tensor; align them before concatenating.
            x = self._match_spatial_size(x, skip)
            x = torch.cat([x, skip], dim=1)

        # Convolutions
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)

        return x

class MSAUNet(nn.Module):
    """Multi-Scale Attention U-Net for medical image segmentation.

    A four-stage encoder with channel attention feeds a 1024-channel
    bottleneck.  The bottleneck tokens then attend to each encoder scale
    (cross-scale attention), the four attended maps are blended by the scale
    selection module, and a four-stage decoder with skip connections produces
    per-pixel class logits.

    Two 1x1 projections make the channel widths line up:

    * ``scale_projections`` lifts the 64/128/256/512-channel encoder outputs
      to the 1024 channels the attention module expects for keys and values.
    * ``fusion_projection`` maps the ``1024 // 4`` channels emitted by the
      scale selection module back to the 1024 channels the decoder expects.

    Both add entries to the ``state_dict``.

    Args:
        in_channels: channels of the input image.
        num_classes: number of output logit channels.
        num_heads: attention heads used by the cross-scale attention.
    """

    def __init__(self, in_channels: int = 3, num_classes: int = 5, num_heads: int = 4):
        super(MSAUNet, self).__init__()
        self.num_classes = num_classes
        self.num_heads = num_heads

        encoder_widths = (64, 128, 256, 512)
        bottleneck_channels = 1024

        # Encoder
        self.encoder1 = EncoderBlock(in_channels, 64, use_attention=True)
        self.encoder2 = EncoderBlock(64, 128, use_attention=True)
        self.encoder3 = EncoderBlock(128, 256, use_attention=True)
        self.encoder4 = EncoderBlock(256, 512, use_attention=True)

        # Bottleneck
        self.bottleneck = nn.Sequential(
            nn.Conv2d(512, bottleneck_channels, 3, padding=1),
            nn.BatchNorm2d(bottleneck_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(bottleneck_channels, bottleneck_channels, 3, padding=1),
            nn.BatchNorm2d(bottleneck_channels),
            nn.ReLU(inplace=True)
        )

        # The attention module shares one channel width for query, key and
        # value, so every encoder scale is projected to the bottleneck width.
        self.scale_projections = nn.ModuleList([
            nn.Conv2d(width, bottleneck_channels, 1) for width in encoder_widths
        ])

        # Cross-scale attention
        self.cross_scale_attention = CrossScaleAttention(bottleneck_channels, num_heads=num_heads)

        # Scale selection
        self.scale_selection = ScaleSelection(bottleneck_channels, num_scales=len(encoder_widths))

        # ScaleSelection outputs in_channels // 4 channels; restore the width
        # required by the first decoder stage.
        self.fusion_projection = nn.Conv2d(bottleneck_channels // 4, bottleneck_channels, 1)

        # Decoder
        self.decoder4 = DecoderBlock(bottleneck_channels, 512, skip_channels=512)
        self.decoder3 = DecoderBlock(512, 256, skip_channels=256)
        self.decoder2 = DecoderBlock(256, 128, skip_channels=128)
        self.decoder1 = DecoderBlock(128, 64, skip_channels=64)

        # Final classification
        self.final_conv = nn.Conv2d(64, num_classes, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits of shape (B, num_classes, H, W) for input (B, in_channels, H, W)."""
        # Encoder
        x1, skip1 = self.encoder1(x)      # 64 channels
        x2, skip2 = self.encoder2(x1)     # 128 channels
        x3, skip3 = self.encoder3(x2)     # 256 channels
        x4, skip4 = self.encoder4(x3)     # 512 channels

        # Bottleneck
        bottleneck = self.bottleneck(x4)  # 1024 channels

        B, C, H, W = bottleneck.size()

        # Bottleneck pixels become the query tokens: (B, H*W, C).
        query_tokens = bottleneck.flatten(2).transpose(1, 2)

        # Attend from the bottleneck to every encoder scale.  Each scale is
        # first resized to the bottleneck resolution (so the 1x1 projection
        # runs on the small map), then projected to C channels and flattened.
        attended_features = []
        for scale, project in zip((x1, x2, x3, x4), self.scale_projections):
            resized = F.interpolate(scale, size=(H, W), mode='bilinear', align_corners=False)
            scale_tokens = project(resized).flatten(2).transpose(1, 2)  # (B, H*W, C)
            attended = self.cross_scale_attention(query_tokens, scale_tokens, scale_tokens)
            # Back to a spatial map: (B, H*W, C) -> (B, C, H, W).
            attended_features.append(attended.transpose(1, 2).reshape(B, C, H, W))

        # Scale selection and fusion, then back to the decoder's input width.
        fused_features = self.scale_selection(attended_features)  # (B, C // 4, H, W)
        fused_features = self.fusion_projection(fused_features)   # (B, C, H, W)

        # Decoder with skip connections
        d4 = self.decoder4(fused_features, skip4)
        d3 = self.decoder3(d4, skip3)
        d2 = self.decoder2(d3, skip2)
        d1 = self.decoder1(d2, skip1)

        # Final classification
        output = self.final_conv(d1)

        return output

class BaselineUNet(nn.Module):
    """Plain U-Net without any attention, used as the comparison model.

    Four encoder stages (64 -> 128 -> 256 -> 512 channels), a 1024-channel
    bottleneck, four decoder stages that consume the encoder skips in reverse
    order, and a 1x1 convolution that emits ``num_classes`` logit channels.
    """

    def __init__(self, in_channels: int = 3, num_classes: int = 5):
        super().__init__()

        # Contracting path (channel attention disabled).
        self.encoder1 = EncoderBlock(in_channels, 64, use_attention=False)
        self.encoder2 = EncoderBlock(64, 128, use_attention=False)
        self.encoder3 = EncoderBlock(128, 256, use_attention=False)
        self.encoder4 = EncoderBlock(256, 512, use_attention=False)

        # Two conv-BN-ReLU stages at the coarsest resolution.
        bottleneck_layers = [
            nn.Conv2d(512, 1024, kernel_size=3, padding=1),
            nn.BatchNorm2d(1024),
            nn.ReLU(inplace=True),
            nn.Conv2d(1024, 1024, kernel_size=3, padding=1),
            nn.BatchNorm2d(1024),
            nn.ReLU(inplace=True),
        ]
        self.bottleneck = nn.Sequential(*bottleneck_layers)

        # Expanding path; each stage receives the skip of matching width.
        self.decoder4 = DecoderBlock(1024, 512, skip_channels=512)
        self.decoder3 = DecoderBlock(512, 256, skip_channels=256)
        self.decoder2 = DecoderBlock(256, 128, skip_channels=128)
        self.decoder1 = DecoderBlock(128, 64, skip_channels=64)

        # Per-pixel classifier.
        self.final_conv = nn.Conv2d(64, num_classes, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the encoder, bottleneck and decoder and return the class logits."""
        skips = []
        features = x
        for encoder in (self.encoder1, self.encoder2, self.encoder3, self.encoder4):
            features, skip = encoder(features)
            skips.append(skip)

        features = self.bottleneck(features)

        # Skips are consumed deepest-first, i.e. in reverse encoder order.
        for decoder in (self.decoder4, self.decoder3, self.decoder2, self.decoder1):
            features = decoder(features, skips.pop())

        return self.final_conv(features)

def count_parameters(model: nn.Module) -> int:
    """Return the total element count of all parameters that require gradients."""
    total = 0
    for param in model.parameters():
        # Frozen parameters are not trainable and are left out of the count.
        if param.requires_grad:
            total += param.numel()
    return total

if __name__ == "__main__":
    # Smoke test: build both models, push one random batch through them and
    # report output shapes and trainable parameter counts.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Create models
    msa_unet = MSAUNet(in_channels=3, num_classes=5, num_heads=4)
    baseline_unet = BaselineUNet(in_channels=3, num_classes=5)

    # Test input
    batch_size = 2
    input_tensor = torch.randn(batch_size, 3, 512, 512)

    # Forward pass.  Only the output shapes are inspected, so autograd is
    # switched off to avoid keeping every intermediate activation in memory.
    with torch.no_grad():
        msa_output = msa_unet(input_tensor)
        baseline_output = baseline_unet(input_tensor)

    print(f"MSA-UNet output shape: {msa_output.shape}")
    print(f"Baseline U-Net output shape: {baseline_output.shape}")
    print(f"MSA-UNet parameters: {count_parameters(msa_unet):,}")
    print(f"Baseline U-Net parameters: {count_parameters(baseline_unet):,}")

    # Test on GPU if available
    if torch.cuda.is_available():
        msa_unet = msa_unet.to(device)
        input_tensor = input_tensor.to(device)

        with torch.no_grad():
            msa_output = msa_unet(input_tensor)
            print(f"GPU MSA-UNet output shape: {msa_output.shape}")

