import numpy as np
import pandas as pd
import json
import torch
import networkx as nx

import re
from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem, Draw

from fcd_torch import FCD


# Default (neutral) valence per atomic number. construct_mol compares a
# reported explicit valence against this to decide whether an over-bonded
# N/O/S atom can be rescued with a +1 formal charge.
ATOM_VALENCY = {6: 4, 7: 3, 8: 2, 9: 1, 15: 3, 16: 2, 17: 1, 35: 1, 53: 1}
# Integer bond order (1, 2, 3) -> RDKit bond type, used when adding bonds.
bond_decoder = {1: Chem.rdchem.BondType.SINGLE, 2: Chem.rdchem.BondType.DOUBLE, 3: Chem.rdchem.BondType.TRIPLE}
# Atomic number -> element symbol, used as node labels in x_adj_to_nx.
AN_TO_SYMBOL = {6: 'C', 7: 'N', 8: 'O', 9: 'F', 15: 'P', 16: 'S', 17: 'Cl', 35: 'Br', 53: 'I'}


def mols_to_smiles(mols):
    """Serialize each RDKit molecule to a SMILES string via ``Chem.MolToSmiles``.

    Args:
        mols: iterable of RDKit molecules.

    Returns:
        list of SMILES strings, in input order.
    """
    smiles_out = [Chem.MolToSmiles(molecule) for molecule in mols]
    return smiles_out


def smiles_to_mols(smiles):
    """Parse each SMILES string into an RDKit molecule via ``Chem.MolFromSmiles``.

    Args:
        smiles: iterable of SMILES strings.

    Returns:
        list with one entry per input, in order; an entry may be None when
        RDKit cannot parse the corresponding SMILES.
    """
    parsed = [Chem.MolFromSmiles(smi) for smi in smiles]
    return parsed


def canonicalize_smiles(smiles):
    """Return the RDKit-canonical SMILES for every input string.

    Every input must be parseable. Unlike ``canonicalize_smiles_for_fcd``,
    nothing is caught here, so an invalid SMILES propagates RDKit's error.

    Args:
        smiles: iterable of SMILES strings.

    Returns:
        list of canonical SMILES, in input order.
    """
    # Lazily parse so each string is parsed and re-serialized in turn.
    parsed = (Chem.MolFromSmiles(text) for text in smiles)
    return [Chem.MolToSmiles(mol) for mol in parsed]


def canonicalize_smiles_for_fcd(smiles):
    """Canonicalize SMILES strings, skipping any that RDKit cannot parse.

    Unlike ``canonicalize_smiles``, invalid entries are dropped instead of
    raising, so the result may be shorter than the input.

    Args:
        smiles: iterable of SMILES strings (non-string entries are skipped).

    Returns:
        list of canonical SMILES for the parseable inputs, in input order.
    """
    canonicalized = []
    for s in smiles:
        try:
            mol = Chem.MolFromSmiles(s)
        except TypeError:
            # RDKit's argument conversion rejects non-string entries (e.g. NaN
            # read from a CSV); treat them as unparseable rather than crashing.
            continue
        if mol is None:
            # MolFromSmiles signals an invalid SMILES by returning None.
            continue
        canonicalized.append(Chem.MolToSmiles(mol))
    return canonicalized


def load_smiles(dataset='QM9'):
    """Load the train/test SMILES split for a dataset.

    Reads ``data/<dataset lowercased>.csv`` and the held-out row indices from
    ``data/valid_idx_<dataset lowercased>.json``. Rows listed in the index file
    form the test split; every other row forms the train split.

    Args:
        dataset: 'QM9' (SMILES column 'SMILES1'; indices stored under the
            'valid_idxs' key and cast to int) or 'ZINC250k' (SMILES column
            'smiles'; the JSON is used directly as the index list).

    Returns:
        (train_smiles, test_smiles): two lists of SMILES strings. Train keeps
        CSV order; test follows the order of the index file.

    Raises:
        ValueError: if ``dataset`` is not one of the names above.
    """
    if dataset == 'QM9':
        col = 'SMILES1'
    elif dataset == 'ZINC250k':
        col = 'smiles'
    else:
        raise ValueError('wrong dataset name in load_smiles')

    df = pd.read_csv(f'data/{dataset.lower()}.csv')

    with open(f'data/valid_idx_{dataset.lower()}.json') as f:
        test_idx = json.load(f)

    if dataset == 'QM9':
        test_idx = [int(i) for i in test_idx['valid_idxs']]

    # Set membership keeps the split O(n); testing `i not in test_idx` against
    # a list costs O(len(test_idx)) per row.
    test_idx_set = set(test_idx)
    train_idx = [i for i in range(len(df)) if i not in test_idx_set]

    return list(df[col].loc[train_idx]), list(df[col].loc[test_idx])


def get_molecular_scores(metrics, gen_mols, gen_smiles, dataset, device=None):
    """Compute the requested sample-quality metrics for generated molecules.

    Args:
        metrics: container of metric keys. Recognized keys:
            'uni' (uniqueness), 'hard_nov' (exact-match novelty),
            'soft_nov' (fingerprint-similarity novelty), 'div' (internal
            diversity) and 'fcd' (Frechet ChemNet Distance vs. the test split).
        gen_mols: RDKit molecules; Morgan fingerprints (radius 2, 1024 bits)
            are computed from these for 'div' and 'soft_nov'.
        gen_smiles: SMILES of the generated molecules, used for 'uni',
            'hard_nov' and 'fcd'. For 'hard_nov' they are compared verbatim
            against canonicalized training SMILES, so pass canonical SMILES.
        dataset: dataset name forwarded to ``load_smiles``.
        device: torch device used for 'div' and 'fcd'. When None, 'cuda' is
            used if available, otherwise 'cpu'.

    Returns:
        dict containing only the requested metrics, under the keys
        'uniqueness', 'novelty', 'diversity', 'novelty_02', 'novelty_03',
        'novelty_04' and 'fcd'.
    """
    scores = {}

    if device is None and ('div' in metrics or 'fcd' in metrics):
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    ### uniqueness
    if 'uni' in metrics:
        scores['uniqueness'] = len(set(gen_smiles)) / len(gen_smiles)

    if 'hard_nov' in metrics or 'soft_nov' in metrics or 'fcd' in metrics:
        train_smiles, test_smiles = load_smiles(dataset)
        train_mols = [Chem.MolFromSmiles(smi) for smi in train_smiles]

    ### hard novelty
    if 'hard_nov' in metrics:
        # Canonicalize the training set into a set: O(1) membership instead of
        # scanning the whole training list for every generated SMILES.
        train_canonical = {Chem.MolToSmiles(mol) for mol in train_mols}
        num_novel = sum(smi not in train_canonical for smi in gen_smiles)
        scores['novelty'] = num_novel / len(gen_smiles)

    if 'div' in metrics or 'soft_nov' in metrics:
        gen_fps = [AllChem.GetMorganFingerprintAsBitVect(mol, 2, 1024) for mol in gen_mols]

    ### diversity
    if 'div' in metrics:
        gen_fps_array = np.array(gen_fps)
        # average_agg_tanimoto already returns the mean over molecules.
        scores['diversity'] = 1 - average_agg_tanimoto(
            gen_fps_array, gen_fps_array, agg='mean', device=device, p=2)

    ### soft novelty
    if 'soft_nov' in metrics:
        train_fps = [AllChem.GetMorganFingerprintAsBitVect(mol, 2, 1024) for mol in train_mols]
        # Nearest-neighbour Tanimoto similarity of each generated molecule to
        # the training set, computed once and reused for every threshold.
        max_sims = [max(DataStructs.BulkTanimotoSimilarity(fp, train_fps)) for fp in gen_fps]
        for threshold, key in ((0.2, 'novelty_02'), (0.3, 'novelty_03'), (0.4, 'novelty_04')):
            num_novel = sum(sim < threshold for sim in max_sims)
            scores[key] = min(1. * num_novel / len(gen_fps), 1.)

    ### FCD score
    if 'fcd' in metrics:
        fcd = FCD(device=device, n_jobs=8)     # canonicalizes SMILES internally
        scores['fcd'] = fcd(gen_smiles, test_smiles)

    return scores


def get_novelty_in_df(df, dataset):
    """Return the hard novelty of ``df['smiles']``; add a ``'sim'`` column if missing.

    Hard novelty is the fraction of rows whose SMILES is absent from the
    canonicalized training SMILES of ``dataset``. Generated SMILES are compared
    verbatim, so they should already be canonical.

    Side effect: if ``df`` has no ``'sim'`` column, one is added holding, for
    each molecule in ``df['mol']``, its maximum Tanimoto similarity (Morgan
    fingerprint, radius 2, 1024 bits) to any training molecule.

    Args:
        df: DataFrame with a ``'smiles'`` column, and a ``'mol'`` column of
            RDKit molecules (read only when ``'sim'`` must be computed).
        dataset: dataset name forwarded to ``load_smiles``.

    Returns:
        float: hard novelty.
    """
    train_smiles, _ = load_smiles(dataset)
    train_mols = [Chem.MolFromSmiles(smi) for smi in train_smiles]

    ### hard novelty
    # A set gives O(1) membership; scanning a list per generated SMILES made
    # this O(len(df) * len(train)).
    train_canonical = {Chem.MolToSmiles(mol) for mol in train_mols}
    num_novel = sum(smi not in train_canonical for smi in df['smiles'])

    ### soft novelty
    if 'sim' not in df.keys():
        gen_fps = [AllChem.GetMorganFingerprintAsBitVect(mol, 2, 1024) for mol in df['mol']]
        train_fps = [AllChem.GetMorganFingerprintAsBitVect(mol, 2, 1024) for mol in train_mols]
        df['sim'] = [max(DataStructs.BulkTanimotoSimilarity(fp, train_fps)) for fp in gen_fps]

    return num_novel / len(df)  # hard novelty


def get_novelty_in_df_low(df, dataset):
    """Hard novelty of ``df['smiles']`` against the low-score ZINC250k subset.

    The reference set is the rows of ``data/zinc250k.csv`` whose indices are
    listed in ``data/low_idx_zinc250k_parp1_qed_sa.json``. Generated SMILES are
    compared verbatim against the canonicalized reference SMILES.

    Side effect: always (re)writes ``df['sim']`` with, for each molecule in
    ``df['mol']``, its maximum Tanimoto similarity (Morgan fingerprint,
    radius 2, 1024 bits) to any reference molecule.

    Args:
        df: DataFrame with ``'smiles'`` and ``'mol'`` (RDKit molecule) columns.
        dataset: must be 'ZINC250k'.

    Returns:
        float: hard novelty.

    Raises:
        ValueError: if ``dataset`` is not 'ZINC250k'.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if dataset != 'ZINC250k':
        raise ValueError(f"get_novelty_in_df_low only supports 'ZINC250k', got {dataset!r}")

    df_ref = pd.read_csv('data/zinc250k.csv')
    with open('data/low_idx_zinc250k_parp1_qed_sa.json') as f:
        train_idx = json.load(f)

    train_smiles = df_ref['smiles'].loc[train_idx]
    train_mols = [Chem.MolFromSmiles(smi) for smi in train_smiles]

    ### hard novelty
    # Set membership is O(1) per lookup instead of a full list scan.
    train_canonical = {Chem.MolToSmiles(mol) for mol in train_mols}
    num_novel = sum(smi not in train_canonical for smi in df['smiles'])

    ### soft novelty (recomputed unconditionally, unlike get_novelty_in_df)
    gen_fps = [AllChem.GetMorganFingerprintAsBitVect(mol, 2, 1024) for mol in df['mol']]
    train_fps = [AllChem.GetMorganFingerprintAsBitVect(mol, 2, 1024) for mol in train_mols]
    df['sim'] = [max(DataStructs.BulkTanimotoSimilarity(fp, train_fps)) for fp in gen_fps]

    return num_novel / len(df)  # hard novelty


### code adapted from https://github.com/molecularsets/moses/blob/7b8f83b21a9b7ded493349ec8ef292384ce2bb52/moses/metrics/utils.py#L122
def average_agg_tanimoto(stock_vecs, gen_vecs,
                         batch_size=5000, agg='max',
                         device='cpu', p=1):
    """
    Aggregate Tanimoto similarities between generated and reference fingerprints.

    For every row of gen_vecs, its Tanimoto similarity to each row of
    stock_vecs is aggregated: agg='max' keeps the nearest neighbour, agg='mean'
    takes the power mean (mean x^p)^(1/p). The result is the mean of these
    per-row scores over gen_vecs. Pairs of all-zero vectors (0/0) count as
    similarity 1. Work is done in batch_size x batch_size tiles on `device`.

    Parameters:
        stock_vecs: numpy array <n_vectors x dim>
        gen_vecs: numpy array <n_vectors' x dim>
        batch_size: tile size along both inputs
        agg: 'max' or 'mean'
        device: torch device for the matrix products
        p: power for averaging: (mean x^p)^(1/p)
    """
    assert agg in ['max', 'mean'], "Can aggregate only max or mean"
    n_gen = len(gen_vecs)
    per_gen = np.zeros(n_gen)
    counts = np.zeros(n_gen)

    for stock_start in range(0, stock_vecs.shape[0], batch_size):
        stock_block = torch.tensor(stock_vecs[stock_start:stock_start + batch_size]).to(device).float()
        stock_bits = stock_block.sum(1, keepdim=True)          # |A| per stock row

        for gen_start in range(0, gen_vecs.shape[0], batch_size):
            gen_block = torch.tensor(gen_vecs[gen_start:gen_start + batch_size]).to(device).float()
            gen_block = gen_block.transpose(0, 1)              # dim x batch
            overlap = torch.mm(stock_block, gen_block)          # |A & B|
            union = stock_bits + gen_block.sum(0, keepdim=True) - overlap
            sim = (overlap / union).cpu().numpy()
            sim[np.isnan(sim)] = 1
            if p != 1:
                sim = sim ** p

            window = slice(gen_start, gen_start + gen_block.shape[1])
            if agg == 'max':
                per_gen[window] = np.maximum(per_gen[window], sim.max(0))
            elif agg == 'mean':
                per_gen[window] += sim.sum(0)
                counts[window] += sim.shape[0]

    if agg == 'mean':
        per_gen /= counts
    if p != 1:
        per_gen = per_gen ** (1 / p)
    return np.mean(per_gen)


def gen_mol(x, adj, dataset, largest_connected_comp=True):
    """Decode a batch of node/edge score tensors into RDKit molecules.

    Each graph is built with construct_mol, valence-repaired with correct_mol,
    then round-tripped through SMILES by valid_mol_can_with_seg. Graphs that end
    up unparseable are dropped.

    Args:
        x: torch tensor of node scores; per the original inline comment e.g.
            (32, 9, 5) for QM9 — confirm for other datasets.
        adj: torch tensor of edge scores, e.g. (32, 4, 9, 9).
        dataset: 'QM9' selects the 4-element atom vocabulary; anything else
            selects the 9-element one. Both end with a 0 virtual-atom slot.
        largest_connected_comp: keep only the largest fragment per molecule.

    Returns:
        (mols, num_no_correct, num_no_correct_fc):
            mols: decoded molecules with None entries removed;
            num_no_correct: graphs that already passed the valence check
                before correct_mol changed anything;
            num_no_correct_fc: graphs that passed the valence check right after
                construct_mol, which also reported that no valence violation
                occurred while adding bonds.
    """
    node_scores = x.detach().cpu().numpy()
    edge_scores = adj.detach().cpu().numpy()

    atomic_num_list = [6, 7, 8, 9, 0] if dataset == 'QM9' else [6, 7, 8, 9, 15, 16, 17, 35, 53, 0]

    decoded = []
    num_no_correct = 0
    num_no_correct_fc = 0
    for node_row, edge_row in zip(node_scores, edge_scores):
        raw_mol, untouched_by_charge = construct_mol(node_row, edge_row, atomic_num_list)
        valid_as_built, _ = check_valency(raw_mol)
        if valid_as_built and untouched_by_charge:
            num_no_correct_fc += 1
        repaired, was_already_valid = correct_mol(raw_mol)
        if was_already_valid:
            num_no_correct += 1
        decoded.append(valid_mol_can_with_seg(repaired, largest_connected_comp=largest_connected_comp))

    return [m for m in decoded if m is not None], num_no_correct, num_no_correct_fc


def construct_mol(x, adj, atomic_num_list): # x: 9, 5; adj: 4, 9, 9
    """Build an RDKit RWMol from one graph's node and edge score arrays.

    Args:
        x: node scores; argmax over axis 1 picks an index into
            ``atomic_num_list``, and the last index marks a virtual (absent)
            node. Shape per the header comment, e.g. (9, 5) — confirm for
            datasets other than QM9.
        adj: edge scores; argmax over axis 0 picks a channel, decoded as
            channels 0, 1, 2 -> bond orders 1, 2, 3 and channel 3 -> no bond.
            Shape per the header comment, e.g. (4, 9, 9).
        atomic_num_list: atomic number for each node-type index; its last
            entry is the virtual-node placeholder.

    Returns:
        (mol, no_formal_charge): the RWMol, and a flag that is False if any
        bond addition triggered a valence violation (whether or not a +1
        charge was then applied), True otherwise.
    """
    mol = Chem.RWMol()

    # Keep only real atoms; the last type index denotes "no atom".
    atoms = np.argmax(x, axis=1)
    atoms_exist = (atoms != len(atomic_num_list) - 1)
    atoms = atoms[atoms_exist]              # (n_real_atoms,)
    for atom in atoms:
        mol.AddAtom(Chem.Atom(int(atomic_num_list[atom])))

    # Collapse the channel axis and restrict to real atoms, so row/column
    # indices line up with the RDKit atom indices assigned above.
    adj = np.argmax(adj, axis=0)            # 9, 9
    adj = adj[atoms_exist, :][:, atoms_exist]
    adj[adj == 3] = -1
    adj += 1                                # bonds 0, 1, 2, 3 -> 1, 2, 3, 0 (0 denotes the virtual bond)

    no_formal_charge = True     # cleared on the first valence violation below
    for start, end in zip(*np.nonzero(adj)):
        # Only the strictly lower triangle is read, so each atom pair is added
        # once; presumably adj is symmetric — TODO confirm with the sampler.
        if start > end:
            mol.AddBond(int(start), int(end), bond_decoder[adj[start, end]])
            # add formal charge to atom: e.g. [O+], [N+], [S+]
            # not support [O-], [N-], [S-], [NH+] etc.
            flag, atomid_valence = check_valency(mol)
            if flag:
                continue
            else:
                no_formal_charge = False  # a violation occurred (a charge may or may not be applied)
                assert len(atomid_valence) == 2
                idx = atomid_valence[0]  # offending atom index
                v = atomid_valence[1]    # its reported explicit valence
                an = mol.GetAtomWithIdx(idx).GetAtomicNum()
                # N, O or S exceeding its default valence by exactly one becomes a
                # cation; any other violation is left as-is here (gen_mol runs
                # correct_mol afterwards).
                if an in (7, 8, 16) and (v - ATOM_VALENCY[an]) == 1:
                    mol.GetAtomWithIdx(idx).SetFormalCharge(1)
    return mol, no_formal_charge


def check_valency(mol):
    """Run RDKit's SANITIZE_PROPERTIES sanitization on ``mol``.

    Returns:
        (True, None) if sanitization succeeds. If RDKit raises ValueError,
        returns (False, ints), where ints are all integers in the error message
        from its first '#' onward. Callers (construct_mol, correct_mol) assert
        this has length 2 and read it as [atom_index, explicit_valence].
    """
    try:
        Chem.SanitizeMol(mol, sanitizeOps=Chem.SanitizeFlags.SANITIZE_PROPERTIES)
    except ValueError as err:
        message = str(err)
        tail = message[message.find('#'):]
        return False, [int(token) for token in re.findall(r'\d+', tail)]
    return True, None


def correct_mol(m):
    """Repair valence violations in place by repeatedly demoting bonds.

    While check_valency reports an offending atom, the highest-order bond on
    that atom (first one in RDKit's bond order on ties) is lowered by one
    order; a single bond is removed outright. If the offending atom has no
    bonds left to demote, correction stops and the molecule is returned as-is,
    since looping again could never change the result.

    Args:
        m: RDKit RWMol (needs RemoveBond/AddBond); it is modified in place.

    Returns:
        (mol, no_correct): the same molecule object, and True if it already
        passed the valence check before any correction was attempted.
    """
    mol = m

    flag, atomid_valence = check_valency(mol)
    no_correct = flag

    while not flag:
        assert len(atomid_valence) == 2
        idx = atomid_valence[0]
        # (bond order, begin atom, end atom) for each bond on the offending atom.
        bonds = [(int(b.GetBondType()), b.GetBeginAtomIdx(), b.GetEndAtomIdx())
                 for b in mol.GetAtomWithIdx(idx).GetBonds()]
        if not bonds:
            # Nothing to demote: the original `while True` loop would re-check
            # the unchanged molecule forever here.
            break
        # max() returns the first maximal element, matching a stable
        # descending sort followed by taking the head.
        order, start, end = max(bonds, key=lambda bond: bond[0])
        mol.RemoveBond(start, end)
        if order - 1 >= 1:
            mol.AddBond(start, end, bond_decoder[order - 1])
        flag, atomid_valence = check_valency(mol)

    return mol, no_correct


def valid_mol_can_with_seg(m, largest_connected_comp=True):
    """Round-trip a molecule through isomeric SMILES, optionally keeping one fragment.

    Args:
        m: RDKit molecule, or None.
        largest_connected_comp: if True and the SMILES has several
            '.'-separated fragments, keep only the fragment with the longest
            SMILES string (the first one on ties). Length is measured in
            characters, not atoms.

    Returns:
        The molecule re-parsed from SMILES; None if ``m`` is None or RDKit
        cannot parse the SMILES.
    """
    if m is None:
        return None
    smiles = Chem.MolToSmiles(m, isomericSmiles=True)
    if not (largest_connected_comp and '.' in smiles):
        return Chem.MolFromSmiles(smiles)
    # e.g. 'C.CC.CCc1ccc(N)cc1CCC=O' -> 'CCc1ccc(N)cc1CCC=O'
    longest_fragment = max(smiles.split('.'), key=len)
    return Chem.MolFromSmiles(longest_fragment)


#####
def mols_to_nx(mols):
    """Convert RDKit molecules into undirected networkx graphs.

    Nodes are keyed by atom index and carry ``label`` = element symbol. Edges
    carry ``label`` = ``int(bond.GetBondTypeAsDouble())``, so non-integer bond
    orders (e.g. aromatic) are truncated by ``int``.

    Args:
        mols: iterable of RDKit molecules.

    Returns:
        list of networkx.Graph, one per molecule, in input order.
    """
    graphs = []
    for mol in mols:
        graph = nx.Graph()
        graph.add_nodes_from(
            (atom.GetIdx(), {'label': atom.GetSymbol()})
            for atom in mol.GetAtoms()
        )
        graph.add_edges_from(
            (bond.GetBeginAtomIdx(), bond.GetEndAtomIdx(),
             {'label': int(bond.GetBondTypeAsDouble())})
            for bond in mol.GetBonds()
        )
        graphs.append(graph)
    return graphs


def x_adj_to_nx(x, adj, dataset):
    """Convert a batch of dense node/edge encodings into labelled networkx graphs.

    Args:
        x: node scores as a torch tensor or array; argmax over the last axis
            picks the atom type, and the index ``len(atomic_num_list)`` marks a
            virtual (absent) node. Per the original comment e.g. (32, 9, 5) for
            QM9 — confirm for other datasets.
        adj: torch tensor or array, either
            - 4-D one-hot bond scores, e.g. (32, 4, 9, 9): argmax over axis 1,
              with channels 0, 1, 2, 3 decoded to bond orders 1, 2, 3, 0; or
            - 3-D bond-order matrices, e.g. (32, 9, 9), used as given.
        dataset: 'QM9' selects the 4-element atom list; any other value the
            9-element one.

    Returns:
        list of networkx.Graph. Node attribute 'label' is the element symbol
        and edge attribute 'label' is the matrix entry (bond order). Virtual
        nodes are removed.
    """
    def to_numpy(arr):
        # Accept both torch tensors and NumPy/array-likes. Previously a 4-D
        # NumPy adj left `adj_` unbound and a 3-D tensor reached networkx
        # unconverted.
        if isinstance(arr, torch.Tensor):
            return arr.detach().cpu().numpy()
        return np.asarray(arr)

    x_ = np.argmax(to_numpy(x), axis=-1)         # (batch, max_nodes)

    adj_ = to_numpy(adj)
    if adj_.ndim == 4:                           # (batch, 4, max_nodes, max_nodes)
        adj_ = np.argmax(adj_, axis=1)
        adj_[adj_ == 3] = -1
        adj_ += 1                                # channels 0, 1, 2, 3 -> 1, 2, 3, 0 (0 = no bond)

    if dataset == 'QM9':
        atomic_num_list = [6, 7, 8, 9]
    else:
        atomic_num_list = [6, 7, 8, 9, 15, 16, 17, 35, 53]

    nx_graphs = []
    for x_single, adj_single in zip(x_, adj_):
        # from_numpy_array replaces from_numpy_matrix (removed in NetworkX 3.0).
        G = nx.from_numpy_array(adj_single)

        for edge in G.edges:    # rename the 'weight' attribute to 'label'
            G.edges[edge]['label'] = G.edges[edge].pop('weight')

        for i, ai in enumerate(x_single):
            if ai != len(atomic_num_list):
                G.add_node(i, label=AN_TO_SYMBOL[atomic_num_list[ai]])
            else:   # a virtual node
                G.remove_node(i)

        nx_graphs.append(G)
    return nx_graphs
