from typing import List, Tuple

import torch
from rdkit import Chem

from src.constants import (
    BUILDING_BLOCKS,
    COMPATIBILITY_MASKS,
    FRAGMENT_ATOMADJ,
    FRAGMENT_ATOMFEATS,
    FRAGMENT_BONDFEATS,
    FRAGMENT_MACCS,
    MAX_ATOMS,
    N_BOND_FEATURES,
    N_BUILDING_BLOCKS,
    N_REACTIONS,
    N_PHARM,
    REACTIONS,
)

# Largest entry length found among the BUILDING_BLOCKS values; used as the
# per-fragment upper bound on reaction centers (each value is presumably the
# collection of reaction centers of one fragment -- confirm in src.constants).
N_CENTERS = max(map(len, BUILDING_BLOCKS.values()))


def smiles_to_idx(smiles: str) -> int:
    """Return the position of ``smiles`` among the keys of BUILDING_BLOCKS.

    Raises:
        ValueError: if ``smiles`` is not a key of BUILDING_BLOCKS.
    """
    vocabulary = list(BUILDING_BLOCKS)  # keys, in the mapping's iteration order
    return vocabulary.index(smiles)


def idx_to_smiles(idx: int) -> str:
    """Return the BUILDING_BLOCKS key stored at position ``idx``.

    Raises:
        IndexError: if ``idx`` is outside the vocabulary.
    """
    vocabulary = list(BUILDING_BLOCKS)  # keys, in the mapping's iteration order
    return vocabulary[idx]


def smarts_to_idx(smarts: str) -> int:
    """Return the position of ``smarts`` among the keys of REACTIONS.

    Raises:
        ValueError: if ``smarts`` is not a key of REACTIONS.
    """
    vocabulary = list(REACTIONS)  # keys, in the mapping's iteration order
    return vocabulary.index(smarts)


def idx_to_smarts(idx: int) -> str:
    """Return the REACTIONS key stored at position ``idx``.

    Raises:
        IndexError: if ``idx`` is outside the vocabulary.
    """
    vocabulary = list(REACTIONS)  # keys, in the mapping's iteration order
    return vocabulary[idx]


def node_indices_to_onehot(indices: torch.Tensor) -> torch.Tensor:
    """Convert fragment indices to a one-hot node matrix.

    Args:
        indices: 1-D tensor with one class index per node. Any integer dtype
            is accepted; it is cast to int64 because ``scatter_`` requires it.

    Returns:
        Float tensor of shape [n_nodes, N_BUILDING_BLOCKS + 1], allocated on
        ``indices.device``, with a single 1 per row. The extra last column is
        the mask class and is required to stay empty.

    Raises:
        AssertionError: if any index selects the last (mask) column.
    """
    indices = indices.to(torch.int64)
    n_nodes = indices.shape[0]

    # Allocate on the same device as the indices: scatter_ needs the target
    # and the index tensor to live on the same device.
    X = torch.zeros((n_nodes, N_BUILDING_BLOCKS + 1), device=indices.device)
    X.scatter_(1, indices.unsqueeze(1), 1)
    assert torch.all(X[:, -1] == 0)  # No masks

    return X


def pharm_indices_to_onehot(indices: torch.Tensor) -> torch.Tensor:
    """Convert pharmacophore indices to a one-hot pharmacophore matrix.

    Args:
        indices: 1-D tensor with one pharmacophore class index per node. It is
            cast to int64 because ``scatter_`` requires it.

    Returns:
        Float tensor of shape [n_nodes, N_PHARM + 1], allocated on
        ``indices.device``, with a single 1 per row. The extra last column is
        the mask class and is required to stay empty.

    Raises:
        AssertionError: if any index selects the last (mask) column.
    """
    indices = indices.to(torch.int64)
    n_nodes = indices.shape[0]

    # Allocate on the same device as the indices: scatter_ needs the target
    # and the index tensor to live on the same device.
    X = torch.zeros((n_nodes, N_PHARM + 1), device=indices.device)
    X.scatter_(1, indices.unsqueeze(1), 1)
    assert torch.all(X[:, -1] == 0)  # No masks

    return X


def onehot_to_node_indices(onehot: torch.Tensor) -> torch.Tensor:
    """Recover class indices from a one-hot node matrix.

    Takes the argmax along dimension 1, so for a [n_nodes, n_classes] input
    the result holds one index per node.
    """
    node_indices = torch.argmax(onehot, dim=1)
    return node_indices


def adj_matrix_to_edge_onehot(adj_matrix: torch.Tensor) -> torch.Tensor:
    """Convert an adjacency matrix to a one-hot edge matrix.

    Args:
        adj_matrix: [n_nodes, n_nodes] adjacency matrix (symmetric)

    Returns:
        E: [n_nodes, n_nodes, N_REACTIONS + 1] one-hot edge features on the
        same device as ``adj_matrix``. Channel 0 is set where the adjacency
        entry equals 1 and channel 1 where it equals 0; entries holding any
        other value are left all-zero. The last channel is the mask class.

    Raises:
        AssertionError: if ``adj_matrix`` is not square, or if the last (mask)
            channel ends up populated.
    """
    n_nodes = adj_matrix.shape[0]
    assert adj_matrix.shape == (n_nodes, n_nodes)

    # Allocate next to the adjacency matrix so the boolean masks below index a
    # tensor on their own device.
    E = torch.zeros((n_nodes, n_nodes, N_REACTIONS + 1), device=adj_matrix.device)
    # Channel 0 marks a connection
    E[adj_matrix == 1, 0] = 1
    # Channel 1 marks the absence of a connection
    E[adj_matrix == 0, 1] = 1
    assert torch.all(E[:, :, -1] == 0)  # No masks

    return E


def reaction_type_and_centers_to_onehot(
    adj_matrix: torch.Tensor, reactions: torch.Tensor
) -> torch.Tensor:
    """Encode reactions (type plus both reaction centers) as one-hot edge features.

    Args:
        adj_matrix: [n_nodes, n_nodes] adjacency matrix
        reactions: [n_reactions, 5] tensor whose rows are
            (reaction_idx, node1, node2, center1_idx, center2_idx)

    Returns:
        Tensor of shape [n_nodes, n_nodes, N_REACTIONS * N_CENTERS * N_CENTERS + 2]
        on ``adj_matrix.device``. The second-to-last channel flags "no edge"
        and the last channel is the mask class, which this function never sets.
    """
    n_nodes = adj_matrix.shape[0]
    centers_sq = N_CENTERS * N_CENTERS
    # Two trailing channels: no-edge and mask
    n_channels = N_REACTIONS * centers_sq + 2
    onehot = torch.zeros((n_nodes, n_nodes, n_channels), device=adj_matrix.device)

    # Each row becomes one flat channel index, written for both edge directions
    for reaction_idx, node1, node2, center1_idx, center2_idx in reactions.long().tolist():
        channel = reaction_idx * centers_sq + center1_idx * N_CENTERS + center2_idx
        onehot[node1, node2, channel] = 1
        onehot[node2, node1, channel] = 1

    # Flag node pairs whose adjacency entry is 0 in the no-edge channel
    onehot[adj_matrix == 0, -2] = 1

    # The mask channel (last index) is left at 0
    return onehot


def onehot_to_reaction_type_and_centers(onehot: torch.Tensor) -> torch.Tensor:
    """Decode one-hot reaction edge features back into reaction rows.

    Args:
        onehot: [n_nodes, n_nodes, n_reactions * n_centers * n_centers + 2] matrix
               whose last two channels (no-edge and mask) are ignored

    Returns:
        torch.Tensor: [n_reactions, 5] tensor with rows
        (reaction_idx, node1, node2, center1_idx, center2_idx).
        Only entries with node1 < node2 are kept, so each symmetric pair is
        reported once, and rows are sorted by the node2 column.
    """
    centers_sq = N_CENTERS * N_CENTERS

    # Coordinates (node1, node2, channel) of every set reaction channel
    hits = onehot[..., :-2].nonzero()

    # Keep the upper triangle only, dropping the mirrored duplicates
    hits = hits[hits[:, 0] < hits[:, 1]]

    # Order rows by the second node index
    hits = hits[hits[:, 1].argsort()]

    node1, node2, channel = hits[:, 0], hits[:, 1], hits[:, 2]

    # Unflatten the channel into (reaction, center1, center2)
    reaction_idx = channel // centers_sq
    within_reaction = channel % centers_sq
    center1_idx = within_reaction // N_CENTERS
    center2_idx = within_reaction % N_CENTERS

    return torch.stack([reaction_idx, node1, node2, center1_idx, center2_idx], dim=1)


def remove_graph_padding(X, E, length):
    """Slice node and edge tensors down to their first ``length`` nodes.

    Returns:
        Tuple ``(X[:length], E[:length, :length])``.
    """
    return X[:length], E[:length, :length]


def remove_coords_padding(coords, length):
    """Slice a coordinates tensor down to its first ``length`` entries."""
    return coords[:length]


def padding_mask(X, E, lengths):
    """Create padding masks for batched node and edge tensors.

    Args:
        X: Node features tensor of shape [batch_size, max_nodes, node_features]
        E: Edge features tensor of shape [batch_size, max_nodes, max_nodes, edge_features].
            Not read by this function; kept so the signature stays unchanged.
        lengths: Actual number of nodes of each graph, shape [batch_size].
            Accepts a tensor (on any device) or a plain sequence of ints.

    Returns:
        Tuple of boolean tensors on ``X.device``:
        node_mask: [batch_size, max_nodes], True at real (non-padding) nodes
        edge_mask: [batch_size, max_nodes, max_nodes], True where both
            endpoints are real nodes
    """
    max_nodes = X.shape[1]
    # Align lengths with X so the comparison below never mixes devices
    lengths = torch.as_tensor(lengths, device=X.device)
    positions = torch.arange(max_nodes, device=X.device)
    node_mask = positions[None, :] < lengths[:, None]
    edge_mask = node_mask.unsqueeze(2) & node_mask.unsqueeze(1)

    return node_mask, edge_mask


def get_partial_maccs_keys(X_indices: torch.Tensor) -> torch.Tensor:
    """Look up the per-fragment rows of FRAGMENT_MACCS for a batch of fragment indices.

    FRAGMENT_MACCS is moved to ``X_indices.device`` and indexed along its
    first dimension, so the result has shape ``X_indices.shape`` followed by
    the trailing dimensions of the table (presumably the MACCS key vector --
    confirm in src.constants).
    """
    maccs_table = FRAGMENT_MACCS.to(X_indices.device)
    return maccs_table[X_indices]


def get_partial_atom_features(X_indices: torch.Tensor) -> torch.Tensor:
    """Look up the per-fragment rows of FRAGMENT_ATOMFEATS for a batch of fragment indices.

    FRAGMENT_ATOMFEATS is moved to ``X_indices.device`` and indexed along its
    first dimension, so the result has shape ``X_indices.shape`` followed by
    the trailing dimensions of the table (presumably [MAX_ATOMS, n_atom_features]
    -- confirm in src.constants).
    """
    atom_table = FRAGMENT_ATOMFEATS.to(X_indices.device)
    return atom_table[X_indices]


def get_partial_bond_features(X_indices, mode="adj"):
    """Build a batched block-diagonal atom-level matrix from per-fragment matrices.

    Every fragment owns a MAX_ATOMS x MAX_ATOMS block on the diagonal. Blocks
    of masked fragments (index == N_BUILDING_BLOCKS) and all off-diagonal
    entries keep their initial value.

    Args:
        X_indices: Fragment indices tensor [BS, n_frags]
        mode: "adj" to copy from FRAGMENT_ATOMADJ, "feats" to copy from
            FRAGMENT_BONDFEATS

    Returns:
        mode == "adj": tensor [BS, n_frags*MAX_ATOMS, n_frags*MAX_ATOMS],
            initialised to 0.
        mode == "feats": tensor [BS, n_frags*MAX_ATOMS, n_frags*MAX_ATOMS, N_BOND_FEATURES],
            initialised with the last channel (is_masked) set to 1. The other
            channels are presumably one-hot bond types -- confirm against
            FRAGMENT_BONDFEATS.

    Raises:
        ValueError: if ``mode`` is neither "adj" nor "feats".
    """
    if mode not in ("adj", "feats"):
        raise ValueError("mode must be either 'adj' or 'feats'")

    batch_size, n_frags = X_indices.shape[0], X_indices.shape[1]
    total_atoms = n_frags * MAX_ATOMS
    use_adj = mode == "adj"

    if use_adj:
        adj = torch.zeros((batch_size, total_atoms, total_atoms), device=X_indices.device)
    else:
        adj = torch.zeros(
            (batch_size, total_atoms, total_atoms, N_BOND_FEATURES), device=X_indices.device
        )
        # Everything starts out masked; real fragment blocks overwrite this below
        adj[..., -1] = 1

    for b in range(batch_size):
        for i in range(n_frags):
            frag_idx = X_indices[b, i]
            if frag_idx == N_BUILDING_BLOCKS:
                # Masked fragment: leave its block at the initial value
                continue

            # Diagonal block owned by fragment i
            block = slice(i * MAX_ATOMS, (i + 1) * MAX_ATOMS)
            if use_adj:
                adj[b, block, block] = FRAGMENT_ATOMADJ[frag_idx]
            else:
                adj[b, block, block] = FRAGMENT_BONDFEATS[frag_idx]

    return adj


def node_to_atom_padding_mask(node_padding_mask):
    """Expand a per-node mask [bs, n] to a per-atom mask [bs, n * MAX_ATOMS].

    Each node's value is repeated MAX_ATOMS times in consecutive positions.
    """
    bs, n = node_padding_mask.shape
    per_atom = node_padding_mask[:, :, None].expand(bs, n, MAX_ATOMS)  # bs, n, MAX_ATOMS
    return per_atom.reshape(bs, n * MAX_ATOMS)  # bs, n*MAX_ATOMS


def perfrag_atom_padding_mask(X_indices):
    """Get the per-fragment atom padding mask.

    Each fragment owns a window of MAX_ATOMS consecutive positions. For a real
    fragment the first ``n_atoms`` positions of its window are set to 1, where
    ``n_atoms`` is the atom count of the fragment's SMILES as parsed by RDKit.
    For a masked fragment (index == N_BUILDING_BLOCKS) the whole window is 1.

    Args:
        X_indices: Fragment indices tensor [bs, n_frags]

    Returns:
        Tensor [bs, n_frags * MAX_ATOMS] with the dtype and device of
        ``X_indices``, holding 1 at atom positions and 0 at padding.
    """
    bs, n_frags = X_indices.shape
    atom_mask = torch.zeros(
        (bs, n_frags * MAX_ATOMS), dtype=X_indices.dtype, device=X_indices.device
    )

    # Atom count per distinct fragment index, so each SMILES is looked up and
    # parsed at most once per call instead of once per occurrence.
    atom_counts = {}

    for b, row in enumerate(X_indices.tolist()):
        for i, frag_idx in enumerate(row):
            start_idx = i * MAX_ATOMS
            if frag_idx == N_BUILDING_BLOCKS:
                # Masked fragment - all positions are atoms
                n_atoms = MAX_ATOMS
            else:
                if frag_idx not in atom_counts:
                    mol = Chem.MolFromSmiles(idx_to_smiles(frag_idx))
                    # Clamp so a fragment can never spill into the window of
                    # the next fragment.
                    atom_counts[frag_idx] = min(mol.GetNumAtoms(), MAX_ATOMS)
                n_atoms = atom_counts[frag_idx]
            atom_mask[b, start_idx : start_idx + n_atoms] = 1

    return atom_mask


def get_compatibility_masks(X, E, limit_edges=True):
    """Compute boolean compatibility masks for a batch of molecules.

    X is a onehot tensor of shape [bs, n, n_building_blocks + 1]
    E is a onehot tensor of shape [bs, n, n, n_reactions + 1]

    Node masks: for every unmasked edge (i, j) with i < j, the mask of node i
    is multiplied by ``bb1_compat[edge_class]`` and the mask of node j by
    ``bb2_compat[edge_class]``. A narrowing is applied only if more than one
    entry would remain allowed.

    Edge masks (only when ``limit_edges``): for every unmasked node i and every
    j > i, the mask of edge (i, j) is multiplied by ``r_out_compat[node_class]``
    under the same "more than one entry remains" rule, and mirrored to (j, i).

    NOTE(review): ``r_in_compat`` is unpacked from COMPATIBILITY_MASKS but never
    consulted -- confirm whether that is intended.

    Returns:
        Tuple ``(compatibility_mask_X, compatibility_mask_E)`` of boolean
        tensors shaped like X and E.

    Raises:
        AssertionError: if any node or edge mask has fewer than two allowed entries.
    """
    bb1_compat, bb2_compat, r_out_compat, _r_in_compat = (
        table.to(X.device) for table in COMPATIBILITY_MASKS
    )
    bs, n = X.shape[0], X.shape[1]
    X_indices = X.argmax(dim=-1)
    E_indices = E.argmax(dim=-1)
    compatibility_mask_X = torch.ones_like(X)
    compatibility_mask_E = torch.ones_like(E)

    for b in range(bs):
        for i in range(n):
            for j in range(i + 1, n):  # Only upper triangle
                if E[b, i, j, -1] == 1:
                    # Masked edge: imposes no constraint on its endpoints
                    continue
                edge_class = E_indices[b, i, j]

                narrowed_i = compatibility_mask_X[b, i] * bb1_compat[edge_class]
                if narrowed_i.sum() > 1:
                    compatibility_mask_X[b, i] = narrowed_i

                narrowed_j = compatibility_mask_X[b, j] * bb2_compat[edge_class]
                if narrowed_j.sum() > 1:
                    compatibility_mask_X[b, j] = narrowed_j

    if limit_edges:
        for b in range(bs):
            for i in range(n):
                if X[b, i, -1] == 1:
                    # Masked building block: its edges stay unconstrained
                    continue
                # An unmasked building block constrains every outgoing reaction
                # from its index through the compatibility tensor.
                for j in range(i + 1, n):
                    narrowed_edge = compatibility_mask_E[b, i, j] * r_out_compat[X_indices[b, i]]
                    if narrowed_edge.sum() > 1:
                        compatibility_mask_E[b, i, j] = narrowed_edge
                        compatibility_mask_E[b, j, i] = compatibility_mask_E[b, i, j]

    assert torch.all(compatibility_mask_X.sum(dim=-1) > 1)
    assert torch.all(compatibility_mask_E.sum(dim=-1) > 1)
    return compatibility_mask_X.bool(), compatibility_mask_E.bool()
