import logging
import numpy as np

from rdkit import Chem
from rdkit import RDLogger
from rdkit.Chem.Scaffolds import MurckoScaffold

from federate.core.splitters import BaseSplitter

# Module-level logger used for progress messages from the split helpers.
logger = logging.getLogger(__name__)

# Import-time side effect: disable RDKit log channels matching 'rdApp.*'
# (presumably to hide RDKit's SMILES parsing warnings/errors during splits).
# NOTE(review): this likely silences RDKit logging for the whole process,
# not only for calls made from this module — confirm that is intended.
RDLogger.DisableLog('rdApp.*')


def generate_scaffold(smiles, include_chirality=False):
    """Return the Murcko scaffold SMILES string of a molecule.

    Args:
        smiles (str): SMILES string of the target molecule.
        include_chirality (bool): forwarded to RDKit as ``includeChirality``;
            controls whether stereochemistry is kept in the scaffold SMILES.

    Returns:
        The scaffold SMILES produced by
        ``MurckoScaffold.MurckoScaffoldSmiles``.

    Raises:
        ValueError: if RDKit cannot parse ``smiles``.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # MolFromSmiles signals a parse failure by returning None. Passing
        # that on would fail inside RDKit with an error that does not say
        # which input was bad, so report the offending SMILES here.
        raise ValueError(
            'Invalid SMILES, cannot compute scaffold: {!r}'.format(smiles))
    return MurckoScaffold.MurckoScaffoldSmiles(
        mol=mol, includeChirality=include_chirality)


def gen_scaffold_split(dataset, client_num=5):
    r"""Split dataset indices into ``client_num`` scaffold-ordered parts.

    Molecules are grouped by their Murcko scaffold (``generate_scaffold``).
    The groups are ordered from largest to smallest. Ties are broken by each
    group's first index, with the larger index first. The groups are then
    concatenated, and the resulting index sequence is cut into
    ``client_num`` contiguous, nearly equal chunks with ``np.array_split``.
    Molecules sharing a scaffold therefore stay in the same chunk unless
    their group straddles a chunk boundary.

    Args:
        dataset: iterable of records exposing a ``smiles`` attribute.
        client_num (int): number of parts to produce.

    Returns:
        list of ``client_num`` numpy arrays holding positional indices
        into ``dataset``.
    """
    logger.info('Scaffold split might take minutes, please wait...')
    scaffolds = {}
    for idx, data in enumerate(dataset):
        scaffold = generate_scaffold(data.smiles)
        scaffolds.setdefault(scaffold, []).append(idx)
    # Each group's indices are already ascending because ``enumerate``
    # yields them in order, so no per-group sort is required.

    # Sort from largest to smallest scaffold sets (ties: larger first index
    # first; first indices are unique, so the order is fully determined).
    ordered_groups = sorted(scaffolds.values(),
                            key=lambda group: (len(group), group[0]),
                            reverse=True)
    # Flatten in a single linear pass; ``sum(lists, [])`` is quadratic.
    scaffold_idxs = [idx for group in ordered_groups for idx in group]
    # Split data to list
    return list(np.array_split(scaffold_idxs, client_num))


class ScaffoldSplitter(BaseSplitter):
    """
    Split molecules by scaffold. This splitter groups all molecules by
    their Murcko scaffold, orders the groups from largest to smallest, and
    splits them into ``client_num`` parts.

    Arguments:
        client_num (int): Split data into client_num of pieces.
    """
    def __init__(self, client_num):
        super(ScaffoldSplitter, self).__init__(client_num)
        # Stored explicitly so __call__ does not rely on how the base class
        # names its attributes.
        self.client_num = client_num

    def __call__(self, dataset, **kwargs):
        """Split ``dataset`` into ``self.client_num`` lists of records.

        Args:
            dataset: iterable of records exposing a ``smiles`` attribute.
            **kwargs: accepted for interface compatibility; unused.

        Returns:
            list of lists of records, one inner list per part, in the order
            produced by ``gen_scaffold_split``.
        """
        # Materialize so records can be looked up by positional index.
        dataset = list(dataset)
        # Forward the configured client count; relying on the helper's
        # default would ignore the value passed to __init__.
        idx_slice = gen_scaffold_split(dataset, client_num=self.client_num)
        return [[dataset[idx] for idx in idxs] for idxs in idx_slice]
