import math

import torch
import torch.nn as nn
import torch_geometric
from torch import Tensor
from torch.nn import functional as F
from torch.nn.modules.dropout import Dropout
from torch.nn.modules.linear import Linear
from torch.nn.modules.normalization import LayerNorm
from torch_geometric.utils import to_dense_adj, to_dense_batch

from src.constants import MAX_ATOMS, N_BUILDING_BLOCKS, N_CENTERS, N_REACTIONS
from src.models.layers import SimplePositionalEmbedding, TimestepEmbedding
from src.models.utils import (
    PlaceHolder,
    assert_correctly_masked,
    encode_no_edge,
    masked_softmax,
    modulate,
    to_dense,
)
from src.utils.indexing_utils import get_partial_maccs_keys


class XEyTransformerLayer(nn.Module):
    """Transformer layer that jointly updates node and edge features.

    The node stream is conditioned on a vector ``c``: a linear projection of
    ``c`` yields shift/scale pairs that are applied (via ``modulate``) to the
    node features before and after the attention block.

    dx: node feature dimension
    de: edge feature dimension
    n_head: the number of heads in the multi-head attention
    dim_ffX: hidden dimension of the node feedforward network after self-attention
    dim_ffE: hidden dimension of the edge feedforward network after self-attention
    cond_dim: dimension of the conditioning vector ``c``
    dropout: dropout probability. 0 to disable
    layer_norm_eps: eps value in layer normalizations.
    device, dtype: factory arguments passed to the linear / normalization
        layers created here and to ``NodeEdgeBlock``.
    """

    def __init__(
        self,
        dx: int,
        de: int,
        n_head: int,
        dim_ffX: int = 2048,
        dim_ffE: int = 128,
        cond_dim: int = 64,
        dropout: float = 0.1,
        layer_norm_eps: float = 1e-5,
        device=None,
        dtype=None,
    ) -> None:
        kw = {"device": device, "dtype": dtype}
        super().__init__()

        self.cond_dim = cond_dim
        self.self_attn = NodeEdgeBlock(dx, de, n_head, **kw)

        # Node stream: feedforward network, norms and dropouts.
        self.linX1 = Linear(dx, dim_ffX, **kw)
        self.linX2 = Linear(dim_ffX, dx, **kw)
        self.normX1 = LayerNorm(dx, eps=layer_norm_eps, **kw)
        self.normX2 = LayerNorm(dx, eps=layer_norm_eps, **kw)
        self.dropoutX1 = Dropout(dropout)
        self.dropoutX2 = Dropout(dropout)
        self.dropoutX3 = Dropout(dropout)

        # Edge stream: feedforward network, norms and dropouts.
        self.linE1 = Linear(de, dim_ffE, **kw)
        self.linE2 = Linear(dim_ffE, de, **kw)
        self.normE1 = LayerNorm(de, eps=layer_norm_eps, **kw)
        self.normE2 = LayerNorm(de, eps=layer_norm_eps, **kw)
        self.dropoutE1 = Dropout(dropout)
        self.dropoutE2 = Dropout(dropout)
        self.dropoutE3 = Dropout(dropout)

        # Projects the conditioning vector to four dx-sized chunks:
        # (shift_pre, scale_pre, shift_post, scale_post). It is created with the
        # same device/dtype as the other layers and zero-initialised, so every
        # shift and scale is exactly 0 at the start of training.
        self.adaLN_modulation = nn.Linear(cond_dim, 4 * dx, bias=True, **kw)
        nn.init.zeros_(self.adaLN_modulation.weight)
        nn.init.zeros_(self.adaLN_modulation.bias)

        self.activation = F.relu

    def forward(self, X_C: Tensor, E: Tensor, c: Tensor, node_padding_mask: Tensor):
        """Pass the input through the encoder layer.
        X_C: (bs, n, dx) joint node/coordinate features
        E: (bs, n, n, de) edge features
        c: (bs, cond_dim) conditioning vector
        node_padding_mask: (bs, n) mask forwarded to the attention block
        Output: newX, newE with the same shape.
        """

        # One (shift, scale) pair for before attention, one for after; each is
        # (bs, dx) and is unsqueezed to (bs, 1, dx) to broadcast over nodes.
        shift_pre, scale_pre, shift_post, scale_post = self.adaLN_modulation(c).chunk(4, dim=1)
        X_C = modulate(X_C, shift_pre.unsqueeze(1), scale_pre.unsqueeze(1))
        newX_C, newE = self.self_attn(X_C, E, node_padding_mask=node_padding_mask)
        newX_C = modulate(newX_C, shift_post.unsqueeze(1), scale_post.unsqueeze(1))

        # Residual + norm after attention. Note the node residual uses the
        # pre-modulated X_C, not the raw layer input.
        newX_C_d = self.dropoutX1(newX_C)
        X_C = self.normX1(X_C + newX_C_d)

        newE_d = self.dropoutE1(newE)
        E = self.normE1(E + newE_d)

        # Position-wise feedforward with residual + norm, node stream.
        ff_outputX = self.linX2(self.dropoutX2(self.activation(self.linX1(X_C))))
        ff_outputX = self.dropoutX3(ff_outputX)
        X_C = self.normX2(X_C + ff_outputX)

        # Position-wise feedforward with residual + norm, edge stream.
        ff_outputE = self.linE2(self.dropoutE2(self.activation(self.linE1(E))))
        ff_outputE = self.dropoutE3(ff_outputE)
        E = self.normE2(E + ff_outputE)

        return X_C, E


class NodeEdgeBlock(nn.Module):
    """Self attention layer that also updates the representations on the edges.

    Attention scores are kept per feature channel (shape ``(bs, n, n, n_head, df)``)
    and are modulated by the edge features in a FiLM fashion before being used
    both as the new edge representation and as the attention logits.
    """

    def __init__(self, dx, de, n_head, **kwargs):
        """
        :param dx: node feature dimension (must be divisible by ``n_head``)
        :param de: edge feature dimension
        :param n_head: number of attention heads
        :param kwargs: optional ``device`` / ``dtype`` factory arguments applied
            to the linear layers; any other key is ignored.
        """
        super().__init__()
        assert dx % n_head == 0, f"dx: {dx} -- nhead: {n_head}"
        self.dx = dx
        self.de = de
        self.df = dx // n_head
        self.n_head = n_head

        # Only device/dtype are forwarded, so unrelated kwargs stay harmless.
        factory_kwargs = {key: kwargs[key] for key in ("device", "dtype") if key in kwargs}

        self.q = Linear(dx, dx, **factory_kwargs)
        self.k = Linear(dx, dx, **factory_kwargs)
        self.v = Linear(dx, dx, **factory_kwargs)

        # FiLM E to X_C
        self.e_add = Linear(de, dx, **factory_kwargs)
        self.e_mul = Linear(de, dx, **factory_kwargs)

        # Output layers
        self.x_out = Linear(dx, dx, **factory_kwargs)
        self.e_out = Linear(dx, de, **factory_kwargs)

        self.pe_x = SimplePositionalEmbedding(dim=1, max_len=32, d_model=dx)
        self.pe_e = SimplePositionalEmbedding(dim=2, max_len=32, d_model=de)

    def forward(self, X_C, E, node_padding_mask):
        """
        :param X_C: bs, n, dx     joint node + coordinate features
        :param E: bs, n, n, de    edge features
        :param node_padding_mask: bs, n
        :return: newX, newE with updated representations
        """
        # Apply positional encodings to the joint node embedding and the edges
        X_C = self.pe_x(X_C)
        E = self.pe_e(E)

        _, n, _ = X_C.shape
        x_mask = node_padding_mask.unsqueeze(-1)  # bs, n, 1
        e_mask1 = x_mask.unsqueeze(2)  # bs, n, 1, 1
        e_mask2 = x_mask.unsqueeze(1)  # bs, 1, n, 1
        pair_mask = e_mask1 * e_mask2  # bs, n, n, 1 -- set only where both endpoints are

        # 1. Map X to keys and queries
        Q = self.q(X_C) * x_mask  # (bs, n, dx)
        K = self.k(X_C) * x_mask  # (bs, n, dx)
        assert_correctly_masked(Q, x_mask)

        # 2. Reshape to (bs, n, n_head, df) with dx = n_head * df
        Q = Q.reshape((Q.size(0), Q.size(1), self.n_head, self.df))
        K = K.reshape((K.size(0), K.size(1), self.n_head, self.df))
        Q = Q.unsqueeze(2)  # (bs, n, 1, n_head, df)
        K = K.unsqueeze(1)  # (bs, 1, n, n_head, df)

        # Compute unnormalized attentions. Y is (bs, n, n, n_head, df)
        Y = Q * K
        Y = Y / math.sqrt(Y.size(-1))
        assert_correctly_masked(Y, pair_mask.unsqueeze(-1))

        E1 = self.e_mul(E) * pair_mask  # bs, n, n, dx
        E1 = E1.reshape((E.size(0), E.size(1), E.size(2), self.n_head, self.df))

        E2 = self.e_add(E) * pair_mask  # bs, n, n, dx
        E2 = E2.reshape((E.size(0), E.size(1), E.size(2), self.n_head, self.df))

        # Incorporate edge features to the self attention scores.
        Y = Y * (E1 + 1) + E2  # (bs, n, n, n_head, df)

        # The edge update is read from the (pre-softmax) scores
        newE = Y.flatten(start_dim=3)  # bs, n, n, dx

        # Output E
        newE = self.e_out(newE) * pair_mask  # bs, n, n, de
        assert_correctly_masked(newE, pair_mask)

        # Compute attentions, normalising over dim 2 of Y.
        softmax_mask = e_mask2.expand(-1, n, -1, self.n_head)  # bs, n, n, n_head
        # NOTE(review): attn is expected to keep Y's shape (bs, n, n, n_head, df);
        # this depends on masked_softmax -- confirm against its definition.
        attn = masked_softmax(Y, softmax_mask, dim=2)

        V = self.v(X_C) * x_mask  # bs, n, dx
        V = V.reshape((V.size(0), V.size(1), self.n_head, self.df))
        V = V.unsqueeze(1)  # (bs, 1, n, n_head, df)

        # Compute weighted values and reduce over dim 2
        weighted_V = (attn * V).sum(dim=2)

        # Send output to input dim
        weighted_V = weighted_V.flatten(start_dim=2)  # bs, n, dx

        # Output X
        newX_C = self.x_out(weighted_V) * x_mask
        assert_correctly_masked(newX_C, x_mask)

        return newX_C, newE


class GTFinalLayer(nn.Module):
    """Conditioned output head: layer norm, modulation by ``c``, then a linear map."""

    def __init__(self, hidden_size, out_channels, cond_dim):
        """
        Args:
            hidden_size: Size of the last dimension of the incoming features
            out_channels: Size of the last dimension of the output
            cond_dim: Size of the conditioning vector
        """
        super().__init__()
        self.norm_final = LayerNorm(hidden_size)
        self.linear = nn.Linear(hidden_size, out_channels)
        self.adaLN_modulation = nn.Linear(cond_dim, 2 * hidden_size, bias=True)

        # Both projections start from all-zero parameters.
        for projection in (self.linear, self.adaLN_modulation):
            nn.init.zeros_(projection.weight)
            nn.init.zeros_(projection.bias)

    def forward(self, x, c):
        """Normalise ``x``, modulate it with shift/scale derived from ``c`` and project it."""
        # The conditioning projection is split in two halves along dim 1.
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)

        # Insert singleton axes right after the batch axis until the modulation
        # tensors have as many dimensions as x, so they broadcast against it.
        n_missing = x.dim() - shift.dim()
        broadcast_index = (slice(None),) + (None,) * n_missing
        shift = shift[broadcast_index]
        scale = scale[broadcast_index]

        normalised = self.norm_final(x)
        return self.linear(modulate(normalised, shift, scale))


class GraphTransformer(nn.Module):
    """Graph transformer predicting node, edge and coordinate outputs.

    config : object exposing
        ``self_conditioning`` -- whether inputs carry a second, self-conditioning copy
        ``model.n_layers`` -- number of transformer layers
        ``model.hidden_mlp_dims`` -- input MLP widths (attributes ``X`` and ``E``)
        ``model.hidden_dims`` -- transformer widths (attributes ``dx``, ``de``,
            ``n_head``, ``dim_ffX``, ``dim_ffE``)
        ``model.act_fn_in`` / ``model.act_fn_out`` -- ``"ReLU"``, ``"LeakyReLU"``; any
            other value is used as given
        ``model.cond_dim`` -- size of the timestep conditioning vector
    """

    # Width of the key vector that get_partial_maccs_keys contributes to the
    # node input (presumably the MACCS fingerprint length -- confirm).
    _N_KEY_FEATURES = 167

    def __init__(self, config):
        super().__init__()
        self.self_conditioning = config.self_conditioning
        self.n_layers = config.model.n_layers
        self.dim_X = N_BUILDING_BLOCKS + 1
        self.dim_E = N_REACTIONS * N_CENTERS * N_CENTERS + 2
        self.hidden_mlp_dims = config.model.hidden_mlp_dims
        self.hidden_dims = config.model.hidden_dims
        self.act_fn_in = config.model.act_fn_in
        self.act_fn_out = config.model.act_fn_out
        self.cond_dim = config.model.cond_dim
        self.sigma_map = TimestepEmbedding(config.model.cond_dim)
        self.final_layer_X = GTFinalLayer(self.hidden_dims.dx, self.dim_X, self.cond_dim)
        self.final_layer_E = GTFinalLayer(self.hidden_dims.de, self.dim_E, self.cond_dim)
        self.final_layer_C = GTFinalLayer(self.hidden_dims.dx, MAX_ATOMS * 3, self.cond_dim)
        activation_functions = {"ReLU": nn.ReLU(), "LeakyReLU": nn.LeakyReLU()}

        # Known names are replaced by module instances; anything else is kept as is.
        # act_fn_out is resolved and stored but not used by the layers built below.
        self.act_fn_in = activation_functions.get(self.act_fn_in, self.act_fn_in)
        self.act_fn_out = activation_functions.get(self.act_fn_out, self.act_fn_out)

        # With self-conditioning the node one-hots, the coordinates and the edge
        # features arrive doubled; the key vector is only present once.
        n_copies = 2 if self.self_conditioning else 1
        input_dim_X_C = self._N_KEY_FEATURES + n_copies * self.dim_X + n_copies * MAX_ATOMS * 3
        input_dim_E = n_copies * self.dim_E

        self.mlp_in_X_C = nn.Sequential(
            nn.Linear(input_dim_X_C, self.hidden_mlp_dims.X),
            self.act_fn_in,
            nn.Linear(self.hidden_mlp_dims.X, self.hidden_dims.dx),
            self.act_fn_in,
        )

        self.mlp_in_E = nn.Sequential(
            nn.Linear(input_dim_E, self.hidden_mlp_dims.E),
            self.act_fn_in,
            nn.Linear(self.hidden_mlp_dims.E, self.hidden_dims.de),
            self.act_fn_in,
        )

        self.tf_layers = nn.ModuleList(
            [
                XEyTransformerLayer(
                    dx=self.hidden_dims.dx,
                    de=self.hidden_dims.de,
                    n_head=self.hidden_dims.n_head,
                    dim_ffX=self.hidden_dims.dim_ffX,
                    dim_ffE=self.hidden_dims.dim_ffE,
                    cond_dim=self.cond_dim,
                )
                for _ in range(self.n_layers)
            ]
        )

    def forward(self, X, E, C, node_padding_mask, sigma):
        """Run the network.

        X: (bs, n, dim_X), or (bs, n, 2 * dim_X) with self-conditioning
        E: (bs, n, n, dim_E), or (bs, n, n, 2 * dim_E) with self-conditioning
        C: (bs, n, ...) coordinates; trailing dimensions are flattened
        node_padding_mask: (bs, n)
        sigma: input to the timestep embedding
        Returns X (bs, n, dim_X), E (bs, n, n, dim_E) and C (bs, n, MAX_ATOMS, 3).
        """
        # Timestep embedding
        c = F.silu(self.sigma_map(sigma))

        X_cond = None
        if self.self_conditioning:
            X, X_cond = X.chunk(2, dim=-1)

        bs, n = X.shape[0], X.shape[1]

        # Key features are looked up from the argmax index of each node.
        X_keys = get_partial_maccs_keys(X.argmax(dim=-1))
        node_parts = [X_keys, X] if X_cond is None else [X_keys, X, X_cond]
        X = torch.cat(node_parts, dim=-1)

        # Flatten C to [bs, n, -1]; reshape also accepts non-contiguous inputs.
        C = C.reshape(C.shape[0], C.shape[1], -1)

        # Concatenate X and C
        X_C = torch.cat([X, C], dim=-1)

        # Mask that is False on the diagonal (self-edges), built directly as a
        # bool tensor on E's device to avoid a per-call host allocation and copy.
        diag_mask = ~torch.eye(n, dtype=torch.bool, device=E.device)
        diag_mask = diag_mask.unsqueeze(0).unsqueeze(-1).expand(bs, -1, -1, -1)

        # First layer for edge features, symmetrized over the two node axes
        new_E = self.mlp_in_E(E)
        new_E = (new_E + new_E.transpose(1, 2)) / 2

        after_in = PlaceHolder(X=self.mlp_in_X_C(X_C), E=new_E, C=C).mask(node_padding_mask)

        # Only the joint node embedding and the edges are carried forward; the
        # coordinates are regenerated from X_C by final_layer_C below.
        X_C, E = after_in.X, after_in.E

        # Process through transformer layers
        for layer in self.tf_layers:
            X_C, E = layer(X_C, E, c, node_padding_mask)

        # Final layers for X, E, and C. Node + Coordinates are both calculated from X_C joint
        X = self.final_layer_X(X_C, c)
        E = self.final_layer_E(E, c)
        C = self.final_layer_C(X_C, c)
        C = C.view(C.shape[0], C.shape[1], MAX_ATOMS, -1)

        # Zero the self-edges, then symmetrize E
        E = E * diag_mask
        E = 1 / 2 * (E + torch.transpose(E, 1, 2))

        return X, E, C
