
import torch
import torch.nn.functional as F

# === RDKit libraries ===
from rdkit import Chem
from rdkit.Chem.rdchem import ChiralType
from rdkit.Chem.rdchem import BondType as BT

# RDKit chiral tag -> scalar chirality feature: -1 clockwise (CW), +1
# counter-clockwise (CCW), 0 unspecified/other. Note the 0 entries are ints while
# the +/-1 entries are floats, so tensors built from these values should pin
# their dtype explicitly.
chirality = {ChiralType.CHI_TETRAHEDRAL_CW: -1.,
             ChiralType.CHI_TETRAHEDRAL_CCW: 1.,
             ChiralType.CHI_UNSPECIFIED: 0,
             ChiralType.CHI_OTHER: 0}

# Rotatable-bond SMARTS: an acyclic (!@) single (-) bond whose two end atoms are
# each not terminal (D1), not in a triple bond (*#*), and not an [NH] bonded
# acyclically to a carbonyl carbon (amide-like N-H).
PATT = Chem.MolFromSmarts('[!$([NH]!@C(=O))&!D1&!$(*#*)]-&!@[!$([NH]!@C(=O))&!D1&!$(*#*)]')

# RDKit bond type -> integer class used for the one-hot edge features.
bonds = {BT.SINGLE: 0, BT.DOUBLE: 1, BT.TRIPLE: 2, BT.AROMATIC: 3}

# Inverse of ``bonds`` (integer class -> RDKit bond type), derived from it so the
# two tables cannot drift apart. "BODN" is a historical misspelling kept for
# existing callers; BOND_TYPES is the correctly spelled alias.
BODN_TYPES = {class_idx: bond_type for bond_type, class_idx in bonds.items()}
BOND_TYPES = BODN_TYPES

# Bond order of each non-aromatic bond type. The misspelled name is kept for
# existing callers; BOND_TYPE_DEGREE is the correctly spelled alias.
BODN_TYPE_DEGREE = {
    BT.SINGLE: 1,
    BT.DOUBLE: 2,
    BT.TRIPLE: 3}
BOND_TYPE_DEGREE = BODN_TYPE_DEGREE

# SMARTS matching any chain of four atoms joined by three bonds of any type
# (candidate dihedral / torsion quadruples).
dihedral_pattern = Chem.MolFromSmarts('[*]~[*]~[*]~[*]')


def one_k_encoding(value, choices):
    """
    One-hot encode ``value`` against ``choices`` with a trailing "other" slot.
    :param value: The value to mark with a one.
    :param choices: Sequence of the expected values, in slot order.
    :return: A list of ``len(choices) + 1`` ints. The slot at ``choices.index(value)`` is 1
             when :code:`value` is one of :code:`choices`; otherwise the last slot is 1.
    """
    num_slots = len(choices) + 1
    hot_slot = choices.index(value) if value in choices else num_slots - 1
    return [int(slot == hot_slot) for slot in range(num_slots)]

def featurize_mol(data, types):
    """
    Compute node and edge features for ``data.mol`` and attach them to ``data``.

    Part of the featurisation code taken from GeoMol https://github.com/PattanaikL/GeoMol & Torsional Diffusion https://github.com/gcorso/torsional-diffusion

    :param data: Object whose ``mol`` attribute is an RDKit molecule; it is updated in place.
    :param types: Mapping from atom symbol to an atom-type index. Indices are assumed to lie in
                  ``range(len(types))``, as required by the one-hot encoding below.
    :return: ``data`` with these attributes set:
             ``node_attr`` -- float ``[N, len(types) + F]``: atom-type one-hot followed by the
             per-atom descriptors built in the loop;
             ``edge_index`` -- long ``[2, 2B]``, every bond listed in both directions;
             ``edge_attr`` -- float ``[2B, len(bonds)]`` one-hot bond type;
             ``z`` -- long ``[N]`` atomic numbers;
             ``chiral_tag`` -- float ``[N]`` holding -1 (CW), 1 (CCW) or 0 (anything else).
    :raises KeyError: if an atom symbol is missing from ``types`` or a bond type is not in ``bonds``.
    """
    mol = data.mol
    N = mol.GetNumAtoms()
    ring = mol.GetRingInfo()
    # Loop-invariant choice list for the hybridization one-hot.
    hybridizations = [
        Chem.rdchem.HybridizationType.SP,
        Chem.rdchem.HybridizationType.SP2,
        Chem.rdchem.HybridizationType.SP3,
        Chem.rdchem.HybridizationType.SP3D,
        Chem.rdchem.HybridizationType.SP3D2,
    ]

    atom_type_idx = []
    atomic_number = []
    atom_features = []
    chiral_tag = []
    for i, atom in enumerate(mol.GetAtoms()):
        atom_type_idx.append(types[atom.GetSymbol()])
        # Tags missing from ``chirality`` (e.g. non-tetrahedral tags in newer RDKit
        # releases) are treated like CHI_OTHER instead of raising KeyError.
        chiral = chirality.get(atom.GetChiralTag(), 0)
        chiral_tag.append(chiral)
        atomic_number.append(atom.GetAtomicNum())
        atom_features.extend([atom.GetAtomicNum(),
                              1 if atom.GetIsAromatic() else 0])
        atom_features.extend(one_k_encoding(atom.GetDegree(), [0, 1, 2, 3, 4, 5, 6]))
        atom_features.extend(one_k_encoding(atom.GetHybridization(), hybridizations))
        atom_features.extend(one_k_encoding(atom.GetImplicitValence(), [0, 1, 2, 3, 4, 5, 6]))
        atom_features.extend(one_k_encoding(atom.GetFormalCharge(), [-2, -1, 0, 1, 2]))
        # Ring-membership flags for ring sizes 3..8.
        atom_features.extend(int(ring.IsAtomInRingOfSize(i, size)) for size in range(3, 9))
        atom_features.extend(one_k_encoding(int(ring.NumAtomRings(i)), [0, 1, 2, 3]))
        atom_features.append(atom.GetTotalNumHs())
        atom_features.append(atom.GetNumRadicalElectrons())
        atom_features.append(int(atom.IsInRing()))
        atom_features.extend(one_k_encoding(chiral, [-1, 0, 1]))

    z = torch.tensor(atomic_number, dtype=torch.long)

    row, col, edge_type = [], [], []
    for bond in mol.GetBonds():
        start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        # Store each bond in both directions.
        row += [start, end]
        col += [end, start]
        edge_type += 2 * [bonds[bond.GetBondType()]]

    edge_index = torch.tensor([row, col], dtype=torch.long)
    edge_type = torch.tensor(edge_type, dtype=torch.long)
    edge_attr = F.one_hot(edge_type, num_classes=len(bonds)).to(torch.float)

    # Explicit dtypes. Without them the all-int feature list becomes an int64 tensor,
    # and ``chiral_tag`` would be int64 or float32 depending on whether the molecule
    # contains a CW/CCW atom.
    x1 = F.one_hot(torch.tensor(atom_type_idx, dtype=torch.long), num_classes=len(types))
    x2 = torch.tensor(atom_features, dtype=torch.float).view(N, -1)
    x = torch.cat([x1.to(torch.float), x2], dim=-1)

    data.node_attr = x
    data.edge_index = edge_index
    data.edge_attr = edge_attr
    data.z = z
    data.chiral_tag = torch.tensor(chiral_tag, dtype=torch.float)
    return data


def add_chiral_edge_order_feature(data, mol):
    """
    Augments data.edge_attr with a new one-hot chiral position encoding (shape: [E, 4]).

    Each atom with a non-zero ``data.chiral_tag`` entry has its neighbors sorted by atom
    index. The directed edge ``center -> neighbor`` then gets a one in the column equal to
    that neighbor's rank (0-3). All of the following get an all-zero row:
    - edges in the opposite direction,
    - edges of non-chiral atoms,
    - edges to neighbors ranked fifth or later (only four slots exist).

    :param data: Graph object exposing ``edge_index`` (``[2, E]``), ``chiral_tag`` (one entry
                 per atom) and ``edge_attr`` (``[E, F]``). ``data.edge_attr`` is replaced.
    :param mol: RDKit molecule. Its atom indices are assumed to be the graph's node indices.
    :return: ``data`` with ``edge_attr`` of shape ``[E, F + 4]``.
    """
    num_slots = 4
    edge_index = data.edge_index  # [2, E]
    order_feat = torch.zeros((edge_index.size(1), num_slots), dtype=torch.float)

    # (src, dst) -> edge id. One tolist() call instead of a per-element .item().
    # If a directed edge appears twice, the last id wins.
    edge_map = {
        (src, dst): eid for eid, (src, dst) in enumerate(edge_index.t().tolist())
    }

    for atom in mol.GetAtoms():
        center = atom.GetIdx()
        if data.chiral_tag[center] == 0:
            continue  # not a chiral center

        # NOTE(review): RDKit defines CW/CCW relative to the atom's bond order, not to
        # sorted neighbor indices; confirm the sorted order is what the model expects.
        neighbors = sorted(nbr.GetIdx() for nbr in atom.GetNeighbors())

        # Only the first four neighbors get a slot. Previously a fifth neighbor
        # raised IndexError.
        for pos, nbr_idx in enumerate(neighbors[:num_slots]):
            eid = edge_map.get((center, nbr_idx))
            if eid is not None:
                order_feat[eid, pos] = 1.0

    data.edge_attr = torch.cat([data.edge_attr, order_feat], dim=1)
    return data

# === Dependencies for the SMILES-based MoleculeFeaturizer below ===
from collections import defaultdict
from typing import Callable, Tuple

import datamol as dm
import torch
from datamol.types import Mol
from rdkit import Chem
from torch_geometric.data import Data

from utils.commons.covmat import build_conformer
from utils.commons.utils import atom_to_feature_vector, compute_edge_index, get_chiral_tensors


def get_mol_from_smiles(smiles):
    """Parse ``smiles`` with RDKit and return the molecule with explicit hydrogens added.

    :param smiles: SMILES string to parse.
    :return: RDKit molecule after ``Chem.AddHs``.
    :raises ValueError: if RDKit cannot parse ``smiles``.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # MolFromSmiles signals a parse failure by returning None, and AddHs(None)
        # would then fail with an opaque argument-type error.
        raise ValueError(f"Could not parse SMILES: {smiles!r}")
    return Chem.AddHs(mol)


def cache_decorator(func: Callable):
    """Memoize a ``(self, smiles, *args, **kwargs)`` method in ``self.cache[smiles]``.

    * A call with no extra arguments is stored under ``func.__name__`` (the original
      key layout).
    * A call whose extra arguments are all scalars (bool/int/float/str/bytes/None) is
      keyed on ``(func.__name__, args, sorted kwargs)``. This way e.g.
      ``use_ogb_feat=False`` never returns a result computed for ``use_ogb_feat=True``.
    * A call with any other extra argument (tensors, lists, ...) bypasses the cache.
      Such objects are unhashable or hash by identity, which is not a safe key.

    ``self.cache`` must allow ``self.cache[smiles][key] = value`` for an unseen
    ``smiles`` (e.g. ``collections.defaultdict(dict)``).
    """
    import functools

    scalar_types = (bool, int, float, str, bytes, type(None))

    @functools.wraps(func)
    def wrapper(self, smiles: str, *args, **kwargs):
        extra = (*args, *kwargs.values())
        if not extra:
            cache_key = func.__name__
        elif all(isinstance(value, scalar_types) for value in extra):
            cache_key = (func.__name__, args, tuple(sorted(kwargs.items())))
        else:
            return func(self, smiles, *args, **kwargs)

        # Membership test first so a lookup miss does not create a defaultdict entry.
        if smiles in self.cache and cache_key in self.cache[smiles]:
            return self.cache[smiles][cache_key]
        result = func(self, smiles, *args, **kwargs)
        self.cache[smiles][cache_key] = result
        return result

    return wrapper


class MoleculeFeaturizer:
    """A Featurizer Class for Molecules.
    - Give smiles, get mol objects, atom features, bond features, etc.
    - Smiles-based Caching to avoid recomputation.

    Methods decorated with ``cache_decorator`` keep their results in
    ``self.cache[smiles]``. The ``*_from_mol`` helpers compute directly from a
    ``Mol`` without caching.
    """

    def __init__(self):
        # smiles -> {cache key -> cached value}. A defaultdict so that cache writers
        # can assign self.cache[smiles][key] for a smiles not seen before.
        self.cache = defaultdict(dict)

    def get_mol(self, smiles: str) -> Mol:
        """Parse ``smiles`` via ``dm.to_mol`` with ``remove_hs=False`` and ``ordered=True``."""
        return dm.to_mol(smiles, remove_hs=False, ordered=True)

    @cache_decorator
    def get_atom_features(self, smiles: str, use_ogb_feat: bool = True) -> torch.Tensor:
        """Return the per-atom feature matrix for ``smiles``.

        See ``get_atom_features_from_mol`` for the meaning of ``use_ogb_feat``.
        """
        mol = self.get_mol(smiles)
        return self.get_atom_features_from_mol(mol, use_ogb_feat=use_ogb_feat)

    @cache_decorator
    def get_atomic_numbers(self, smiles: str) -> torch.Tensor:
        """Return the atomic numbers of the atoms of ``smiles`` as an int32 tensor."""
        mol = self.get_mol(smiles)
        return self.get_atomic_numbers_from_mol(mol)

    def get_atomic_numbers_from_mol(self, mol: Mol) -> torch.Tensor:
        """Return each atom's atomic number, in atom order, as an int32 tensor of shape ``[N]``."""
        return torch.tensor(
            [atom.GetAtomicNum() for atom in mol.GetAtoms()],
            dtype=torch.int32,
        )

    def get_atom_features_from_mol(
        self, mol: Mol, use_ogb_feat: bool = True
    ) -> torch.Tensor:
        """Return a float32 per-atom feature matrix for ``mol``.

        :param use_ogb_feat: If True, each row is ``atom_to_feature_vector(atom)``.
            Otherwise the result is an ``[N, 1]`` column of formal charges.
        """
        if use_ogb_feat:
            return torch.tensor(
                [atom_to_feature_vector(atom) for atom in mol.GetAtoms()],
                dtype=torch.float32,
            )
        return torch.tensor(
            [atom.GetFormalCharge() for atom in mol.GetAtoms()],
            dtype=torch.float32,
        ).view(-1, 1)

    @cache_decorator
    def get_chiral_centers(
        self, smiles: str
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(chiral_index, chiral_nbr_index, chiral_tag)`` for ``smiles``.

        The tuple is also stored under ``self.cache[smiles]["chiral_centers"]``.
        Element types come from ``get_chiral_tensors`` (assumed tensors -- confirm).
        """
        mol = self.get_mol(smiles)
        chiral_index, chiral_nbr_index, chiral_tag = self.get_chiral_centers_from_mol(
            mol
        )

        self.cache[smiles]["chiral_centers"] = (
            chiral_index,
            chiral_nbr_index,
            chiral_tag,
        )
        return chiral_index, chiral_nbr_index, chiral_tag

    def get_chiral_centers_from_mol(
        self, mol: Mol
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(chiral_index, chiral_nbr_index, chiral_tag)`` from ``get_chiral_tensors``."""
        chiral_index, chiral_nbr_index, chiral_tag = get_chiral_tensors(mol)
        return chiral_index, chiral_nbr_index, chiral_tag

    def get_mol_with_conformer(self, smiles: str, positions: torch.Tensor) -> Mol:
        """Return a fresh mol for ``smiles`` with a conformer built from ``positions``.

        Deliberately not cached. The result depends on ``positions``, and a
        smiles-keyed cache would return a mol carrying a stale conformer.
        """
        mol = self.get_mol(smiles)
        mol.AddConformer(build_conformer(positions))
        return mol

    @cache_decorator
    def get_edge_index(
        self, smiles: str, use_edge_feat: bool
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Returns edge index and edge attributes for a given smiles.

        Both are also stored under ``self.cache[smiles]["edge_index"]`` and
        ``self.cache[smiles]["edge_attr"]``.
        """
        mol = self.get_mol(smiles)
        edge_index, edge_attr = self.get_edge_index_from_mol(
            mol, use_edge_feat=use_edge_feat
        )

        self.cache[smiles]["edge_index"] = edge_index
        self.cache[smiles]["edge_attr"] = edge_attr
        return edge_index, edge_attr

    def get_edge_index_from_mol(
        self, mol: Mol, use_edge_feat: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Returns edge index and edge attributes for a given mol object.

        ``edge_attr`` is whatever ``compute_edge_index`` returns when
        ``with_edge_attr=use_edge_feat``. It is presumably None when
        ``use_edge_feat`` is False -- confirm.
        """
        edge_index, edge_attr = compute_edge_index(mol, with_edge_attr=use_edge_feat)
        return edge_index, edge_attr

    def get_data_from_smiles(self, smiles: str) -> Data:
        """Build a ``Data`` graph for ``smiles`` with explicit hydrogens added.

        The molecule comes from ``get_mol_from_smiles`` (RDKit parse + ``AddHs``), not
        from ``get_mol``, so atom order may differ from the cached getters above.
        ``graph.smiles`` is the non-canonical, explicit-H, atom-indexed isomeric SMILES
        of that molecule. Edge features are not computed (``use_edge_feat=False``).
        """
        mol = get_mol_from_smiles(smiles)  # added hs
        smiles_changed = dm.to_smiles(
            mol,
            canonical=False,
            explicit_hs=True,
            with_atom_indices=True,
            isomeric=True,
        )
        node_attr = self.get_atom_features_from_mol(mol, True)
        chiral_index, chiral_nbr_index, chiral_tag = self.get_chiral_centers_from_mol(
            mol
        )
        edge_index, edge_attr = self.get_edge_index_from_mol(mol, False)
        atomic_numbers = self.get_atomic_numbers_from_mol(mol)

        graph = Data(
            atomic_numbers=atomic_numbers,
            smiles=smiles_changed,
            edge_index=edge_index,
            chiral_index=chiral_index,
            chiral_nbr_index=chiral_nbr_index,
            chiral_tag=chiral_tag,
            node_attr=node_attr,
            edge_attr=edge_attr,
        )
        return graph