import json
import os
import os.path as osp
import pathlib
from typing import Any, Sequence
from logging import getLogger
import traceback
import torch
import torch.nn.functional as F
from rdkit import Chem, RDLogger
from rdkit.Chem.rdchem import BondType as BT
from rdkit.Chem import rdmolops
from tqdm import tqdm
import numpy as np
import numpy
import pandas as pd
from collections import Counter
from torch_geometric.data import Data, InMemoryDataset, download_url, extract_zip
from torch_geometric.utils import subgraph

import src.utils as utils
from src.datasets.abstract_dataset import MolecularDataModule, AbstractDatasetInfos
from src.analysis.rdkit_functions import mol2smiles, build_molecule_with_partial_charges
from src.analysis.rdkit_functions import compute_molecular_metrics


def files_exist(files) -> bool:
    """Return True only when `files` is non-empty and every path in it exists.

    An empty collection deliberately yields False, which makes callers
    re-process their files on every instantiation.
    """
    if len(files) == 0:
        return False
    return all(osp.exists(path) for path in files)


def to_list(value: Any) -> Sequence:
    """Return `value` unchanged if it is a non-string sequence, else wrap it.

    Strings count as scalars here: they come back as a one-element list
    rather than being treated as a sequence of characters.
    """
    is_scalar = isinstance(value, str) or not isinstance(value, Sequence)
    if is_scalar:
        return [value]
    return value


class RemoveYTransform:
    """Transform that replaces `data.y` with an empty `(1, 0)` float tensor."""

    def __call__(self, data):
        empty_target = torch.zeros(1, 0, dtype=torch.float32)
        data.y = empty_target
        return data


class SelectMuTransform:
    """Transform that keeps only the first entry along the last axis of `data.y`."""

    def __call__(self, data):
        leading_column = data.y[..., :1]
        data.y = leading_column
        return data


class SelectHOMOTransform:
    """Transform that drops the first entry along the last axis of `data.y`."""

    def __call__(self, data):
        trailing_columns = data.y[..., 1:]
        data.y = trailing_columns
        return data


class ZincDataset(InMemoryDataset):
    """One split (train / val / test) of ZINC250k as an in-memory graph dataset.

    Every molecule is turned into a `Data` object with one-hot atom types in
    `x`, a symmetric `edge_index`, one-hot bond types in `edge_attr` and an
    empty `(1, 0)` tensor in `y`. Bond classes are shifted by one, so index 0
    of `edge_attr` is never assigned to a real bond (presumably the "no bond"
    class -- confirm against the model code). The canonical SMILES of the
    split are written next to the processed tensors and exposed as
    `self.smiles`.

    Args:
        stage: "train" or "val"; any other value selects the test split.
        root: Dataset root directory handed to `InMemoryDataset`.
        remove_h: If False, explicit hydrogens are added to every molecule.
        target_prop: Stored on the instance; not used inside this class.
        transform: Forwarded to `InMemoryDataset`.
        pre_transform: Forwarded to `InMemoryDataset`; applied in `process`.
        pre_filter: Forwarded to `InMemoryDataset`; applied in `process`.
        kekulize: If True, molecules are kekulized before bonds are read.
    """

    # raw.githubusercontent.com serves files at /<owner>/<repo>/<ref>/<path>.
    # The "tree/" segment only exists in github.com web-UI URLs and must not
    # appear here.
    data_url = "https://raw.githubusercontent.com/harryjo97/DruM/master/DruM_2D/data/zinc250k.csv"
    idx_url = "https://raw.githubusercontent.com/harryjo97/DruM/master/DruM_2D/data/valid_idx_zinc250k.json"

    def __init__(
        self,
        stage,
        root,
        remove_h: bool,
        target_prop=None,
        transform=None,
        pre_transform=None,
        pre_filter=None,
        kekulize=True,
    ):
        self.target_prop = target_prop
        self.stage = stage
        # Position of this split in `split_paths`, `processed_paths` and
        # `split_smiles_path`.
        if self.stage == "train":
            self.file_idx = 0
        elif self.stage == "val":
            self.file_idx = 1
        else:
            self.file_idx = 2
        self.remove_h = remove_h
        self.kekulize = kekulize
        super().__init__(root, transform, pre_transform, pre_filter)
        self.data, self.slices = torch.load(self.processed_paths[self.file_idx])
        with open(self.split_smiles_path[self.file_idx], "r") as f:
            self.smiles = json.load(f)

    @property
    def raw_file_names(self):
        """Names of the raw files: the molecule table and the held-out index list."""
        return ["zinc250k.csv", "valid_idx_zinc250k.json"]

    @property
    def split_file_name(self):
        """Names of the per-split CSV files, ordered train / valid / test."""
        return ["train.csv", "valid.csv", "test.csv"]

    @property
    def split_paths(self):
        r"""The absolute filepaths that must be present in order to skip
        splitting."""
        files = to_list(self.split_file_name)
        return [osp.join(self.raw_dir, f) for f in files]

    @property
    def split_smiles_path(self):
        r"""The filepaths of the per-split canonical SMILES lists, stored in
        the processed directory and ordered train / valid / test."""
        return [
            osp.join(self.processed_dir, "train_smiles.json"),
            osp.join(self.processed_dir, "valid_smiles.json"),
            osp.join(self.processed_dir, "test_smiles.json"),
        ]

    @property
    def processed_file_names(self):
        """Names of the processed tensors; they differ with `remove_h` so both
        variants can coexist in one directory."""
        if self.remove_h:
            return ["proc_tr_no_h.pt", "proc_val_no_h.pt", "proc_test_no_h.pt"]
        else:
            return ["proc_tr_h.pt", "proc_val_h.pt", "proc_test_h.pt"]

    def _write_splits(self):
        """Split the raw molecule table into train / valid / test CSV files.

        Rows whose index appears in the held-out index file go to the valid
        split; all remaining rows form the train split.
        """
        data_frame = pd.read_csv(self.raw_paths[0], index_col=0)
        with open(self.raw_paths[1], "r") as f:
            valid_indexes = json.load(f)
        # Set membership keeps this linear in the number of rows instead of
        # scanning the held-out list once per row.
        held_out = set(valid_indexes)
        train_indexes = [
            idx for idx in data_frame.index.tolist() if idx not in held_out
        ]

        data_frame.iloc[train_indexes].to_csv(self.split_paths[0])
        data_frame.iloc[valid_indexes].to_csv(self.split_paths[1])
        # NOTE: the test split is written from the same indices as the valid
        # split.
        data_frame.iloc[valid_indexes].to_csv(self.split_paths[2])

    def download(self):
        """
        Download the raw zinc file and the held-out index file, then write
        the train / valid / test CSV splits.
        """
        download_url(self.data_url, self.raw_dir)
        # The index file is read right below, so it has to be fetched too.
        download_url(self.idx_url, self.raw_dir)
        self._write_splits()

    @staticmethod
    def _prepare_smiles_and_mol(mol, remove_h, kekulize):
        """Return the non-isomeric canonical SMILES of `mol` and a molecule
        rebuilt from it, with hydrogens added unless `remove_h` and kekulized
        if `kekulize`."""
        canonical_smiles = Chem.MolToSmiles(mol, isomericSmiles=False, canonical=True)
        mol = Chem.MolFromSmiles(canonical_smiles)
        if not remove_h:
            mol = Chem.AddHs(mol)
        if kekulize:
            Chem.Kekulize(mol)
        return canonical_smiles, mol

    def process(self):
        """Convert this instance's split into graphs and save it to disk.

        Writes the collated `Data` list to `processed_paths[file_idx]` and the
        canonical SMILES to `split_smiles_path[file_idx]`. The split CSVs are
        (re)created first if the one for this split is missing.
        """
        if not osp.exists(self.split_paths[self.file_idx]):
            self._write_splits()

        # Atom symbol -> class index.
        types = {
            "H": 0,
            "C": 1,
            "N": 2,
            "O": 3,
            "F": 4,
            "P": 5,
            "S": 6,
            "Cl": 7,
            "Br": 8,
            "I": 9,
        }
        bonds = {BT.SINGLE: 0, BT.DOUBLE: 1, BT.TRIPLE: 2, BT.AROMATIC: 3}

        target_df = pd.read_csv(self.split_paths[self.file_idx], index_col=0)
        data_list = []
        smiles_df = target_df["smiles"].tolist()
        idxes = target_df.index.tolist()
        self.smiles = []
        for idx, smiles in tqdm(zip(idxes, smiles_df), total=len(idxes)):
            mol = Chem.MolFromSmiles(smiles)
            canonical_smiles, mol = self._prepare_smiles_and_mol(
                mol, self.remove_h, self.kekulize
            )
            # NOTE: the SMILES is recorded before `pre_filter` runs, so the
            # saved list also contains molecules the filter drops below.
            self.smiles.append(canonical_smiles)
            N = mol.GetNumAtoms()

            type_idx = [types[atom.GetSymbol()] for atom in mol.GetAtoms()]

            # Each bond is stored in both directions.
            row, col, edge_type = [], [], []
            for bond in mol.GetBonds():
                start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
                row += [start, end]
                col += [end, start]
                # Shift by one so that class 0 is never used for a real bond.
                edge_type += 2 * [bonds[bond.GetBondType()] + 1]

            edge_index = torch.tensor([row, col], dtype=torch.long)
            edge_type = torch.tensor(edge_type, dtype=torch.long)
            edge_attr = F.one_hot(edge_type, num_classes=len(bonds) + 1).to(torch.float)

            # Sort edges by (source, target).
            perm = (edge_index[0] * N + edge_index[1]).argsort()
            edge_index = edge_index[:, perm]
            edge_attr = edge_attr[perm]

            x = F.one_hot(torch.tensor(type_idx), num_classes=len(types)).float()
            y = torch.zeros((1, 0), dtype=torch.float)

            data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y, idx=idx)

            if self.pre_filter is not None and not self.pre_filter(data):
                continue
            if self.pre_transform is not None:
                data = self.pre_transform(data)

            data_list.append(data)

        torch.save(self.collate(data_list), self.processed_paths[self.file_idx])
        with open(self.split_smiles_path[self.file_idx], "w") as f:
            json.dump(self.smiles, f)


class ZincDataModule(MolecularDataModule):
    """Data module that builds the train / val / test `ZincDataset` splits.

    The train and val splits always use `RemoveYTransform`. The test split's
    transform depends on `cfg.general.guidance_target` and on a `regressor`
    attribute looked up on the instance (None when absent).
    """

    def __init__(self, cfg):
        self.datadir = cfg.dataset.datadir
        self.remove_h = cfg.dataset.remove_h

        target = getattr(cfg.general, "guidance_target", None)
        regressor = getattr(self, "regressor", None)

        # Default: strip the targets. Only a truthy `regressor` combined with
        # a recognised target selects something else for the test split.
        transform = RemoveYTransform()
        if regressor:
            if target == "mu":
                transform = SelectMuTransform()
            elif target == "homo":
                transform = SelectHOMOTransform()
            elif target == "both":
                transform = None

        # The data directory is resolved relative to the directory two levels
        # above the one containing this file.
        base_path = pathlib.Path(os.path.realpath(__file__)).parents[2]
        root_path = os.path.join(base_path, self.datadir)

        stage_transforms = (
            ("train", RemoveYTransform()),
            ("val", RemoveYTransform()),
            ("test", transform),
        )
        datasets = {
            stage: ZincDataset(
                stage=stage,
                root=root_path,
                remove_h=cfg.dataset.remove_h,
                target_prop=target,
                transform=stage_transform,
            )
            for stage, stage_transform in stage_transforms
        }
        super().__init__(cfg, datasets)


class Zincinfos(AbstractDatasetInfos):
    """Atom vocabulary and hard-coded marginal statistics for the ZINC dataset.

    Two sets of statistics are stored, chosen by `cfg.dataset.remove_h`.
    With `recompute_statistics=True` the statistics are recomputed from
    `datamodule`, printed, dumped to text files in the working directory,
    and the constructor then stops on a failing assertion.
    """

    def __init__(self, datamodule, cfg, recompute_statistics=False):
        symbols = ["H", "C", "N", "O", "F", "P", "S", "Cl", "Br", "I"]
        masses = [1, 12, 14, 16, 19, 31, 32, 35, 80, 127]

        self.remove_h = cfg.dataset.remove_h
        # Indicates whether one output of the model has to be ignored.
        self.need_to_strip = False

        self.name = "zinc"
        # Symbol -> class index, following the order of `symbols`.
        self.atom_encoder = {symbol: position for position, symbol in enumerate(symbols)}
        self.atom_decoder = symbols
        self.valencies = [1, 4, 3, 2, 1, 3, 2, 1, 1, 1]
        self.num_atom_types = len(self.atom_decoder)
        # Class index -> integer atomic weight.
        self.atom_weights = dict(enumerate(masses))

        # Histogram used as `n_nodes` by both configurations below.
        # NOTE(review): it has 39 entries even when hydrogens are kept and
        # `max_n_nodes` is 84 -- confirm the hydrogen statistics are intended.
        node_count_probs = [0.0] * 6 + [
            1.3358982579886716e-05,
            2.226497096647786e-05,
            5.788892451284243e-05,
            0.0002983506109508033,
            0.0007926329664066118,
            0.002912258202415304,
            0.004689002885540237,
            0.0071515086744326885,
            0.011274981297424389,
            0.017117309679028178,
            0.02535980193081828,
            0.03501389334188308,
            0.046707456093477255,
            0.058178369135406645,
            0.07082932563855937,
            0.08147198176053579,
            0.074921627302198,
            0.0843842399629511,
            0.09309874959923052,
            0.09145114174771116,
            0.07717484236400556,
            0.06339727832994906,
            0.040330768408678,
            0.031130882405329345,
            0.024393502190873145,
            0.01923693491503687,
            0.015028855402372556,
            0.010362117487798797,
            0.006915499982188023,
            0.004119019628798404,
            0.0015941719211998147,
            0.0005610772683552421,
            8.905988386591144e-06,
        ]

        if self.remove_h:
            self.max_n_nodes = 39
            self.max_weight = 497
            atom_type_probs = [
                0.0,
                0.7367823659857735,
                0.12211220550781154,
                0.09974558113849699,
                0.013745389175070837,
                2.442821154523441e-05,
                0.017806050387129447,
                0.007423098739792806,
                0.0022056559196000237,
                0.0001552249347795604,
            ]
            bond_type_probs = [
                0.7430251644317138,
                0.25446130720284316,
                0.0025135283654429907,
                0.0,
                0.0,
            ]
            valency_head = [2.6071e-06, 0.163, 0.352, 0.320, 0.16313, 0.00073]
        else:
            self.max_n_nodes = 84
            self.max_weight = 500
            atom_type_probs = [
                0.47133961724059564,
                0.38950764761241863,
                0.0645558853033547,
                0.052731537103237046,
                0.007266642702469922,
                1.2914227665631326e-05,
                0.009413353413093095,
                0.0039242982210397175,
                0.0011660429026912945,
                8.206127343436598e-05,
            ]
            bond_type_probs = [
                0.8594905404109199,
                0.13913510512550323,
                0.001374354463576918,
                0.0,
                0.0,
            ]
            valency_head = [0, 0.5136, 0.0840, 0.0554, 0.3456, 0.0012]

        self.n_nodes = torch.tensor(node_count_probs)
        self.node_types = torch.tensor(atom_type_probs)
        self.edge_types = torch.tensor(bond_type_probs)

        super().complete_infos(n_nodes=self.n_nodes, node_types=self.node_types)

        # Sized from `max_n_nodes` as it stands after `complete_infos` has
        # run; only the first six entries are filled.
        self.valency_distribution = torch.zeros(3 * self.max_n_nodes - 2)
        self.valency_distribution[0:6] = torch.tensor(valency_head)

        if not recompute_statistics:
            return

        np.set_printoptions(suppress=True, precision=5)
        self.n_nodes = datamodule.node_counts()
        print("Distribution of number of nodes", self.n_nodes)
        np.savetxt("n_counts.txt", self.n_nodes.numpy())
        self.node_types = datamodule.node_types()
        print("Distribution of node types", self.node_types)
        np.savetxt("atom_types.txt", self.node_types.numpy())

        self.edge_types = datamodule.edge_counts()
        print("Distribution of edge types", self.edge_types)
        np.savetxt("edge_types.txt", self.edge_types.numpy())

        valency_counts = datamodule.valency_count(self.max_n_nodes)
        print("Distribution of the valencies", valency_counts)
        np.savetxt("valencies.txt", valency_counts.numpy())
        self.valency_distribution = valency_counts
        # Stops construction once the statistics have been written out.
        assert False
