from pathlib import Path

import torch
from rdkit import Chem
from rdkit.Chem import AllChem

from src.constants import (
    BUILDING_BLOCKS,
    N_BUILDING_BLOCKS,
    N_CENTERS,
    N_REACTIONS,
    REACTIONS,
)
from src.utils.indexing_utils import *


def generate_bb1_bb2_compatibilities(save_dir: Path):
    """Generate compatibility tensors for building blocks with reactions.

    Saves (into ``save_dir``):
        bb1_compatibilities.pt: torch.Tensor
            Compatibility tensor for reactions with the building block as the first reactant.
        bb2_compatibilities.pt: torch.Tensor
            Compatibility tensor for reactions with the building block as the second reactant.

        Both tensors have shape (N_REACTIONS * N_CENTERS * N_CENTERS + 2, N_BUILDING_BLOCKS + 1):
        one row per flat reaction index
        ``reaction_idx * N_CENTERS * N_CENTERS + center1_idx * N_CENTERS + center2_idx``.

        ``bb1[row, bb]`` is 0 when building block ``bb`` has no center ``center1_idx``, or when
        that center atom does not appear in any substructure match of the reaction's first
        reactant template. ``bb2`` applies the same rule to ``center2_idx`` and the second
        reactant template. Every other entry is 1, including the final two rows and the final
        column, which are never written here (presumably special/padding indices — confirm).

        Intuitively, these two tensors describe the possible reactant 1 and 2's for each
        denoised reaction. When the model denoises a reaction, it uses them to limit the
        reactant choices for the building blocks.

    Args:
        save_dir: Directory the two ``.pt`` files are written into; it must already exist.

    Raises:
        ValueError: If RDKit cannot parse a building block's SMILES.
    """
    n_reaction_rows = N_REACTIONS * N_CENTERS * N_CENTERS
    bb1_compatibilities = torch.ones(n_reaction_rows + 2, N_BUILDING_BLOCKS + 1)
    bb2_compatibilities = torch.ones(n_reaction_rows + 2, N_BUILDING_BLOCKS + 1)

    # Parse every reaction template once. The reaction objects are kept in a list so they
    # stay alive for as long as their reactant templates are used.
    reactions = [AllChem.ReactionFromSmarts(idx_to_smarts(i)) for i in range(N_REACTIONS)]
    reactant_templates = [rxn.GetReactants() for rxn in reactions]

    for idx in range(N_BUILDING_BLOCKS):
        # Parse each building block once instead of once per (reaction, center1, center2).
        smiles = idx_to_smiles(idx)
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            raise ValueError(f"RDKit could not parse building block {idx}: {smiles!r}")
        centers = BUILDING_BLOCKS[smiles]

        for reaction_idx, reactants in enumerate(reactant_templates):
            matches_r1 = mol.GetSubstructMatches(reactants[0])
            matches_r2 = mol.GetSubstructMatches(reactants[1])
            block_start = reaction_idx * N_CENTERS * N_CENTERS

            for center_idx in range(N_CENTERS):
                has_center = center_idx < len(centers)
                # The bb1 verdict depends only on center1 and the bb2 verdict only on
                # center2, so each verdict is written to all N_CENTERS rows sharing it.
                if not (has_center and any(centers[center_idx] in m for m in matches_r1)):
                    # Rows with center1_idx == center_idx form a contiguous run.
                    start = block_start + center_idx * N_CENTERS
                    bb1_compatibilities[start : start + N_CENTERS, idx] = 0
                if not (has_center and any(centers[center_idx] in m for m in matches_r2)):
                    # Rows with center2_idx == center_idx are every N_CENTERS-th row of
                    # this reaction's block.
                    block_end = block_start + N_CENTERS * N_CENTERS
                    bb2_compatibilities[block_start + center_idx : block_end : N_CENTERS, idx] = 0

    torch.save(bb1_compatibilities, save_dir / "bb1_compatibilities.pt")
    torch.save(bb2_compatibilities, save_dir / "bb2_compatibilities.pt")


def generate_reaction_in_out_compatibilities(save_dir: Path):
    """Generate compatibility tensors for reactions with building blocks.

    Saves (into ``save_dir``):
        r_in_compatibilities.pt: torch.Tensor
            Compatibility tensor for reactions with the building block as the second reactant.
        r_out_compatibilities.pt: torch.Tensor
            Compatibility tensor for reactions with the building block as the first reactant.

        Both tensors have shape (N_BUILDING_BLOCKS + 1, N_REACTIONS * N_CENTERS * N_CENTERS + 2):
        one row per building block index. Columns use the flat reaction index
        ``reaction_idx * N_CENTERS * N_CENTERS + center1_idx * N_CENTERS + center2_idx``.

        ``r_out[bb, col]`` is 0 when building block ``bb`` has no center ``center1_idx``, or
        when that center atom does not appear in any substructure match of the reaction's
        first reactant template. ``r_in`` applies the same rule to ``center2_idx`` and the
        second reactant template. Every other entry is 1, including the final row and the
        final two columns, which are never written here. This mirrors
        ``generate_bb1_bb2_compatibilities``: ``r_out`` is the transpose of ``bb1`` and
        ``r_in`` the transpose of ``bb2``.

        Intuitively, these tensors describe the possible outgoing and incoming reactions for
        each building block. When the model denoises a building block, it uses them to limit
        the reaction choices.

    Args:
        save_dir: Directory the two ``.pt`` files are written into; it must already exist.

    Raises:
        ValueError: If RDKit cannot parse a building block's SMILES.
    """
    n_reaction_cols = N_REACTIONS * N_CENTERS * N_CENTERS
    reaction_in_compatibilities = torch.ones(N_BUILDING_BLOCKS + 1, n_reaction_cols + 2)
    reaction_out_compatibilities = torch.ones(N_BUILDING_BLOCKS + 1, n_reaction_cols + 2)

    # Parse every reaction template once rather than once per building block. The
    # reaction objects are kept in a list so their reactant templates stay valid.
    reactions = [AllChem.ReactionFromSmarts(idx_to_smarts(i)) for i in range(N_REACTIONS)]
    reactant_templates = [rxn.GetReactants() for rxn in reactions]

    for building_block_idx in range(N_BUILDING_BLOCKS):
        smiles = idx_to_smiles(building_block_idx)
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            raise ValueError(
                f"RDKit could not parse building block {building_block_idx}: {smiles!r}"
            )
        centers = BUILDING_BLOCKS[smiles]

        for reaction_idx, reactants in enumerate(reactant_templates):
            reactant1_matches = mol.GetSubstructMatches(reactants[0])
            reactant2_matches = mol.GetSubstructMatches(reactants[1])
            block_start = reaction_idx * N_CENTERS * N_CENTERS

            # Visit every one of the N_CENTERS center slots, not only the centers this
            # building block has. Slots it lacks must be marked incompatible, exactly as
            # generate_bb1_bb2_compatibilities does. Bounding by N_CENTERS also stops a block
            # with extra centers from writing into the next reaction's columns.
            for center_idx in range(N_CENTERS):
                has_center = center_idx < len(centers)
                if not (has_center and any(centers[center_idx] in m for m in reactant1_matches)):
                    # Columns with center1_idx == center_idx form a contiguous run.
                    start = block_start + center_idx * N_CENTERS
                    reaction_out_compatibilities[building_block_idx, start : start + N_CENTERS] = 0
                if not (has_center and any(centers[center_idx] in m for m in reactant2_matches)):
                    # Columns with center2_idx == center_idx are every N_CENTERS-th column of
                    # this reaction's block.
                    block_end = block_start + N_CENTERS * N_CENTERS
                    reaction_in_compatibilities[
                        building_block_idx, block_start + center_idx : block_end : N_CENTERS
                    ] = 0

    torch.save(reaction_in_compatibilities, save_dir / "r_in_compatibilities.pt")
    torch.save(reaction_out_compatibilities, save_dir / "r_out_compatibilities.pt")


def main():
    """Build every compatibility tensor and write it under the vocabulary output directory."""
    output_dir = Path("vocabulary/compatibilities_plus20")
    # Create the directory and any missing parents; an existing directory is fine.
    output_dir.mkdir(exist_ok=True, parents=True)

    for build_tensors in (
        generate_bb1_bb2_compatibilities,
        generate_reaction_in_out_compatibilities,
    ):
        build_tensors(output_dir)


if __name__ == "__main__":
    main()
