from typing import List, Optional, Tuple, Union

import torch
from rdkit import Chem
from rdkit.Chem import AllChem

from src.api.data_structures import MoleculeFragmentGraph
from src.constants import BUILDING_BLOCKS
from src.utils.indexing_utils import idx_to_smarts, idx_to_smiles


def is_valid_smiles(smiles_or_mol) -> bool:
    """Check whether a SMILES string (or RDKit molecule) can be parsed and fully sanitized.

    Args:
        smiles_or_mol: A SMILES string, or an already-built RDKit molecule. Note that a
            molecule passed in directly is sanitized in place by ``Chem.SanitizeMol``.

    Returns:
        True if parsing succeeds and ``Chem.SanitizeMol`` raises nothing; False otherwise
        (including when ``smiles_or_mol`` is None or the SMILES cannot be parsed).
    """
    try:
        if isinstance(smiles_or_mol, str):
            mol = Chem.MolFromSmiles(smiles_or_mol)
        else:
            mol = smiles_or_mol
        if mol is None:
            # MolFromSmiles signals a parse failure by returning None.
            return False

        # Ensure full sanitization including kekulization
        Chem.SanitizeMol(mol)
        return True
    except Exception:
        # `Exception` rather than a bare `except:` so KeyboardInterrupt/SystemExit still propagate.
        return False


def build_molecule(
    nodes: torch.Tensor, decoded_edges: torch.Tensor, smiles: bool = False
) -> Optional[Union[str, "Chem.Mol"]]:
    """Build a molecule from model outputs. Used in calculating pvalid.

    The first node seeds a MoleculeFragmentGraph. Each following node ``i + 1`` is then
    attached using decoded edge ``i``.

    Args:
        nodes: Tensor of node indices [n_nodes]
        decoded_edges: Tensor of decoded model outputs [n_edges, 5] containing
                     (reaction_id, node1_order, node2_order, center1_idx, center2_idx).
                     These are not the same as actions for adding to MFG and must be converted.
        smiles: Whether to return a SMILES string or an RDKit molecule

    Returns:
        ``mfg.to_smiles()`` when ``smiles`` is True, otherwise ``mfg.to_mol()``.
        Returns None when the node/edge counts are inconsistent, when an action is
        rejected by ``is_valid_action``, or when any step raises.
    """
    try:
        nodes_list = nodes.tolist()
        edges_list = decoded_edges.tolist()

        # This condition has to be in place because you can sometimes still build a molecule that has
        # too many denoised reactions. It is an explicit check rather than `assert` so it still runs
        # under `python -O`. Running it first also turns empty input into a reported mismatch.
        if len(nodes_list) != len(edges_list) + 1:
            raise ValueError(f"Mismatch in nodes and edges: {len(nodes_list)} != {len(edges_list)} + 1")

        mfg = MoleculeFragmentGraph()
        mfg.add_fragment(nodes_list[0])

        for node, edge in zip(nodes_list[1:], edges_list):
            # Convert the decoded edge into an MFG action. node2_order (edge[2]) is not part of the
            # action. Presumably this is because the new fragment is always the node being added;
            # confirm against MoleculeFragmentGraph.add_fragment.
            action = [
                node,
                edge[0],  # reaction_idx
                edge[1],  # node1
                edge[3],  # center1_idx
                edge[4],  # center2_idx
            ]

            if not is_valid_action(mfg, action):
                print("invalid action:", action)
                return None
            mfg.add_fragment(*action)

        return mfg.to_smiles() if smiles else mfg.to_mol()

    except Exception as e:
        # Best-effort boundary: any failure counts as "could not build" for pvalid.
        print("Failed to build molecule:", e)
        return None


def _fragment_matches_at_center(mol, template, center) -> bool:
    """Return True if ``template`` matches ``mol`` in at least one match containing atom ``center``."""
    # Chem.MolFromSmiles returns None for unparsable SMILES; treat that as "no match".
    if mol is None:
        return False
    return any(center in match for match in mol.GetSubstructMatches(template))


def is_valid_action(mfg: MoleculeFragmentGraph, action: list) -> bool:
    """Check if an action is valid for the current molecule fragment graph.

    Args:
        mfg: Current molecule fragment graph
        action: List of [new_fragment_global_id, reaction_id, existing_frag_idx, center1_idx, center2_idx]

    Returns:
        bool: Whether the action is valid. Returns False when the existing fragment's reaction
        center is no longer available, when the reaction has fewer than two reactant templates,
        when either fragment's SMILES fails to parse, or when either fragment fails to match its
        reactant template at the chosen center.
    """
    new_frag_global_id, reaction_id, existing_frag_idx, center1_idx, center2_idx = action[:5]
    existing_node = mfg.fragment_graph.nodes[existing_frag_idx]

    # Check if reaction center is still available
    if not existing_node["rxn_center_available"][center1_idx]:
        return False

    # Get reaction and reactant patterns. Both reactant templates are indexed below, so a
    # template with fewer than two reactants cannot join two fragments.
    rxn = AllChem.ReactionFromSmarts(idx_to_smarts(reaction_id))
    if rxn.GetNumReactantTemplates() < 2:
        return False
    reactants = rxn.GetReactants()

    # Existing fragment must match reactant 0 with its selected center inside the match
    existing_smiles = existing_node["node"].smiles
    existing_center = BUILDING_BLOCKS[existing_smiles][center1_idx]
    if not _fragment_matches_at_center(Chem.MolFromSmiles(existing_smiles), reactants[0], existing_center):
        return False

    # New fragment must match reactant 1 with its selected center inside the match
    new_smiles = idx_to_smiles(new_frag_global_id)
    new_center = BUILDING_BLOCKS[new_smiles][center2_idx]
    return _fragment_matches_at_center(Chem.MolFromSmiles(new_smiles), reactants[1], new_center)
