import logging
import numpy as np
import torch

from rdkit import Chem
from rdkit import RDLogger
from federate.core.splitters.utils import \
    dirichlet_distribution_noniid_slice
from federate.core.splitters.graph.scaffold_splitter import \
    generate_scaffold
from federate.core.splitters import BaseSplitter

logger = logging.getLogger(__name__)

RDLogger.DisableLog('rdApp.*')


class GenFeatures:
    r"""Implementation of ``CanonicalAtomFeaturizer`` and
    ``CanonicalBondFeaturizer`` in DGL. \
    Source: https://lifesci.dgl.ai/_modules/dgllife/utils/featurizers.html

    Arguments:
        data: PyG.data in PyG.dataset.

    Returns:
        PyG.data: data passing featurizer.

    """
    def __init__(self):
        self.symbols = [
            'C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca',
            'Fe', 'As', 'Al', 'I', 'B', 'V', 'K', 'Tl', 'Yb', 'Sb', 'Sn', 'Ag',
            'Pd', 'Co', 'Se', 'Ti', 'Zn', 'H', 'Li', 'Ge', 'Cu', 'Au', 'Ni',
            'Cd', 'In', 'Mn', 'Zr', 'Cr', 'Pt', 'Hg', 'Pb', 'other'
        ]

        self.hybridizations = [
            Chem.rdchem.HybridizationType.SP,
            Chem.rdchem.HybridizationType.SP2,
            Chem.rdchem.HybridizationType.SP3,
            Chem.rdchem.HybridizationType.SP3D,
            Chem.rdchem.HybridizationType.SP3D2,
            'other',
        ]

        self.stereos = [
            Chem.rdchem.BondStereo.STEREONONE,
            Chem.rdchem.BondStereo.STEREOANY,
            Chem.rdchem.BondStereo.STEREOZ,
            Chem.rdchem.BondStereo.STEREOE,
            Chem.rdchem.BondStereo.STEREOCIS,
            Chem.rdchem.BondStereo.STEREOTRANS,
        ]

    def __call__(self, data, **kwargs):
        mol = Chem.MolFromSmiles(data.smiles)

        xs = []
        for atom in mol.GetAtoms():
            symbol = [0.] * len(self.symbols)
            if atom.GetSymbol() in self.symbols:
                symbol[self.symbols.index(atom.GetSymbol())] = 1.
            else:
                symbol[self.symbols.index('other')] = 1.
            degree = [0.] * 10
            degree[atom.GetDegree()] = 1.
            implicit = [0.] * 6
            implicit[atom.GetImplicitValence()] = 1.
            formal_charge = atom.GetFormalCharge()
            radical_electrons = atom.GetNumRadicalElectrons()
            hybridization = [0.] * len(self.hybridizations)
            if atom.GetHybridization() in self.hybridizations:
                hybridization[self.hybridizations.index(
                    atom.GetHybridization())] = 1.
            else:
                hybridization[self.hybridizations.index('other')] = 1.
            aromaticity = 1. if atom.GetIsAromatic() else 0.
            hydrogens = [0.] * 5
            hydrogens[atom.GetTotalNumHs()] = 1.

            x = torch.tensor(symbol + degree + implicit + [formal_charge] +
                             [radical_electrons] + hybridization +
                             [aromaticity] + hydrogens)
            xs.append(x)

        data.x = torch.stack(xs, dim=0)

        edge_attrs = []
        for bond in mol.GetBonds():
            bond_type = bond.GetBondType()
            single = 1. if bond_type == Chem.rdchem.BondType.SINGLE else 0.
            double = 1. if bond_type == Chem.rdchem.BondType.DOUBLE else 0.
            triple = 1. if bond_type == Chem.rdchem.BondType.TRIPLE else 0.
            aromatic = 1. if bond_type == Chem.rdchem.BondType.AROMATIC else 0.
            conjugation = 1. if bond.GetIsConjugated() else 0.
            ring = 1. if bond.IsInRing() else 0.
            stereo = [0.] * 6
            stereo[self.stereos.index(bond.GetStereo())] = 1.

            edge_attr = torch.tensor(
                [single, double, triple, aromatic, conjugation, ring] + stereo)

            edge_attrs += [edge_attr, edge_attr]

        if len(edge_attrs) == 0:
            data.edge_index = torch.zeros((2, 0), dtype=torch.long)
            data.edge_attr = torch.zeros((0, 10), dtype=torch.float)
        else:
            num_atoms = mol.GetNumAtoms()
            feats = torch.stack(edge_attrs, dim=0)
            feats = torch.cat([feats, torch.zeros(feats.shape[0], 1)], dim=1)
            self_loop_feats = torch.zeros(num_atoms, feats.shape[1])
            self_loop_feats[:, -1] = 1
            feats = torch.cat([feats, self_loop_feats], dim=0)
            data.edge_attr = feats

        return data


def gen_scaffold_lda_split(dataset, client_num=5, alpha=0.1):
    r"""
    return dict{ID:[idxs]}
    """
    logger.info('Scaffold split might take minutes, please wait...')
    scaffolds = {}
    for idx, data in enumerate(dataset):
        smiles = data.smiles
        _ = Chem.MolFromSmiles(smiles)
        scaffold = generate_scaffold(smiles)
        if scaffold not in scaffolds:
            scaffolds[scaffold] = [idx]
        else:
            scaffolds[scaffold].append(idx)
    # Sort from largest to smallest scaffold sets
    scaffolds = {key: sorted(value) for key, value in scaffolds.items()}
    scaffold_list = [
        list(scaffold_set)
        for (scaffold,
             scaffold_set) in sorted(scaffolds.items(),
                                     key=lambda x: (len(x[1]), x[1][0]),
                                     reverse=True)
    ]
    label = np.zeros(len(dataset))
    for i in range(len(scaffold_list)):
        label[scaffold_list[i]] = i + 1
    label = torch.LongTensor(label)
    # Split data to list
    idx_slice = dirichlet_distribution_noniid_slice(label, client_num, alpha)
    return idx_slice


class ScaffoldLdaSplitter(BaseSplitter):
    """
    First adopt scaffold splitting and then assign the samples to \
    clients according to Latent Dirichlet Allocation.

    Arguments:
        dataset (List or PyG.dataset): The molecular datasets.
        alpha (float): Partition hyperparameter in LDA, smaller alpha \
            generates more extreme heterogeneous scenario see \
            ``np.random.dirichlet``

    Returns:
         List(List(PyG.data)): data_list of split dataset via scaffold split.

    """
    def __init__(self, client_num, alpha):
        super(ScaffoldLdaSplitter, self).__init__(client_num)
        self.alpha = alpha

    def __call__(self, dataset):
        featurizer = GenFeatures()
        data = []
        for ds in dataset:
            ds = featurizer(ds)
            data.append(ds)
        dataset = data
        idx_slice = gen_scaffold_lda_split(dataset, self.client_num,
                                           self.alpha)
        data_list = [[dataset[idx] for idx in idxs] for idxs in idx_slice]
        return data_list
