import math

import torch
import torch.nn as nn
import torch_geometric
from torch import Tensor
from torch.nn import functional as F
from torch.nn.modules.dropout import Dropout
from torch.nn.modules.linear import Linear
from torch.nn.modules.normalization import LayerNorm
from torch_geometric.utils import to_dense_adj, to_dense_batch

from src.constants import (
    MAX_ATOMS,
    N_ATOM_FEATURES,
    N_BOND_FEATURES,
    N_BUILDING_BLOCKS,
    N_CENTERS,
    N_REACTIONS,
)
from src.models.layers import (
    SimplePositionalEmbedding,
    SinusoidalPositionalEmbedding,
    TimestepEmbedding,
)
from src.models.utils import (
    PlaceHolder,
    assert_correctly_masked,
    encode_no_edge,
    masked_softmax,
    mean_pool_fragments,
    modulate,
    to_dense,
)
from src.utils.indexing_utils import (
    get_partial_atom_features,
    get_partial_bond_features,
    get_partial_maccs_keys,
)


class XEyTransformerLayer(nn.Module):
    """Post-norm transformer layer over node (X), edge (E) and coordinate (C) streams.

    The three streams are first mixed by a shared ``NodeEdgeBlock``; each stream
    then gets its own residual + LayerNorm followed by a two-layer feed-forward
    network with a second residual + LayerNorm.

    dx: node (and coordinate) features dimension
    de: edge features dimension
    n_head: the number of heads in the multi_head_attention
    dim_ffX: hidden width of the node and coordinate feedforward networks
    dim_ffE: hidden width of the edge feedforward network
    dropout: dropout probability. 0 to disable
    layer_norm_eps: eps value in layer normalizations.
    device, dtype: factory arguments handed to the layers created here.
    """

    def __init__(
        self,
        dx: int,
        de: int,
        n_head: int,
        dim_ffX: int = 2048,
        dim_ffE: int = 128,
        dropout: float = 0.1,
        layer_norm_eps: float = 1e-5,
        device=None,
        dtype=None,
    ) -> None:
        super().__init__()
        factory_kwargs = dict(device=device, dtype=dtype)

        self.self_attn = NodeEdgeBlock(dx, de, n_head, **factory_kwargs)

        # One (tag, feature width, feed-forward width) entry per stream. The
        # X, E, C order fixes both the sub-module registration order and the
        # order in which the linear layers draw their initial weights.
        streams = (("X", dx, dim_ffX), ("E", de, dim_ffE), ("C", dx, dim_ffX))
        for tag, width, hidden in streams:
            setattr(self, f"lin{tag}1", Linear(width, hidden, **factory_kwargs))
            setattr(self, f"lin{tag}2", Linear(hidden, width, **factory_kwargs))
            setattr(
                self,
                f"norm{tag}1",
                LayerNorm(width, eps=layer_norm_eps, **factory_kwargs),
            )
            setattr(
                self,
                f"norm{tag}2",
                LayerNorm(width, eps=layer_norm_eps, **factory_kwargs),
            )
            for idx in (1, 2, 3):
                setattr(self, f"dropout{tag}{idx}", Dropout(dropout))

        self.activation = F.relu

    def _feed_forward(self, h, lin_in, drop_mid, lin_out, drop_out):
        """Return drop_out(lin_out(drop_mid(activation(lin_in(h)))))."""
        return drop_out(lin_out(drop_mid(self.activation(lin_in(h)))))

    def forward(
        self,
        X: Tensor,
        E: Tensor,
        C: Tensor,
        atom_features: Tensor,
        node_padding_mask: Tensor,
        denoised_mask: Tensor,
    ):
        """Run attention and the per-stream feed-forward blocks.

        X: Node features (bs, n*MAX_ATOMS, dx)
        E: Edge features (bs, n*MAX_ATOMS, n*MAX_ATOMS, de)
        C: Coordinate features (bs, n*MAX_ATOMS, dx)
        atom_features: Per-atom features, forwarded untouched to the attention block
        node_padding_mask: (bs, n*MAX_ATOMS) padding mask, forwarded to the attention block
        denoised_mask: mask for denoised fragments, forwarded to the attention block
        Output: updated X, E, C with the same shapes as the inputs.
        """
        attn_X, attn_E, attn_C = self.self_attn(
            X,
            E,
            C,
            atom_features=atom_features,
            node_padding_mask=node_padding_mask,
            denoised_mask=denoised_mask,
        )

        # First residual + norm, in X, E, C order (keeps dropout RNG order stable).
        X = self.normX1(X + self.dropoutX1(attn_X))
        E = self.normE1(E + self.dropoutE1(attn_E))
        C = self.normC1(C + self.dropoutC1(attn_C))

        # Feed-forward residual + norm, again one stream at a time.
        X = self.normX2(
            X + self._feed_forward(X, self.linX1, self.dropoutX2, self.linX2, self.dropoutX3)
        )
        E = self.normE2(
            E + self._feed_forward(E, self.linE1, self.dropoutE2, self.linE2, self.dropoutE3)
        )
        C = self.normC2(
            C + self._feed_forward(C, self.linC1, self.dropoutC2, self.linC2, self.dropoutC3)
        )

        return X, E, C


class NodeEdgeBlock(nn.Module):
    """Edge-conditioned self and cross attention over node (X) and coordinate (C) features.

    Five attention passes are run per call: an atom-feature -> C pass restricted
    to the denoised atoms, X -> X and C -> C self-attention, and X -> C / C -> X
    cross-attention. The raw scores of the last four passes are concatenated and
    projected to produce the edge update.

    Args:
        dx: node / coordinate feature dimension; must be divisible by ``n_head``.
        de: edge feature dimension.
        n_head: number of attention heads.
        **kwargs: ``device`` and ``dtype`` are forwarded to every linear layer
            created here so the block matches the layer that owns it; any other
            key is ignored.
    """

    def __init__(self, dx, de, n_head, **kwargs):
        super().__init__()
        assert dx % n_head == 0, f"dx: {dx} -- nhead: {n_head}"
        self.dx = dx
        self.de = de
        self.df = dx // n_head  # per-head feature width
        self.n_head = n_head

        # Honour the torch factory arguments instead of silently dropping them;
        # otherwise a parent built with a non-default dtype/device would hold
        # sub-modules of mismatching dtype/device.
        fk = {key: kwargs[key] for key in ("device", "dtype") if key in kwargs}

        # Shared positional encoding.
        # NOTE(review): built without factory kwargs because the signature of
        # SinusoidalPositionalEmbedding is not visible from this file.
        self.pe = SinusoidalPositionalEmbedding(
            d_model=dx, max_len=MAX_ATOMS * 5
        )  # TODO: Should not be hardcoded

        # Denoised self-attention modulation
        self.q_denoised = Linear(dx, dx, **fk)
        self.k_denoised = Linear(dx, dx, **fk)
        self.v_denoised = Linear(dx, dx, **fk)
        self.e_mul_denoised = Linear(de, dx, **fk)
        self.e_add_denoised = Linear(de, dx, **fk)

        # X→X self-attention
        self.q_x_self = Linear(dx, dx, **fk)
        self.k_x_self = Linear(dx, dx, **fk)
        self.v_x_self = Linear(dx, dx, **fk)
        self.e_mul_x_self = Linear(de, dx, **fk)
        self.e_add_x_self = Linear(de, dx, **fk)

        # C→C self-attention
        self.q_c_self = Linear(dx, dx, **fk)
        self.k_c_self = Linear(dx, dx, **fk)
        self.v_c_self = Linear(dx, dx, **fk)
        self.e_mul_c_self = Linear(de, dx, **fk)
        self.e_add_c_self = Linear(de, dx, **fk)

        # X→C cross-attention
        self.q_x_to_c = Linear(dx, dx, **fk)
        self.k_c_for_x = Linear(dx, dx, **fk)
        self.v_c_for_x = Linear(dx, dx, **fk)
        self.e_mul_x_to_c = Linear(de, dx, **fk)
        self.e_add_x_to_c = Linear(de, dx, **fk)

        # C→X cross-attention
        self.q_c_to_x = Linear(dx, dx, **fk)
        self.k_x_for_c = Linear(dx, dx, **fk)
        self.v_x_for_c = Linear(dx, dx, **fk)
        self.e_mul_c_to_x = Linear(de, dx, **fk)
        self.e_add_c_to_x = Linear(de, dx, **fk)

        # Output combinations
        self.combine_x = nn.Sequential(
            nn.Linear(2 * dx, dx, **fk), nn.ReLU(), nn.Linear(dx, dx, **fk)
        )
        self.combine_c = nn.Sequential(
            nn.Linear(2 * dx, dx, **fk), nn.ReLU(), nn.Linear(dx, dx, **fk)
        )

        # Final output projections
        self.x_out = Linear(dx, dx, **fk)
        self.c_out = Linear(dx, dx, **fk)

        # Edge update from all attentions
        self.edge_update = nn.Sequential(
            nn.Linear(4 * dx, 2 * de, **fk), nn.ReLU(), nn.Linear(2 * de, de, **fk)
        )

    def compute_attention(self, Q, K, V, E, e_mul, e_add, mask, name=""):
        """Compute per-feature attention modulated by edge features.

        Below, ``nm`` is the number of atoms (n*MAX_ATOMS).

        Args:
            Q: [bs, nm, dx] Query
            K: [bs, nm, dx] Key
            V: [bs, nm, dx] Value
            E: [bs, nm, nm, de] Edge features
            e_mul, e_add: Edge projections (de -> dx) giving the multiplicative
                and additive modulation of the scores
            mask: Tuple (row_mask, col_mask) shaped [bs, nm, 1, 1] and
                [bs, 1, nm, 1]; col_mask also selects the valid softmax entries
            name: Optional name for debugging (currently unused)

        Returns:
            weighted_V: [bs, nm, dx] attention output
            Y: [bs, nm, nm, n_head, df] edge-modulated scores before the softmax
        """
        bs, nm = Q.shape[:2]

        # Reshape to attention heads
        Q = Q.reshape(bs, nm, self.n_head, self.df)
        K = K.reshape(bs, nm, self.n_head, self.df)
        V = V.reshape(bs, nm, self.n_head, self.df)

        Q = Q.unsqueeze(2)  # [bs, nm, 1, n_head, df]
        K = K.unsqueeze(1)  # [bs, 1, nm, n_head, df]
        V = V.unsqueeze(1)  # [bs, 1, nm, n_head, df]

        # Element-wise (not dot-product) scores: [bs, nm, nm, n_head, df]
        Y = Q * K / math.sqrt(self.df)

        # Incorporate edge features, zeroing pairs that involve a masked atom
        E1 = e_mul(E) * mask[0] * mask[1]
        E2 = e_add(E) * mask[0] * mask[1]
        E1 = E1.reshape(bs, nm, nm, self.n_head, self.df)
        E2 = E2.reshape(bs, nm, nm, self.n_head, self.df)
        Y = Y * (E1 + 1) + E2

        # Normalise over dim 2, the axis that indexes K and V
        softmax_mask = mask[1].expand(-1, nm, -1, self.n_head)
        attn = masked_softmax(Y, softmax_mask, dim=2)

        weighted_V = attn * V
        weighted_V = weighted_V.sum(dim=2)
        weighted_V = weighted_V.flatten(start_dim=2)  # [bs, nm, dx]

        return weighted_V, Y

    def _attend(
        self,
        q_in,
        kv_in,
        q_proj,
        k_proj,
        v_proj,
        E,
        e_mul,
        e_add,
        node_mask,
        pair_masks,
        name,
    ):
        """Project queries from ``q_in`` and keys/values from ``kv_in``, mask them, and attend."""
        Q = q_proj(q_in) * node_mask
        K = k_proj(kv_in) * node_mask
        V = v_proj(kv_in) * node_mask
        # Q, K and V are locals, so they are released as soon as this returns.
        return self.compute_attention(Q, K, V, E, e_mul, e_add, pair_masks, name)

    def forward(self, X, E, C, atom_features, node_padding_mask, denoised_mask):
        """Update node, edge and coordinate features (``nm`` = n*MAX_ATOMS).

        Args:
            X: [bs, nm, dx] Node features
            E: [bs, nm, nm, de] Edge features
            C: [bs, nm, dx] Coordinate features
            atom_features: [bs, nm, dx] Per-atom features (queries of the
                denoised pass, hence the dx width)
            node_padding_mask: [bs, nm] Atom-level padding mask
            denoised_mask: [bs, nm] mask for denoised fragments; it is inverted
                with ``~`` below, so a boolean tensor is expected

        Returns:
            newX [bs, nm, dx], newE [bs, nm, nm, de], newC [bs, nm, dx]; masked
            positions are zeroed.
        """
        _, nm, _ = X.shape

        # Apply the shared positional encoding to both streams
        positions = torch.arange(nm, device=X.device)
        X = X + self.pe(positions)
        C = C + self.pe(positions)

        # Create masks
        x_mask = node_padding_mask.unsqueeze(-1)  # [bs, nm, 1]
        e_mask1 = x_mask.unsqueeze(2)  # [bs, nm, 1, 1]
        e_mask2 = x_mask.unsqueeze(1)  # [bs, 1, nm, 1]
        e_masks = (e_mask1, e_mask2)

        # 0. atom_features→C cross-attention, restricted to denoised atoms
        d_mask = denoised_mask.unsqueeze(-1)
        d_masks = (d_mask.unsqueeze(2), d_mask.unsqueeze(1))
        C_denoised, _ = self._attend(
            atom_features,
            C,
            self.q_denoised,
            self.k_denoised,
            self.v_denoised,
            E,
            self.e_mul_denoised,
            self.e_add_denoised,
            d_mask,
            d_masks,
            "denoised",
        )
        # Replace C on denoised atoms only; other atoms keep their value
        C = C * (~d_mask) + C_denoised * d_mask
        del C_denoised

        # 1. X→X self-attention
        x_self, Y_xx = self._attend(
            X,
            X,
            self.q_x_self,
            self.k_x_self,
            self.v_x_self,
            E,
            self.e_mul_x_self,
            self.e_add_x_self,
            x_mask,
            e_masks,
            "x_self",
        )

        # 2. C→C self-attention
        c_self, Y_cc = self._attend(
            C,
            C,
            self.q_c_self,
            self.k_c_self,
            self.v_c_self,
            E,
            self.e_mul_c_self,
            self.e_add_c_self,
            x_mask,
            e_masks,
            "c_self",
        )

        # 3. X→C cross-attention (queries from X, keys/values from C)
        x_to_c, Y_xc = self._attend(
            X,
            C,
            self.q_x_to_c,
            self.k_c_for_x,
            self.v_c_for_x,
            E,
            self.e_mul_x_to_c,
            self.e_add_x_to_c,
            x_mask,
            e_masks,
            "x_to_c",
        )

        # 4. C→X cross-attention (queries from C, keys/values from X)
        c_to_x, Y_cx = self._attend(
            C,
            X,
            self.q_c_to_x,
            self.k_x_for_c,
            self.v_x_for_c,
            E,
            self.e_mul_c_to_x,
            self.e_add_c_to_x,
            x_mask,
            e_masks,
            "c_to_x",
        )

        # Combine attention outputs
        newX = self.combine_x(torch.cat([x_self, c_to_x], dim=-1))
        newX = self.x_out(newX) * x_mask

        newC = self.combine_c(torch.cat([c_self, x_to_c], dim=-1))
        newC = self.c_out(newC) * x_mask

        # Update edge features using all attention scores:
        # each Y is flattened to [bs, nm, nm, dx], giving 4*dx after the concat
        Y_combined = torch.cat(
            [
                Y_xx.flatten(start_dim=3),
                Y_cc.flatten(start_dim=3),
                Y_xc.flatten(start_dim=3),
                Y_cx.flatten(start_dim=3),
            ],
            dim=-1,
        )

        newE = self.edge_update(Y_combined) * e_mask1 * e_mask2

        return newX, newE, newC


class GTFinalLayer(nn.Module):
    """Output head: LayerNorm, conditioning-driven shift/scale, then a linear map."""

    def __init__(self, hidden_size, out_channels, cond_dim):
        """
        Args:
            hidden_size: Width of the incoming features
            out_channels: Width of the produced features
            cond_dim: Width of the conditioning vector
        """
        super().__init__()
        self.norm_final = LayerNorm(hidden_size)
        self.linear = nn.Linear(hidden_size, out_channels)
        # A single projection yields both shift and scale (hence 2 * hidden_size).
        self.adaLN_modulation = nn.Linear(cond_dim, 2 * hidden_size, bias=True)

    def forward(self, x, c):
        """Modulate the normalised ``x`` with parameters derived from ``c`` and project it."""
        shift, scale = torch.chunk(self.adaLN_modulation(c), 2, dim=1)

        # Insert singleton axes right after the batch axis until shift/scale
        # have as many dimensions as x, so they broadcast over the middle axes.
        n_missing = x.dim() - shift.dim()
        pad = (slice(None),) + (None,) * n_missing
        shift, scale = shift[pad], scale[pad]

        return self.linear(modulate(self.norm_final(x), shift, scale))


class GraphTransformer(nn.Module):
    """Atom-level graph transformer predicting fragment, edge and coordinate outputs.

    Fragment indices are expanded to per-atom inputs, processed by a stack of
    ``XEyTransformerLayer`` blocks, and pooled back to fragment level for the
    node and edge outputs.

    Args:
        config: configuration object; the fields read here are
            ``config.model.n_layers``, ``hidden_mlp_dims``, ``hidden_dims``
            (``dx``, ``de``, ``n_head``, ``dim_ffX``, ``dim_ffE``),
            ``act_fn_in``, ``act_fn_out`` and ``cond_dim``.
    """

    def __init__(self, config):
        super().__init__()
        self.n_layers = config.model.n_layers
        self.dim_X = N_BUILDING_BLOCKS + 1
        self.dim_E = N_REACTIONS * N_CENTERS * N_CENTERS + 2
        self.hidden_mlp_dims = config.model.hidden_mlp_dims
        self.hidden_dims = config.model.hidden_dims
        self.act_fn_in = config.model.act_fn_in
        self.act_fn_out = config.model.act_fn_out
        self.cond_dim = config.model.cond_dim
        self.maccs_dim = 167

        self.sigma_map = TimestepEmbedding(config.model.cond_dim)
        self.final_layer_X = GTFinalLayer(self.hidden_dims.dx, self.dim_X, self.cond_dim)
        self.final_layer_E = GTFinalLayer(self.hidden_dims.de, self.dim_E, self.cond_dim)
        self.final_layer_C = GTFinalLayer(self.hidden_dims.dx, 3, self.cond_dim)
        activation_functions = {"ReLU": nn.ReLU(), "LeakyReLU": nn.LeakyReLU()}

        # Known names are mapped to modules; any other value is kept as given.
        self.act_fn_in = activation_functions.get(self.act_fn_in, self.act_fn_in)
        self.act_fn_out = activation_functions.get(self.act_fn_out, self.act_fn_out)

        self.mlp_in_X = nn.Sequential(
            nn.Linear(self.maccs_dim, self.hidden_dims.dx),
            self.act_fn_in,
            nn.Linear(self.hidden_dims.dx, self.hidden_dims.dx),
            self.act_fn_in,
        )

        self.mlp_in_atomfeat = nn.Sequential(
            nn.Linear(N_ATOM_FEATURES, self.hidden_dims.dx),
            self.act_fn_in,
            nn.Linear(self.hidden_dims.dx, self.hidden_dims.dx),
            self.act_fn_in,
        )

        self.mlp_in_E = nn.Sequential(
            nn.Linear(N_BOND_FEATURES, self.hidden_dims.de),
            self.act_fn_in,
            nn.Linear(self.hidden_dims.de, self.hidden_dims.de),
            self.act_fn_in,
        )

        self.mlp_in_C = nn.Sequential(
            nn.Linear(3, self.hidden_dims.dx),
            self.act_fn_in,
            nn.Linear(self.hidden_dims.dx, self.hidden_dims.dx),
            self.act_fn_in,
        )

        self.tf_layers = nn.ModuleList(
            [
                XEyTransformerLayer(
                    dx=self.hidden_dims.dx,
                    de=self.hidden_dims.de,
                    n_head=self.hidden_dims.n_head,
                    dim_ffX=self.hidden_dims.dim_ffX,
                    dim_ffE=self.hidden_dims.dim_ffE,
                )
                for _ in range(self.n_layers)
            ]
        )

        self.mlp_out_X = nn.Sequential(
            nn.Linear(self.dim_X, self.dim_X),
            nn.ReLU(),
            nn.Linear(self.dim_X, self.dim_X),
            nn.ReLU(),
            nn.Linear(self.dim_X, self.dim_X),
        )

        self.mlp_out_E = nn.Sequential(
            nn.Linear(self.dim_E, self.dim_E),
            nn.ReLU(),
            nn.Linear(self.dim_E, self.dim_E),
            nn.ReLU(),
            nn.Linear(self.dim_E, self.dim_E),
        )

    @staticmethod
    def _expand_to_atoms(mask, bs, n):
        """Repeat a fragment-level (bs, n) mask MAX_ATOMS times -> (bs, n*MAX_ATOMS)."""
        return mask.unsqueeze(-1).expand(-1, -1, MAX_ATOMS).reshape(bs, n * MAX_ATOMS)

    def forward(self, X, E, C, node_padding_mask, sigma):
        """Run the model on a batch of fragment graphs.

        Args:
            X: (bs, n, ...) fragment scores; only the argmax over the last axis
                is used, to look up per-fragment features.
            E: not read by this method; the edge stream is rebuilt from the bond
                features of the selected fragments. Kept for interface
                compatibility.
            C: coordinates reshaped to (bs, C.shape[1]*MAX_ATOMS, 3) —
                presumably (bs, n, MAX_ATOMS, 3); confirm against the caller.
            node_padding_mask: (bs, n) fragment-level padding mask.
            sigma: noise level / timestep fed to the timestep embedding.

        Returns:
            X, E, C: pooled node output, pooled and symmetrised edge output with
            a zeroed diagonal, and coordinates shaped (bs, n, MAX_ATOMS, 3).
        """
        # Timestep embedding
        c = F.silu(self.sigma_map(sigma))
        bs, n = X.shape[0], X.shape[1]
        X_indices = X.argmax(dim=-1)
        bond_feats = get_partial_bond_features(
            X_indices, node_padding_mask, mode="feats"
        )  # bs, n*MAX_ATOMS, n*MAX_ATOMS, N_BOND_FEATURES

        # Every non-padded fragment is currently treated as denoised. The
        # original code left a restriction to X_indices != N_BUILDING_BLOCKS
        # disabled.
        denoised_mask = self._expand_to_atoms(node_padding_mask, bs, n)
        node_padding_mask = self._expand_to_atoms(node_padding_mask, bs, n)

        # Per-atom features of the selected fragments
        atom_feats = get_partial_atom_features(X_indices)  # bs, n, MAX_ATOMS, feat
        atom_feats = atom_feats.reshape(bs, n * MAX_ATOMS, -1)
        atom_feats = self.mlp_in_atomfeat(atom_feats)  # bs, n*MAX_ATOMS, dx

        # Fragment fingerprint, repeated for every atom of the fragment
        X = get_partial_maccs_keys(X_indices)  # bs, n, 167
        X = X.unsqueeze(2).expand(-1, -1, MAX_ATOMS, -1)  # bs, n, MAX_ATOMS, 167
        X = X.reshape(bs, n * MAX_ATOMS, self.maccs_dim)  # bs, n*MAX_ATOMS, 167
        X = self.mlp_in_X(X)  # bs, n*MAX_ATOMS, dx

        # Flatten C to atom level; reshape (not view) also accepts
        # non-contiguous coordinate tensors.
        C = C.reshape(C.shape[0], C.shape[1] * MAX_ATOMS, 3)  # bs, n*MAX_ATOMS, 3
        C = self.mlp_in_C(C)  # bs, n*MAX_ATOMS, dx

        # Embed bond features and symmetrise them
        E = self.mlp_in_E(bond_feats)  # bs, n*MAX_ATOMS, n*MAX_ATOMS, de
        E = (E + E.transpose(1, 2)) / 2

        # Process through transformer layers
        for layer in self.tf_layers:
            X, E, C = layer(X, E, C, atom_feats, node_padding_mask, denoised_mask)

        # Final layers, conditioned on the timestep embedding
        X = self.final_layer_X(X, c)  # bs, n*MAX_ATOMS, dim_X
        E = self.final_layer_E(E, c)  # bs, n*MAX_ATOMS, n*MAX_ATOMS, dim_E
        C = self.final_layer_C(C, c)  # bs, n*MAX_ATOMS, 3

        # Pool atoms back to fragment level (helper defined in src.models.utils)
        X, E = mean_pool_fragments(X, E, n, MAX_ATOMS)

        X = self.mlp_out_X(X)
        E = self.mlp_out_E(E)

        C = C.reshape(C.shape[0], n, MAX_ATOMS, 3)  # bs, n, MAX_ATOMS, 3

        # Zero the diagonal (no self-edges) and symmetrise E. The mask is built
        # directly as bool on E's device, avoiding a CPU allocation and a
        # host-to-device copy on every call.
        off_diag = ~torch.eye(n, dtype=torch.bool, device=E.device)
        E = E * off_diag[None, :, :, None]
        E = 0.5 * (E + torch.transpose(E, 1, 2))

        return X, E, C
