import math

import torch
import torch.nn as nn
import torch_geometric
from torch.nn import functional as F
from torch.nn.modules.dropout import Dropout
from torch.nn.modules.linear import Linear
from torch.nn.modules.normalization import LayerNorm
from torch_geometric.utils import to_dense_adj, to_dense_batch

from src.constants import MAX_ATOMS, N_BUILDING_BLOCKS, N_CENTERS, N_REACTIONS
from src.models.layers import (
    CtoE,
    CtoX,
    EtoC,
    EtoX,
    PositionsMLP,
    SE3Norm,
    SimplePositionalEmbedding,
    TimestepEmbedding,
)
from src.models.utils import (
    PlaceHolder,
    assert_correctly_masked,
    encode_no_edge,
    masked_softmax,
    modulate,
    to_dense,
)
from src.utils.spatial_utils import center_atom_coords


class XEyTransformerLayer(nn.Module):
    """One transformer layer that refreshes node features, edge features and coordinates.

    dx: width of the node features
    de: width of the edge features
    n_head: number of heads handed to the attention block
    dim_ffX: hidden width of the node feed-forward network
    dim_ffE: hidden width of the edge feed-forward network
    dropout: dropout probability. 0 to disable
    layer_norm_eps: eps value used by the normalization layers
    device, dtype: passed as keyword arguments to the sub-module constructors
        (everything except the Dropout layers)
    """

    def __init__(
        self,
        dx: int,
        de: int,
        n_head: int,
        dim_ffX: int = 2048,
        dim_ffE: int = 128,
        dropout: float = 0.1,
        layer_norm_eps: float = 1e-5,
        device=None,
        dtype=None,
    ) -> None:
        super().__init__()
        factory_kwargs = dict(device=device, dtype=dtype)
        norm_kwargs = dict(eps=layer_norm_eps, **factory_kwargs)

        # Attention block producing node, edge and coordinate updates
        self.self_attn = NodeEdgeBlock(dx, de, n_head, **factory_kwargs)

        # Node branch: feed-forward network, two layer norms, three dropouts
        self.linX1 = Linear(dx, dim_ffX, **factory_kwargs)
        self.linX2 = Linear(dim_ffX, dx, **factory_kwargs)
        self.normX1 = LayerNorm(dx, **norm_kwargs)
        self.normX2 = LayerNorm(dx, **norm_kwargs)
        self.dropoutX1 = Dropout(dropout)
        self.dropoutX2 = Dropout(dropout)
        self.dropoutX3 = Dropout(dropout)

        # Coordinate branch: normalization applied to the predicted velocity
        self.norm_pos1 = SE3Norm(**norm_kwargs)

        # Edge branch: same layout as the node branch
        self.linE1 = Linear(de, dim_ffE, **factory_kwargs)
        self.linE2 = Linear(dim_ffE, de, **factory_kwargs)
        self.normE1 = LayerNorm(de, **norm_kwargs)
        self.normE2 = LayerNorm(de, **norm_kwargs)
        self.dropoutE1 = Dropout(dropout)
        self.dropoutE2 = Dropout(dropout)
        self.dropoutE3 = Dropout(dropout)

        self.activation = F.relu

    def forward(self, X, E, C, node_padding_mask, pos_mask):
        """Run the layer once and return the updated (X, E, C).

        X: (bs, n, dx)                     node features
        E: (bs, n, n, de)                  edge features
        C: (bs, n * MAX_ATOMS, 3)          atomic coordinates
        node_padding_mask: bs, n           mask for nodes
        pos_mask: bs, n * MAX_ATOMS        mask for atomic coordinates
        """
        # Attention block yields residual updates for nodes and edges plus a velocity
        attn_X, attn_E, velocity = self.self_attn(
            X, E, C, node_padding_mask=node_padding_mask, pos_mask=pos_mask
        )

        # Nodes: dropout on the update, residual connection, layer norm
        X = self.normX1(X + self.dropoutX1(attn_X))

        # Coordinates: the normalized velocity is added onto the current positions
        C = self.norm_pos1(velocity, pos_mask) + C

        # Edges: dropout on the update, residual connection, layer norm
        E = self.normE1(E + self.dropoutE1(attn_E))

        # Node feed-forward sub-layer with its own residual connection and norm
        hidden_X = self.dropoutX2(self.activation(self.linX1(X)))
        residual_X = self.dropoutX3(self.linX2(hidden_X))
        X = self.normX2(X + residual_X)

        # Edge feed-forward sub-layer with its own residual connection and norm
        hidden_E = self.dropoutE2(self.activation(self.linE1(E)))
        residual_E = self.dropoutE3(self.linE2(hidden_E))
        E = self.normE2(E + residual_E)

        # Average E with its transpose over the two node axes so it stays symmetric
        E = 0.5 * (E + E.transpose(1, 2))

        return X, E, C


class NodeEdgeBlock(nn.Module):
    """Self attention layer that also updates the representations on the edges.

    Besides the node update (Xout) and the edge update (Eout), the block
    predicts a per-atom velocity computed from pairwise coordinate differences.

    dx: width of the node features (must be divisible by n_head)
    de: width of the edge features
    n_head: number of attention heads
    **kwargs: optional ``device`` / ``dtype``; when given (and not None) the
        whole block is moved/cast accordingly after construction. Any other
        keyword is ignored.
    """

    def __init__(self, dx, de, n_head, **kwargs):
        super().__init__()
        assert dx % n_head == 0, f"dx: {dx} -- nhead: {n_head}"
        self.dx = dx
        self.de = de
        self.df = int(dx / n_head)
        self.n_head = n_head

        # Extracting SE(3) features from coordinates
        self.lin_dist1 = Linear(2, de)  # (distance, cosine) -> de
        self.lin_norm_pos1 = Linear(1, de)  # coordinate norm -> de
        self.lin_norm_pos2 = Linear(1, de)

        # Edge-Level Feature Update
        self.in_E = Linear(de, de)
        self.x_e_mul1 = Linear(dx, de)  # X -> E
        self.x_e_mul2 = Linear(dx, de)
        self.dist_add_e = CtoE(de, de)  # dist1 -> E (TODO: change de to dc)
        self.dist_mul_e = CtoE(de, de)
        self.e_out = Linear(de, de)

        # Node-Level Self-Attention
        self.k = Linear(dx, dx)
        self.q = Linear(dx, dx)
        self.v = Linear(dx, dx)
        self.a = Linear(dx, n_head, bias=False)
        self.out = Linear(dx * n_head, dx)
        self.e_att_mul = Linear(de, n_head)
        self.pos_att_mul = CtoE(de, n_head)  # dist1 -> n_head

        # Node-Level Feature Update
        self.e_x_mul = EtoX(de, dx)  # E -> X
        self.pos_x_mul = CtoX(de, dx)  # C -> X
        self.x_out = Linear(dx, dx)

        # Coordinates Update Conditioned on Edge Features
        self.e_pos1 = EtoC(de, de)  # E -> C
        self.e_pos2 = Linear(de, 1, bias=False)

        # Node and Edge level positional encodings
        self.pe_x = SimplePositionalEmbedding(dim=1, max_len=32, d_model=dx)
        self.pe_e = SimplePositionalEmbedding(dim=2, max_len=32, d_model=de)

        # Honor the factory kwargs supplied by the caller. They used to be
        # swallowed silently, leaving this block on the default device/dtype
        # while the caller's own layers were created with the requested ones.
        # Moving after construction keeps the default (CPU) initialization
        # and is a no-op when neither option is set.
        placement = {
            key: kwargs[key] for key in ("device", "dtype") if kwargs.get(key) is not None
        }
        if placement:
            self.to(**placement)

    def forward(self, X, E, C, node_padding_mask, pos_mask):
        """
        :param X: bs, n, dx                     node features
        :param E: bs, n, n, de                  edge features
        :param C: bs, n * MAX_ATOMS, 3          atomic coordinates
        :param node_padding_mask: bs, n                 mask for nodes
        :param pos_mask: bs, n * MAX_ATOMS      mask for atomic coordinates
        :return: Xout (bs, n, dx), Eout (bs, n, n, de) and vel, the masked and
            re-centered coordinate update (presumably bs, n * MAX_ATOMS, 3;
            depends on center_atom_coords)
        """
        X = self.pe_x(X)  # Apply positional encoding to node features
        E = self.pe_e(E)  # Apply positional encoding to edge features

        _, n, _ = X.shape
        nm = n * MAX_ATOMS
        x_mask = node_padding_mask.unsqueeze(-1)  # bs, n, 1
        e_mask1 = x_mask.unsqueeze(2)  # bs, n, 1, 1
        e_mask2 = x_mask.unsqueeze(1)  # bs, 1, n, 1
        d_mask = pos_mask.unsqueeze(-1)  # bs, nm, 1
        c_mask1 = d_mask.unsqueeze(2)  # bs, nm, 1, 1
        c_mask2 = d_mask.unsqueeze(1)  # bs, 1, nm, 1

        ############################################################################
        # 1. Extract 3d info about atomic coordinates: norms, pair-wise distances
        #    and cosine similarities
        ############################################################################
        pos = C * pos_mask.unsqueeze(-1)  # bs, nm, 3 (masked atoms zeroed)
        norm_pos = torch.norm(pos, dim=-1, keepdim=True)  # bs, nm, 1
        normalized_pos = pos / (norm_pos + 1e-7)  # bs, nm, 3 (epsilon avoids 0/0)

        # Compute pairwise distance and cosine sim, then concat into a tensor of shape (bs, nm, nm, 2)
        pairwise_dist = torch.cdist(pos, pos).unsqueeze(-1).float()
        cosines = torch.sum(
            normalized_pos.unsqueeze(1) * normalized_pos.unsqueeze(2), dim=-1, keepdim=True
        )
        pos_info = torch.cat((pairwise_dist, cosines), dim=-1)  # bs, nm, nm, 2

        # Aggregate info about norms, pairwise distances and cosine similarities
        norm1 = self.lin_norm_pos1(norm_pos)  # bs, nm, de
        norm2 = self.lin_norm_pos2(norm_pos)  # bs, nm, de
        dist1 = F.relu(self.lin_dist1(pos_info) + norm1.unsqueeze(2) + norm2.unsqueeze(1))
        dist1 = dist1 * c_mask1 * c_mask2  # bs, nm, nm, de

        ############################################################################
        # 2. Compute edge embeddings conditioned on node, edge features and dist1
        ############################################################################
        Y = self.in_E(E)  # bs, n, n, de

        # 2.1 Incorporate X into E multiplicatively (dx -> de)
        x_e_mul1 = self.x_e_mul1(X) * x_mask
        x_e_mul2 = self.x_e_mul2(X) * x_mask
        Y = Y * x_e_mul1.unsqueeze(1) * x_e_mul2.unsqueeze(2) * e_mask1 * e_mask2

        # 2.2 Incorporate dist1 (bs, n*m, n*m, de) into E (bs, n, n, de).
        # CtoE is assumed to pool atom pairs down to fragment pairs -- TODO confirm.
        dist_add = self.dist_add_e(dist1, c_mask1, c_mask2)
        dist_mul = self.dist_mul_e(dist1, c_mask1, c_mask2)
        Y = (Y + dist_add + Y * dist_mul) * e_mask1 * e_mask2  # bs, n, n, de

        # 2.3 Linear output projection of the edge update
        Eout = self.e_out(Y) * e_mask1 * e_mask2  # bs, n, n, de

        ############################################################################
        # 3. Compute node-level self-attention conditioned on edge features and dist1
        ############################################################################
        Q = (self.q(X) * x_mask).unsqueeze(2)  # bs, n, 1, dx
        K = (self.k(X) * x_mask).unsqueeze(1)  # bs, 1, n, dx
        # NOTE(review): the scale uses the edge width (Y.size(-1) == de), not the
        # node/head width -- confirm this is intended before changing it.
        prod = Q * K / math.sqrt(Y.size(-1))  # bs, n, n, dx
        a = self.a(prod) * e_mask1 * e_mask2  # bs, n, n, n_head

        # 3.1 Gate the attention scores with the (positionally encoded) input edge features
        edge_gate = self.e_att_mul(E)  # bs, n, n, n_head
        a = a + edge_gate * a

        # 3.2 Incorporate dist1 (bs, n*m, n*m, de) into attention scores (bs, n, n, n_head)
        pos_gate = self.pos_att_mul(dist1, c_mask1, c_mask2)
        a = a + pos_gate * a
        a = a * e_mask1 * e_mask2

        # 3.3 Softmax over axis 2 (the key fragments), restricted to unmasked keys
        softmax_mask = e_mask2.expand(-1, n, -1, self.n_head)  # bs, n, n, n_head
        alpha = masked_softmax(a, softmax_mask, dim=2).unsqueeze(-1)  # bs, n, n, n_head, 1
        V = (self.v(X) * x_mask).unsqueeze(1).unsqueeze(3)  # bs, 1, n, 1, dx
        weighted_V = alpha * V  # bs, n, n, n_head, dx
        weighted_V = weighted_V.sum(dim=2)  # bs, n, n_head, dx
        weighted_V = weighted_V.flatten(start_dim=2)  # bs, n, n_head x dx
        weighted_V = self.out(weighted_V) * x_mask  # bs, n, dx

        ############################################################################
        # 4. Update node features conditioned on edge features and dist1
        ############################################################################
        e_x_mul = self.e_x_mul(E, e_mask2)  # E -> X
        weighted_V = weighted_V + e_x_mul * weighted_V

        # Incorporate dist1 (bs, n*m, n*m, de) into newX (bs, n, dx)
        pos_x_mul = self.pos_x_mul(dist1, c_mask1, c_mask2, e_mask2)
        weighted_V = weighted_V + pos_x_mul * weighted_V

        # Output newX by passing through a linear layer to obtain shape (bs, n, dx)
        Xout = self.x_out(weighted_V) * x_mask

        ############################################################################
        # 5. Compute the positions update velocity conditioned on edge features Y
        ############################################################################
        pos1 = pos.unsqueeze(1).expand(-1, nm, -1, -1)  # bs, nm, nm, 3 (varies along axis 2)
        pos2 = pos.unsqueeze(2).expand(-1, -1, nm, -1)  # bs, nm, nm, 3 (varies along axis 1)
        delta_pos = pos2 - pos1  # bs, nm, nm, 3; entry [i, j] = pos[i] - pos[j]

        # Incorporate edge features Y (bs, n, n, de) into the positional update.
        # EtoC is assumed to expand fragment pairs to atom pairs -- TODO confirm.
        pa = F.relu(self.e_pos1(Y, e_mask1, e_mask2))  # presumably bs, nm, nm, de
        messages = self.e_pos2(pa)  # bs, nm, nm, 1
        vel = (messages * delta_pos).sum(dim=2) * d_mask  # bs, nm, 3

        # Recenter the velocity using the project helper
        vel = center_atom_coords(vel, pos_mask)
        return Xout, Eout, vel


class GTFinalLayer(nn.Module):
    """Output head: layer norm, conditioning-dependent modulation, linear projection.

    Both the projection and the modulation layer start with all-zero weights
    and biases, so a freshly built layer outputs zeros.
    """

    def __init__(self, hidden_size, out_channels, cond_dim):
        """
        Args:
            hidden_size: Dimension of hidden layer
            out_channels: Number of output channels
            cond_dim: Dimension of conditioning vector
        """
        super().__init__()
        self.norm_final = LayerNorm(hidden_size)
        self.linear = nn.Linear(hidden_size, out_channels)
        # One projection emits both modulation parameters (shift and scale)
        self.adaLN_modulation = nn.Linear(cond_dim, 2 * hidden_size, bias=True)

        # Zero-initialize every weight and bias of the two linear layers
        for projection in (self.linear, self.adaLN_modulation):
            nn.init.zeros_(projection.weight)
            nn.init.zeros_(projection.bias)

    def forward(self, x, c):
        """Modulate the normalized ``x`` with parameters derived from ``c`` and project it.

        x: features whose last dimension is hidden_size
        c: conditioning tensor; its projection is split in two along dim 1
        """
        # Split the projected conditioning into its shift and scale halves
        shift, scale = torch.chunk(self.adaLN_modulation(c), 2, dim=1)

        # Insert singleton axes right after the first axis until the rank
        # matches x, so shift/scale broadcast over the intermediate axes
        n_missing = x.dim() - shift.dim()
        pad_index = (slice(None),) + (None,) * n_missing
        shift = shift[pad_index]
        scale = scale[pad_index]

        modulated = modulate(self.norm_final(x), shift, scale)
        return self.linear(modulated)


class GraphTransformer(nn.Module):
    """Graph transformer over fragment nodes, edges and atomic coordinates.

    config : object exposing the following attributes under ``config.model``
        n_layers : int -- number of transformer layers
        hidden_mlp_dims : widths of the input MLPs (attributes X, E, C)
        hidden_dims : widths used inside the transformer
            (attributes dx, de, dc, n_head, dim_ffX, dim_ffE)
        act_fn_in, act_fn_out : "ReLU" / "LeakyReLU"; any other value is used
            as given
        cond_dim : int -- width of the sigma (noise level) embedding
    """

    def __init__(self, config):
        super().__init__()
        model_cfg = config.model
        self.n_layers = model_cfg.n_layers
        self.dim_X = N_BUILDING_BLOCKS + 1
        self.dim_E = N_REACTIONS * N_CENTERS * N_CENTERS + 2
        self.hidden_mlp_dims = model_cfg.hidden_mlp_dims
        self.hidden_dims = model_cfg.hidden_dims
        self.cond_dim = model_cfg.cond_dim
        self.sigma_map = TimestepEmbedding(model_cfg.cond_dim)

        # Activation functions: known names are mapped to modules; anything
        # else is passed through unchanged
        activation_functions = {"ReLU": nn.ReLU(), "LeakyReLU": nn.LeakyReLU()}
        self.act_fn_in = activation_functions.get(model_cfg.act_fn_in, model_cfg.act_fn_in)
        self.act_fn_out = activation_functions.get(model_cfg.act_fn_out, model_cfg.act_fn_out)

        # Input Embedding layers
        self.mlp_in_X = nn.Sequential(
            nn.Linear(self.dim_X, self.hidden_mlp_dims.X),
            self.act_fn_in,
            nn.Linear(self.hidden_mlp_dims.X, self.hidden_dims.dx),
            self.act_fn_in,
        )
        self.mlp_in_E = nn.Sequential(
            nn.Linear(self.dim_E, self.hidden_mlp_dims.E),
            self.act_fn_in,
            nn.Linear(self.hidden_mlp_dims.E, self.hidden_dims.de),
            self.act_fn_in,
        )
        self.mlp_in_C = PositionsMLP(self.hidden_mlp_dims.C)

        # Transformer Layers
        self.tf_layers = nn.ModuleList(
            [
                XEyTransformerLayer(
                    dx=self.hidden_dims.dx,
                    de=self.hidden_dims.de,
                    n_head=self.hidden_dims.n_head,
                    dim_ffX=self.hidden_dims.dim_ffX,
                    dim_ffE=self.hidden_dims.dim_ffE,
                )
                for _ in range(self.n_layers)
            ]
        )

        # Final output layers
        self.final_layer_X = GTFinalLayer(self.hidden_dims.dx, self.dim_X, self.cond_dim)
        self.final_layer_E = GTFinalLayer(self.hidden_dims.de, self.dim_E, self.cond_dim)
        self.final_layer_C = PositionsMLP(self.hidden_dims.dc)

    def forward(self, X, E, C, sigma, node_padding_mask):
        """Forward pass through the graph transformer model.
        Args:
            X : torch.Tensor -- fragment features (bs, n, dim_X)
            E : torch.Tensor -- edge features (bs, n, n, dim_E)
            C : torch.Tensor -- atomic coordinates (bs, n, MAX_ATOMS, 3)
            sigma : torch.Tensor -- noise level fed to the timestep embedding
                (presumably one value per batch element -- TODO confirm)
            node_padding_mask : torch.Tensor -- mask for nodes (bs, n)
        Returns:
            X : torch.Tensor -- (bs, n, dim_X)
            E : torch.Tensor -- (bs, n, n, dim_E), zero diagonal, symmetric in
                the two node axes
            C : torch.Tensor -- (bs, n, MAX_ATOMS, 3)
        """
        bs, n, _ = X.shape
        nm = n * MAX_ATOMS

        # Pre. Flatten C to shape (bs, n * MAX_ATOMS, 3), and create pos_mask based on
        # node_padding_mask. reshape (not view) so non-contiguous inputs are accepted.
        C = C.reshape(bs, nm, 3)
        x_mask = node_padding_mask.unsqueeze(-1)  # bs, n, 1
        pos_mask = x_mask.expand(-1, -1, MAX_ATOMS).reshape(bs, nm)  # bs, n*m

        # Boolean mask that is False on the diagonal (self-edges), built directly on
        # E's device; shape (1, n, n, 1) broadcasts over batch and feature axes
        diag_mask = ~torch.eye(n, dtype=torch.bool, device=E.device)
        diag_mask = diag_mask.unsqueeze(0).unsqueeze(-1)

        # Get sigma embedding
        sigma_z = F.silu(self.sigma_map(sigma))

        # 1. Embed inputs to obtain features, ready for transformer layers.
        # The embedded edges are symmetrized over the two node axes.
        new_E = self.mlp_in_E(E)
        new_E = (new_E + new_E.transpose(1, 2)) / 2
        after_in = PlaceHolder(X=self.mlp_in_X(X), E=new_E, C=self.mlp_in_C(C, pos_mask)).mask(
            node_padding_mask
        )

        # 2. Go through transformer layers that update X, E, C (See XEyTransformerLayer)
        X, E, C = after_in.X, after_in.E, after_in.C
        for layer in self.tf_layers:
            X, E, C = layer(X, E, C, node_padding_mask, pos_mask)

        # 3. Pass the updated features through the output heads that project them back
        # into the original feature space
        X = self.final_layer_X(X, sigma_z)
        E = self.final_layer_E(E, sigma_z)
        C = self.final_layer_C(C, pos_mask)

        # 4. Process the output: zero the self-edges, symmetrize E, restore C's shape
        E = E * diag_mask
        E = 1 / 2 * (E + torch.transpose(E, 1, 2))
        C = C.reshape(bs, n, MAX_ATOMS, 3)

        return X, E, C
