import torch
import torch.nn.functional as F
from torch import nn


class ConvDecoderAE(nn.Module):
    """Two-level hourglass with input skip for spatial residual refinement.

    Encodes an (N, C, H, W) map at full, 1/2 and 1/4 resolution with C, C//2
    and C//4 channels. It then decodes back up with skip concatenations, fuses
    the result with the raw input through a 1x1 conv, and projects to 3 output
    channels. Upsampling targets the exact size of each skip tensor, so the
    output has the input's H x W even when H or W is odd.

    Args:
        hidden_dim: input channel count C (also the full-resolution width).
    """

    def __init__(self, hidden_dim):
        super().__init__()
        full, half, quarter = hidden_dim, hidden_dim // 2, hidden_dim // 4

        def conv_relu(in_ch, out_ch, kernel=3):
            # Odd kernel with padding kernel // 2 keeps the spatial size unchanged.
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=kernel, padding=kernel // 2),
                nn.ReLU(),
            )

        # Creation order is significant: it fixes the parameter names
        # (state_dict keys) and the order of seeded weight initialisation.
        self.pool = nn.MaxPool2d(2)

        self.enc1 = conv_relu(full, full)
        self.enc2 = conv_relu(full, half)
        self.enc3 = conv_relu(half, quarter)

        self.up2_conv = conv_relu(quarter, half)
        self.up1_conv = conv_relu(half, full)

        self.dec2 = conv_relu(2 * half, half)  # upsampled + half-res skip
        self.dec1 = conv_relu(2 * full, full)  # upsampled + full-res skip

        self.fuse_input = conv_relu(2 * full, full, kernel=1)  # decoded + raw input
        self.final = nn.Conv2d(full, 3, kernel_size=1)

    def forward(self, x):
        # Encoder: full -> 1/2 -> 1/4 resolution.
        skip_full = self.enc1(x)
        skip_half = self.enc2(self.pool(skip_full))
        deepest = self.enc3(self.pool(skip_half))

        # Decoder: upsample to the skip's exact size, refine, then merge the skip.
        up_half = self.up2_conv(
            F.interpolate(deepest, size=skip_half.shape[-2:], mode="nearest")
        )
        dec_half = self.dec2(torch.cat((up_half, skip_half), dim=1))

        up_full = self.up1_conv(
            F.interpolate(dec_half, size=skip_full.shape[-2:], mode="nearest")
        )
        dec_full = self.dec1(torch.cat((up_full, skip_full), dim=1))

        # Re-inject the raw input so the 1x1 head sees it next to the decoded features.
        return self.final(self.fuse_input(torch.cat((dec_full, x), dim=1)))


class ColorFusionResidualNet(nn.Module):
    """Aggregates per-view features into a single residual correction.

    [Simplified] Only supports 'iterative' mode. Per-view sampling offsets are
    refined for ``n_offset_iters`` steps. Each step re-samples the warped
    images at the current offsets and re-encodes the resulting color
    residuals. The final per-view features are softmax-weighted over views,
    summed, and refined spatially by a ``ConvDecoderAE``.

    Args:
        height, width: image size. ``forward`` expects exactly
            ``height * width`` pixels in row-major (y, x) order.
        per_view_feat_dim: width of the per-view feature encoder.
        max_offset: soft bound (via tanh) on both the per-iteration offset
            step and the accumulated offset, in pixels.
        n_offset_iters: number of offset-refinement iterations.
    """

    def __init__(
        self,
        *,
        height,
        width,
        per_view_feat_dim=32,
        max_offset=7.0,
        n_offset_iters=5,  # number of iterations
    ):
        super().__init__()
        self.height = height
        self.width = width
        self.per_view_feat_dim = per_view_feat_dim
        self.max_offset = max_offset
        self.n_offset_iters = n_offset_iters

        # Feature encoder: 7 raw per-view values -> per_view_feat_dim.
        self.per_view_mlp = nn.Sequential(
            nn.Linear(7, per_view_feat_dim),
            nn.ReLU(),
            nn.Linear(per_view_feat_dim, per_view_feat_dim),
            nn.ReLU(),
        )

        # Offset prediction: WHERE to sample (Δx, Δy per view)
        self.offset_net = nn.Sequential(
            nn.Linear(per_view_feat_dim, per_view_feat_dim),
            nn.ReLU(),
            nn.Linear(per_view_feat_dim, 2),  # Δx, Δy
        )

        # Weight prediction: HOW MUCH to weight
        self.weight_net = nn.Sequential(
            nn.Linear(per_view_feat_dim, 16),
            nn.ReLU(),
            nn.Linear(16, 1),
        )

        cnn_input_dim = per_view_feat_dim + 6  # aggregated features + ray dir + gaussian color
        self.conv_decoder = ConvDecoderAE(cnn_input_dim)

    def forward(
        self,
        x_views,
        ray_dir,
        c_3dgs,
        warped_images,  # mandatory in iterative mode
    ):
        """Return per-pixel residual in RGB space.

        Args:
            x_views: (H*W, M, 7) stacked per-view features. On every
                refinement step, channels 0:3 are replaced by
                (sampled warped color - c_3dgs) before re-encoding.
            ray_dir: (H*W, 3) normalized ray directions per pixel.
            c_3dgs: (H*W, 3) rendered Gaussian color per pixel.
            warped_images: (H, W, M, 3) actual warped images.

        Returns:
            (H*W, 3) residual, pixels in row-major (y, x) order.

        Raises:
            ValueError: if ``warped_images`` is None, if ``x_views`` does not
                hold ``height * width`` pixels, or if ``warped_images`` is not
                ``(height, width, M, 3)``.
        """
        if warped_images is None:
            raise ValueError("Iterative mode requires warped_images.")

        B, M, _ = x_views.shape  # B = H*W
        H, W = self.height, self.width
        # Fail early with a clear message instead of a cryptic reshape error later.
        if B != H * W:
            raise ValueError(
                f"x_views has {B} pixels but the network was built for "
                f"{H}x{W}={H * W}."
            )
        if tuple(warped_images.shape) != (H, W, M, 3):
            raise ValueError(
                f"warped_images must have shape {(H, W, M, 3)}, "
                f"got {tuple(warped_images.shape)}."
            )
        device = x_views.device

        # nn.Linear acts on the last dim, so no flatten/view is needed
        # (and non-contiguous inputs work).
        features = self.per_view_mlp(x_views)  # (H*W, M, F)

        # Identity sampling grid in grid_sample's (x, y) order. With
        # align_corners=True, -1 and +1 are the centres of the first and last pixel.
        grid_y, grid_x = torch.meshgrid(
            torch.linspace(-1, 1, H, device=device),
            torch.linspace(-1, 1, W, device=device),
            indexing="ij",
        )
        base_grid = torch.stack([grid_x, grid_y], dim=-1)  # (H, W, 2)

        # Under align_corners=True one pixel spans 2 / (size - 1) normalized units.
        # Dividing by size / 2 would under-scale offsets. max() keeps 1-pixel axes finite.
        px_to_norm = torch.tensor(
            [2.0 / max(W - 1, 1), 2.0 / max(H - 1, 1)], device=device
        )

        # Loop-invariant: all views as one grid_sample batch, (M, 3, H, W).
        warped_batch = warped_images.permute(2, 3, 0, 1)
        rendered = c_3dgs.unsqueeze(1)  # (H*W, 1, 3), broadcast over views

        # === Iterative deformable logic ===
        offsets = torch.zeros(B, M, 2, device=device)  # (H*W, M, 2), in pixels

        for _ in range(self.n_offset_iters):
            # Bounded step, then soft-clamp the accumulated offset to (-max, max).
            delta = torch.tanh(self.offset_net(features)) * self.max_offset
            offsets = torch.tanh((offsets + delta) / self.max_offset) * self.max_offset

            # (H*W, M, 2) pixels -> (M, H, W, 2) normalized grid displacements.
            norm_offsets = (
                (offsets * px_to_norm).reshape(H, W, M, 2).permute(2, 0, 1, 3)
            )
            # grid_sample requires grid and input to share a dtype.
            sample_grid = (base_grid.unsqueeze(0) + norm_offsets).to(warped_batch.dtype)
            sampled = F.grid_sample(
                warped_batch, sample_grid,
                mode="bilinear", padding_mode="border", align_corners=True,
            )  # (M, 3, H, W)
            sampled_colors = sampled.permute(2, 3, 0, 1).reshape(B, M, 3)

            # Swap in the freshly sampled residuals (channels 0:3) and re-encode.
            sampled_residuals = (sampled_colors - rendered).to(x_views.dtype)
            x_views_updated = torch.cat([sampled_residuals, x_views[..., 3:]], dim=-1)
            features = self.per_view_mlp(x_views_updated)

        # Softmax over views gives per-pixel convex weights on the view features.
        weights = F.softmax(self.weight_net(features).squeeze(-1), dim=1)  # (H*W, M)
        aggregated = (weights.unsqueeze(-1) * features).sum(dim=1)  # (H*W, F)

        def to_image(pixels):
            # Pixel-major (H*W, C) -> channel-first (1, C, H, W).
            return pixels.reshape(H, W, -1).permute(2, 0, 1).unsqueeze(0)

        # Spatial refinement
        cnn_input = torch.cat(
            [to_image(aggregated), to_image(ray_dir), to_image(c_3dgs)], dim=1
        )
        residual = self.conv_decoder(cnn_input)  # (1, 3, H, W)
        return residual[0].permute(1, 2, 0).reshape(B, 3)



def create_color_aggregation_network(
    height,
    width,
    **kwargs
):
    """Build a ``ColorFusionResidualNet`` (iterative mode only).

    Only ``per_view_feat_dim``, ``max_offset`` and ``n_offset_iters`` are read
    from ``kwargs``. Missing ones fall back to the defaults below, and any
    other keyword is ignored.
    """
    defaults = {
        "per_view_feat_dim": 32,
        "max_offset": 7.0,
        "n_offset_iters": 5,
    }
    options = {name: kwargs.get(name, fallback) for name, fallback in defaults.items()}
    return ColorFusionResidualNet(height=height, width=width, **options)