""" PCQM4Mv2 molecule data set: processing, loading and evaluation of (text, graph) pairs """
import argparse
from collections import defaultdict
from functools import partial
import json
import os
from typing import Dict, List, Union

import numpy as np
from ogb.lsc import PCQM4Mv2Dataset
from rdkit import Chem
from rdkit.Chem import Descriptors, QED
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')
import tqdm

from text2graph.data.base_dataset import BaseDataset, TextGraph
from text2graph.models.edge_feature_processing import get_disconnected_edge_indexes


# Weights handed to QED.qed: ALOGP, PSA, ROTB and AROM have weight 1.0 while MW, HBA, HBD and
# ALERTS are given zero weight.
_QED_WEIGHTS = QED.QEDproperties(
    MW=0.0, ALOGP=1.0, HBA=0.0, HBD=0.0, PSA=1.0, ROTB=1.0, AROM=1.0, ALERTS=0.0
)

# Natural-language property name -> (function evaluated on a Chem.Mol, Python type the rounded
# value is cast to). The names double as metric names and are inserted into the text prompts,
# and the insertion order fixes the order of both.
PROPERTIES_TO_CALCULATE = {
    'a number of valence electrons': (Descriptors.NumValenceElectrons, int),
    'a weighted quantitative estimation of drug-likeness': (
        partial(QED.qed, w=_QED_WEIGHTS, qedProperties=None),
        float,
    ),
}

# Number of decimal places every calculated property is rounded to.
NUM_DECIMALS = 2


class PCQM4MDataset(BaseDataset):
    """ Class for processing, loading and evaluating a model on pcqm4m data """
    def __getitem__(self, idx: int) -> TextGraph:
        """ Loads the processed (text, graph) pair with index `idx` of the current split from
            its json file and returns it as a TextGraph. The stored text may be a list of
            prompts (one per property, see generate_text), in which case the prompts are
            concatenated into a single string
        """
        file_path = os.path.join(
            self.dataset_path,
            f"{self.file_prefix}_{self.split_name}_{idx}.json"
        )
        with open(file_path, 'r', encoding='utf-8') as data_json:
            graph = json.load(data_json)
        text = graph['text']
        return TextGraph(
            text="".join(text) if isinstance(text, list) else text,
            nodes=graph['nodes'],
            edges=graph['edges'],
            edge_index=graph['edge_index']
        )

    @staticmethod
    def metric_names() -> List[str]:
        """ Returns the names of the metrics used to evaluate generated graphs: 'parsability'
            followed by one entry per property in PROPERTIES_TO_CALCULATE
        """
        return ['parsability'] + list(PROPERTIES_TO_CALCULATE)

    @staticmethod
    def download_and_process(parent_path: str) -> None:
        """ Fetches the PCQM4Mv2 SMILES data (via ogb's PCQM4Mv2Dataset rooted at
            `<parent_path>/pcqm4m/raw_data`) and processes it into (text, graph) pairs where
            each graph is represented by three lists: nodes -> List[str], edges -> List[str],
            edge_index -> List[List[int]], plus its 'properties' and 'text'.

            Molecules are kept only if they parse, have at least one bond, form a single
            connected bond graph covering every atom, and their properties survive a
            graph -> molecule round trip. Molecules with identical property values (and hence
            identical text) are placed in the same split. Each graph is saved as its own json
            file `pcqm4m_<split>_<idx>.json` inside `<parent_path>/pcqm4m`
        """
        datasplit_ratios = [0.8, 0.1, 0.1]
        dataset_name = 'pcqm4m'
        dataset_path = os.path.join(parent_path, dataset_name)
        raw_data_path = os.path.join(dataset_path, 'raw_data')
        # makedirs also creates dataset_path (and parent_path if missing); no-op if present
        os.makedirs(raw_data_path, exist_ok=True)
        smiles_dataset = PCQM4Mv2Dataset(root=raw_data_path, only_smiles=True)

        graph_dataset = []
        for smiles, _ in tqdm.tqdm(smiles_dataset):
            mol = Chem.MolFromSmiles(smiles, sanitize=True)
            if mol is None:
                continue
            graph = PCQM4MDataset._molecule_to_graph(mol)
            if graph is not None:
                graph_dataset.append(graph)

        split_dataset = PCQM4MDataset._split_by_properties(graph_dataset, datasplit_ratios)
        for split_name, graphs in split_dataset.items():
            print(f"{split_name} data set contains {len(graphs)} graphs")
            for idx, graph in enumerate(graphs):
                processed_file_name = f"{dataset_name}_{split_name}_{idx}.json"
                processed_file_path = os.path.join(dataset_path, processed_file_name)
                with open(processed_file_path, 'w', encoding='utf-8') as split_file:
                    json.dump(graph, split_file)

    @staticmethod
    def _molecule_to_graph(mol: Chem.Mol) -> Union[Dict, None]:
        """ Converts a parsed molecule into a graph dict with 'nodes', 'edges', 'edge_index',
            'properties' and 'text' entries, or returns None if the molecule is excluded
            from the data set
        """
        graph = PCQM4MDataset.eval_rep2graph(mol)
        edge_index = graph['edge_index']
        # cheap structural filters first, before the costly round-trip reconstruction
        if len(edge_index) == 0:
            return None
        # every atom must take part in at least one bond ...
        if len(graph['nodes']) != np.unique(edge_index).size:
            return None
        # ... and get_disconnected_edge_indexes must yield exactly one group
        # (presumably one group per connected component)
        if len(get_disconnected_edge_indexes(np.array(edge_index))) != 1:
            return None
        properties = PCQM4MDataset.calculate_properties(mol)
        # keep only molecules whose properties are unchanged after rebuilding the molecule
        # from its graph, so the generated text stays consistent with the stored graph
        reconstructed = PCQM4MDataset.graph2eval_rep(
            TextGraph(text='', nodes=graph['nodes'], edges=graph['edges'], edge_index=edge_index)
        )
        if PCQM4MDataset.calculate_properties(reconstructed) != properties:
            return None
        graph['properties'] = properties
        graph['text'] = PCQM4MDataset.generate_text(properties)
        return graph

    @staticmethod
    def _split_by_properties(
        graph_dataset: List[Dict],
        datasplit_ratios: List[float]
    ) -> Dict[str, List[Dict]]:
        """ Groups graphs by their property values and assigns each group as a whole to one
            of 'train', 'val', 'test', drawn with probabilities `datasplit_ratios` using
            numpy's global random state
        """
        split_names = ['train', 'val', 'test']
        split_dataset = {split_name: [] for split_name in split_names}
        groups = defaultdict(list)
        for graph in graph_dataset:
            groups[tuple(sorted(graph['properties'].items()))].append(graph)
        for group in groups.values():
            split_idx = np.random.multinomial(1, datasplit_ratios, size=1)[0].argmax()
            split_dataset[split_names[split_idx]].extend(group)
        return split_dataset

    @staticmethod
    def eval_rep2graph(data_point: Chem.Mol) -> Dict[str, Union[List[str], List[List[int]]]]:
        """ NOTE: each data point in the data set will have a representation from which it is
            possible to calculate metrics which measure model performance.

            Maps a molecule to its graph representation, a dict of three lists:
            nodes -> List[str] with one "<symbol> <formal charge> <total number of Hs>" string
            per atom, edges -> List[str] with the Chem.BondType name of each bond, and
            edge_index -> List[List[int]] with the [begin, end] atom indices of each bond
        """
        atoms_list, bonds_list, edge_index = [], [], []
        for atom in data_point.GetAtoms():
            atoms_list.append(f"{atom.GetSymbol()} {atom.GetFormalCharge()} {atom.GetTotalNumHs()}")
        for bond in data_point.GetBonds():
            edge_index.append([bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()])
            bonds_list.append(bond.GetBondType().name)
        return {'nodes': atoms_list, 'edges': bonds_list, 'edge_index': edge_index}

    @staticmethod
    def graph2eval_rep(graph: TextGraph) -> Chem.Mol:
        """ NOTE: each data point in the data set will have a representation from which it is
            possible to calculate metrics which measure model performance.

            Rebuilds a molecule from a graph in the format produced by eval_rep2graph. Each
            atom gets exactly the stated number of hydrogens (implicit Hs are disabled), each
            edge label must be a Chem.BondType name, and the result is passed through
            Chem.SanitizeMol, which raises for chemically invalid graphs
        """
        molecule = Chem.RWMol()
        for node_str in graph.nodes:
            node_feats = node_str.split(" ")
            atom = Chem.Atom(node_feats[0])
            atom.SetNoImplicit(True)
            atom.SetFormalCharge(int(node_feats[1]))
            atom.SetNumExplicitHs(int(node_feats[2]))
            molecule.AddAtom(atom)
        for bond_type, (pred_idx, succ_idx) in zip(graph.edges, graph.edge_index):
            molecule.AddBond(pred_idx, succ_idx, Chem.BondType.names[bond_type])
        molecule = molecule.GetMol()
        Chem.SanitizeMol(molecule)
        return molecule

    def calculate_metrics(
        self,
        *,
        ground_truth_point: Chem.Mol,
        generated_point: Chem.Mol
    ) -> Dict[str, Union[int, float]]:
        """ Returns, for each property in PROPERTIES_TO_CALCULATE, the absolute difference
            between its value on the generated molecule and on the ground-truth molecule.
            'parsability' (listed in metric_names) is not computed here
        """
        generated_properties = PCQM4MDataset.calculate_properties(generated_point)
        ground_truth_properties = PCQM4MDataset.calculate_properties(ground_truth_point)
        return {
            key: abs(ground_truth_properties[key] - value)
            for key, value in generated_properties.items()
        }

    @staticmethod
    def calculate_properties(molecule: Chem.Mol) -> Dict[str, Union[int, float]]:
        """ Evaluates every property in PROPERTIES_TO_CALCULATE on a molecule, rounding each
            value to NUM_DECIMALS decimals and casting it to the property's Python type
        """
        return {
            key: dtype(np.round(function(molecule), NUM_DECIMALS))
            for key, (function, dtype) in PROPERTIES_TO_CALCULATE.items()
        }

    @staticmethod
    def generate_text(properties: Dict[str, Union[int, float]]) -> List[str]:
        """ Returns one sentence per property describing the molecule. Each sentence starts
            with a space so the list can be joined with "" into a single prompt. The property
            names already begin with an article ("a ..."), so none is added here
        """
        return [f" A molecule with {key} equal to {value}." for key, value in properties.items()]


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='data set location')
    parser.add_argument('--dataset-path', type=str, required=True)
    args = parser.parse_args()
    PCQM4MDataset.download_and_process(args.dataset_path)
