import json
import warnings
from pathlib import Path

import numpy as np
import torch
from rdkit import Chem
from rdkit.Chem import AllChem

# Project root: the directory two levels above this file. All vocabulary and
# data paths in this module are resolved relative to it.
ROOT_DIR: Path = Path(__file__).parent.parent

# Load configs once at module level.
# building_blocks.json maps each building-block SMILES to its reaction centers.
# Those entries are tested for membership against atom indices below, so they are
# presumably atom indices (TODO confirm against the vocabulary builder).
with open(ROOT_DIR / "vocabulary/building_blocks.json", encoding="utf-8") as f:
    BUILDING_BLOCKS = json.load(f)
N_BUILDING_BLOCKS = len(BUILDING_BLOCKS)

# Atom vocabulary for the one-hot encoding. Atoms whose symbol is not listed
# get an all-zero one-hot (there is no "other" bucket).
ATOM_TYPES = ["C", "N", "O", "B", "F", "Cl", "Br", "S"]
N_ATOM_FEATURES = (
    len(ATOM_TYPES) + 3
)  # one-hot atom type, in ring, is_reaction_center, is_masked
N_BOND_FEATURES = 5  # single, double, triple, aromatic, is_masked

# RDKit Bond.GetBondTypeAsDouble() value -> bond-feature channel. Other bond
# types still set adjacency but leave every bond-type channel at zero.
_BOND_TYPE_CHANNEL = {1.0: 0, 2.0: 1, 3.0: 2, 1.5: 3}


def _parse_building_block(smiles):
    """Parse one building-block SMILES, raising ``ValueError`` if RDKit rejects it.

    ``Chem.MolFromSmiles`` returns ``None`` instead of raising on bad input; without
    this check the failure would surface later as an opaque ``AttributeError``.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError(f"Could not parse building-block SMILES {smiles!r}")
    return mol


def _maccs_features(mol):
    """Return the MACCS keys fingerprint of ``mol`` as a length-167 float32 tensor."""
    arr = np.zeros((167,))
    AllChem.DataStructs.ConvertToNumpyArray(AllChem.GetMACCSKeysFingerprint(mol), arr)
    return torch.tensor(arr, dtype=torch.float32)


def _atom_features(mol, reaction_centers):
    """Return an ``(n_atoms, N_ATOM_FEATURES)`` float32 tensor of per-atom features.

    Columns: one-hot over ``ATOM_TYPES``, in-ring flag, is-reaction-center flag
    (atom index found in ``reaction_centers``), and is_masked (always 0 here).
    """
    rows = []
    for idx, atom in enumerate(mol.GetAtoms()):
        symbol = atom.GetSymbol()
        row = [1.0 if symbol == t else 0.0 for t in ATOM_TYPES]
        row.append(float(atom.IsInRing()))
        row.append(1.0 if idx in reaction_centers else 0.0)
        row.append(0.0)  # is_masked
        rows.append(row)
    # reshape keeps the result 2-D even for a molecule with no atoms.
    return torch.tensor(rows, dtype=torch.float32).reshape(-1, N_ATOM_FEATURES)


def _fill_bond_features(mol, bond_feats, atom_adj):
    """Write symmetric one-hot bond types and adjacency for ``mol`` in place.

    ``bond_feats`` is indexed ``[atom, atom, channel]`` and ``atom_adj`` ``[atom, atom]``.
    """
    for bond in mol.GetBonds():
        a1, a2 = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        atom_adj[a1, a2] = atom_adj[a2, a1] = 1
        channel = _BOND_TYPE_CHANNEL.get(bond.GetBondTypeAsDouble())
        if channel is not None:
            bond_feats[a1, a2, channel] = bond_feats[a2, a1, channel] = 1


def _build_fragment_tensors(building_blocks):
    """Featurize every building block, plus a trailing masked-token row.

    Each SMILES is parsed exactly once. Row ``i`` of every returned tensor
    corresponds to the ``i``-th key of ``building_blocks``. The extra last row is
    the masked token: its only non-zero entries are the is_masked channels of
    the atom and bond features (MACCS and adjacency stay all-zero).

    Returns:
        ``(max_atoms, maccs, atom_feats, bond_feats, atom_adj)``.

    Raises:
        ValueError: if a SMILES cannot be parsed or ``building_blocks`` is empty.
    """
    mols = [_parse_building_block(smiles) for smiles in building_blocks]
    max_atoms = max(mol.GetNumAtoms() for mol in mols)
    n_rows = len(mols) + 1  # + masked token
    maccs = torch.zeros((n_rows, 167), dtype=torch.float32)
    atom_feats = torch.zeros((n_rows, max_atoms, N_ATOM_FEATURES), dtype=torch.float32)
    bond_feats = torch.zeros(
        (n_rows, max_atoms, max_atoms, N_BOND_FEATURES), dtype=torch.float32
    )
    atom_adj = torch.zeros((n_rows, max_atoms, max_atoms), dtype=torch.float32)

    for i, (mol, centers) in enumerate(zip(mols, building_blocks.values())):
        maccs[i] = _maccs_features(mol)
        rows = _atom_features(mol, centers)
        atom_feats[i, : rows.shape[0]] = rows
        # bond_feats[i] / atom_adj[i] are views, so the helper writes in place.
        _fill_bond_features(mol, bond_feats[i], atom_adj[i])

    # Masked token (last row): is_masked=1 at every atom / atom pair.
    atom_feats[-1, :, -1] = 1.0
    bond_feats[-1, :, :, -1] = 1.0
    return max_atoms, maccs, atom_feats, bond_feats, atom_adj


(
    MAX_ATOMS,
    FRAGMENT_MACCS,
    FRAGMENT_ATOMFEATS,
    FRAGMENT_BONDFEATS,
    FRAGMENT_ATOMADJ,
) = _build_fragment_tensors(BUILDING_BLOCKS)

# Reaction vocabulary; only its length is derived here.
with open(ROOT_DIR / "vocabulary/reactions.json") as f:
    REACTIONS = json.load(f)
N_REACTIONS = len(REACTIONS)

# Largest number of reaction centers listed for any single building block.
N_CENTERS = max(map(len, BUILDING_BLOCKS.values()))

# Optional dataset priors. A missing file falls back to an empty dict
# (presumably the priors have simply not been generated yet). A file that exists
# but cannot be decoded also falls back to {}, but with a warning so the
# corruption is not silently mistaken for "no priors".
DATASET_PRIORS_PATH = ROOT_DIR / "data/dataset_priors.json"
try:
    with open(DATASET_PRIORS_PATH, encoding="utf-8") as f:
        DATASET_PRIORS = json.load(f)
except FileNotFoundError:
    DATASET_PRIORS = {}
except (json.JSONDecodeError, UnicodeDecodeError) as err:
    warnings.warn(
        f"Ignoring unreadable dataset priors at {DATASET_PRIORS_PATH}: {err}",
        stacklevel=2,
    )
    DATASET_PRIORS = {}

# NOTE(review): presumably the standard deviation used to normalize 3D
# coordinates; its origin is not shown here — confirm against the data pipeline.
COORDS_STD: float = 2.7962567753408374

# Compatibility masks loaded as a tuple in this fixed order: bb1, bb2, r_out, r_in.
# NOTE(review): torch.load unpickles its input; these are project-local files.
_COMPATIBILITY_FILES = (
    "vocabulary/compatibilities/bb1_compatibilities.pt",
    "vocabulary/compatibilities/bb2_compatibilities.pt",
    "vocabulary/compatibilities/r_out_compatibilities.pt",
    "vocabulary/compatibilities/r_in_compatibilities.pt",
)
COMPATIBILITY_MASKS = tuple(torch.load(ROOT_DIR / rel) for rel in _COMPATIBILITY_FILES)

SMILES_PATH = ROOT_DIR / "vocabulary/smiles_train_full.txt"

# One training SMILES per line, with surrounding whitespace stripped.
# Blank lines are kept as empty strings.
with open(SMILES_PATH) as f:
    TRAIN_SMILES = [line.strip() for line in f]

# Shepherd pharmacophore settings.
# NOTE(review): presumably N_PHARM is the number of pharmacophore types and
# MAX_PHARM the per-molecule pharmacophore cap — confirm against their users.
N_PHARM: int = 8
MAX_PHARM: int = 40
