import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from TLAD.maze.energy import LagrangianFreeEnergy


class TLADModel(nn.Module):
    """Grid model that refines attention logits by descending a free energy.

    A convolutional stem plus a Transformer encoder (``perception``) turns a
    batch of integer grids into per-cell features ``h_semantic``. In
    ``'pretrain'`` mode those features are decoded directly by ``logit_head``.
    Otherwise, initial attention logits ``A`` are built from ``h_semantic``.
    ``A`` is then refined by ``config.EBA_STEPS`` steps of momentum gradient
    descent on ``LagrangianFreeEnergy.compute`` under a sigmoid temperature
    schedule. The refined attention is decoded into per-cell logits.

    Config attributes read in this class: NUM_HEADS, HIDDEN_DIM, SEQ_LEN,
    GRID_SIZE, INPUT_CLASSES, OUTPUT_CLASSES, NUM_LAYERS_S1, DEVICE, EBA_STEPS,
    EBA_MOMENTUM, EBA_STEP_SIZE, TRANSITION_CENTER, ANNEALING_SLOPE,
    TEMP_START, TEMP_END. ``LagrangianFreeEnergy`` may read others.
    """

    def __init__(self, config):
        super().__init__()
        self.cfg = config

        # The attention mask and gradient weighting split heads into two halves
        # (local / global), see _create_custom_202_mask and forward.
        assert config.NUM_HEADS % 4 == 0 and config.NUM_HEADS > 0, f"NUM_HEADS must be a positive multiple of 4, got {config.NUM_HEADS}"
        self.num_heads = config.NUM_HEADS

        self.hidden_dim = config.HIDDEN_DIM
        self.seq_len = config.SEQ_LEN
        self.grid_size = config.GRID_SIZE
        self.input_classes = config.INPUT_CLASSES

        # Three 3x3 conv layers (padding=1 keeps H, W): INPUT_CLASSES -> 32 -> 64 -> HIDDEN_DIM.
        self.stem = nn.Sequential(
            nn.Conv2d(self.input_classes, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, self.hidden_dim, kernel_size=3, padding=1),
            nn.BatchNorm2d(self.hidden_dim),
            nn.ReLU()
        )
        # Learned row / column embeddings; each cell concatenates one of each.
        self.row_embed = nn.Parameter(torch.randn(self.grid_size, self.hidden_dim // 2))
        self.col_embed = nn.Parameter(torch.randn(self.grid_size, self.hidden_dim // 2))

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.hidden_dim,
            nhead=self.num_heads,
            batch_first=True,
            dim_feedforward=self.hidden_dim * 4,
            dropout=0.1
        )
        self.perception = nn.TransformerEncoder(encoder_layer, num_layers=config.NUM_LAYERS_S1)

        self.value_head = nn.Linear(self.hidden_dim, self.hidden_dim)
        self.logit_head = nn.Linear(self.hidden_dim, config.OUTPUT_CLASSES)

        self.free_energy = LagrangianFreeEnergy(config)
        # Registered as a buffer so it follows the module across .to(...) calls.
        self.register_buffer('base_topo_mask', self._create_custom_202_mask())

    def _get_2d_learned_pos_embed(self, B, H, W):
        """Return learned 2-D positional embeddings of shape (B, H*W, 2 * (HIDDEN_DIM // 2)).

        Cell (r, c) receives ``cat(row_embed[r], col_embed[c])``. Cells are
        ordered row-major, matching ``flatten(2)`` of the stem output in ``forward``.
        The batch dimension is a broadcast view, not a copy.
        """
        # expand() creates views; torch.cat below materializes the (H, W, D) table once.
        rows = self.row_embed[:H].unsqueeze(1).expand(-1, W, -1)
        cols = self.col_embed[:W].unsqueeze(0).expand(H, -1, -1)
        pe = torch.cat([rows, cols], dim=-1)
        return pe.flatten(0, 1).unsqueeze(0).expand(B, -1, -1)

    def _create_custom_202_mask(self):
        """Build the additive per-head attention mask, shape (num_heads, SEQ_LEN, SEQ_LEN).

        Sequence position ``i`` is grid cell ``(i // GRID_SIZE, i % GRID_SIZE)``.

        * The first ``num_heads // 2`` heads are local. A query may attend only
          to itself and to cells at Manhattan distance 1 (entry 0.0); every
          other entry is ``-inf``.
        * The remaining heads are global: every entry is 0.0.

        The mask is built with broadcasting rather than an O(heads * L^2)
        Python loop. The result is identical.
        """
        n = self.grid_size
        length = self.seq_len
        device = self.cfg.DEVICE

        idx = torch.arange(length, device=device)
        rows, cols = idx // n, idx % n
        # Pairwise Manhattan distance between all positions, shape (L, L).
        dist = (rows[:, None] - rows[None, :]).abs() + (cols[:, None] - cols[None, :]).abs()

        masks = torch.zeros((self.num_heads, length, length), device=device)
        num_local_heads = self.num_heads // 2
        # Distinct indices map to distinct cells, so dist == 0 only on the diagonal.
        # Blocking dist > 1 therefore leaves exactly "self + distance-1 neighbours".
        masks[:num_local_heads].masked_fill_(dist > 1, float('-inf'))
        return masks

    def _get_probs_from_A(self, A_curr, dyn_mask, temp, h_semantic, V):
        """Decode per-cell logits from attention logits ``A_curr``.

        Despite the name, this returns *logits* (the ``logit_head`` output), not
        probabilities.

        Args:
            A_curr: attention logits, (B, num_heads, L, L) as built in ``forward``.
            dyn_mask: additive mask broadcastable to ``A_curr``.
            temp: softmax temperature (scalar or 0-dim tensor).
            h_semantic: (B, L, HIDDEN_DIM) features, added residually.
            V: (B, num_heads, L, head_dim) per-head values.

        Returns:
            Logits of shape (B, L, OUTPUT_CLASSES).
        """
        B, L = h_semantic.shape[0], h_semantic.shape[1]
        scores = (A_curr + dyn_mask) / temp
        attn_probs = F.softmax(scores, dim=-1)
        # (B, heads, L, head_dim) -> (B, L, heads * head_dim)
        context = torch.matmul(attn_probs, V).transpose(1, 2).reshape(B, L, -1)
        logits = self.logit_head(h_semantic + context)
        return logits

    def _get_continuous_temps(self, device=None):
        """Return the per-step softmax temperatures, a tensor of shape (EBA_STEPS,).

        ``temp[t] = TEMP_START * (1 - s_t) + TEMP_END * s_t``, where
        ``s_t = sigmoid(-(t / EBA_STEPS - TRANSITION_CENTER) * ANNEALING_SLOPE)``.

        NOTE(review): for ANNEALING_SLOPE > 0, ``s_t`` decreases with ``t``. The
        schedule therefore runs from about TEMP_END at step 0 toward about
        TEMP_START at the last step. Confirm this is the intended direction.

        Args:
            device: where to allocate the schedule. Defaults to ``cfg.DEVICE``
                (the original behaviour). ``forward`` passes the device of its
                working tensors, so it also works after ``.to(...)``.
        """
        if device is None:
            device = self.cfg.DEVICE
        n_steps = self.cfg.EBA_STEPS
        progress = torch.arange(n_steps, device=device) / n_steps
        weight_end = torch.sigmoid(-(progress - self.cfg.TRANSITION_CENTER) * self.cfg.ANNEALING_SLOPE)
        return self.cfg.TEMP_START * (1 - weight_end) + self.cfg.TEMP_END * weight_end

    def forward(self, x, mode='reasoning'):
        """Run the model on a batch of integer grids.

        Args:
            x: (B, L) tensor of class indices in ``[0, INPUT_CLASSES)``. The
                one-hot encoding is viewed as (B, GRID_SIZE, GRID_SIZE, -1), so L
                is expected to equal ``GRID_SIZE ** 2``.
            mode: ``'pretrain'`` returns ``logit_head(h_semantic)`` directly. Any
                other value runs the energy-based attention refinement.

        Returns:
            In ``'pretrain'`` mode, a (B, L, OUTPUT_CLASSES) logits tensor.
            Otherwise, a dict with these keys:

            * ``'logits'``: logits decoded from the refined attention at the
              final temperature.
            * ``'energy'``: ``energy_components['total'].item()`` for each step.
            * ``'A_final'``: the refined attention logits, (B, num_heads, L, L).
            * ``'probs'``, ``'grad_norm'``: currently left empty.

        Assumes ``cfg.EBA_STEPS >= 1``, because the final decode uses ``temps[-1]``.
        """
        B, L = x.shape
        H, W = self.grid_size, self.grid_size

        # (B, L) indices -> (B, INPUT_CLASSES, H, W) one-hot image for the conv stem.
        x_onehot = F.one_hot(x.long(), num_classes=self.input_classes).float().view(B, H, W, -1).permute(0, 3, 1, 2)
        # (B, HIDDEN_DIM, H, W) -> (B, H*W, HIDDEN_DIM), row-major cell order.
        h = self.stem(x_onehot).flatten(2).transpose(1, 2)

        h = h + self._get_2d_learned_pos_embed(B, H, W)
        h_semantic = self.perception(h)

        if mode == 'pretrain':
            return self.logit_head(h_semantic)

        dyn_mask = self.base_topo_mask.unsqueeze(0)  # add a batch axis for broadcasting

        # Per-head split: (B, L, HIDDEN_DIM) -> (B, num_heads, L, head_dim).
        V = self.value_head(h_semantic).view(B, L, self.num_heads, -1).transpose(1, 2)
        # Q and K are the same un-projected per-head split of h_semantic.
        Q = h_semantic.view(B, L, self.num_heads, -1).transpose(1, 2)
        K = Q

        # Initial attention logits (scaled dot product), shape (B, num_heads, L, L).
        A_curr = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(Q.size(-1))
        # Allocate on A_curr's own device rather than cfg.DEVICE, so the model
        # still works after being moved with .to(...).
        velocity = torch.zeros_like(A_curr)
        temps = self._get_continuous_temps(device=A_curr.device)
        num_local_heads = self.num_heads // 2

        outputs = {'probs': [], 'energy': [], 'grad_norm': [], 'A_final': None, 'logits': None}

        # Gradients w.r.t. A are needed even if the caller runs under torch.no_grad().
        with torch.enable_grad():
            for t in range(self.cfg.EBA_STEPS):
                A_curr.requires_grad_(True)
                self.free_energy.T = temps[t]

                logits = self._get_probs_from_A(A_curr, dyn_mask, temps[t], h_semantic, V)
                probs = F.softmax(logits, dim=-1)

                F_total, energy_components = self.free_energy.compute(probs, x)
                grads = torch.autograd.grad(F_total, A_curr, retain_graph=True)[0]

                # Weight the gradient per head group: local heads by lambda_degree,
                # global heads by lambda_end.
                grads[:, :num_local_heads] *= self.free_energy.lambda_degree
                grads[:, num_local_heads:] *= self.free_energy.lambda_end

                # Heavy-ball momentum descent step on the attention logits.
                velocity = self.cfg.EBA_MOMENTUM * velocity + grads
                A_curr = A_curr - self.cfg.EBA_STEP_SIZE * velocity

                outputs['energy'].append(energy_components['total'].item())

        final_logits = self._get_probs_from_A(A_curr, dyn_mask, temps[-1], h_semantic, V)
        outputs['logits'] = final_logits
        outputs['A_final'] = A_curr
        return outputs