from rdkit import Chem
from rdkit.Chem.MolStandardize import rdMolStandardize
import torch
from optimization.props.properties import similarity


class BasicSmilesMetrics(object):
    """Validity / uniqueness / novelty metrics for generated SMILES strings.

    Args:
        dataset_info: dict providing 'atom_decoder' and, when ``dataset_smiles_list``
            is None, 'name' (used to locate the training SMILES file).
        n_generated: total number of generated molecules; validity is reported
            relative to this count rather than to ``len(smiles_list)``.
        dataset_smiles_list: training-set SMILES used for novelty. If None, they are
            read from ``data/<name>/smiles/train.txt`` (one SMILES per line).
    """

    def __init__(self, dataset_info, n_generated, dataset_smiles_list=None):
        self.atom_decoder = dataset_info['atom_decoder']
        self.dataset_smiles_list = dataset_smiles_list
        self.dataset_info = dataset_info
        self.n_generated = n_generated

        if dataset_smiles_list is None:
            with open(f'data/{dataset_info["name"]}/smiles/train.txt', 'r') as smiles_file:
                self.dataset_smiles_list = [line.strip() for line in smiles_file]

    def evaluate(self, smiles_list):
        """Compute validity, uniqueness and novelty of generated SMILES strings.

        Args:
            smiles_list: generated SMILES strings (entries may be invalid or None).

        Returns:
            ([validity, uniqueness, novelty], unique) where ``unique`` is the
            de-duplicated list of valid SMILES in first-occurrence order, or None
            when no generated SMILES was valid.
        """
        valid, validity = self.compute_validity(smiles_list)
        print(f"Validity over {self.n_generated} molecules: {validity * 100 :.2f}%")

        if validity > 0:
            # dict.fromkeys de-duplicates while keeping first-occurrence order, so the
            # returned list is reproducible (set iteration order depends on hash seeding).
            unique = list(dict.fromkeys(valid))
            uniqueness = len(unique) / len(valid)
            print(f"Uniqueness over {len(valid)} valid molecules: {uniqueness * 100 :.2f}%")

            if self.dataset_smiles_list is not None:
                _, novelty = self.compute_novelty(unique)
                print(f"Novelty over {len(unique)} unique valid molecules: {novelty * 100 :.2f}%")
            else:
                novelty = 0.0
        else:
            novelty = 0.0
            uniqueness = 0.0
            unique = None
        return [validity, uniqueness, novelty], unique

    def compute_validity(self, smiles_list):
        """Return (valid_smiles, fraction_valid), the fraction taken over ``n_generated``.

        Returns a fraction of 0.0 when ``n_generated`` is 0 instead of dividing by zero.
        """
        valid = [smiles for smiles in smiles_list if is_valid(smiles)]
        if not self.n_generated:
            return valid, 0.0
        return valid, len(valid) / self.n_generated

    def compute_novelty(self, unique):
        """Return (novel_smiles, fraction_novel) for SMILES absent from the training set.

        Args:
            unique: de-duplicated SMILES strings to check.

        Returns:
            The SMILES of ``unique`` not present in ``dataset_smiles_list`` (input order
            preserved) and their fraction of ``len(unique)``; ``([], 0.0)`` for empty input.
        """
        if len(unique) == 0:
            return [], 0.0
        # Set lookup is O(1); membership in the raw training list was O(len(train)) per query.
        training_smiles = set(self.dataset_smiles_list)
        novel = [smiles for smiles in unique if smiles not in training_smiles]
        return novel, len(novel) / len(unique)

def is_valid(smiles: str):
    """
    Check that a SMILES string parses with RDKit into a molecule with at least one atom.

    Args:
        smiles: SMILES string (None is accepted and treated as invalid)

    Returns:
        False for None, the empty string, unparsable SMILES or an atom-less molecule;
        True otherwise.
    """
    if smiles is None:
        return False
    if smiles == '':
        return False

    parsed = Chem.MolFromSmiles(smiles)
    if parsed is None:
        return False
    return parsed.GetNumAtoms() > 0

def get_largest_connected_component(smiles):
    """
    Return the SMILES of the largest fragment of a (possibly disconnected) molecule.

    If `smiles` does not parse as a whole but contains '.'-separated fragments, the
    fragments that are individually valid (per `is_valid`) are kept and re-joined
    before the largest one is chosen.

    Args:
        smiles: SMILES string, possibly containing several '.'-separated fragments.

    Returns:
        SMILES of the fragment picked by RDKit's LargestFragmentChooser, or None if no
        valid molecule/fragment could be recovered or fragment selection failed.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        if '.' not in smiles:
            return None
        # The molecule as a whole is invalid but disconnected: some of the
        # fragments may still be valid on their own, so salvage those.
        valid_fragments = [fragment for fragment in smiles.split('.') if is_valid(fragment)]
        if not valid_fragments:
            return None
        mol = Chem.MolFromSmiles('.'.join(valid_fragments))

    try:
        # setup standardization module
        largest_fragment_chooser = rdMolStandardize.LargestFragmentChooser()
        largest_mol = largest_fragment_chooser.choose(mol)
        return Chem.MolToSmiles(largest_mol)
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt / SystemExit still propagate.
        print(f'smiles {smiles} is weird. Could not get largest connected component')
        return None

def canon_smiles(smiles):
    """Return RDKit's canonical SMILES for `smiles`.

    Unparsable input makes MolFromSmiles return None, which MolToSmiles is then
    expected to reject with an exception (no None is returned here).
    """
    parsed_mol = Chem.MolFromSmiles(smiles)
    canonical = Chem.MolToSmiles(parsed_mol)
    return canonical

def compute_diversity(smiles_list):
    """Mean pairwise dissimilarity (1 - similarity) over all unordered pairs.

    Returns 0. when fewer than two SMILES are given.
    """
    count = len(smiles_list)
    if count < 2:
        return 0.
    n_pairs = count * (count - 1) // 2
    total_dissimilarity = 0.0
    for first_idx, anchor in enumerate(smiles_list):
        for second_idx in range(first_idx + 1, count):
            total_dissimilarity += 1 - similarity(anchor, smiles_list[second_idx])
    total_dissimilarity /= n_pairs
    return total_dissimilarity

# Lookup table from the integer bond code stored in adjacency matrices to the RDKit
# bond type: 0 -> None (no bond), 1 -> single, 2 -> double, 3 -> triple, 4 -> aromatic.
bond_dict = [
    None,
    Chem.rdchem.BondType.SINGLE,
    Chem.rdchem.BondType.DOUBLE,
    Chem.rdchem.BondType.TRIPLE,
    Chem.rdchem.BondType.AROMATIC,
]

def build_2D_mols(molecules_dict, dataset_info, use_ghost_nodes=False):
    """
    Creates a list of RDKit Mol objects from the generated atoms as well as the predicted structures using the edge model
    Args:
        molecules_dict (dict): contains the prediction of the diffusion model and the edge model.
                               Has keys: 'positions', 'node_mask', 'atom_types', 'formal_charges', 'adjacency_matrices'
        dataset_info (dict): must provide 'atom_decoder', which maps an atom-type index to
                             a value accepted by Chem.Atom (presumably an element symbol).
        use_ghost_nodes (bool): if True, nodes whose atom type is 0 are dropped as ghost nodes;
                                otherwise each molecule is truncated to the number of atoms
                                given by its 'node_mask' row.

    Returns:
        mol_list (list): one entry per sample, either an (unsanitized) RDKit RWMol or None
                         if building that molecule raised an exception.
    """
    atom_types = molecules_dict['atom_types']
    positions = molecules_dict['positions']
    node_mask = molecules_dict['node_mask']
    formal_charges = molecules_dict['formal_charges']
    adjacency_matrices = molecules_dict['adjacency_matrices']

    atom_decoder = dataset_info["atom_decoder"]
    # positions are only used to count the samples; the 2D graph does not need them
    n_samples = len(positions)

    if use_ghost_nodes:
        # based on decoder's prediction, remove the nodes that are predicted as ghost nodes
        if isinstance(atom_types, torch.Tensor):
            ghost_nodes_masks = atom_types != 0
        else:
            # per-molecule sequence (e.g. list of tensors), mirroring the node_mask handling below
            ghost_nodes_masks = [types != 0 for types in atom_types]
    else:
        if isinstance(node_mask, torch.Tensor):
            atomsxmol = torch.sum(node_mask, dim=1)
        else:
            atomsxmol = [torch.sum(m) for m in node_mask]

    mol_list = []
    for i in range(n_samples):
        try:
            atom_type = atom_types[i]
            charge = formal_charges[i]
            adj_matrix = adjacency_matrices[i]

            if use_ghost_nodes:
                keep = ghost_nodes_masks[i]
                atom_type = atom_type[keep]
                charge = charge[keep]
                adj_matrix = adj_matrix[keep, :][:, keep]
            else:
                n_atoms = int(atomsxmol[i])
                atom_type = atom_type[:n_atoms]
                charge = charge[:n_atoms]
                adj_matrix = adj_matrix[:n_atoms, :n_atoms]

            mol = Chem.RWMol()
            # distinct loop names so the per-molecule `charge` tensor is not shadowed
            for atom_code, atom_charge in zip(atom_type, charge):
                # int() so float-typed atom-type tensors still index the decoder
                a = Chem.Atom(atom_decoder[int(atom_code.item())])
                a.SetFormalCharge(int(atom_charge.item()))
                mol.AddAtom(a)

            # adj_matrix is treated as symmetric, so only the upper triangle is scanned to
            # avoid adding each edge twice. NOTE(review): torch.triu keeps the main diagonal,
            # so a nonzero diagonal entry would attempt a self-bond; confirm it is always 0.
            all_bonds = torch.nonzero(torch.triu(adj_matrix))
            for begin, end in all_bonds.tolist():
                # int() so float-typed adjacency matrices still index bond_dict
                bond_type = bond_dict[int(adj_matrix[begin, end].item())]
                mol.AddBond(begin, end, bond_type)

            mol_list.append(mol)
        except Exception as e:
            # Best-effort: a sample that cannot be converted becomes None so indices stay aligned.
            print(f'Caught an exception in build_2D_mols fn {str(e)}. Appending None to mol_list')
            mol_list.append(None)

    assert len(mol_list) == n_samples, "Lost a molecule somewhere!"
    return mol_list

def smiles_from_2d_mols_list(mol_list):
    """Convert each molecule to SMILES, using the string 'None' for missing/empty results.

    The output has the same length and order as `mol_list`.
    """
    def _mol_to_smiles_or_placeholder(mol):
        if mol is None:
            return 'None'
        smiles = Chem.MolToSmiles(mol)
        if smiles is None or smiles == '':
            return 'None'
        return smiles

    smiles_list = [_mol_to_smiles_or_placeholder(mol) for mol in mol_list]
    assert len(smiles_list) == len(mol_list), "Lost a molecule somewhere!"
    return smiles_list
