"""
U-Net architecture for image-based jump-diffusion model.

This network takes a stack of images (current noisy image + memory frames) and
time embeddings as input, and outputs three images corresponding to the
lambda, mean, and sigma parameters for every pixel.

The time conditioning is now handled by concatenating the embeddings for both
the current time (t) and the future time (t2), allowing the network to learn
the optimal way to combine this information.
"""
import torch
import torch.nn as nn
import math

class SinusoidalPositionEmbeddings(nn.Module):
    """Encode scalar timesteps as fixed sinusoidal features.

    The module has no parameters: the embedding is a deterministic function
    of the input. Any learning happens in the layers that consume it.

    Args:
        dim: Width of the returned embedding. The first ``dim // 2`` columns
            are sines and the next ``dim // 2`` are cosines; when ``dim`` is
            odd a single zero column is appended so the width is exactly
            ``dim``.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        """Return the ``(len(time), dim)`` embedding of a 1-D ``time`` tensor.

        ``time`` is converted to float before use, so integer timesteps are
        accepted.
        """
        half_dim = self.dim // 2
        # Frequencies are spaced geometrically from 1 down to 1/10000. The
        # clamp avoids a ZeroDivisionError when half_dim == 1 (dim of 2 or 3),
        # in which case the single frequency is simply 1.
        log_step = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=time.device) * -log_step)
        angles = time.float()[:, None] * freqs[None, :]
        embeddings = torch.cat((angles.sin(), angles.cos()), dim=-1)
        if self.dim % 2 == 1:
            # Floor division dropped one feature; pad with zeros so that
            # downstream layers sized for `dim` inputs still line up.
            pad = embeddings.new_zeros(embeddings.shape[0], 1)
            embeddings = torch.cat((embeddings, pad), dim=-1)
        return embeddings

class Block(nn.Module):
    """Convolution, group normalization and SiLU applied in sequence.

    This unit has no skip connection of its own.

    Args:
        in_channels: Channel count of the incoming feature map.
        out_channels: Channel count produced by the convolution.
        groups: Number of groups passed to ``nn.GroupNorm``.
    """

    def __init__(self, in_channels, out_channels, groups=8):
        super().__init__()
        # A 3x3 kernel with padding 1 leaves the spatial size unchanged.
        self.proj = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.norm = nn.GroupNorm(groups, out_channels)
        self.act = nn.SiLU()

    def forward(self, x):
        """Run ``x`` through the projection, the norm and the activation."""
        for layer in (self.proj, self.norm, self.act):
            x = layer(x)
        return x

class ResnetBlock(nn.Module):
    """Two ``Block`` layers wrapped in a skip connection, with optional time conditioning.

    Args:
        in_channels: Channel count of the input feature map.
        out_channels: Channel count of the output feature map.
        time_emb_dim: Width of the time embedding passed to ``forward``. When
            ``None`` no conditioning MLP is created.
        groups: Number of normalization groups forwarded to each ``Block``.
    """

    def __init__(self, in_channels, out_channels, *, time_emb_dim=None, groups=8):
        super().__init__()
        if time_emb_dim is None:
            self.mlp = None
        else:
            # Projects the time embedding to one additive offset per channel.
            self.mlp = nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, out_channels))

        self.block1 = Block(in_channels, out_channels, groups=groups)
        self.block2 = Block(out_channels, out_channels, groups=groups)

        # The skip path needs a 1x1 convolution only when the channel count changes.
        if in_channels == out_channels:
            self.res_conv = nn.Identity()
        else:
            self.res_conv = nn.Conv2d(in_channels, out_channels, 1)

    def forward(self, x, time_emb=None):
        """Return ``block2(block1(x) [+ time offset]) + res_conv(x)``.

        The time offset is added only when the block was built with a
        ``time_emb_dim`` and a ``time_emb`` is supplied.
        """
        hidden = self.block1(x)
        conditioned = not (self.mlp is None or time_emb is None)
        if conditioned:
            offset = self.mlp(time_emb)
            # Append two trailing singleton axes so the offset broadcasts
            # over the spatial dimensions.
            hidden = hidden + offset[..., None, None]
        hidden = self.block2(hidden)
        shortcut = self.res_conv(x)
        return hidden + shortcut
        
class UNet(nn.Module):
    """Two-level U-Net conditioned on a pair of time embeddings.

    Maps an image stack with ``in_channels`` channels to ``out_channels``
    per-pixel maps. Per the module docstring, the default three outputs are
    the lambda, mean and sigma parameters.

    Args:
        in_channels: Channels of the input image stack.
        out_channels: Channels of the output.
        base_channels: Width of the first level; deeper levels use 2x and 4x
            this value, and the per-time embedding width is 4x.
    """

    def __init__(self, in_channels=5, out_channels=3, base_channels=32):
        super().__init__()
        width1 = base_channels
        width2 = base_channels * 2
        width4 = base_channels * 4

        self.init_conv = nn.Conv2d(in_channels, width1, 3, padding=1)

        # Shared embedding network, applied to both time values at once.
        time_dim = width4
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(width1),
            nn.Linear(width1, time_dim),
            nn.GELU(),
            nn.Linear(time_dim, time_dim),
        )

        # Each residual block receives the two embeddings concatenated.
        cond_dim = 2 * time_dim

        # Encoder
        self.down1 = ResnetBlock(width1, width2, time_emb_dim=cond_dim)
        self.down2 = ResnetBlock(width2, width4, time_emb_dim=cond_dim)
        self.pool = nn.MaxPool2d(2)

        # Bottleneck
        self.mid = ResnetBlock(width4, width4, time_emb_dim=cond_dim)

        # Decoder: each residual block sees the upsampled features
        # concatenated with the matching encoder output.
        self.up1 = nn.ConvTranspose2d(width4, width2, 2, stride=2)
        self.up_res1 = ResnetBlock(width4 + width2, width2, time_emb_dim=cond_dim)
        self.up2 = nn.ConvTranspose2d(width2, width1, 2, stride=2)
        self.up_res2 = ResnetBlock(width2 + width1, width1, time_emb_dim=cond_dim)

        self.out_conv = nn.Conv2d(width1, out_channels, 1)

    def forward(self, x, time):
        """Run the network.

        Args:
            x: Image stack with ``in_channels`` channels.
                NOTE(review): two 2x poolings followed by two 2x transposed
                convolutions suggest H and W should be multiples of 4 for the
                skip concatenations to line up -- confirm with callers.
            time: 1-D tensor of time values. Its embedding is split in half
                along dim 0; the first half is used as the t2 embedding and
                the second half as the t embedding.

        Returns:
            Tensor with ``out_channels`` channels.
        """
        features = self.init_conv(x)

        # Embed every time value in one pass, then place the two halves side
        # by side along the feature axis.
        stacked_emb = self.time_mlp(time)
        t2_part, t_part = torch.chunk(stacked_emb, 2, dim=0)
        cond = torch.cat((t2_part, t_part), dim=1)

        # Encoder
        skip_hi = self.down1(features, cond)
        skip_lo = self.down2(self.pool(skip_hi), cond)

        # Bottleneck
        bottleneck = self.mid(self.pool(skip_lo), cond)

        # Decoder
        merged = torch.cat((self.up1(bottleneck), skip_lo), dim=1)
        decoded = self.up_res1(merged, cond)
        merged = torch.cat((self.up2(decoded), skip_hi), dim=1)
        decoded = self.up_res2(merged, cond)

        return self.out_conv(decoded)

