import torch
from torch import nn
import math
import torch.nn.functional as F


class DisentangledTransformer(nn.Module):
    """Attention-only transformer whose residual stream grows by concatenation.

    Layer ``i`` owns one parameter tensor of shape ``[n_head_i, d_i, d_i]``
    holding a bilinear attention matrix per head. For each head the layer
    computes ``relu(x @ A @ x^T) / num_nodes``, applies that map to ``x``, and
    concatenates every head's output onto ``x``. The feature width therefore
    becomes ``d_i * (1 + n_head_i)``, and earlier features are kept unchanged.

    Inputs are ``[..., num_nodes, num_nodes]`` (typically a batch of
    ``num_nodes x num_nodes`` matrices; an unbatched single matrix also works).
    """

    _INIT_TYPES = ("randn", "zeros", "eye", "psd", "sym")
    _READOUT_TYPES = ("linear", "sum", "last")

    def __init__(
        self,
        num_nodes,
        heads,  # number of heads of each layer
        extra_pos_id=True,
        init_type="randn",  # "randn", "zeros", "eye", "psd" or "sym"
        readout_type="linear",  # "linear" or "sum", "last"
        final_activation="tanh",  # "tanh", "relu", "sigmoid", "clamp_to_prob", or "phi_logits", None
    ):
        """Build the layers and the readout.

        Args:
            num_nodes: Size ``n`` of the square input and of the output's last dim.
            heads: Iterable of ints; one layer per entry with that many heads.
            extra_pos_id: If True, prepend an ``n x n`` identity block to the input
                features (see ``embed``), doubling the initial width.
            init_type: Initialisation of the attention matrices; one of
                ``_INIT_TYPES``.
            readout_type: ``"linear"`` (trainable ``nn.Linear``), ``"sum"`` (fixed
                weights summing every ``num_nodes``-wide feature block), or
                ``"last"`` (the final ``num_nodes`` features).
            final_activation: See ``final_activation_func``.

        Raises:
            ValueError: If ``init_type`` or ``readout_type`` is not recognised.
        """
        super().__init__()
        if init_type not in self._INIT_TYPES:
            raise ValueError(
                f"init_type must be one of {self._INIT_TYPES}, got {init_type!r}"
            )
        if readout_type not in self._READOUT_TYPES:
            raise ValueError(
                f"readout_type must be one of {self._READOUT_TYPES}, got {readout_type!r}"
            )

        self.num_nodes = num_nodes
        d = 2 * num_nodes if extra_pos_id else num_nodes

        A = []
        for n_head in heads:
            A.append(nn.Parameter(self._init_attention(init_type, n_head, d)))
            # Each head's output (width d) is concatenated onto the stream.
            d *= 1 + n_head

        self.readout_type = readout_type
        self.A = nn.ParameterList(A)
        self.W = nn.Linear(d, num_nodes)  # readout layer
        if readout_type == "sum":
            # d is always a multiple of num_nodes, so [I I ... I] sums every
            # num_nodes-wide block of features into the output.
            num_repeats = d // num_nodes
            with torch.no_grad():
                self.W.weight.copy_(torch.eye(num_nodes).repeat(1, num_repeats))
                self.W.bias.zero_()
            self.W.weight.requires_grad_(False)
            self.W.bias.requires_grad_(False)

        self.hidden_states = []
        self.store_hidden_states = False  # Only enabled during evaluation, not training
        self.extra_pos_id = extra_pos_id
        self.final_activation = final_activation

    @staticmethod
    def _init_attention(init_type, n_head, d):
        """Return an ``[n_head, d, d]`` tensor initialised according to ``init_type``."""
        if init_type == "randn":
            return torch.randn([n_head, d, d]) / math.sqrt(d)
        if init_type == "zeros":
            return torch.zeros([n_head, d, d])
        if init_type == "eye":
            return torch.eye(d).repeat(n_head, 1, 1)
        if init_type == "psd":
            mat = torch.randn([n_head, d, d])
            # mat @ mat^T per head is positive semi-definite.
            return torch.einsum("bjk,blk->bjl", mat, mat) / d
        if init_type == "sym":
            mat = torch.randn([n_head, d, d])
            return (mat + torch.transpose(mat, -1, -2)) / (2 * math.sqrt(d))
        raise ValueError(f"unknown init_type {init_type!r}")

    def attn(self, x, A):
        """Single-head attention: ``(relu(x @ A @ x^T) / num_nodes) @ x``.

        ``x`` is ``[..., n, d]`` and ``A`` is ``[d, d]``; the result is ``[..., n, d]``.
        """
        xA = torch.einsum("...ij,jk->...ik", x, A)
        scores = torch.einsum("...ij,...lj->...il", xA, x)
        # ReLU-normalised scores are used instead of a softmax.
        scores = torch.relu(scores) / self.num_nodes
        return torch.einsum("...ij,...jk->...ik", scores, x)

    def attn_vectorized(self, x, A_heads):
        """Apply ``attn`` for every head at once and concatenate the results.

        Args:
            x: ``[..., n, d]`` stream (any number of leading batch dims, incl. none).
            A_heads: ``[num_heads, d, d]`` attention matrices.

        Returns:
            ``[..., n, num_heads * d]``; head ``h`` occupies columns
            ``h * d : (h + 1) * d``.
        """
        num_heads = A_heads.shape[0]
        d = x.shape[-1]
        # Broadcasting over the head axis in einsum avoids expanding x per head.
        xA = torch.einsum("...nd,hde->...hne", x, A_heads)  # [..., h, n, d]
        scores = torch.einsum("...hne,...me->...hnm", xA, x)  # [..., h, n, n]
        scores = torch.relu(scores) / self.num_nodes
        result = torch.einsum("...hnm,...md->...hnd", scores, x)  # [..., h, n, d]
        result = result.transpose(-3, -2)  # [..., n, h, d]
        return result.reshape(*result.shape[:-2], num_heads * d)

    def embed(self, x):
        """Prepend an identity "positional id" block: ``[..., n, n] -> [..., n, 2n]``."""
        wte = x
        # Create the identity directly on x's device (no host-to-device copy).
        wpe = torch.eye(x.shape[-1], device=x.device)
        wpe = torch.broadcast_to(wpe, x.shape)
        return torch.cat([wpe, wte], -1)

    def get_hidden_states(self, x):
        """Run ``forward`` on ``x`` and return the recorded residual streams.

        The returned list holds numpy copies of the stream before the first
        layer (after ``embed`` when ``extra_pos_id``) and after every layer.
        Autograd is disabled because the forward output is discarded, and
        recording is switched off again even if the forward pass raises.
        """
        self.hidden_states = []
        self.store_hidden_states = True  # Enable hidden state storage
        try:
            with torch.no_grad():
                self.forward(x)
        finally:
            self.store_hidden_states = False  # Disable after use to prevent memory leaks
        return self.hidden_states

    def _record(self, x):
        """Append a detached numpy copy of ``x`` when hidden-state capture is on."""
        if self.store_hidden_states:
            self.hidden_states.append(x.detach().cpu().numpy())

    # ---- NEW: φ-logits activation ----
    @staticmethod
    def _logit(p, eps=1e-6):
        """Numerically-clamped logit ``log(p / (1 - p))``."""
        p = p.clamp(eps, 1 - eps)
        return torch.log(p) - torch.log(1 - p)

    def final_activation_func(self, out):
        """Apply the configured output activation.

        - ``"tanh"`` / ``"relu"`` / ``"sigmoid"``: the usual element-wise functions
          (with sigmoid, do NOT combine with ``BCEWithLogitsLoss``).
        - ``"clamp_to_prob"``: despite the name, clamps to ``[-1, 1]``.
        - ``"phi_logits"``: ``logit(1 - exp(-relu(out)))``.
        - anything else (including ``None``): clamp to ``[-100, 100]``.
        """
        if self.final_activation == "tanh":
            return torch.tanh(out)
        elif self.final_activation == "relu":
            return torch.relu(out)
        elif self.final_activation == "sigmoid":
            return torch.sigmoid(out)  # NOTE: then DON'T use BCEWithLogitsLoss
        elif self.final_activation == "clamp_to_prob":
            return torch.clamp(out, min=-1, max=1)
        elif self.final_activation == "phi_logits":
            z_pos = F.relu(out)  # ensure nonnegativity
            # -expm1(-z) == 1 - exp(-z), but stays accurate for small z.
            p = -torch.expm1(-z_pos)  # [0, 1)
            return self._logit(p)  # logits
        else:
            return torch.clamp(out, min=-100, max=100)

    def forward(self, x: torch.Tensor):
        """Map ``[..., num_nodes, num_nodes]`` inputs to ``[..., num_nodes, num_nodes]``."""
        if self.extra_pos_id:
            x = self.embed(x)
        # Save streams for visualization (only when explicitly enabled).
        self._record(x)

        for Ai in self.A:
            x = torch.cat([x, self.attn_vectorized(x, Ai)], -1)
            self._record(x)

        # shape of x: [..., num_nodes, d]
        if self.readout_type in ("linear", "sum"):
            out = self.W(x)
        elif self.readout_type == "last":
            out = x[..., -self.num_nodes :]
        else:
            raise ValueError(f"unknown readout_type {self.readout_type!r}")
        return self.final_activation_func(out)
