import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
from TLAD.sdk.energy import LagrangianFreeEnergy

class TLADModel(nn.Module):
    """Two-stage Sudoku model: Transformer perception (System 1) followed by an
    iterative, energy-driven refinement of masked attention logits (System 2).

    The 81-cell / 9-digit layout is hard-coded throughout (``reshape(B, 81, ...)``,
    a 9-way logit head and ``[H, 81, 81]`` attention masks).
    """

    def __init__(self, config):
        super().__init__()
        self.cfg = config

        # --- System 1: Perception Module ---
        # Cells are one-hot encoded over 10 classes before embedding
        # (presumably 0 = empty cell, 1-9 = given digit — TODO confirm).
        self.embedding = nn.Linear(10, config.HIDDEN_DIM)
        # NOTE(review): sized by SEQ_LEN, while the reasoning phase assumes 81 cells.
        self.pos_embed = nn.Parameter(torch.randn(1, config.SEQ_LEN, config.HIDDEN_DIM))

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=config.HIDDEN_DIM, nhead=config.NUM_HEADS, batch_first=True,
            dim_feedforward=512, dropout=0.1
        )
        self.perception = nn.TransformerEncoder(encoder_layer, num_layers=config.NUM_LAYERS_S1)

        self.value_head = nn.Linear(config.HIDDEN_DIM, config.HIDDEN_DIM)
        self.logit_head = nn.Linear(config.HIDDEN_DIM, 9)

        # --- System 2: Thermodynamic Engine ---
        self.free_energy = LagrangianFreeEnergy(config)
        # Additive attention mask (0 = allowed, -inf = blocked); registered as a
        # buffer so it follows the module across devices.
        self.register_buffer('hard_mask', self._create_disentangled_mask())

    def _create_disentangled_mask(self):
        """Build the additive per-head attention mask, shape ``[NUM_HEADS, 81, 81]``.

        Head ``h`` uses constraint mode ``h % 4``: 0 = same row, 1 = same column,
        2 = same 3x3 box, 3 = global (every cell). Connected pairs (including each
        cell with itself) get 0.0; every other entry is -inf.
        """
        cells = torch.arange(81)
        rows, cols = cells // 9, cells % 9
        boxes = (rows // 3) * 3 + cols // 3

        def same(labels):
            # [81, 81] boolean: True where cells i and j share the label.
            return labels[:, None] == labels[None, :]

        # A cell always shares its own row/col/box, so the diagonal is allowed in
        # every mode without special-casing it.
        by_mode = torch.stack([
            same(rows),
            same(cols),
            same(boxes),
            torch.ones(81, 81, dtype=torch.bool),
        ])
        modes = torch.arange(self.cfg.NUM_HEADS) % 4
        masks = torch.full((self.cfg.NUM_HEADS, 81, 81), float('-inf'))
        return masks.masked_fill(by_mode[modes], 0.0)

    def _get_continuous_temps(self, device=None):
        """Return the per-step temperature schedule, a 1-D tensor of length ``EBA_STEPS``.

        A sigmoid in normalised progress ``t / EBA_STEPS`` blends from ``TEMP_START``
        (first step) towards ``TEMP_END`` (last step). The transition is centred at
        ``TRANSITION_CENTER`` with steepness ``ANNEALING_SLPOE`` (sic: config name).

        Args:
            device: device for the result; defaults to ``cfg.DEVICE``.
        """
        if device is None:
            device = self.cfg.DEVICE
        n_steps = self.cfg.EBA_STEPS
        progress = torch.arange(n_steps, device=device) / n_steps
        # Weight of TEMP_END must *increase* with progress so the schedule starts
        # at TEMP_START and ends near TEMP_END.
        weight = torch.sigmoid(
            (progress - self.cfg.TRANSITION_CENTER) * self.cfg.ANNEALING_SLPOE
        )
        return self.cfg.TEMP_START * (1 - weight) + self.cfg.TEMP_END * weight

    def _compute_constraint_error(self, probs):
        """Count Sudoku rule violations in the argmax prediction of the first puzzle.

        Accepts ``[B, 81, 9]`` or ``[B, 9, 9, C]`` probabilities; only batch element 0
        is inspected. For each of the 27 units (9 rows, 9 columns, 9 boxes) the error
        is ``9 - number of distinct predicted digits``; the total is returned as an
        ``int``. Returns 0 for unsupported ranks or an empty batch.
        """
        with torch.no_grad():
            if probs.dim() not in (3, 4) or probs.shape[0] == 0:
                return 0
            if probs.dim() == 4 and probs.shape[1:3] == (9, 9):
                probs_grid = probs
            else:
                # [B, 81, 9] (or another layout with 81 * 9 entries per item).
                probs_grid = probs.reshape(probs.shape[0], 9, 9, 9)

            pred = probs_grid[0].argmax(-1)  # [9, 9] predicted digit indices
            # [r, c] -> [box_row, in_row, box_col, in_col] -> one box per row.
            boxes = pred.reshape(3, 3, 3, 3).permute(0, 2, 1, 3).reshape(9, 9)
            units = torch.cat([pred, pred.T, boxes])  # [27, 9]
            ordered = units.sort(dim=1).values
            distinct = 1 + (ordered[:, 1:] != ordered[:, :-1]).sum(dim=1)
            return int((9 - distinct).sum().item())

    def _attend(self, A, temp, V, h_semantic):
        """Masked, temperature-scaled attention over ``V``, then the logit head.

        Returns ``(attn_probs, logits)`` shaped ``[B, H, 81, 81]`` and ``[B, 81, 9]``.
        """
        scores = (A + self.hard_mask) / temp
        attn_probs = F.softmax(scores, dim=-1)
        context = torch.matmul(attn_probs, V).transpose(1, 2).reshape(h_semantic.shape[0], 81, -1)
        logits = self.logit_head(h_semantic + F.relu(context))
        return attn_probs, logits

    def _head_grad_scales(self, like):
        """Per-head gradient multipliers shaped ``[1, NUM_HEADS, 1, 1]``.

        Head ``h`` is scaled by the mean of ``free_energy.lambda_row`` / ``lambda_col``
        / ``lambda_box`` according to its mask mode ``h % 4`` (the same mapping used
        by the attention mask); global-mode heads are left unscaled. The result
        matches ``like``'s dtype and device.
        """
        fe = self.free_energy
        per_mode = torch.stack([
            torch.as_tensor(fe.lambda_row.mean()).to(like),
            torch.as_tensor(fe.lambda_col.mean()).to(like),
            torch.as_tensor(fe.lambda_box.mean()).to(like),
            like.new_ones(()),
        ])
        modes = torch.arange(self.cfg.NUM_HEADS, device=like.device) % 4
        return per_mode[modes].view(1, -1, 1, 1)

    def forward(self, x, mode='reasoning', return_step_wise=False, compute_metrics=False):
        """Run perception and, unless ``mode == 'pretrain'``, the reasoning loop.

        Args:
            x: integer-valued cell tensor with values in ``[0, 9]``; the reshapes
                below assume shape ``[B, 81]``.
            mode: ``'pretrain'`` returns per-cell 9-way perception logits directly;
                any other value runs the iterative reasoning phase.
            return_step_wise: additionally record per-step energy, gradient norm,
                probabilities and attention maps under ``step_*`` keys.
            compute_metrics: additionally record the first puzzle's constraint error
                per step (``constraint_errors``) and for the last step's
                probabilities (``final_constraint_error``).

        Returns:
            In pretrain mode, a logits tensor. Otherwise a dict with ``logits``
            (decoded from the attention logits after the last update, at the last
            temperature), ``A_final``, ``A_0`` (detached initial logits) and per-step
            lists ``probs``, ``energy``, ``grad_norm`` (norm before head scaling),
            plus the optional keys above.
        """
        B = x.shape[0]
        n_heads = self.cfg.NUM_HEADS

        # === Phase 1: Perception ===
        x_onehot = F.one_hot(x.long(), num_classes=10).float()
        h = self.embedding(x_onehot) + self.pos_embed
        h_semantic = self.perception(h)

        if mode == 'pretrain':
            return self.logit_head(h_semantic)

        # === Phase 2: Thermodynamic Reasoning ===
        # Split into heads: [B, 81, D] -> [B, H, 81, D / H].
        V = self.value_head(h_semantic).reshape(B, 81, n_heads, -1).transpose(1, 2)
        # Queries and keys are the same unprojected head split of h_semantic.
        Q = h_semantic.reshape(B, 81, n_heads, -1).transpose(1, 2)
        A_curr = torch.matmul(Q, Q.transpose(-2, -1)) / math.sqrt(Q.size(-1))
        velocity = torch.zeros_like(A_curr)
        temps = self._get_continuous_temps(device=A_curr.device)

        outputs = {
            'probs': [], 'energy': [], 'grad_norm': [],
            'A_final': None, 'A_0': A_curr.detach().clone(),
        }
        if return_step_wise:
            outputs.update({
                'step_energy': [], 'step_grad_norms': [],
                'step_probs': [], 'step_A': [],
            })
        if compute_metrics:
            outputs['constraint_errors'] = []

        # Gradients w.r.t. the attention logits are needed even if the caller
        # runs inference under torch.no_grad().
        with torch.enable_grad():
            for t in range(self.cfg.EBA_STEPS):
                A_curr.requires_grad_(True)
                temp = temps[t]
                self.free_energy.T = temp

                attn_probs, logits = self._attend(A_curr, temp, V, h_semantic)
                probs = F.softmax(logits, dim=-1)

                F_total, conflict_energy = self.free_energy.compute(probs)

                # Keep the graph of this gradient only while training, so an outer
                # loss can backpropagate through the unrolled updates.
                grads = torch.autograd.grad(F_total, A_curr, create_graph=self.training)[0]
                grad_norm = torch.norm(grads).item()  # recorded before head scaling

                # Out-of-place: never mutate the autograd-produced tensor. Scales are
                # rebuilt after free_energy.compute() in case it updates the lambdas.
                grads = grads * self._head_grad_scales(grads)

                # Momentum (heavy-ball) descent on the attention logits.
                velocity = self.cfg.EBA_MOMENTUM * velocity + grads
                A_curr = A_curr - self.cfg.EBA_STEP_SIZE * velocity

                energy = conflict_energy.item()
                outputs['energy'].append(energy)
                outputs['grad_norm'].append(grad_norm)
                outputs['probs'].append(probs.detach())

                if compute_metrics:
                    outputs['constraint_errors'].append(self._compute_constraint_error(probs))

                if return_step_wise:
                    outputs['step_energy'].append(energy)
                    outputs['step_grad_norms'].append(grad_norm)
                    outputs['step_probs'].append(probs.detach().clone())
                    outputs['step_A'].append(attn_probs.detach().clone())

        # === Final Output Calculation ===
        _, final_logits = self._attend(A_curr, temps[-1], V, h_semantic)
        outputs['logits'] = final_logits
        outputs['A_final'] = A_curr

        if compute_metrics and outputs['probs']:
            outputs['final_constraint_error'] = self._compute_constraint_error(outputs['probs'][-1])

        return outputs