import torch
import torch.nn as nn
import torch.nn.functional as F
from TLAD.maze.utils import get_valid_path_region 


class LagrangianFreeEnergy(nn.Module):
    """Constraint energy for grid path prediction with learnable multipliers.

    Four penalty terms (wall, invalid activation, degree, endpoint) are each
    scaled by a learnable multiplier (``lambda_*``) and a fixed weight read
    from the config (``W_WALL``, ``W_INVALID``, ``W_DEGREE``, ``W_ENDPOINT``).

    Config attributes read: ``LAMBDA_INIT``, ``DEVICE``, ``GRID_SIZE``,
    ``LAMBDA_MAX``, the four ``W_*`` weights, and optionally
    ``LIT_THRESHOLD`` (defaults to 0.6).
    """

    def __init__(self, config):
        super().__init__()
        self.cfg = config

        # One learnable multiplier per constraint, all starting at LAMBDA_INIT.
        # Registration order is kept stable so parameter ordering / state
        # dicts are unchanged.
        self.lambda_wall = nn.Parameter(torch.tensor(float(config.LAMBDA_INIT), device=config.DEVICE))
        self.lambda_invalid = nn.Parameter(torch.tensor(float(config.LAMBDA_INIT), device=config.DEVICE))
        self.lambda_degree = nn.Parameter(torch.tensor(float(config.LAMBDA_INIT), device=config.DEVICE))
        self.lambda_end = nn.Parameter(torch.tensor(float(config.LAMBDA_INIT), device=config.DEVICE))

        # 4-neighbourhood (up/down/left/right) summation kernel. It is created
        # on the same device as the multipliers so the module is usable
        # without a separate ``.to(device)`` call.
        self.register_buffer(
            'kernel',
            torch.tensor([[[[0, 1, 0],
                            [1, 0, 1],
                            [0, 1, 0]]]], dtype=torch.float32, device=config.DEVICE),
        )
        self.lit_threshold = getattr(config, 'LIT_THRESHOLD', 0.6)

    def _clamped(self, multiplier):
        """Return *multiplier* clamped to ``[0, cfg.LAMBDA_MAX]``."""
        return torch.clamp(multiplier, 0.0, self.cfg.LAMBDA_MAX)

    def compute(self, probs, inputs):
        """
        Compute structured energy with hard constraints derived from input topology.

        Every term is a sum over all cells of all samples divided by
        ``B * L`` (with ``L = N * N``), i.e. a mean per sample and per cell.

        Args:
            probs: tensor of shape (B, L, C); channel 1 is used as the
                per-cell path probability.
            inputs: per-cell codes compared against 1 (wall) and 2 / 3
                (special cells, treated as path endpoints).
                Assumed to be shaped (B, L) -- TODO confirm against caller.

        Returns:
            ``(total_energy, components)`` where ``components`` maps
            ``"wall"``, ``"invalid"``, ``"degree"``, ``"end"`` and ``"total"``
            to their weighted energy values.
        """
        B, L, _ = probs.shape
        N = self.cfg.GRID_SIZE
        assert L == N * N, f"expected L == GRID_SIZE**2 == {N * N}, got {L}"
        norm = B * float(L)

        p_path = probs[..., 1]                              # (B, L)
        # reshape (not view) so a non-contiguous `probs` is accepted as well.
        p_grid = p_path.reshape(B, 1, N, N)                 # (B, 1, N, N)

        # === 1. Wall Repulsion ===
        is_wall = (inputs == 1).float()                     # (B, L)
        wall_energy = torch.sum(p_path * is_wall) / norm

        # === 2. Endpoint Anchoring ===
        is_special = (inputs == 2) | (inputs == 3)          # (B, L), bool
        end_energy = torch.sum((1.0 - p_path) ** 2 * is_special.float()) / norm

        # === 3. Valid Path Region (HARD, from input only) ===
        # NOTE(review): assumed to return a (B, L) bool mask -- confirm
        # against TLAD.maze.utils.
        with torch.no_grad():
            valid_region = get_valid_path_region(inputs, N)

        # === 4. Invalid Activation Penalty ===
        # The threshold only selects cells; gradients flow through p_path.
        is_lit = p_path > self.lit_threshold                # (B, L), bool
        invalid_mask = is_lit & (~valid_region) & (~is_special)
        invalid_energy = torch.sum(p_path * invalid_mask.float()) / norm

        # === 5. Degree Constraint (only on valid path) ===
        # Match the kernel to the input's device/dtype; conv2d rejects a
        # mismatch. This is a no-op when they already agree.
        kernel = self.kernel.to(device=p_grid.device, dtype=p_grid.dtype)
        degree = F.conv2d(p_grid, kernel, padding=1).reshape(B, L)   # (B, L)

        # Targets: 1 on special cells, 2 on the remaining valid cells,
        # 0 elsewhere. The two masks are disjoint, so adding them is the same
        # as assigning them one after the other.
        internal_path = valid_region & (~is_special)
        target_deg = is_special.to(degree.dtype) + 2.0 * internal_path.to(degree.dtype)

        degree_err = (degree - target_deg) ** 2
        degree_energy = torch.sum(degree_err * valid_region.float()) / norm

        # === Combine Energy Terms (all already normalized per sample and per cell) ===
        components = {
            "wall": self._clamped(self.lambda_wall) * self.cfg.W_WALL * wall_energy,
            "invalid": self._clamped(self.lambda_invalid) * self.cfg.W_INVALID * invalid_energy,
            "degree": self._clamped(self.lambda_degree) * self.cfg.W_DEGREE * degree_energy,
            "end": self._clamped(self.lambda_end) * self.cfg.W_ENDPOINT * end_energy,
        }
        total_energy = sum(components.values())
        components["total"] = total_energy
        return total_energy, components