# FILE: unet_model.py
# VERSION: v3 (With FiLM layer and SE-Block for intelligent feature fusion)

import torch
import torch.nn as nn
import torch.nn.functional as F

class SEBlock(nn.Module):
    """
    Squeeze-and-Excitation Block: a channel-wise attention mechanism.

    Each channel is summarised by global average pooling ("squeeze"), the
    resulting vector is passed through a small bottleneck MLP ending in a
    sigmoid ("excitation"), and the input is rescaled channel by channel
    with the resulting gates.

    Args:
        channel: Number of channels of the incoming feature map.
        reduction: Divisor applied to ``channel`` to obtain the width of the
            bottleneck MLP. The width is clamped to at least 1 so that
            ``channel < reduction`` does not create a zero-width layer.
    """
    def __init__(self, channel, reduction=16):
        super().__init__()
        # Without the clamp, channel < reduction yields a Linear layer with
        # zero output features; the gates would then be a constant
        # sigmoid(0) = 0.5 and the block could not learn anything.
        hidden = max(1, channel // reduction)
        # Squeeze operation: Global Average Pooling
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # Excitation operation: a small MLP
        self.fc = nn.Sequential(
            nn.Linear(channel, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        """
        Rescale every channel of ``x`` by its learned gate.

        Args:
            x: Feature map of shape (B, C, H, W).

        Returns:
            Tensor with the same shape as ``x``.
        """
        b, c, _, _ = x.size()
        # Squeeze: (B, C, H, W) -> (B, C)
        y = self.avg_pool(x).view(b, c)
        # Excitation: per-channel gates in (0, 1), shaped (B, C, 1, 1)
        y = self.fc(y).view(b, c, 1, 1)
        # Scale the input x; the gates broadcast over the spatial dimensions
        return x * y

class DoubleConv(nn.Module):
    """
    Two consecutive (3x3 convolution => BatchNorm => ReLU) stages followed by
    an SEBlock that recalibrates the output channels.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the second convolution.
        mid_channels: Channels between the two convolutions. Any falsy value
            (including the default ``None``) falls back to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = mid_channels or out_channels
        # Layers are created in execution order so parameter initialisation
        # and the Sequential indices stay stable.
        self.double_conv = nn.Sequential(
            *self._conv_bn_relu(in_channels, hidden),
            *self._conv_bn_relu(hidden, out_channels),
            # Channel attention applied after both convolutions
            SEBlock(out_channels),
        )

    @staticmethod
    def _conv_bn_relu(n_in, n_out):
        """Build one size-preserving conv => BN => ReLU stage as a tuple of layers."""
        return (
            # bias is omitted because the following BatchNorm has its own shift
            nn.Conv2d(n_in, n_out, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(n_out),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Run ``x`` through both conv stages and the SE block."""
        return self.double_conv(x)

class Down(nn.Module):
    """
    Encoder stage: 2x2 max pooling followed by a DoubleConv.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the DoubleConv.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        """Pool ``x`` by a factor of two, then apply the convolution stack."""
        return self.maxpool_conv(x)

class Up(nn.Module):
    """
    Decoder stage: upsample, align with the skip connection, concatenate,
    then apply a DoubleConv.

    Args:
        in_channels: Channels entering the DoubleConv, i.e. after the skip
            connection and the upsampled tensor have been concatenated.
        out_channels: Channels produced by the DoubleConv.
        bilinear: If truthy, upsample with parameter-free bilinear
            interpolation; otherwise use a learned transposed convolution
            that also halves the channel count.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        half = in_channels // 2
        if not bilinear:
            self.up = nn.ConvTranspose2d(in_channels, half, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)
        else:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            # Interpolation keeps the channel count, so the convolution
            # stack narrows through `half` channels instead.
            self.conv = DoubleConv(in_channels, out_channels, half)

    def forward(self, x1, x2):
        """
        Args:
            x1: Tensor from the deeper level; it is upsampled here.
            x2: Skip-connection tensor whose spatial size is the target.

        Returns:
            Output of the DoubleConv applied to ``cat([x2, upsampled x1])``.
        """
        upsampled = self.up(x1)
        # Size gap to the skip tensor along height (dim 2) and width (dim 3)
        gap_h = x2.shape[2] - upsampled.shape[2]
        gap_w = x2.shape[3] - upsampled.shape[3]
        left, top = gap_w // 2, gap_h // 2
        # F.pad expects [left, right, top, bottom]; any odd remainder goes
        # to the right/bottom edge.
        upsampled = F.pad(upsampled, [left, gap_w - left, top, gap_h - top])
        merged = torch.cat((x2, upsampled), dim=1)
        return self.conv(merged)

class OutConv(nn.Module):
    """
    Output head: a single 1x1 convolution mapping feature channels to the
    requested number of output channels.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels of the produced map.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # A 1x1 kernel mixes channels only, leaving the spatial size unchanged
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Project ``x`` to ``out_channels`` channels."""
        projected = self.conv(x)
        return projected

class FiLMLayer(nn.Module):
    """
    Feature-wise Linear Modulation layer.

    A small MLP maps a conditioning vector to one scale (gamma) and one shift
    (beta) per channel; the feature map is then transformed channel by
    channel as ``gamma * x + beta``.

    Args:
        channels: Number of channels of the feature map to modulate.
        condition_dim: Size of the conditioning vector.
    """

    def __init__(self, channels, condition_dim):
        super().__init__()
        self.channels = channels
        # The final layer emits 2 * channels values: gammas first, then betas
        self.generator = nn.Sequential(
            nn.Linear(condition_dim, channels),
            nn.ReLU(inplace=True),
            nn.Linear(channels, channels * 2),
        )

    def forward(self, x, condition):
        """
        Args:
            x: Feature map to modulate; its channel dimension (dim 1) is
                expected to hold ``channels`` entries.
            condition: Conditioning input fed to the generator MLP; its last
                dimension must equal ``condition_dim``.

        Returns:
            The modulated feature map ``gamma * x + beta``.
        """
        # Append two singleton axes up front so both halves broadcast over
        # the spatial dimensions of x.
        params = self.generator(condition)[..., None, None]
        scale = params[:, :self.channels]
        shift = params[:, self.channels:]
        return scale * x + shift

class UNetPotentialField(nn.Module):
    """
    Three-level U-Net whose bottleneck is modulated by a FiLM layer.

    The encoder sees only the spatial input. The non-spatial vector is
    injected once, at the bottleneck, through FiLM, and the decoder rebuilds
    the output from the modulated features and the encoder skip connections.
    Channel attention comes from the SEBlock inside every DoubleConv.

    Args:
        n_spatial_channels: Channels of the spatial input.
        n_non_spatial_features: Length of the non-spatial conditioning vector.
        n_classes_out: Channels of the output map.
        bilinear_upsample: Use bilinear upsampling in the decoder instead of
            transposed convolutions.
        init_channel: Channel width of the first encoder level; deeper levels
            use multiples of it.
    """

    def __init__(
        self,
        n_spatial_channels,
        n_non_spatial_features=2,
        n_classes_out=1,
        bilinear_upsample=True,
        init_channel=64,
    ):
        super().__init__()

        self.n_in_channels = n_spatial_channels
        self.n_classes_out = n_classes_out
        self.bilinear = bilinear_upsample

        # Channel widths per level
        c1 = init_channel
        c2, c4, c8 = c1 * 2, c1 * 4, c1 * 8
        # Bilinear upsampling keeps the channel count, so the deepest level
        # is halved to make the concatenated decoder input come out at c8.
        factor = 2 if bilinear_upsample else 1
        bottleneck_channels = c8 // factor

        # Encoder path
        self.inc = DoubleConv(self.n_in_channels, c1)
        self.down1 = Down(c1, c2)
        self.down2 = Down(c2, c4)
        self.down3 = Down(c4, bottleneck_channels)

        # FiLM conditioning of the bottleneck on the non-spatial features
        self.film_layer = FiLMLayer(bottleneck_channels, n_non_spatial_features)

        # Decoder path
        self.up1 = Up(c8, c4 // factor, bilinear_upsample)
        self.up2 = Up(c4, c2 // factor, bilinear_upsample)
        self.up3 = Up(c2, c1, bilinear_upsample)
        self.outc = OutConv(c1, n_classes_out)

    def forward(self, x_spatial: torch.Tensor, x_non_spatial: torch.Tensor) -> torch.Tensor:
        """
        Predict the potential field.

        Args:
            x_spatial (torch.Tensor): Spatial feature maps.
                                      Shape (B, C_spatial, H, W).
            x_non_spatial (torch.Tensor): Non-spatial conditioning vector
                                          (e.g., goal vector).
                                          Shape (B, C_non_spatial).

        Returns:
            torch.Tensor: Raw (un-activated) output of the final 1x1
            convolution. Shape (B, n_classes_out, H, W).
        """
        # Encoder: spatial information only; keep each level for the skips
        skip1 = self.inc(x_spatial)
        skip2 = self.down1(skip1)
        skip3 = self.down2(skip2)
        bottleneck = self.down3(skip3)

        # Inject the non-spatial information at the bottleneck
        features = self.film_layer(bottleneck, x_non_spatial)

        # Decoder: pair each up stage with its skip, deepest first
        stages = (self.up1, self.up2, self.up3)
        skips = (skip3, skip2, skip1)
        for up_stage, skip in zip(stages, skips):
            features = up_stage(features, skip)

        return self.outc(features)