import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric
from torch_geometric.utils import to_dense_adj, to_dense_batch


def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Apply an affine modulation: scale ``x`` by ``(1 + scale)`` and add ``shift``.

    All three tensors are combined with ordinary broadcasting.
    """
    scaled = x * (1 + scale)
    return scaled + shift


def assert_correctly_masked(variable, node_mask):
    """Assert that ``variable`` is (numerically) zero wherever ``node_mask`` is 0/False.

    Entries at masked-out positions may deviate from zero by less than 1e-4.
    ``node_mask`` must broadcast against ``variable``.
    """
    # Keep only the entries that should have been masked out.
    leaked = variable * (1 - node_mask.long())
    largest_leak = leaked.abs().max().item()
    assert largest_leak < 1e-4, "Variables not masked properly."


def encode_no_edge(E):
    """Fill in the "no edge" channel of a dense edge tensor, in place.

    For every position ``(b, i, j)`` whose feature vector sums to zero, the
    channel at index ``E.shape[-1] - 2`` (the second-to-last channel) is set
    to 1. With a single channel this index is -1, which wraps to channel 0.
    Afterwards the whole feature vector on the diagonal (``i == j``) is zeroed.

    Args:
        E: 4-D tensor indexed as ``E[b, i, j, channel]``; assumes
            ``E.shape[1] == E.shape[2]``. Modified in place.

    Returns:
        The same tensor object ``E``. If it has zero channels, it is returned
        unchanged.
    """
    assert len(E.shape) == 4
    if E.shape[-1] == 0:
        return E
    no_edge = torch.sum(E, dim=3) == 0
    no_edge_idx = E.shape[-1] - 2  # Second to last index
    # Integer indexing returns a view, so this assignment writes straight into E.
    no_edge_elt = E[:, :, :, no_edge_idx]
    no_edge_elt[no_edge] = 1
    # Build the diagonal mask on E's own device so no host-to-device copy is
    # needed when E lives on an accelerator.
    diag = (
        torch.eye(E.shape[1], dtype=torch.bool, device=E.device)
        .unsqueeze(0)
        .expand(E.shape[0], -1, -1)
    )
    E[diag] = 0
    return E


def _fixed_size_batch_vector(batch_size, nodes_per_graph, device):
    """Return a batch-assignment vector giving each graph ``nodes_per_graph`` consecutive rows."""
    if nodes_per_graph is None:
        raise ValueError("max_pharm must be provided when conditioning inputs are given.")
    # Same result as cat([full((k,), i) for i in range(batch_size)]), but also
    # valid for batch_size == 0 (torch.cat rejects an empty list).
    return torch.arange(batch_size, dtype=torch.long, device=device).repeat_interleave(
        nodes_per_graph
    )


def to_dense(x, edge_index, edge_attr, batch, max_num_nodes, coordinates=None, coords_mask=None, cond_X=None, cond_C=None, cond_padding_mask=None, max_pharm=None):
    """Convert a sparse (PyG-style) graph batch into padded dense tensors.

    Args:
        x: Node features, densified per graph with ``to_dense_batch``.
        edge_index, edge_attr: Sparse edges. Self loops are removed before
            densifying with ``to_dense_adj``. The result is indexed as
            ``E[:, :, :, 0]``, so ``edge_attr`` must yield a 4-D dense tensor
            (presumably ``[num_edges, num_edge_features]`` — confirm with callers).
        batch: Graph assignment vector for the nodes in ``x``.
        max_num_nodes: Padded node count per graph.
        coordinates: Optional per-node coordinates, densified like ``x``.
        coords_mask: Optional per-node coordinate mask; only used when
            ``coordinates`` is given.
        cond_X, cond_C, cond_padding_mask: Optional flattened conditioning
            tensors. Each is assumed to hold exactly ``max_pharm`` rows per
            graph, stored graph after graph.
        max_pharm: Rows per graph in the conditioning tensors. Required when
            any conditioning tensor is given.

    Returns:
        Tuple ``(PlaceHolder(X, E, C), node_mask, edge_mask, coords_mask_dense,
        cond_X, cond_C, cond_padding_mask)``. ``C`` and ``coords_mask_dense``
        are None when not requested. The conditioning outputs are densified
        versions of the inputs, or None when not given.

    Raises:
        ValueError: If any conditioning tensor is given but ``max_pharm`` is None.
    """
    X, node_mask = to_dense_batch(x=x, batch=batch, max_num_nodes=max_num_nodes)
    edge_index, edge_attr = torch_geometric.utils.remove_self_loops(edge_index, edge_attr)
    E = to_dense_adj(
        edge_index=edge_index, batch=batch, edge_attr=edge_attr, max_num_nodes=max_num_nodes
    )

    # Edge (i, j) is valid only when both endpoints are real (non-padding) nodes.
    # node_mask: [batch_size, max_nodes] -> edge_mask: [batch_size, max_nodes, max_nodes]
    edge_mask = node_mask.unsqueeze(-1) & node_mask.unsqueeze(-2)

    # Make channel 0 symmetric: keep its upper triangle (incl. diagonal) and mirror
    # the strict upper triangle onto the lower one, discarding the original lower part.
    # NOTE(review): only channel 0 is symmetrized; other channels stay as produced
    # by to_dense_adj — confirm this is intended.
    E[:, :, :, 0] = torch.triu(E[:, :, :, 0]) + torch.triu(E[:, :, :, 0], 1).transpose(-1, -2)
    E = encode_no_edge(E)

    C = None
    coords_mask_dense = None
    if coordinates is not None:
        C, _ = to_dense_batch(x=coordinates, batch=batch, max_num_nodes=max_num_nodes)
        if coords_mask is not None:
            coords_mask_dense, _ = to_dense_batch(
                x=coords_mask, batch=batch, max_num_nodes=max_num_nodes
            )

    # Build the conditioning batch vector whenever ANY conditioning input is present.
    # Building it only alongside cond_X would send cond_C / cond_padding_mask to
    # to_dense_batch with batch=None, which merges the whole batch into one graph.
    cond_inputs = [t for t in (cond_X, cond_C, cond_padding_mask) if t is not None]
    if cond_inputs:
        cond_batch = _fixed_size_batch_vector(X.shape[0], max_pharm, cond_inputs[0].device)
        if cond_X is not None:
            cond_X, _ = to_dense_batch(x=cond_X, batch=cond_batch, max_num_nodes=max_pharm)
        if cond_C is not None:
            cond_C, _ = to_dense_batch(x=cond_C, batch=cond_batch, max_num_nodes=max_pharm)
        if cond_padding_mask is not None:
            cond_padding_mask, _ = to_dense_batch(
                x=cond_padding_mask, batch=cond_batch, max_num_nodes=max_pharm
            )

    return PlaceHolder(X=X, E=E, C=C), node_mask, edge_mask, coords_mask_dense, cond_X, cond_C, cond_padding_mask


def masked_softmax(x, mask, **kwargs):
    """Softmax of ``x`` with positions where ``mask`` is False pushed toward zero.

    Masked positions receive an additive penalty of -1e9 before the softmax.
    If ``mask`` has one fewer dimension than ``x``, a trailing axis is added so
    it broadcasts. If ``mask`` contains no True entry at all, ``x`` is returned
    unchanged (no softmax is applied). Extra keyword arguments (e.g. ``dim``)
    are forwarded to ``torch.softmax``.
    """
    if torch.any(mask):
        penalty = (~mask).float()
        if penalty.dim() != x.dim():
            penalty = penalty.unsqueeze(-1)
        return torch.softmax(x + penalty * -1e9, **kwargs)
    return x


class PlaceHolder:
    """Bundle of dense graph tensors: node features X, edge features E,
    coordinates C and an optional extra tensor y.

    C and y may be None.
    """

    def __init__(self, X, E, C, y=None):
        self.X = X  # node tensor
        self.E = E  # edge tensor
        self.C = C  # coordinate tensor, or None
        self.y = y  # optional extra tensor, or None

    def type_as(self, x: torch.Tensor):
        """Change the device and dtype of X, E, and (when not None) C and y to match ``x``.

        Returns:
            ``self``, to allow chaining.
        """
        self.X = self.X.type_as(x)
        self.E = self.E.type_as(x)
        # C is None whenever no coordinates were supplied; skip it instead of
        # raising AttributeError on None.type_as.
        if self.C is not None:
            self.C = self.C.type_as(x)
        if self.y is not None:
            self.y = self.y.type_as(x)
        return self

    def mask(self, node_mask, collapse=False):
        """Apply a per-node mask to X and E, in place. C and y are not touched.

        Args:
            node_mask: Per-node validity mask; 0/False marks padding. Assumed
                to be indexed as ``[bs, n]`` — confirm with callers.
            collapse: If True, replace X and E by their argmax over the last
                dimension and set masked entries to -1. If False, multiply X
                and E by the mask (zeroing padding) and assert that E is
                symmetric in its two node dimensions.

        Returns:
            ``self``, to allow chaining.
        """
        x_mask = node_mask.unsqueeze(-1)  # bs, n, 1
        e_mask1 = x_mask.unsqueeze(2)  # bs, n, 1, 1
        e_mask2 = x_mask.unsqueeze(1)  # bs, 1, n, 1

        if collapse:
            self.X = torch.argmax(self.X, dim=-1)
            self.E = torch.argmax(self.E, dim=-1)

            self.X[node_mask == 0] = -1
            # An edge is masked when either endpoint is masked.
            self.E[(e_mask1 * e_mask2).squeeze(-1) == 0] = -1
        else:
            self.X = self.X * x_mask
            self.E = self.E * e_mask1 * e_mask2
            assert torch.allclose(self.E, torch.transpose(self.E, 1, 2))
        return self


def mean_pool_fragments(X, E, n, m):
    """Average node and edge features over the atoms of each fragment.

    The atom axis of length ``n * m`` is read as ``n`` consecutive groups of
    ``m`` atoms. Every atom slot in a group is included in the mean.

    Args:
        X (torch.Tensor): Node features [bs, n*m, hidden_dims].
        E (torch.Tensor): Edge features [bs, n*m, n*m, hidden_dims].
        n (int): Number of fragments.
        m (int): Atoms per fragment.

    Returns:
        tuple: ``(X_pooled, E_pooled)`` with shapes [bs, n, hidden_dims] and
        [bs, n, n, hidden_dims].
    """
    node_batch, node_dim = X.shape[0], X.shape[-1]
    pooled_nodes = X.view(node_batch, n, m, node_dim).mean(dim=2)

    edge_batch, edge_dim = E.shape[0], E.shape[-1]
    # Split both atom axes into (fragment, atom) and average over the atom axes.
    pooled_edges = E.view(edge_batch, n, m, n, m, edge_dim).mean(dim=(2, 4))

    return pooled_nodes, pooled_edges
