from rdkit import Chem, RDLogger
from rdkit.Chem import Descriptors
from rdkit.Chem import AllChem, DataStructs
RDLogger.DisableLog('rdApp.*')

from fcd_torch import FCD as FCDMetric
from mini_moses.metrics.metrics import FragMetric, internal_diversity, ScafMetric, SNNMetric
from mini_moses.metrics.utils import get_mol, mapper

import pandas as pd
import os, gc, re, time, random
random.seed(0)

import torch
import numpy as np
from scipy.stats import wasserstein_distance
from multiprocessing import Pool

from analysis.nspdk import compute_nspdk
from metrics.property_metric import calculateSAS, SA
from utils import mols_to_nx


# Edge-type code -> RDKit bond type, indexed by the integer codes stored in the
# edge-type tensor. Code 0 means "no bond" and maps to None.
bond_dict = [
    None,
    Chem.rdchem.BondType.SINGLE,
    Chem.rdchem.BondType.DOUBLE,
    Chem.rdchem.BondType.TRIPLE,
    Chem.rdchem.BondType.AROMATIC,
]
# The same mapping restricted to real bonds (codes 1-4), derived from bond_dict so
# the two tables cannot drift apart.
bond_decoder = {code: bond_type for code, bond_type in enumerate(bond_dict) if bond_type is not None}
# Atomic number -> valence compared against the valence reported by RDKit when
# deciding whether a +1 formal charge explains an over-valent N/O/S atom.
ATOM_VALENCY = {6: 4, 7: 3, 8: 2, 9: 1, 15: 3, 16: 2, 17: 1, 35: 1, 53: 1}


class BasicMolecularMetrics(object):
    """Validity, uniqueness, novelty, distribution and property metrics for generated molecules.

    Args:
        atom_decoder: maps atom-type indices to element symbols; forwarded to
            ``build_molecule_with_partial_charges``.
        train_smiles: training-set SMILES used for novelty; ``None`` skips novelty.
        stat_ref: precomputed reference statistics with keys ``'Frag'``, ``'Scaf'`` and
            ``'FCD'``; ``None`` skips those metrics.
        task_evaluator: mapping from property name to a callable that scores a list of
            SMILES. The optional key ``'meta_taskname'`` holds a ``'-'``-separated string
            and is not evaluated. ``None`` skips property evaluation.
        n_jobs: number of worker processes (1 = run without a process pool).
        device: device forwarded to the mini_moses / FCD metrics.
        batch_size: batch size forwarded to the mini_moses / FCD metrics.
    """

    def __init__(self, atom_decoder, train_smiles=None, stat_ref=None, task_evaluator=None, n_jobs=8, device='cpu', batch_size=512):
        self.atom_decoder = atom_decoder
        self.dataset_smiles_list = train_smiles
        self.n_jobs = n_jobs
        self.device = device
        self.batch_size = batch_size
        self.stat_ref = stat_ref
        self.task_evaluator = task_evaluator

    def compute_relaxed_validity(self, generated, task_list):
        """Build, correct and validate every generated graph.

        Each graph is built with ``build_molecule_with_partial_charges``. It is then
        repaired with ``correct_mol_graphdit`` when ``"O2"`` is in ``task_list`` (polymer
        tasks), otherwise with ``correct_mol_general``. Finally it is reduced to its
        largest fragment.

        Args:
            generated: sequence of ``(atom_types, edge_types)`` pairs.
            task_list: names of the evaluated properties (may be ``None``).

        Returns:
            A tuple ``(valid, validity, direct_validity, num_components, all_smiles,
            covered_atoms)``:

            - SMILES of the valid molecules.
            - Fraction valid after correction.
            - Fraction valid before correction.
            - Array with the fragment count of every molecule that could be fragmented.
            - One SMILES-or-``None`` entry per input graph.
            - The set of element symbols occurring in valid molecules.

            Fractions are 0.0 for an empty ``generated``.
        """
        valid = []
        num_components = []
        all_smiles = []
        covered_atoms = set()
        direct_valid_count = 0
        # Polymers (an "O2" task is present) adhere to GraphDiT's evaluation protocol.
        use_graphdit = task_list is not None and "O2" in task_list
        for atom_types, edge_types in generated:
            mol = build_molecule_with_partial_charges(atom_types, edge_types, self.atom_decoder)
            if check_mol(mol, largest_connected_comp=True) is not None:
                direct_valid_count += 1
            if use_graphdit:
                mol, _ = correct_mol_graphdit(mol)
            else:
                mol, _ = correct_mol_general(mol)
            # Round-trip through SMILES to get a freshly parsed molecule (None on failure).
            mol = get_mol(mol2smiles(mol))

            smiles = None
            try:
                mol_frags = Chem.rdmolops.GetMolFrags(mol, asMols=True, sanitizeFrags=True)
                num_components.append(len(mol_frags))
                largest_mol = max(mol_frags, default=mol, key=lambda m: m.GetNumAtoms())
                candidate = mol2smiles(largest_mol)
                if (candidate is not None and largest_mol is not None and len(candidate) > 1
                        and Chem.MolFromSmiles(candidate) is not None):
                    symbols = {atom.GetSymbol() for atom in largest_mol.GetAtoms()}
                    # Record the molecule only once everything succeeded, so that
                    # ``valid`` and the non-None entries of ``all_smiles`` stay aligned.
                    covered_atoms |= symbols
                    valid.append(candidate)
                    smiles = candidate
            except Exception:
                # e.g. ``mol`` is None after a failed correction: counts as invalid.
                smiles = None
            all_smiles.append(smiles)

        n_total = max(len(generated), 1)  # avoid ZeroDivisionError on empty input
        return valid, len(valid) / n_total, direct_valid_count / n_total, np.array(num_components), all_smiles, covered_atoms

    def compute_uniqueness(self, valid):
        """Return ``(unique_smiles, fraction_unique)`` for a list of SMILES (0.0 when empty)."""
        unique = list(set(valid))
        if len(valid) == 0:
            return unique, 0.0
        return unique, len(unique) / len(valid)

    def compute_novelty(self, unique):
        """Return ``(novel_smiles, fraction_novel)`` relative to the training SMILES.

        Returns ``(1, 1)`` when no training SMILES were given, and ``([], 0.0)`` for an
        empty ``unique``.
        """
        if self.dataset_smiles_list is None:
            print("Dataset smiles is None, novelty computation skipped")
            return 1, 1
        if len(unique) == 0:
            return [], 0.0
        # Build a set once: O(1) membership instead of scanning the training list per SMILES.
        train_set = set(self.dataset_smiles_list)
        novel = [smiles for smiles in unique if smiles not in train_set]
        return novel, len(novel) / len(unique)

    def evaluate(self, generated, targets, active_atoms=None, test=False, ref_smiles=None):
        """Compute (and print) all metrics for a batch of generated graphs.

        Args:
            generated: sequence of ``(atom_types, edge_types)`` pairs, already masked.
            targets: 2-D array with one row per generated graph. It has one column per
                evaluated task, in ``task_evaluator`` key order excluding
                ``'meta_taskname'``. Only used when ``task_evaluator`` is set.
            active_atoms: element vocabulary; its size is the denominator of the
                atom-coverage percentage.
            test: unused; kept for interface compatibility.
            ref_smiles: reference SMILES for the NSPDK metric; ``None`` skips NSPDK.

        Returns:
            ``(unique, nc_stats, all_smiles, dist_metrics, targets_log)``. When no
            molecule is valid, ``unique`` is ``[]``, ``dist_metrics`` is ``{}`` and
            ``targets_log`` is ``None``.
        """
        task_list = list(self.task_evaluator.keys()) if self.task_evaluator is not None else []
        valid, validity, nc_validity, num_components, all_smiles, covered_atoms = self.compute_relaxed_validity(generated, task_list)
        if len(num_components) > 0:
            nc_mu, nc_min, nc_max = num_components.mean(), num_components.min(), num_components.max()
        else:
            nc_mu = nc_min = nc_max = 0
        nc_stats = dict(nc_min=nc_min, nc_max=nc_max, nc_mu=nc_mu)

        # Never divide by zero, even for an empty vocabulary.
        len_active = max(len(active_atoms), 1) if active_atoms is not None else 1
        n_covered = len(covered_atoms)
        cover_pct = n_covered / len_active * 100
        cover_str = f"Cover {n_covered} ({cover_pct:.2f}%) atoms: {covered_atoms}"
        print(f"Validity over {len(generated)} molecules: {validity * 100 :.2f}% (w/o correction: {nc_validity * 100 :.2f}%), cover {n_covered} ({cover_pct:.2f}%) atoms: {covered_atoms}")
        print(f"Number of connected components of {len(generated)} molecules: min:{nc_min:.2f} mean:{nc_mu:.2f} max:{nc_max:.2f}")

        if validity <= 0:
            return [], nc_stats, all_smiles, {}, None

        unique, uniqueness = self.compute_uniqueness(valid)
        _, novelty = self.compute_novelty(unique)
        dist_metrics = {'cover_str': cover_str, 'validity': validity, 'validity_nc': nc_validity,
                        'uniqueness': uniqueness, 'novelty': novelty}

        # With n_jobs == 1 the integer 1 is passed to the mini_moses helpers instead of a Pool.
        close_pool = self.n_jobs != 1
        pool = Pool(self.n_jobs) if close_pool else 1
        try:
            # Drop molecules that failed to parse so every downstream metric sees real Mols.
            valid_mols = [mol for mol in mapper(pool)(get_mol, valid) if mol is not None]
            dist_metrics['internal_diversity'] = internal_diversity(valid_mols, pool, device=self.device)
            print(f"internal_diversity: {dist_metrics['internal_diversity']}")

            if ref_smiles is not None:
                ref_mols = [mol for mol in (Chem.MolFromSmiles(s) for s in ref_smiles) if mol is not None]
                dist_metrics['nspdk'] = compute_nspdk(mols_to_nx(ref_mols), mols_to_nx(valid_mols), methods=['nspdk'])['nspdk']
                print(f"nspdk: {dist_metrics['nspdk']}")
            else:
                print("Reference smiles is None, nspdk computation skipped")

            start_time = time.time()
            if self.stat_ref is not None:
                self._compute_reference_metrics(valid, valid_mols, pool, dist_metrics)

            targets_log = {}
            if self.task_evaluator is not None:
                targets_log = self._compute_task_metrics(valid, all_smiles, targets, dist_metrics)

            self._print_metric_table(dist_metrics, len(valid), len(generated), time.time() - start_time)
        finally:
            # Release worker processes even if a metric raised.
            if close_pool:
                pool.close()
                pool.join()

        return unique, nc_stats, all_smiles, dist_metrics, targets_log

    def _compute_reference_metrics(self, valid, valid_mols, pool, dist_metrics):
        """Add Frag / Scaf similarity and FCD against ``self.stat_ref`` to ``dist_metrics``."""
        kwargs = {'n_jobs': pool, 'device': self.device, 'batch_size': self.batch_size}
        # FCD receives the integer job count rather than the Pool object.
        kwargs_fcd = {'n_jobs': self.n_jobs, 'device': self.device, 'batch_size': self.batch_size}
        gc.collect()

        dist_metrics['sim/Frag'] = FragMetric(**kwargs)(gen=valid_mols, pref=self.stat_ref['Frag'])
        print(f"sim/Frag: {dist_metrics['sim/Frag']}")
        gc.collect()

        dist_metrics['sim/Scaf'] = ScafMetric(**kwargs)(gen=valid_mols, pref=self.stat_ref['Scaf'])
        gc.collect()

        dist_metrics['dist/FCD'] = FCDMetric(**kwargs_fcd)(gen=valid, pref=self.stat_ref['FCD'])
        print(f"dist/FCD: {dist_metrics['dist/FCD']}")

    def _compute_task_metrics(self, valid, all_smiles, targets, dist_metrics):
        """Score valid molecules per task, add MAE/accuracy to ``dist_metrics``, return a per-sample log."""
        # 'meta_taskname' is a descriptor string, not an evaluator. The remaining keys line
        # up, in order, with the columns of ``targets``.
        evaluation_list = [name for name in self.task_evaluator.keys() if name != 'meta_taskname']

        valid_index = np.array([bool(smiles) for smiles in all_smiles])
        n_samples = len(valid_index)

        targets_log = {f'input_{name}': targets[:, i] for i, name in enumerate(evaluation_list)}
        valid_targets = targets[valid_index]

        for i, name in enumerate(evaluation_list):
            if name == 'scs':
                continue
            if name == 'sas':
                scores = calculateSAS(valid)
            else:
                scores = self.task_evaluator[name](valid)

            # Per-sample predictions; NaN where the generated molecule was invalid.
            output = np.full(n_samples, np.nan)
            output[valid_index] = scores
            targets_log[f'output_{name}'] = output

            cur_targets = valid_targets[:, i]
            if name in ('O2', 'N2', 'CO2'):
                # Compared in log10 space; predictions are clipped to avoid log10(0).
                log_scores = np.log10(np.maximum(scores, 1e-6))
                dist_metrics[f'{name}/mae'] = np.mean(np.abs(log_scores - np.log10(cur_targets)))
            elif name == 'sas':
                dist_metrics[f'{name}/mae'] = np.mean(np.abs(scores - cur_targets))
            else:
                # Classification task: threshold the predicted scores at 0.5.
                predicted_labels = (np.asarray(scores) >= 0.5).astype(int)
                dist_metrics[f'{name}/acc'] = (predicted_labels == cur_targets).sum() / len(cur_targets)
        return targets_log

    @staticmethod
    def _print_metric_table(dist_metrics, n_valid, n_total, elapsed_time):
        """Print every numeric metric, four per line, right-aligned on the longest key."""
        max_key_length = max(len(key) for key in dist_metrics)
        print(f'Details over {n_valid} ({n_total}) valid (total) molecules, calculating metrics using {elapsed_time:.2f} s:')
        strs = ''
        for i, (key, value) in enumerate(dist_metrics.items()):
            if isinstance(value, (int, float, np.floating, np.integer)):
                strs = strs + f'{key:>{max_key_length}}:{value:<7.4f}\t'
            if i % 4 == 3:
                strs = strs + '\n'
        print(strs)


def mol2smiles(mol):
    """Sanitize ``mol`` in place and return its canonical SMILES.

    Returns ``None`` for a ``None`` input, or when ``Chem.SanitizeMol`` raises a
    ``ValueError``.
    """
    if mol is not None:
        try:
            Chem.SanitizeMol(mol)
        except ValueError:
            return None
        return Chem.MolToSmiles(mol)
    return None


def build_molecule_with_partial_charges(atom_types, edge_types, atom_decoder, verbose=False):
    """Build an RDKit ``RWMol`` from atom/edge type tensors, adding +1 charges where that fixes valence.

    After each bond is added the valence is checked. If RDKit reports an over-valent
    N, O or S atom whose valence exceeds ``ATOM_VALENCY`` by exactly one, a +1 formal
    charge is set on it (e.g. [N+], [O+], [S+]). Negative charges ([O-], [N-], [S-])
    and charged hydrogens ([NH+]) are not supported.

    Args:
        atom_types: iterable of scalar tensors; ``.item()`` indexes ``atom_decoder``.
        edge_types: square tensor of edge-type codes indexing ``bond_dict``; only the
            upper triangle is read and diagonal entries are ignored.
        atom_decoder: maps an atom-type index to an element symbol for ``Chem.Atom``.
        verbose: print every construction step.

    Returns:
        The (unsanitized) ``Chem.RWMol``.

    Raises:
        ValueError: if a bond cannot be added, e.g. an edge code outside ``bond_dict``.
    """
    if verbose:
        print("\nbuilding new molecule")

    mol = Chem.RWMol()
    for atom in atom_types:
        a = Chem.Atom(atom_decoder[atom.item()])
        mol.AddAtom(a)
        if verbose:
            print("Atom added: ", atom.item(), atom_decoder[atom.item()])

    edge_types = torch.triu(edge_types)
    for bond in torch.nonzero(edge_types):
        begin, end = bond[0].item(), bond[1].item()
        if begin == end:
            continue  # self-loops are ignored
        bond_code = edge_types[begin, end].item()
        try:
            mol.AddBond(begin, end, bond_dict[bond_code])
        except Exception as err:
            # Previously this dropped into pdb, which hangs or aborts unattended runs.
            raise ValueError(
                f"could not add bond ({begin}, {end}) with edge type {bond_code!r}"
            ) from err
        if verbose:
            print("bond added:", begin, end, bond_code, bond_dict[bond_code])

        flag, atomid_valence = check_valency(mol)
        if verbose:
            print("flag, valence", flag, atomid_valence)
        # Only act when the valence check failed and (atom index, valence) was parsed.
        if flag or len(atomid_valence) != 2:
            continue
        idx, v = atomid_valence
        an = mol.GetAtomWithIdx(idx).GetAtomicNum()
        if verbose:
            print("atomic num of atom with a large valence", an)
        if an in (7, 8, 16) and (v - ATOM_VALENCY[an]) == 1:
            mol.GetAtomWithIdx(idx).SetFormalCharge(1)
    return mol


def check_valency(mol):
    """Check atom valences of ``mol`` by sanitizing its properties only.

    Returns:
        ``(True, None)`` if RDKit accepts the valences. Otherwise ``(False, numbers)``,
        where ``numbers`` are all integers found in the error message from the first
        ``'#'`` onward. If the message has no ``'#'``, only its last character is
        scanned, because ``str.find`` returns -1. Callers interpret a two-element result
        as ``[atom_index, valence]``.
    """
    try:
        Chem.SanitizeMol(mol, sanitizeOps=Chem.SanitizeFlags.SANITIZE_PROPERTIES)
    except ValueError as err:
        message = str(err)
        tail = message[message.find('#'):]
        return False, [int(token) for token in re.findall(r'\d+', tail)]
    return True, None


def correct_mol_graphdit(mol):
    """Repair valence errors following GraphDiT's evaluation protocol (used for polymers).

    Each iteration does two things until ``check_valency`` passes:

    1. Join disconnected fragments with ``connect_fragments``.
    2. Lower by one the order of the highest-order non-aromatic bond on the
       over-valent atom reported by RDKit, removing the bond if it was single.

    Args:
        mol: ``Chem.RWMol`` to repair; it may be modified in place.

    Returns:
        ``(mol, no_correct)``. ``mol`` is the repaired molecule, or ``None`` when it
        cannot be repaired. That happens when fragments cannot be joined, the RDKit
        error cannot be parsed, the atom has no bonds or only aromatic ones, or editing
        fails. ``no_correct`` is ``True`` when the input already passed the valence check.
    """
    no_correct, _ = check_valency(mol)

    while True:
        mol = connect_fragments(mol)
        if mol is None:
            return None, no_correct
        flag, atomid_valence = check_valency(mol)
        if flag:
            return mol, no_correct
        # Need exactly (atom index, valence). An explicit check rather than ``assert``
        # keeps the behaviour identical under ``python -O``.
        if len(atomid_valence) != 2:
            return None, no_correct
        try:
            offending_atom = mol.GetAtomWithIdx(atomid_valence[0])
            bonds = [(int(b.GetBondType()), b.GetBeginAtomIdx(), b.GetEndAtomIdx())
                     for b in offending_atom.GetBonds()]
            if not bonds:
                return None, no_correct
            # Highest type first (stable sort), so aromatic bonds (type 12) lead the list
            # and are skipped by starting at index ``n_aromatic``.
            bonds.sort(key=lambda bond: bond[0], reverse=True)
            if bonds[-1][0] == 12:
                # Every bond on the atom is aromatic: nothing this protocol can lower.
                return None, no_correct
            n_aromatic = sum(1 for bond in bonds if bond[0] == 12)
            bond_type, start, end = bonds[n_aromatic]
            mol.RemoveBond(start, end)
            if bond_type - 1 >= 1:
                mol.AddBond(start, end, bond_dict[bond_type - 1])
        except Exception:
            # Any failure while editing bonds marks the molecule as unrepairable.
            return None, no_correct


def correct_mol_general(m):
    """Repair valence errors by repeatedly lowering the highest-order bond on the offending atom.

    While ``check_valency`` fails, the highest-order bond of the atom named in the RDKit
    error is removed and re-added one order lower (a single bond is simply removed).

    Args:
        m: ``Chem.RWMol`` to repair; it is modified in place.

    Returns:
        ``(mol, no_correct)``. ``mol`` is the repaired molecule, or ``None`` when repair
        is impossible. That happens when the error cannot be parsed, the atom has no
        bonds, or the lowered order has no entry in ``bond_decoder``. ``no_correct`` is
        ``True`` when the input already passed the valence check.
    """
    mol = m
    no_correct, _ = check_valency(mol)

    while True:
        flag, atomid_valence = check_valency(mol)
        if flag:
            return mol, no_correct
        # (atom index, valence) could not be parsed: previously an AssertionError that
        # aborted the whole evaluation; now the molecule is reported as invalid.
        if len(atomid_valence) != 2:
            return None, no_correct
        bonds = [(int(b.GetBondType()), b.GetBeginAtomIdx(), b.GetEndAtomIdx())
                 for b in mol.GetAtomWithIdx(atomid_valence[0]).GetBonds()]
        if not bonds:
            # Nothing to lower; the previous loop would never terminate here.
            return None, no_correct
        # First bond with the highest type (matches a stable descending sort).
        bond_type, start, end = max(bonds, key=lambda bond: bond[0])
        mol.RemoveBond(start, end)
        lowered = bond_type - 1
        if lowered >= 1:
            try:
                # e.g. aromatic (12) -> 11 has no entry in bond_decoder.
                mol.AddBond(start, end, bond_decoder[lowered])
            except Exception:
                return None, no_correct


def check_mol(m, largest_connected_comp=True):
    """Sanitize ``m`` and return a molecule re-parsed from its SMILES, or ``None`` if invalid.

    Args:
        m: RDKit molecule (sanitized in place) or ``None``.
        largest_connected_comp: if True and the SMILES has several '.'-separated
            fragments, keep only the fragment with the longest SMILES string; ties go
            to the first such fragment.

    Returns:
        ``Chem.MolFromSmiles`` of the (possibly reduced) isomeric SMILES, or ``None``
        if ``m`` is ``None`` or cannot be sanitized / written.
    """
    if m is None:
        return None
    try:
        Chem.SanitizeMol(m)
        sm = Chem.MolToSmiles(m, isomericSmiles=True)
    except Exception:
        # Narrowed from a bare ``except:`` so KeyboardInterrupt/SystemExit propagate.
        return None
    if largest_connected_comp and '.' in sm:
        sm = max(sm.split('.'), key=len)
    return Chem.MolFromSmiles(sm)


##### connect fragments
def select_atom_with_available_valency(frag):
    """Return a randomly chosen heavy atom of ``frag`` with positive implicit valence, or ``None``.

    The order is drawn with the module-level ``random`` generator.
    """
    shuffled = list(frag.GetAtoms())
    random.shuffle(shuffled)
    return next(
        (candidate for candidate in shuffled
         if candidate.GetAtomicNum() > 1 and candidate.GetImplicitValence() > 0),
        None,
    )


def select_atoms_with_available_valency(frag):
    """Return, in atom order, every heavy atom of ``frag`` (atomic number > 1) with positive implicit valence."""
    def _has_free_valence(atom):
        return atom.GetAtomicNum() > 1 and atom.GetImplicitValence() > 0

    return list(filter(_has_free_valence, frag.GetAtoms()))


def try_to_connect_fragments(combined_mol, frag, atom1, atom2):
    """Append ``frag`` to a copy of ``combined_mol`` and join ``atom1``–``atom2`` by a single bond.

    Neither input molecule is modified. Each joined atom gets one hydrogen fewer
    (floored at zero) as its explicit-H count.

    Args:
        combined_mol: molecule that ``atom1`` belongs to.
        frag: fragment that ``atom2`` belongs to.
        atom1: atom of ``combined_mol`` to bond from.
        atom2: atom of ``frag`` to bond to.

    Returns:
        The sanitized joined ``Chem.Mol``, or ``None`` if sanitization raises
        ``Chem.MolSanitizeException``.
    """
    candidate = Chem.RWMol(combined_mol)
    frag_copy = Chem.RWMol(frag)

    # Copy the fragment's atoms over, remembering each one's new index.
    index_map = {}
    for frag_atom in frag_copy.GetAtoms():
        index_map[frag_atom.GetIdx()] = candidate.AddAtom(frag_atom)

    anchor_a = atom1.GetIdx()
    anchor_b = index_map[atom2.GetIdx()]
    candidate.AddBond(anchor_a, anchor_b, Chem.BondType.SINGLE)

    # The new bond consumes one hydrogen on each side.
    for anchor in (anchor_a, anchor_b):
        anchor_atom = candidate.GetAtomWithIdx(anchor)
        anchor_atom.SetNumExplicitHs(max(0, anchor_atom.GetTotalNumHs() - 1))

    # Recreate the fragment's internal bonds under the new indices.
    for frag_bond in frag_copy.GetBonds():
        candidate.AddBond(
            index_map[frag_bond.GetBeginAtomIdx()],
            index_map[frag_bond.GetEndAtomIdx()],
            frag_bond.GetBondType(),
        )

    result = Chem.Mol(candidate)
    try:
        Chem.SanitizeMol(result)
    except Chem.MolSanitizeException:
        return None
    return result


def connect_fragments(mol):
    """Join all disconnected fragments of ``mol`` into one molecule using single bonds.

    Fragments are attached one at a time to the growing molecule. For each fragment,
    atom pairs are tried in order; both atoms must be heavy atoms with positive
    implicit valence. The first pair whose joined molecule sanitizes is kept.

    Returns:
        ``mol`` itself if it has fewer than two fragments. Otherwise the joined,
        sanitized molecule, or ``None`` if some fragment could not be attached.
    """
    fragments = Chem.GetMolFrags(mol, asMols=True, sanitizeFrags=False)
    if len(fragments) < 2:
        return mol

    combined_mol = Chem.RWMol(fragments[0])
    for fragment in fragments[1:]:
        anchors_combined = select_atoms_with_available_valency(combined_mol)
        anchors_fragment = select_atoms_with_available_valency(fragment)
        # Lazily try every pair; stop at the first successful connection.
        attempts = (
            try_to_connect_fragments(combined_mol, fragment, anchor_a, anchor_b)
            for anchor_a in anchors_combined
            for anchor_b in anchors_fragment
        )
        joined = next((candidate for candidate in attempts if candidate is not None), None)
        if joined is None:
            return None
        combined_mol = joined
    return combined_mol


def compute_molecular_metrics(molecule_list, targets, train_smiles, stat_ref, dataset_info, task_evaluator, comput_config,
                              test=False, ref_smiles=None):
    """Evaluate generated molecules with ``BasicMolecularMetrics``.

    Args:
        molecule_list: generated graphs; each element is unpacked as
            ``(atom_types, edge_types)`` by ``BasicMolecularMetrics``.
        targets: per-sample task targets forwarded to ``evaluate``.
        train_smiles: training SMILES used for novelty.
        stat_ref: reference statistics for the Frag / Scaf / FCD metrics.
        dataset_info: object providing ``atom_decoder`` and ``active_atoms``.
        task_evaluator: mapping of property evaluators (see ``BasicMolecularMetrics``).
        comput_config: extra keyword arguments for ``BasicMolecularMetrics``.
        test: forwarded to ``evaluate``.
        ref_smiles: reference SMILES forwarded to ``evaluate``.

    Returns:
        ``(unique_smiles, all_smiles, all_metrics, targets_log)``.
    """
    decoder, vocabulary = dataset_info.atom_decoder, dataset_info.active_atoms
    evaluator = BasicMolecularMetrics(decoder, train_smiles, stat_ref, task_evaluator, **comput_config)
    unique_smiles, *_, all_smiles, all_metrics, targets_log = evaluator.evaluate(
        molecule_list, targets, vocabulary, test=test, ref_smiles=ref_smiles
    )
    return unique_smiles, all_smiles, all_metrics, targets_log