from rdkit import Chem, RDLogger
from rdkit.Chem.rdchem import BondType as BT

import os
import os.path as osp
import pathlib
from typing import Any, Sequence

import torch
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np
from torch_geometric.data import Data, InMemoryDataset, download_url
import pandas as pd

import utils
from analysis.rdkit_functions import (
    mol2smiles,
    build_molecule_with_partial_charges,
    compute_molecular_metrics,
)
from datasets.abstract_dataset import AbstractDatasetInfos, MolecularDataModule


def to_list(value: Any) -> Sequence:
    """Return ``value`` unchanged if it is a non-string sequence, else wrap it.

    Strings are ``Sequence`` instances too, but they are treated as scalars
    here so that a single name is wrapped as ``[name]`` rather than passed
    through as a sequence of characters.
    """
    already_sequence = isinstance(value, Sequence) and not isinstance(value, str)
    return value if already_sequence else [value]


# Atom vocabulary for MOSES. The position of each symbol is the class index
# used for the one-hot node features built in MOSESDataset.process and for
# MOSESinfos.atom_encoder; reordering it changes the meaning of stored data.
atom_decoder = ["C", "N", "S", "O", "F", "Cl", "Br", "H"]


class MOSESDataset(InMemoryDataset):
    """PyG in-memory dataset holding one split of the MOSES benchmark.

    Each SMILES string is turned into a graph with one-hot atom-type node
    features (indexed by ``atom_decoder``) and one-hot bond features with
    ``len(bonds) + 1`` classes: class 0 is never produced for a real bond
    (presumably reserved for "no bond" in the dense representation), and
    classes 1-4 are single, double, triple and aromatic bonds.

    Args:
        stage: ``"train"`` or ``"val"``; any other value selects the test split.
        root: Dataset root directory (PyG places raw files in ``raw_dir`` and
            processed tensors in ``processed_dir`` below it).
        filter_dataset: If True, keep only molecules whose graph decodes back
            to a single-fragment molecule, and write the kept SMILES to
            ``new_<stage>.smiles`` in the raw directory.
        transform, pre_transform, pre_filter: Standard PyG hooks. Note that
            ``pre_filter`` and ``pre_transform`` are only applied when
            ``filter_dataset`` is False.
    """

    train_url = "https://media.githubusercontent.com/media/molecularsets/moses/master/data/train.csv"
    # The MOSES "test" split is used for validation and the scaffold split for test.
    val_url = "https://media.githubusercontent.com/media/molecularsets/moses/master/data/test.csv"
    test_url = "https://media.githubusercontent.com/media/molecularsets/moses/master/data/test_scaffolds.csv"

    def __init__(
        self,
        stage,
        root,
        filter_dataset: bool,
        transform=None,
        pre_transform=None,
        pre_filter=None,
    ):
        self.stage = stage
        self.atom_decoder = atom_decoder
        self.filter_dataset = filter_dataset
        # Index into split_paths / processed_paths for this stage.
        if self.stage == "train":
            self.file_idx = 0
        elif self.stage == "val":
            self.file_idx = 1
        else:
            self.file_idx = 2
        # The parent constructor triggers download()/process() when needed.
        super().__init__(root, transform, pre_transform, pre_filter)
        self.data, self.slices = torch.load(self.processed_paths[self.file_idx])

    @property
    def raw_file_names(self):
        return ["train_moses.csv", "val_moses.csv", "test_moses.csv"]

    @property
    def split_file_name(self):
        return ["train_moses.csv", "val_moses.csv", "test_moses.csv"]

    @property
    def split_paths(self):
        """Absolute paths of the raw train/val/test CSV files."""
        files = to_list(self.split_file_name)
        return [osp.join(self.raw_dir, f) for f in files]

    @property
    def processed_file_names(self):
        # Order matches file_idx: train, val (MOSES test), test (MOSES scaffolds).
        if self.filter_dataset:
            return [
                "train_filtered.pt",
                "test_filtered.pt",
                "test_scaffold_filtered.pt",
            ]
        else:
            return ["train.pt", "test.pt", "test_scaffold.pt"]

    def download(self):
        """Download the three MOSES CSVs and rename them to the split names.

        ``val_url`` is stored as ``val_moses.csv`` and ``test_url`` as
        ``test_moses.csv``, matching ``processed_file_names``.
        """
        urls = (self.train_url, self.val_url, self.test_url)
        for url, file_name in zip(urls, self.split_file_name):
            downloaded_path = download_url(url, self.raw_dir)
            os.rename(downloaded_path, osp.join(self.raw_dir, file_name))

    def process(self):
        """Convert this stage's SMILES into graphs and save the collated dataset."""
        RDLogger.DisableLog("rdApp.*")
        types = {atom: i for i, atom in enumerate(self.atom_decoder)}
        bonds = {BT.SINGLE: 0, BT.DOUBLE: 1, BT.TRIPLE: 2, BT.AROMATIC: 3}

        path = self.split_paths[self.file_idx]
        smiles_list = pd.read_csv(path)["SMILES"].values

        data_list = []
        smiles_kept = []

        for i, smile in enumerate(tqdm(smiles_list)):
            mol = Chem.MolFromSmiles(smile)
            if mol is None:
                # RDKit could not parse this SMILES; skip it instead of crashing.
                continue

            data = self._mol_to_data(mol, i, types, bonds)
            if data is None:
                continue

            if self.filter_dataset:
                smiles = self._round_trip_smiles(data)
                if smiles is not None:
                    data_list.append(data)
                    smiles_kept.append(smiles)
            else:
                if self.pre_filter is not None and not self.pre_filter(data):
                    continue
                if self.pre_transform is not None:
                    data = self.pre_transform(data)
                data_list.append(data)

        torch.save(self.collate(data_list), self.processed_paths[self.file_idx])

        if self.filter_dataset:
            smiles_save_path = osp.join(
                pathlib.Path(self.raw_paths[0]).parent, f"new_{self.stage}.smiles"
            )
            print(smiles_save_path)
            with open(smiles_save_path, "w") as f:
                f.writelines("%s\n" % s for s in smiles_kept)
            print(f"Number of molecules kept: {len(smiles_kept)} / {len(smiles_list)}")

    @staticmethod
    def _mol_to_data(mol, idx, types, bonds):
        """Build a ``Data`` graph for ``mol``; return None if it has no bonds."""
        N = mol.GetNumAtoms()
        type_idx = [types[atom.GetSymbol()] for atom in mol.GetAtoms()]

        row, col, edge_type = [], [], []
        for bond in mol.GetBonds():
            start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            # Store each bond in both directions (undirected graph).
            row += [start, end]
            col += [end, start]
            # +1 keeps class 0 of the edge one-hot free of real bonds.
            edge_type += 2 * [bonds[bond.GetBondType()] + 1]

        # Molecules without bonds (e.g. single atoms) are skipped.
        if not row:
            return None

        edge_index = torch.tensor([row, col], dtype=torch.long)
        edge_type = torch.tensor(edge_type, dtype=torch.long)
        edge_attr = F.one_hot(edge_type, num_classes=len(bonds) + 1).to(torch.float)

        # Sort edges lexicographically by (source, target).
        perm = (edge_index[0] * N + edge_index[1]).argsort()
        edge_index = edge_index[:, perm]
        edge_attr = edge_attr[perm]

        x = F.one_hot(torch.tensor(type_idx), num_classes=len(types)).float()
        y = torch.zeros(size=(1, 0), dtype=torch.float)
        return Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y, idx=idx)

    def _round_trip_smiles(self, data):
        """Decode ``data`` back into a molecule and return its SMILES.

        Returns None when ``mol2smiles`` yields None, when RDKit raises a
        valence/kekulization error while splitting fragments, or when the
        molecule has more than one fragment.
        """
        dense_data, node_mask = utils.to_dense(
            data.x, data.edge_index, data.edge_attr, data.batch
        )
        dense_data = dense_data.mask(node_mask, collapse=True)
        X, E = dense_data.X, dense_data.E

        assert X.size(0) == 1
        mol = build_molecule_with_partial_charges(X[0], E[0], self.atom_decoder)
        smiles = mol2smiles(mol)
        if smiles is None:
            return None
        try:
            mol_frags = Chem.rdmolops.GetMolFrags(
                mol, asMols=True, sanitizeFrags=True
            )
        except Chem.rdchem.AtomValenceException:
            print("Valence error in GetmolFrags")
            return None
        except Chem.rdchem.KekulizeException:
            print("Can't kekulize molecule")
            return None
        return smiles if len(mol_frags) == 1 else None


class MosesDataModule(MolecularDataModule):
    """Data module that bundles the train/val/test ``MOSESDataset`` splits.

    All three splits share one root directory,
    ``<parents[2] of this file>/<cfg.dataset.datadir>``, and the
    ``cfg.dataset.filter`` flag.
    """

    def __init__(self, cfg):
        dataset_cfg = cfg.dataset
        self.remove_h = False
        self.datadir = dataset_cfg.datadir
        self.filter_dataset = dataset_cfg.filter
        self.train_smiles = []

        project_root = pathlib.Path(os.path.realpath(__file__)).parents[2]
        root_path = os.path.join(project_root, self.datadir)

        # Built in train -> val -> test order.
        split_datasets = {
            split: MOSESDataset(
                stage=split, root=root_path, filter_dataset=self.filter_dataset
            )
            for split in ("train", "val", "test")
        }
        super().__init__(cfg, split_datasets)


class MOSESinfos(AbstractDatasetInfos):
    """Chemistry metadata and dataset statistics for MOSES.

    Hard-coded statistics are used by default. Each one is overridden by the
    matching entry of ``meta`` if given, otherwise by the matching
    ``MOSES_*.txt`` file if it exists (path relative to the current working
    directory). When ``recompute_statistics`` is True, all statistics are
    recomputed from ``datamodule`` and written back to those files.
    """

    def __init__(self, datamodule, cfg, recompute_statistics=False, meta=None):
        self.name = "MOSES"
        self.input_dims = None
        self.output_dims = None
        self.remove_h = False
        self.compute_fcd = cfg.dataset.compute_fcd

        self.atom_decoder = atom_decoder
        self.atom_encoder = {atom: i for i, atom in enumerate(self.atom_decoder)}
        # Indexed like atom_decoder (C, N, S, O, F, Cl, Br, H).
        self.atom_weights = {0: 12, 1: 14, 2: 32, 3: 16, 4: 19, 5: 35.4, 6: 79.9, 7: 1}
        self.valencies = [4, 3, 4, 2, 1, 1, 1, 1]
        self.num_atom_types = len(self.atom_decoder)
        self.max_weight = 350

        meta_files = dict(
            n_nodes=f"{self.name}_n_counts.txt",
            node_types=f"{self.name}_atom_types.txt",
            edge_types=f"{self.name}_edge_types.txt",
            valency_distribution=f"{self.name}_valencies.txt",
        )

        # Probability of each graph size; index = number of nodes.
        self.n_nodes = torch.tensor(
            [
                0.0,
                0.0,
                0.0,
                0.0,
                0.0,
                0.0,
                0.0,
                0.0,
                3.097634362347889692e-06,
                1.858580617408733815e-05,
                5.007842264603823423e-05,
                5.678996240021660924e-05,
                1.244216400664299726e-04,
                4.486406978685408831e-04,
                2.253012731671333313e-03,
                3.231865121051669121e-03,
                6.709992419928312302e-03,
                2.289564721286296844e-02,
                5.411050841212272644e-02,
                1.099515631794929504e-01,
                1.223291903734207153e-01,
                1.280680745840072632e-01,
                1.445975750684738159e-01,
                1.505961418151855469e-01,
                1.436946094036102295e-01,
                9.265746921300888062e-02,
                1.820066757500171661e-02,
                2.065089574898593128e-06,
            ]
        )
        self.max_n_nodes = len(self.n_nodes) - 1
        # NOTE(review): these hard-coded node-type and edge-type frequencies
        # sum to about 1.275 and 1.056 respectively, not 1 — confirm whether
        # downstream code normalizes them.
        self.node_types = torch.tensor(
            [0.722338, 0.13661, 0.163655, 0.103549, 0.1421803, 0.005411, 0.00150, 0.0]
        )
        self.edge_types = torch.tensor(
            [0.89740, 0.0472947, 0.062670, 0.0003524, 0.0486]
        )
        self.valency_distribution = torch.zeros(3 * self.max_n_nodes - 2)
        self.valency_distribution[:7] = torch.tensor(
            [0.0, 0.1055, 0.2728, 0.3613, 0.2499, 0.00544, 0.00485]
        )

        if meta is None:
            meta = dict(
                n_nodes=None,
                node_types=None,
                edge_types=None,
                valency_distribution=None,
            )
        assert set(meta.keys()) == set(meta_files.keys())
        # Work on a copy so the caller's dict is not mutated.
        meta = dict(meta)
        for k, path in meta_files.items():
            if meta[k] is None and os.path.exists(path):
                meta[k] = np.loadtxt(path)
            if meta[k] is not None:
                # Caller-provided or file-loaded statistics override the
                # defaults; store them as float tensors like the defaults so
                # `.numpy()` and downstream tensor code keep working.
                setattr(self, k, torch.as_tensor(meta[k], dtype=torch.float))
        # n_nodes may have been replaced above; keep max_n_nodes in sync.
        self.max_n_nodes = len(self.n_nodes) - 1

        if recompute_statistics or self.n_nodes is None:
            self.n_nodes = datamodule.node_counts()
            print("Distribution of number of nodes", self.n_nodes)
            np.savetxt(meta_files["n_nodes"], self.n_nodes.numpy())
            self.max_n_nodes = len(self.n_nodes) - 1
        if recompute_statistics or self.node_types is None:
            self.node_types = datamodule.node_types()
            print("Distribution of node types", self.node_types)
            np.savetxt(meta_files["node_types"], self.node_types.numpy())

        if recompute_statistics or self.edge_types is None:
            self.edge_types = datamodule.edge_counts()
            print("Distribution of edge types", self.edge_types)
            np.savetxt(meta_files["edge_types"], self.edge_types.numpy())
        if recompute_statistics or self.valency_distribution is None:
            valencies = datamodule.valency_count(self.max_n_nodes)
            print("Distribution of the valencies", valencies)
            np.savetxt(meta_files["valency_distribution"], valencies.numpy())
            self.valency_distribution = valencies
        self.complete_infos(n_nodes=self.n_nodes, node_types=self.node_types)


def get_smiles(raw_dir, filter_dataset):
    """Return the SMILES strings of every MOSES split.

    Args:
        raw_dir: Directory holding the raw split files.
        filter_dataset: If True, read the ``new_<split>.smiles`` files (one
            SMILES per line) written by ``MOSESDataset.process`` when
            filtering; otherwise read the ``SMILES`` column of
            ``<split>_moses.csv``.

    Returns:
        Dict mapping ``"train"``, ``"val"`` and ``"test"`` to lists of SMILES
        strings. Both branches return bare strings without line terminators.
    """
    splits = ("train", "val", "test")

    if filter_dataset:
        smiles_by_split = {}
        for split in splits:
            # `with` closes each file; rstrip drops the "\n" written per line
            # so results match the CSV branch.
            with open(osp.join(raw_dir, f"new_{split}.smiles")) as f:
                smiles_by_split[split] = [line.rstrip("\n") for line in f]
        return smiles_by_split

    return {
        split: extract_smiles_from_csv(osp.join(raw_dir, f"{split}_moses.csv"))
        for split in splits
    }


def extract_smiles_from_csv(csv_path):
    """Read the CSV at ``csv_path`` and return its ``SMILES`` column as a list."""
    frame = pd.read_csv(csv_path)
    smiles_column = frame["SMILES"]
    return smiles_column.tolist()


if __name__ == "__main__":
    # Build (downloading/processing on first use) all three MOSES splits.
    # The root is resolved like MosesDataModule's: two levels above this
    # file's directory, then data/moses.
    moses_root = os.path.join(
        pathlib.Path(os.path.realpath(__file__)).parents[2], "data", "moses"
    )
    # MOSESDataset requires filter_dataset; set to True to also build the
    # filtered splits and new_<stage>.smiles files.
    filter_dataset = False
    ds = [
        MOSESDataset(s, moses_root, filter_dataset=filter_dataset)
        for s in ["train", "val", "test"]
    ]
