import rdkit.Chem.AllChem as Chem
import torch

from synthetic_coordinates.rdkit_helpers import smiles_to_mol
from synthetic_coordinates.conformer_generation import set_3D_coords_rdkit
from qm9.data.prepare.process_synthetic_coordinates import get_adj_list_from_adj_matrix
from conditional_generation.penalized_logP import compute_penalized_logP
from equivariant_diffusion.utils import remove_mean_with_mask

# Element vocabulary: the atom types of ZINC250k, which also covers the atom types of QM9.
# Maps the element symbol to its atomic number.
symbol_to_atomic_number = {
    'H': 1,
    'C': 6,
    'N': 7,
    'O': 8,
    'F': 9,
    'P': 15,
    'S': 16,
    'Cl': 17,
    'Br': 35,
    'I': 53,
}

class Molecule:
    """
    High-level molecule class that provides useful methods for switching between molecule representations.

    A molecule is described by a SMILES string, by a graph dictionary, or by both.
    The graph is built from the SMILES string on demand by ``get_graph`` and cached on the instance.
    """
    def __init__(self, device, dtype, smiles: str = None, graph: dict = None, only_explicit_H: bool = True):
        """
        Initializes a molecule object with either a SMILES or a graph representation.

        Args:
            device: device the graph tensors are moved to when the graph is built
            dtype: dtype the graph tensors are cast to when the graph is built
            smiles (str): SMILES representation (optional if ``graph`` is given)
            graph (dict): dictionary containing atom features and edge information and optionally
                other attributes (optional if ``smiles`` is given)
            only_explicit_H (bool): forwarded to ``smiles_to_mol``; if True, ``RemoveHs(implicitOnly=True)``
                is applied to the molecule after the 3D coordinates have been generated

        Raises:
            AssertionError: if neither ``smiles`` nor ``graph`` is provided
        """
        assert smiles is not None or graph is not None, "To initialize a molecule object, you need to provide either a SMILES or graph representation"

        self.smiles = smiles
        self.graph = graph
        self.only_explicit_H = only_explicit_H
        self.device = device
        self.dtype = dtype
        # RDKit molecule; stays None until get_graph builds the graph from the SMILES string
        self.mol = None

    def get_graph(self, force_graph_recomputation=False):
        """
        Returns the graph representation of the molecule, computing it from the SMILES string if necessary.

        The computed graph is a dict of tensors that all carry a leading singleton batch dimension.
        Every entry except 'adj_matrix' is moved to ``self.device`` and cast to ``self.dtype``
        (this includes the masks and the integer-valued entries). The positions are centered
        with ``remove_mean_with_mask``.

        Args:
            force_graph_recomputation (bool): if True, ignore a cached graph and rebuild it from the SMILES string

        Returns:
            dict: the cached or freshly computed graph

        Raises:
            AssertionError: if the graph has to be computed but no SMILES string is set, or if a
                consistency check on the computed features fails
            KeyError: if the molecule contains an element that is not in ``symbol_to_atomic_number``
        """
        # Serve the cached graph first: a molecule created from a graph alone has no SMILES string,
        # so the SMILES requirement must only apply when the graph really has to be (re)computed.
        if self.graph is not None and not force_graph_recomputation:
            return self.graph

        assert self.smiles is not None, "To get the graph representation you need to set a smiles representation"

        self.mol = smiles_to_mol(self.smiles, kekulize=False, only_explicit_H=self.only_explicit_H)

        # compute coords, this will add all Hs
        self.mol = set_3D_coords_rdkit(self.mol)

        if self.only_explicit_H:
            self.mol = Chem.RemoveHs(self.mol, implicitOnly=True)

        num_atoms = self.mol.GetNumAtoms()
        # the conformer does not change inside the loop, fetch it once
        conformer = self.mol.GetConformer()
        atomic_numbers, positions, formal_charges = [], [], []
        for atom in self.mol.GetAtoms():
            coords = conformer.GetAtomPosition(atom.GetIdx())

            atomic_numbers.append(symbol_to_atomic_number[atom.GetSymbol()])
            positions.append([float(coords.x), float(coords.y), float(coords.z)])
            formal_charges.append(int(atom.GetFormalCharge()))

        # useBO=True: matrix entries hold the bond orders instead of 0/1 flags
        adj_matrix = Chem.rdmolops.GetAdjacencyMatrix(self.mol, useBO=True)
        assert adj_matrix.shape[0] == adj_matrix.shape[1] == num_atoms, "Adj matrix shape and num_atoms do not match"
        adj_list = get_adj_list_from_adj_matrix(adj_matrix)

        penalized_logP = compute_penalized_logP(self.smiles)

        graph = {'num_atoms': num_atoms, 'atomic_numbers': atomic_numbers, 'positions': positions, 'formal_charges': formal_charges, 'adj_matrix': adj_matrix, 'adj_list': adj_list, 'penalized_logP': penalized_logP}
        graph = {key: torch.tensor(val) for key, val in graph.items()}

        # one-hot encode the atom types over the sorted set of supported atomic numbers
        included_species = torch.unique(torch.Tensor(list(symbol_to_atomic_number.values())), sorted=True)
        graph['atomic_numbers_one_hot'] = graph['atomic_numbers'].unsqueeze(-1) == included_species.unsqueeze(0)
        assert torch.all(torch.any(graph['atomic_numbers_one_hot'], -1)), "Found an atom whose atomic number is not among the included species"

        # one-hot encode the formal charges; only -1, 0 and +1 are supported
        possible_formal_charges = torch.Tensor([-1, 0, 1])
        graph['formal_charges_one_hot'] = graph['formal_charges'].unsqueeze(-1) == possible_formal_charges.unsqueeze(0)
        assert torch.all(torch.any(graph['formal_charges_one_hot'], -1)), "Found an atom whose formal charge is not in [-1, 0, 1]"

        atom_mask = graph['atomic_numbers'] > 0
        graph['atom_mask'] = atom_mask

        # an edge is valid if both of its end points are valid atoms ...
        edge_mask = atom_mask.unsqueeze(0) * atom_mask.unsqueeze(1)
        # ... and it is not a self-loop (mask diagonal)
        diag_mask = ~torch.eye(edge_mask.size(0), dtype=torch.bool)
        edge_mask *= diag_mask
        graph['edge_mask'] = edge_mask

        for key in graph:
            # add singleton batch dimension
            graph[key] = graph[key].unsqueeze(0)
            # move to device
            if key != 'adj_matrix':
                graph[key] = graph[key].to(self.device, self.dtype)

        graph['positions'] = remove_mean_with_mask(graph['positions'], graph['atom_mask'].unsqueeze(2))

        self.graph = graph
        return self.graph

    def get_img(self,):
        """
        Returns an image of the molecule as drawn by ``rdkit.Chem.Draw.MolToImage``.

        If ``get_graph`` has already built the RDKit molecule, that molecule is drawn. Otherwise the
        molecule is parsed from the SMILES string for drawing only (``self.mol`` is left untouched).

        Raises:
            ValueError: if no RDKit molecule is available and no SMILES string is set
        """
        # getattr guards against instances that were created without running __init__
        mol = getattr(self, 'mol', None)
        if mol is None:
            if self.smiles is None:
                raise ValueError("To draw the molecule you need to set a smiles representation")
            mol = smiles_to_mol(self.smiles, kekulize=False, only_explicit_H=self.only_explicit_H)

        from rdkit.Chem import Draw
        return Draw.MolToImage(mol)

class MoleculeBatch:
    """
    Class for a batch of molecules.

    The batch is built by repeating the graph tensors of a single molecule ``batch_size`` times
    along the leading (batch) dimension.
    """
    def __init__(self, batch_size: int, molecule: "Molecule" = None):
        """
        Args:
            batch_size (int): number of copies of the molecule in the batch
            molecule (Molecule): molecule whose graph is replicated; if None, the batch is
                empty and ``self.graph`` is None
        """
        # Always define the attribute so that an empty batch can be detected
        # instead of raising AttributeError on first access.
        self.graph = None
        if molecule is None:
            return

        graph = molecule.graph
        if graph is None:
            # the molecule was created from a SMILES string and its graph has not been computed yet
            graph = molecule.get_graph()

        graph_batch = {key: value.repeat_interleave(batch_size, dim=0) for key, value in graph.items()}
        n_nodes = graph_batch['positions'].size(1)
        # flatten the edge mask from (batch_size, n_nodes, n_nodes) to one row per (molecule, atom pair)
        graph_batch['edge_mask'] = graph_batch['edge_mask'].view(batch_size*n_nodes*n_nodes, 1)
        self.graph = graph_batch
