import math

import torch
import torch.nn as nn
import torch_geometric
from torch import Tensor
from torch.nn import functional as F
from torch.nn.modules.dropout import Dropout
from torch.nn.modules.linear import Linear
from torch.nn.modules.normalization import LayerNorm
from torch_geometric.utils import to_dense_adj, to_dense_batch

from src.constants import (
    MAX_ATOMS,
    N_ATOM_FEATURES,
    N_BOND_FEATURES,
    N_BUILDING_BLOCKS,
    N_CENTERS,
    N_REACTIONS,
)
from src.models.layers import (
    HierarchicalPositionalEmbedding,
    SimplePositionalEmbedding,
    SinusoidalPositionalEmbedding,
    TimestepEmbedding,
)
from src.models.utils import (
    PlaceHolder,
    assert_correctly_masked,
    encode_no_edge,
    masked_softmax,
    mean_pool_fragments,
    modulate,
    to_dense,
)
from src.utils.indexing_utils import (
    get_partial_atom_features,
    get_partial_bond_features,
    get_partial_maccs_keys,
)


class XEyTransformerLayer(nn.Module):
    """Transformer layer that updates atom-level and fragment-level features side by side.

    Each granularity has its own ``NodeEdgeBlock`` self-attention. After attention, the
    atom-level node/edge features are pooled to fragment resolution with strided
    convolutions and fused into the fragment-level stream, then every stream goes through
    its own feed-forward network.

    dx: node features dimension
    de: edge features dimension
    n_head: the number of heads in the multi_head_attention
    dim_ffX: dimension of node feedforward network
    dim_ffE: dimension of edge feedforward network
    dropout: dropout probability. 0 to disable
    layer_norm_eps: eps value in layer normalizations.
    device, dtype: factory arguments forwarded to the layers created here.
    """

    def __init__(
        self,
        dx: int,
        de: int,
        n_head: int,
        dim_ffX: int = 2048,
        dim_ffE: int = 128,
        dropout: float = 0.1,
        layer_norm_eps: float = 1e-5,
        device=None,
        dtype=None,
    ) -> None:
        kw = {"device": device, "dtype": dtype}
        super().__init__()

        # Atom-level self-attention
        self.atom_self_attn = NodeEdgeBlock(dx, de, n_head, mode="atom", **kw)

        # Fragment-level self-attention
        self.frag_self_attn = NodeEdgeBlock(dx, de, n_head, mode="fragment", **kw)

        # Node feature layers
        self.linX1 = Linear(dx, dim_ffX, **kw)
        self.linXf1 = Linear(dx, dim_ffX, **kw)
        self.linX2 = Linear(dim_ffX, dx, **kw)
        self.linXf2 = Linear(dim_ffX, dx, **kw)
        self.normX1 = LayerNorm(dx, eps=layer_norm_eps, **kw)
        self.normXf1 = LayerNorm(dx, eps=layer_norm_eps, **kw)
        self.normX2 = LayerNorm(dx, eps=layer_norm_eps, **kw)
        self.normXf2 = LayerNorm(dx, eps=layer_norm_eps, **kw)
        self.dropoutX1 = Dropout(dropout)
        self.dropoutXf1 = Dropout(dropout)
        self.dropoutX2 = Dropout(dropout)
        self.dropoutXf2 = Dropout(dropout)
        self.dropoutX3 = Dropout(dropout)
        self.dropoutXf3 = Dropout(dropout)
        # Kernel and stride are both MAX_ATOMS, so every output position summarizes one
        # non-overlapping window of MAX_ATOMS atom slots (i.e. one fragment).
        # The factory kwargs are forwarded so the convolution parameters share the
        # device/dtype of the Linear/LayerNorm layers of this module.
        self.X_atomtofrag = nn.Conv1d(
            in_channels=dx,
            out_channels=dx,
            kernel_size=MAX_ATOMS,
            stride=MAX_ATOMS,
            padding=0,
            **kw,
        )
        self.X_frag_in = Linear(dx * 2, dx, **kw)

        # Edge feature layers
        self.linE1 = Linear(de, dim_ffE, **kw)
        self.linEf1 = Linear(de, dim_ffE, **kw)
        self.linE2 = Linear(dim_ffE, de, **kw)
        self.linEf2 = Linear(dim_ffE, de, **kw)
        self.normE1 = LayerNorm(de, eps=layer_norm_eps, **kw)
        self.normEf1 = LayerNorm(de, eps=layer_norm_eps, **kw)
        self.normE2 = LayerNorm(de, eps=layer_norm_eps, **kw)
        self.normEf2 = LayerNorm(de, eps=layer_norm_eps, **kw)
        self.dropoutE1 = Dropout(dropout)
        self.dropoutEf1 = Dropout(dropout)
        self.dropoutE2 = Dropout(dropout)
        self.dropoutEf2 = Dropout(dropout)
        self.dropoutE3 = Dropout(dropout)
        self.dropoutEf3 = Dropout(dropout)
        # 2D counterpart of X_atomtofrag: pools each MAX_ATOMS x MAX_ATOMS tile of the
        # atom-level edge map into a single fragment-level edge.
        self.E_atomtofrag = nn.Conv2d(
            in_channels=de,
            out_channels=de,
            kernel_size=MAX_ATOMS,
            stride=MAX_ATOMS,
            padding=0,
            **kw,
        )
        self.E_frag_in = Linear(de * 2, de, **kw)

        # Coordinate feature layers
        self.linC1 = Linear(dx, dim_ffX, **kw)
        self.linC2 = Linear(dim_ffX, dx, **kw)
        self.normC1 = LayerNorm(dx, eps=layer_norm_eps, **kw)
        self.normC2 = LayerNorm(dx, eps=layer_norm_eps, **kw)
        self.dropoutC1 = Dropout(dropout)
        self.dropoutC2 = Dropout(dropout)
        self.dropoutC3 = Dropout(dropout)

        self.activation = F.relu

    def forward(
        self,
        X: Tensor,
        X_frag: Tensor,
        X_onehot: Tensor,
        E: Tensor,
        E_frag: Tensor,
        C: Tensor,
        node_padding_mask: Tensor,
    ):
        """Pass the input through the encoder layer.
        X: Node features (bs, n*MAX_ATOMS, dx)
        X_frag: Fragment-level node features (bs, n, dx)
        X_onehot: Node onehots (bs, n, dx)
        E: Edge features (bs, n*MAX_ATOMS, n*MAX_ATOMS, de)
        E_frag: Fragment-level edge features (bs, n, n, de)
        C: Coordinate features (bs, n*MAX_ATOMS, dx)
        node_padding_mask: (bs, n*MAX_ATOMS) Mask for the src keys per batch
        Output: the tuple (X, X_frag, E, E_frag, C), each with the same shape as the
            corresponding input.
        """
        bs = X.shape[0]
        n = X_frag.shape[1]

        # Create fragment level padding mask: a fragment is flagged when any of its atom
        # slots is flagged. reshape (rather than view) also accepts non-contiguous masks.
        frag_padding_mask = node_padding_mask.reshape(bs, n, -1).any(dim=-1)  # (bs, n)

        # 1. Atom-level self-attention with combined node-coordinate features
        newX, newE, newC = self.atom_self_attn(
            X,
            E,
            C,
            node_padding_mask=node_padding_mask,
        )

        # Node residual + norm
        newX_d = self.dropoutX1(newX)
        X = self.normX1(X + newX_d)

        # Edge residual + norm
        newE_d = self.dropoutE1(newE)
        E = self.normE1(E + newE_d)

        # Coordinate residual + norm
        newC_d = self.dropoutC1(newC)
        C = self.normC1(C + newC_d)

        # 2. Fragment-level self-attention. The embedded one-hots are passed in the
        # block's third ("C") argument, where they are concatenated with X_frag.
        newX_frag, newE_frag = self.frag_self_attn(
            X_frag,
            E_frag,
            X_onehot,
            node_padding_mask=frag_padding_mask,
        )

        # Fragment-level residual + norm
        newX_frag_d = self.dropoutXf1(newX_frag)
        X_frag = self.normXf1(X_frag + newX_frag_d)

        # Fragment-level edge residual + norm
        newE_frag_d = self.dropoutEf1(newE_frag)
        E_frag = self.normEf1(E_frag + newE_frag_d)

        # 3. Cross-granularity information exchange
        # Transform atom-level features to fragment-level. The convolutions expect the
        # channel axis right after the batch axis, hence the permutes around them.
        atom_to_frag_X = X.permute(0, 2, 1)  # (bs, dx, n*MAX_ATOMS)
        atom_to_frag_X = self.X_atomtofrag(atom_to_frag_X)  # (bs, dx, n)
        atom_to_frag_X = atom_to_frag_X.permute(0, 2, 1)  # (bs, n, dx)

        atom_to_frag_E = E.permute(0, 3, 1, 2)  # (bs, de, n*MAX_ATOMS, n*MAX_ATOMS)
        atom_to_frag_E = self.E_atomtofrag(atom_to_frag_E)  # (bs, de, n, n)
        atom_to_frag_E = atom_to_frag_E.permute(0, 2, 3, 1)  # (bs, n, n, de)

        # Combine with fragment-level features
        X_frag = torch.cat([atom_to_frag_X, X_frag], dim=-1)
        E_frag = torch.cat([atom_to_frag_E, E_frag], dim=-1)

        # Project the concatenation back to the original widths (2*dx -> dx, 2*de -> de)
        X_frag = self.X_frag_in(X_frag)
        E_frag = self.E_frag_in(E_frag)

        # 4. Apply feed-forward networks to both levels
        # Node-level feedforward
        ff_outputX = self.linX2(self.dropoutX2(self.activation(self.linX1(X))))
        ff_outputX = self.dropoutX3(ff_outputX)
        X = self.normX2(X + ff_outputX)

        # Fragment-level node feedforward
        ff_outputX_frag = self.linXf2(self.dropoutXf2(self.activation(self.linXf1(X_frag))))
        ff_outputX_frag = self.dropoutXf3(ff_outputX_frag)
        X_frag = self.normXf2(X_frag + ff_outputX_frag)

        # Edge feedforward
        ff_outputE = self.linE2(self.dropoutE2(self.activation(self.linE1(E))))
        ff_outputE = self.dropoutE3(ff_outputE)
        E = self.normE2(E + ff_outputE)

        # Fragment-level edge feedforward
        ff_outputE_frag = self.linEf2(self.dropoutEf2(self.activation(self.linEf1(E_frag))))
        ff_outputE_frag = self.dropoutEf3(ff_outputE_frag)
        E_frag = self.normEf2(E_frag + ff_outputE_frag)

        # Coordinate feedforward
        ff_outputC = self.linC2(self.dropoutC2(self.activation(self.linC1(C))))
        ff_outputC = self.dropoutC3(ff_outputC)
        C = self.normC2(C + ff_outputC)

        return X, X_frag, E, E_frag, C


class NodeEdgeBlock(nn.Module):
    """Self attention layer that updates combined node-coordinate features and edge features.
    Works with both fragment-level and atom-level granularities.

    The node features X and the auxiliary features C are concatenated along the feature
    axis (width 2*dx) before attention. In ``mode="atom"`` the attended result is split
    back into a node half and a coordinate half; in any other mode a single node update
    is produced from the full 2*dx-wide result.

    Args:
        dx: Node feature dimension (also the width of C).
        de: Edge feature dimension.
        n_head: Number of attention heads; must divide 2*dx.
        mode: "atom" to return (newX, newE, newC); anything else returns (newX, newE).
        **kwargs: ``device`` and ``dtype`` are forwarded to every layer created here;
            any other keyword is ignored.
    """

    def __init__(self, dx, de, n_head, mode="fragment", **kwargs):
        super().__init__()
        assert (dx * 2) % n_head == 0, f"dx: {dx} -- nhead: {n_head}"
        self.dx = dx
        self.de = de
        self.df = (dx * 2) // n_head  # per-head width of the combined features
        self.n_head = n_head
        self.mode = mode

        # Only the tensor-factory arguments are forwarded, so unrelated keywords keep
        # being ignored rather than raising inside nn.Linear.
        fkw = {key: kwargs[key] for key in ("device", "dtype") if key in kwargs}

        # Self-attention for combined features
        self.q_self = Linear(dx * 2, dx * 2, **fkw)
        self.k_self = Linear(dx * 2, dx * 2, **fkw)
        self.v_self = Linear(dx * 2, dx * 2, **fkw)
        self.e_mul_self = Linear(de, dx * 2, **fkw)
        self.e_add_self = Linear(de, dx * 2, **fkw)

        # Output projections
        self.combined_out = Linear(dx * 2, dx * 2, **fkw)

        # Edge update from attention
        self.edge_update = nn.Sequential(
            nn.Linear(dx * 2, 2 * de, **fkw), nn.ReLU(), nn.Linear(2 * de, de, **fkw)
        )

        # Split the output back into node and coordinate features
        if mode == "atom":
            self.node_proj = Linear(dx, dx, **fkw)
            self.coord_proj = Linear(dx, dx, **fkw)
        else:
            self.node_proj = Linear(dx * 2, dx, **fkw)

    def compute_attention(self, Q, K, V, E, e_mul, e_add, mask, name=""):
        """Helper function to compute attention with edge features.
        Works with both fragment-level (N=n) and atom-level (N=n*MAX_ATOMS) inputs.

        Args:
            Q: [bs, N, 2*dx] Query, where N is either n (fragments) or n*MAX_ATOMS (atoms)
            K: [bs, N, 2*dx] Key
            V: [bs, N, 2*dx] Value
            E: [bs, N, N, de] Edge features
            e_mul, e_add: Edge feature projections (de -> 2*dx)
            mask: Tuple of (e_mask1, e_mask2) for masking, shaped [bs, N, 1, 1] and
                [bs, 1, N, 1]
            name: Optional name for debugging (currently unused)

        Returns:
            weighted_V: [bs, N, 2*dx] attention-weighted values
            Y: [bs, N, N, n_head, df] edge-modulated scores before the softmax
        """
        bs, N = Q.shape[:2]

        # Reshape to attention heads
        Q = Q.reshape(bs, N, self.n_head, self.df)
        K = K.reshape(bs, N, self.n_head, self.df)
        V = V.reshape(bs, N, self.n_head, self.df)

        Q = Q.unsqueeze(2)  # [bs, N, 1, n_head, df]  (query index on dim 1)
        K = K.unsqueeze(1)  # [bs, 1, N, n_head, df]  (key index on dim 2)
        V = V.unsqueeze(1)  # [bs, 1, N, n_head, df]

        # Compute attention scores; the product is element-wise, so a separate score is
        # kept for every feature channel of every head: [bs, N, N, n_head, df]
        Y = Q * K / math.sqrt(self.df)

        # Incorporate edge features
        E1 = e_mul(E) * mask[0] * mask[1]  # e_mask1 * e_mask2
        E2 = e_add(E) * mask[0] * mask[1]
        E1 = E1.reshape(bs, N, N, self.n_head, self.df)
        E2 = E2.reshape(bs, N, N, self.n_head, self.df)
        Y = Y * (E1 + 1) + E2

        # Apply attention: normalize over dim 2 (the key axis), using the key mask
        # broadcast to [bs, N, N, n_head]
        softmax_mask = mask[1].expand(-1, N, -1, self.n_head)
        attn = masked_softmax(Y, softmax_mask, dim=2)

        weighted_V = attn * V
        weighted_V = weighted_V.sum(dim=2)
        weighted_V = weighted_V.flatten(start_dim=2)  # [bs, N, 2*dx]

        return weighted_V, Y

    def forward(self, X, E, C, node_padding_mask):
        """
        Args:
            X: [bs, N, dx] Node features, where N can be either:
               - n for fragment-level features
               - n*MAX_ATOMS for atom-level features
            E: [bs, N, N, de] Edge features matching X's granularity
            C: [bs, N, dx] Coordinate features matching X's granularity
            node_padding_mask: [bs, N] Padding mask matching X's granularity; it is
                multiplied into the features, so zero/False entries are zeroed out

        Returns:
            newX: [bs, N, dx] Updated node features at same granularity as input
            newE: [bs, N, N, de] Updated edge features at same granularity as input
            newC: [bs, N, dx] Updated coordinate features at same granularity as input
                (returned only when mode == "atom")
        """
        # Create masks
        x_mask = node_padding_mask.unsqueeze(-1)  # [bs, N, 1]
        e_mask1 = x_mask.unsqueeze(2)  # [bs, N, 1, 1]
        e_mask2 = x_mask.unsqueeze(1)  # [bs, 1, N, 1]
        e_masks = (e_mask1, e_mask2)

        # Combine node and coordinate features
        combined_features = torch.cat([X, C], dim=-1)  # [bs, N, 2*dx]

        # Self-attention on combined features
        Q = self.q_self(combined_features) * x_mask
        K = self.k_self(combined_features) * x_mask
        V = self.v_self(combined_features) * x_mask

        combined_out, Y_attn = self.compute_attention(
            Q, K, V, E, self.e_mul_self, self.e_add_self, e_masks, "combined_self"
        )
        # Drop the projections early; only the attention outputs are needed from here on
        del Q, K, V

        # Project the output
        combined_out = self.combined_out(combined_out) * x_mask

        # Update edge features using attention scores (heads merged back to 2*dx)
        newE = self.edge_update(Y_attn.flatten(start_dim=3)) * e_mask1 * e_mask2

        if self.mode == "atom":
            # Split combined features back to node and coordinate
            node_features, coord_features = torch.split(combined_out, self.dx, dim=-1)
            newX = self.node_proj(node_features) * x_mask
            newC = self.coord_proj(coord_features) * x_mask
            return newX, newE, newC
        else:
            # For fragment mode, project entire tensor
            newX = self.node_proj(combined_out) * x_mask
            return newX, newE


class GTFinalLayer(nn.Module):
    """Output head: layer-norm, conditioning-dependent modulation, then a linear projection."""

    def __init__(self, hidden_size, out_channels, cond_dim):
        """
        Args:
            hidden_size: Width of the incoming features
            out_channels: Width of the projected output
            cond_dim: Width of the conditioning vector
        """
        super().__init__()
        self.norm_final = LayerNorm(hidden_size)
        self.linear = nn.Linear(hidden_size, out_channels)
        # One linear map emits both modulation vectors, hence the doubled output width.
        self.adaLN_modulation = nn.Linear(cond_dim, 2 * hidden_size, bias=True)

    def forward(self, x, c):
        """Normalize ``x``, modulate it with parameters derived from ``c`` and project it."""
        # First half along dim 1 is the shift, second half is the scale.
        modulation = self.adaLN_modulation(c)
        shift, scale = torch.chunk(modulation, 2, dim=1)

        # Insert singleton axes at position 1 until both tensors have x's rank, so they
        # broadcast over the axes between batch and feature.
        rank_gap = x.dim() - shift.dim()
        for _ in range(rank_gap):
            shift, scale = shift.unsqueeze(1), scale.unsqueeze(1)

        normalized = self.norm_final(x)
        return self.linear(modulate(normalized, shift, scale))


class GraphTransformer(nn.Module):
    """Hierarchical graph transformer working on atom-level and fragment-level features.

    config : object whose ``config.model`` provides
        n_layers : int -- number of transformer layers
        hidden_dims : object with ``dx``, ``de``, ``n_head``, ``dim_ffX``, ``dim_ffE``
        act_fn_in, act_fn_out : activation names ("ReLU" or "LeakyReLU"); a value that
            is not one of these names is used as given
        cond_dim : int -- width of the timestep conditioning vector
    """

    def __init__(self, config):
        super().__init__()
        self.n_layers = config.model.n_layers
        self.dim_X = N_BUILDING_BLOCKS + 1
        self.dim_E = N_REACTIONS * N_CENTERS * N_CENTERS + 2
        self.hidden_dims = config.model.hidden_dims
        self.act_fn_in = config.model.act_fn_in
        self.act_fn_out = config.model.act_fn_out
        self.cond_dim = config.model.cond_dim
        self.maccs_dim = 167

        self.sigma_map = TimestepEmbedding(config.model.cond_dim)
        # Node and edge heads read the concatenation of the atom-derived and the
        # fragment-level streams, hence the doubled input widths.
        self.final_layer_X = GTFinalLayer(self.hidden_dims.dx * 2, self.dim_X, self.cond_dim)
        self.final_layer_E = GTFinalLayer(self.hidden_dims.de * 2, self.dim_E, self.cond_dim)
        self.final_layer_C = GTFinalLayer(self.hidden_dims.dx, 3, self.cond_dim)
        activation_functions = {"ReLU": nn.ReLU(), "LeakyReLU": nn.LeakyReLU()}

        # Resolve the configured names to modules; unknown values pass through unchanged.
        self.act_fn_in = activation_functions.get(self.act_fn_in, self.act_fn_in)
        self.act_fn_out = activation_functions.get(self.act_fn_out, self.act_fn_out)

        # Hierarchical positional encoding instead of shared positional encoding
        self.pe = HierarchicalPositionalEmbedding(
            d_model=self.hidden_dims.dx,
            max_fragments=5,  # Adjust based on your data
            max_atoms_per_fragment=MAX_ATOMS,
        )

        dx = self.hidden_dims.dx
        de = self.hidden_dims.de

        # Input embeddings, one two-layer MLP per feature stream
        self.mlp_in_X = self._make_input_mlp(self.maccs_dim, dx)
        self.mlp_in_X_frag = self._make_input_mlp(self.maccs_dim, dx)
        self.mlp_in_X_onehot = self._make_input_mlp(self.dim_X, dx)
        self.mlp_in_E = self._make_input_mlp(N_BOND_FEATURES, de)
        self.mlp_in_E_frag = self._make_input_mlp(self.dim_E, de)
        self.mlp_in_C = self._make_input_mlp(3, dx)

        self.tf_layers = nn.ModuleList(
            [
                XEyTransformerLayer(
                    dx=self.hidden_dims.dx,
                    de=self.hidden_dims.de,
                    n_head=self.hidden_dims.n_head,
                    dim_ffX=self.hidden_dims.dim_ffX,
                    dim_ffE=self.hidden_dims.dim_ffE,
                )
                for _ in range(self.n_layers)
            ]
        )

        # Strided convolutions (kernel == stride == MAX_ATOMS) that pool the final
        # atom-level features to fragment resolution.
        self.X_atomtofrag_out = nn.Conv1d(
            in_channels=self.hidden_dims.dx,
            out_channels=self.hidden_dims.dx,
            kernel_size=MAX_ATOMS,
            stride=MAX_ATOMS,
            padding=0,
        )
        self.E_atomtofrag_out = nn.Conv2d(
            in_channels=self.hidden_dims.de,
            out_channels=self.hidden_dims.de,
            kernel_size=MAX_ATOMS,
            stride=MAX_ATOMS,
            padding=0,
        )

    def _make_input_mlp(self, in_dim, hidden_dim):
        """Build the Linear -> act -> Linear -> act embedding used for every input stream."""
        return nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            self.act_fn_in,
            nn.Linear(hidden_dim, hidden_dim),
            self.act_fn_in,
        )

    def forward(self, X, E, C, node_padding_mask, sigma):
        """Run the network on a batch of fragment graphs.

        Args:
            X: (bs, n, dim_X) fragment-level node features; the argmax over the last axis
                is used as the building-block index.
            E: (bs, n, n, dim_E) fragment-level edge features.
            C: coordinates, reshaped internally to (bs, n*MAX_ATOMS, 3)
                -- presumably (bs, n, MAX_ATOMS, 3) on input; verify against caller.
            node_padding_mask: (bs, n) fragment-level mask, repeated MAX_ATOMS times per
                fragment to obtain the atom-level mask.
            sigma: input of the timestep embedding.

        Returns:
            X_output: (bs, n, dim_X)
            E_output: (bs, n, n, dim_E), zero on the diagonal and symmetric in the two
                fragment axes
            C: (bs, n, MAX_ATOMS, 3)
        """
        # Timestep embedding
        c = F.silu(self.sigma_map(sigma))
        bs, n = X.shape[0], X.shape[1]
        X_indices = X.argmax(dim=-1)
        bond_feats = get_partial_bond_features(
            X_indices, node_padding_mask, mode="feats"
        )  # bs, n*MAX_ATOMS, n*MAX_ATOMS, 5
        # Node padding mask - expand
        node_padding_mask = node_padding_mask.unsqueeze(-1)  # bs, n, 1
        node_padding_mask = node_padding_mask.expand(-1, -1, MAX_ATOMS)  # bs, n, MAX_ATOMS
        node_padding_mask = node_padding_mask.reshape(bs, n * MAX_ATOMS)  # bs, n*MAX_ATOMS

        # Create diagonal mask for edges (False on the diagonal), built directly as a
        # boolean tensor on E's device
        diag_mask = ~torch.eye(n, dtype=torch.bool, device=E.device)
        diag_mask = diag_mask.unsqueeze(0).unsqueeze(-1).expand(bs, -1, -1, -1)

        # Calculate fragment features
        X_frag = get_partial_maccs_keys(X_indices)  # bs, n, 167
        X_atom = X_frag.unsqueeze(2).expand(-1, -1, MAX_ATOMS, -1)  # bs, n, MAX_ATOMS, 167
        X_atom = X_atom.reshape(bs, n * MAX_ATOMS, self.maccs_dim)  # bs, n*MAX_ATOMS, 167

        # Apply hierarchical positional encoding
        fragment_pe, atom_pe = self.pe(X_frag, X_atom)
        X_atom = self.mlp_in_X(X_atom) + atom_pe  # bs, n*MAX_ATOMS, hidden_dims.dx
        X_frag = self.mlp_in_X_frag(X_frag) + fragment_pe

        X_onehot = self.mlp_in_X_onehot(X)

        # Flatten C and combine with positional encoding. reshape (rather than view)
        # also accepts non-contiguous coordinate tensors.
        C = C.reshape(C.shape[0], C.shape[1] * MAX_ATOMS, 3)  # bs, n*MAX_ATOMS, 3
        C = self.mlp_in_C(C) + atom_pe  # bs, n*MAX_ATOMS, hidden_dims.dx

        # Prepare edge features
        E_frag = self.mlp_in_E_frag(E)  # Fragment-level edges (bs, n, n, hidden_dims.de)
        E_atom = self.mlp_in_E(
            bond_feats
        )  # Atom-level edges (bs, n*MAX_ATOMS, n*MAX_ATOMS, hidden_dims.de)

        # Make edges symmetric
        E_frag = (E_frag + E_frag.transpose(1, 2)) / 2
        E_atom = (E_atom + E_atom.transpose(1, 2)) / 2

        # Process through transformer layers
        # X onehot is just conditioning information
        for layer in self.tf_layers:
            X_atom, X_frag, E_atom, E_frag, C = layer(
                X_atom, X_frag, X_onehot, E_atom, E_frag, C, node_padding_mask
            )

        # Convert atom-level features to fragment-level
        X_atom = X_atom.permute(0, 2, 1)
        X_atom = self.X_atomtofrag_out(X_atom)  # (bs, dx, n)
        X_atom = X_atom.permute(0, 2, 1)  # (bs, n, dx)

        # Convert atom-level edges to fragment-level
        E_atom = E_atom.permute(0, 3, 1, 2)  # (bs, de, n*MAX_ATOMS, n*MAX_ATOMS)
        E_atom = self.E_atomtofrag_out(E_atom)  # (bs, de, n, n)
        E_atom = E_atom.permute(0, 2, 3, 1)  # (bs, n, n, de)

        # Final fragment-level features combine atom and fragment level information
        X_combined = torch.cat([X_atom, X_frag], dim=-1)  # (bs, n, dx*2)
        X_output = self.final_layer_X(X_combined, c)  # (bs, n, dim_X)

        # Final fragment-level edge features combine atom and fragment level information
        E_combined = torch.cat([E_atom, E_frag], dim=-1)  # (bs, n, n, de*2)
        E_output = self.final_layer_E(E_combined, c)  # (bs, n, n, dim_E)

        # Reshape coordinates back to original shape
        C = self.final_layer_C(C, c)  # bs, n*MAX_ATOMS, 3
        C = C.reshape(C.shape[0], n, MAX_ATOMS, 3)  # bs, n, MAX_ATOMS, 3

        # Remove self-edges, then symmetrize final edge features
        E_output = E_output * diag_mask
        E_output = 1 / 2 * (E_output + torch.transpose(E_output, 1, 2))
        return X_output, E_output, C
