import torch
import torch.nn as nn

# --- LEO Attention Map: fire attention prior (unnormalized mixture of Gaussians) ---
class LEOAttentionMap(nn.Module):
    """Per-pixel attention map from a learnable mixture of Gaussian kernels.

    Every spatial location of the input is treated as a ``C``-dimensional
    feature vector. Each of the ``K`` components scores that vector with an
    unnormalized diagonal Gaussian kernel,
    ``exp(-0.5 * sum_c (x_c - mu_kc)^2 / sigma2_kc)``, and the component
    scores are combined using softmax-normalized mixture weights.

    Args:
        input_dim: Number of channels ``C`` the module is built for.
        num_components: Number of Gaussian components ``K``.
    """

    def __init__(self, input_dim, num_components=3):
        super().__init__()
        self.input_dim = input_dim
        self.K = num_components
        # Component means, one row per component.
        self.mu = nn.Parameter(torch.randn(self.K, input_dim))
        # Log-variances; exponentiated in forward() so variances stay positive.
        self.log_sigma2 = nn.Parameter(torch.zeros(self.K, input_dim))
        # Unnormalized mixture logits; softmaxed in forward().
        self.raw_weights = nn.Parameter(torch.ones(self.K))

    def forward(self, x):
        """Compute the attention map for a batch of feature maps.

        Args:
            x: Tensor of shape ``(B, C, H, W)`` with ``C == input_dim``.

        Returns:
            Tensor of shape ``(B, 1, H, W)`` holding the weighted sum of the
            component kernel responses at each spatial location.

        Raises:
            ValueError: If ``x`` is not 4-D or its channel count differs
                from ``input_dim``.
        """
        if x.dim() != 4:
            raise ValueError(
                f"expected a 4-D (B, C, H, W) input, got {x.dim()}-D"
            )
        B, C, H, W = x.shape
        # Without this check a size-1 channel dimension on either side would
        # broadcast silently and yield a meaningless attention map.
        if C != self.input_dim:
            raise ValueError(
                f"expected {self.input_dim} input channels, got {C}"
            )

        # (B, C, H, W) -> (B*H*W, 1, C); the singleton axis broadcasts
        # against the K components below.
        x_flat = x.permute(0, 2, 3, 1).reshape(B * H * W, 1, C)
        weights = torch.softmax(self.raw_weights, dim=0)
        sigma2 = torch.exp(self.log_sigma2)

        # Variance-scaled squared distance to every component: (B*H*W, K).
        dist2 = (((x_flat - self.mu) ** 2) / sigma2).sum(dim=-1)
        # Weighted sum of the K kernel responses: (B*H*W,).
        attn = (torch.exp(-0.5 * dist2) * weights).sum(dim=-1)
        return attn.view(B, 1, H, W)

# --- Stacked LEO Attention Maps ---
class StackedLAM(nn.Module):
    """Chain of attention maps that each append one channel to the input.

    Stage ``i`` is a ``LEOAttentionMap`` built for ``input_dim + i``
    channels, matching the channel count produced by the stages before it.

    Args:
        input_dim: Channel count of the tensor fed to the first stage.
        num_layers: Number of stacked attention stages.
        num_components: Number of components passed to every stage.
    """

    def __init__(self, input_dim, num_layers=3, num_components=3):
        super(StackedLAM, self).__init__()
        stages = []
        for depth in range(num_layers):
            # Each earlier stage has contributed one extra channel.
            stages.append(LEOAttentionMap(input_dim + depth, num_components))
        self.layers = nn.ModuleList(stages)

    def forward(self, x):
        """Run every stage, concatenating its map onto the channel axis.

        Returns the expanded tensor (original channels followed by one map
        per stage), not only the final attention map.
        """
        features = x
        for stage in self.layers:
            stage_map = stage(features)
            features = torch.cat((features, stage_map), dim=1)
        return features