import torch
import torch.nn as nn


class LagrangianFreeEnergy(nn.Module):
    """Lagrangian free energy for Sudoku hard constraints.

    Holds one learnable multiplier per (row, digit), (column, digit) and
    (box, digit) pair, plus a single multiplier for the per-cell term, and
    combines the weighted penalty terms with a temperature-scaled entropy
    term.

    The ``config`` object must provide: ``DEVICE``, ``LAMBDA_INIT``,
    ``LAMBDA_MAX``, ``W_CELL_UNIQ``, ``W_ROW``, ``W_COL``, ``W_BOX`` and
    ``W_ENTROPY``.
    """

    def __init__(self, config):
        super().__init__()
        self.cfg = config

        def _init_multiplier(*shape):
            # Allocate directly on the target device instead of allocating on
            # the default device and copying afterwards.
            return nn.Parameter(
                torch.ones(*shape, device=config.DEVICE) * config.LAMBDA_INIT
            )

        # Lagrangian multipliers, indexed [unit, digit] for the 9x9 ones.
        self.lambda_col = _init_multiplier(9, 9)
        self.lambda_row = _init_multiplier(9, 9)
        self.lambda_box = _init_multiplier(9, 9)
        self.lambda_cell = _init_multiplier(1)
        self.T = 1.0  # Current temperature (plain attribute, not a parameter)

    def forward(self, probs):
        """Delegate to :meth:`compute` so the module can be called directly."""
        return self.compute(probs)

    def compute(self, probs):
        """Compute the free energy and the raw constraint-violation energy.

        Args:
            probs: Tensor with 729 values per batch element, laid out as
                ``[B, 81, 9]`` (cell, digit) with cells in row-major order.
                Any layout that reshapes to that (e.g. ``[B, 9, 9, 9]``) is
                accepted. Presumably each cell holds a probability
                distribution over digits; this is not checked here.

        Returns:
            Tuple ``(total_energy, conflict_energy)`` of shape-``[1]``
            tensors. ``total_energy`` is the ``W_*``-weighted sum of the
            constraint losses minus ``T * W_ENTROPY * entropy``;
            ``conflict_energy`` is the unweighted sum of the four constraint
            losses.
        """
        B = probs.shape[0]
        # reshape (rather than view) also accepts non-contiguous inputs.
        cells = probs.reshape(B, 81, 9)  # [B, Cell, Digit]
        grid = cells.reshape(B, 9, 9, 9)  # [B, Row, Col, Digit]

        # Only an upper bound is enforced on the multipliers.
        lambda_row = torch.clamp(self.lambda_row, max=self.cfg.LAMBDA_MAX)
        lambda_col = torch.clamp(self.lambda_col, max=self.cfg.LAMBDA_MAX)
        lambda_box = torch.clamp(self.lambda_box, max=self.cfg.LAMBDA_MAX)
        lambda_cell = torch.clamp(self.lambda_cell, max=self.cfg.LAMBDA_MAX - 5)

        # Cell uniqueness constraint: sum of squared probabilities is 1 only
        # for a one-hot cell, so (1 - sum p^2) vanishes for decided cells.
        cell_uniq = torch.sum(cells ** 2, dim=-1)  # [B,81]
        cell_loss = lambda_cell * (1 - cell_uniq).sum() / B

        # Row constraint: sum over columns, each digit mass should be 1.
        row_sum = grid.sum(dim=2)  # [B,9(Row),9(Digit)]
        row_loss = torch.sum(lambda_row * (row_sum - 1.0) ** 2) / B

        # Column constraint: sum over rows, each digit mass should be 1.
        col_sum = grid.sum(dim=1)  # [B,9(Col),9(Digit)]
        col_loss = torch.sum(lambda_col * (col_sum - 1.0) ** 2) / B

        # Box constraint. Splitting Row -> (box_row, row_in_box) and
        # Col -> (box_col, col_in_box) gives axes
        # [B, box_row, row_in_box, box_col, col_in_box, Digit]; a box is
        # obtained by summing over the two *in-box* axes (2 and 4). Summing
        # over axes (3, 4) would collapse all columns and merely reproduce
        # the row sums.
        box_grid = grid.reshape(B, 3, 3, 3, 3, 9)
        box_sum = box_grid.sum(dim=(2, 4))  # [B,3(box_row),3(box_col),9]
        box_sum = box_sum.reshape(B, 9, 9)  # [B,9(Box),9(Digit)]
        box_loss = torch.sum(lambda_box * (box_sum - 1.0) ** 2) / B

        # Modal entropy: -p log p of the most likely digit in each cell.
        max_prob = torch.max(cells, dim=-1)[0]  # [B,81]
        entropy = -torch.sum(max_prob * torch.log(max_prob + 1e-9)) / B

        # Free energy calculation
        total_energy = (
            self.cfg.W_CELL_UNIQ * cell_loss +
            self.cfg.W_ROW * row_loss +
            self.cfg.W_COL * col_loss +
            self.cfg.W_BOX * box_loss -
            self.T * self.cfg.W_ENTROPY * entropy
        )

        conflict_energy = cell_loss + row_loss + col_loss + box_loss
        return total_energy, conflict_energy