import ast
import hashlib
import os.path as osp
import pickle
import shutil
import sys

import pandas as pd
import torch
from ogb.utils import smiles2graph
from ogb.utils.torch_util import replace_numpy_with_torchtensor
from ogb.utils.url import decide_download
from torch_geometric.data import Data, download_url
from torch_geometric.data import InMemoryDataset
from tqdm import tqdm

from graphgps.transform.posenc_stats import compute_posenc_stats
from torch_geometric.graphgym.config import cfg


class PeptidesFunctional_LG_Dataset(InMemoryDataset):
    def __init__(self, root='datasets', smiles2graph=smiles2graph,
                 transform=None, pre_transform=None):
        """
        PyG dataset of 15,535 peptides with 10-way multi-task binary
        classification of their functional classes. Each peptide is
        represented as the *line graph* of its molecular graph (SMILES).

        The goal is to use the molecular representation of peptides instead
        of the amino acid sequence representation ('peptide_seq' field in the
        file, provided for possible baseline benchmarking but not used here)
        to test GNNs' representation capability.

        Line-graph construction (see ``_to_line_graph``): every directed bond
        (u, v) of the molecule becomes a node, and node (u, v) is linked to
        node (v, w) for every w != u (immediate reversals are excluded).

        The 10 classes represent the following functional classes (in order):
            ['antifungal', 'cell_cell_communication', 'anticancer',
            'drug_delivery_vehicle', 'antimicrobial', 'antiviral',
            'antihypertensive', 'antibacterial', 'antiparasitic', 'toxic']

        Args:
            root (string): Root directory where the dataset should be saved.
            smiles2graph (callable): A callable function that converts a SMILES
                string into a graph object. We use the OGB featurization.
                * The default smiles2graph requires rdkit to be installed *
            transform (callable, optional): Forwarded to ``InMemoryDataset``
                (applied to each ``Data`` object on access).
            pre_transform (callable, optional): Applied once to every
                ``Data`` object in ``process`` before the dataset is saved.
        """

        self.original_root = root
        self.smiles2graph = smiles2graph
        self.folder = osp.join(root, 'peptides-functional_lg')

        self.url = 'https://www.dropbox.com/s/ol2v01usvaxbsr8/peptide_multi_class_dataset.csv.gz?dl=1'
        self.version = '701eb743e899f4d793f0e13c8fa5a1b4'  # MD5 hash of the intended dataset file
        self.url_stratified_split = 'https://www.dropbox.com/s/j4zcnx2eipuo0xz/splits_random_stratified_peptide.pickle?dl=1'
        self.md5sum_stratified_split = '5a0114bdadc80b94fc7ae974f13ef061'

        # Check version and update if necessary: `download` drops an empty
        # marker file named after the dataset MD5 into the dataset folder.
        release_tag = osp.join(self.folder, self.version)
        if osp.isdir(self.folder) and (not osp.exists(release_tag)):
            print(f"{self.__class__.__name__} has been updated.")
            if input("Will you update the dataset now? (y/N)\n").lower() == 'y':
                shutil.rmtree(self.folder)

        super().__init__(self.folder, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return 'peptide_multi_class_dataset.csv.gz'

    @property
    def processed_file_names(self):
        return 'geometric_data_processed.pt'

    def _md5sum(self, path):
        """Return the hex MD5 digest of the file at ``path``.

        The file is hashed in 1 MiB chunks so large downloads are not loaded
        into memory all at once.
        """
        hash_md5 = hashlib.md5()
        with open(path, 'rb') as f:
            for chunk in iter(lambda: f.read(1 << 20), b''):
                hash_md5.update(chunk)
        return hash_md5.hexdigest()

    def download(self):
        """Download the raw CSV and the split pickle, verifying both MD5s.

        Raises:
            ValueError: If either downloaded file has an unexpected MD5 hash.
        """
        if not decide_download(self.url):
            print('Stop download.')
            sys.exit(-1)

        path = download_url(self.url, self.raw_dir)
        file_hash = self._md5sum(path)
        if file_hash != self.version:
            raise ValueError("Unexpected MD5 hash of the downloaded file")
        # Empty marker named after the hash; `__init__` checks for it to
        # detect whether the on-disk dataset matches this version.
        with open(osp.join(self.root, file_hash), 'w'):
            pass

        # Download train/val/test splits.
        path_split = download_url(self.url_stratified_split, self.root)
        if self._md5sum(path_split) != self.md5sum_stratified_split:
            raise ValueError(
                "Unexpected MD5 hash of the downloaded split file")

    @staticmethod
    def _to_line_graph(x, edge_index, edge_attr):
        """Build the non-backtracking directed line graph of a molecule.

        Line-graph node i is the directed bond (src_i, dst_i). A line-graph
        edge i -> j exists when dst_i == src_j (the bonds share an atom) and
        src_i != dst_j (so j is not simply the reverse of i).

        Args:
            x (Tensor): Atom features, one row per atom.
            edge_index (Tensor): Directed bonds, shape (2, E).
            edge_attr (Tensor): Bond features, one row per column of
                ``edge_index``.

        Returns:
            Tuple ``(lg_x, lg_edge_index, lg_edge_attr)``:
                * ``lg_x``: ``[edge_attr | x[src] | x[dst]]``, E rows.
                * ``lg_edge_index``: shape (2, M), indices into the E
                  line-graph nodes.
                * ``lg_edge_attr``: ``[x[shared atom] | edge_attr[i] |
                  edge_attr[j]]``, M rows.
        """
        src, dst = edge_index[0], edge_index[1]

        # Line-graph node features: bond features plus both endpoint atoms.
        lg_x = torch.cat([edge_attr, x[src], x[dst]], dim=1)

        # Dense (E, E) adjacency between bonds; nonzero() yields (i, j)
        # pairs in row-major order.
        adjacency = (dst[:, None] == src[None, :]) & \
                    (src[:, None] != dst[None, :])
        pairs = torch.nonzero(adjacency)
        first, second = pairs[:, 0], pairs[:, 1]

        # Edge features: the shared atom (dst of the first bond) plus the
        # features of both bonds.
        lg_edge_attr = torch.cat(
            [x[dst[first]], edge_attr[first], edge_attr[second]], dim=1)
        lg_edge_index = pairs.t().contiguous()
        return lg_x, lg_edge_index, lg_edge_attr

    def process(self):
        """Convert each SMILES string into a line-graph ``Data`` and save."""
        data_df = pd.read_csv(self.raw_paths[0])
        smiles_col = data_df['smiles']
        labels_col = data_df['labels']

        print('Converting SMILES strings into graphs...')
        data_list = []
        for i in tqdm(range(len(smiles_col))):
            graph = self.smiles2graph(smiles_col.iloc[i])

            assert (len(graph['edge_feat']) == graph['edge_index'].shape[1])
            assert (len(graph['node_feat']) == graph['num_nodes'])

            x = torch.from_numpy(graph['node_feat']).to(torch.int64)
            edge_index = torch.from_numpy(graph['edge_index']).to(torch.int64)
            edge_attr = torch.from_numpy(graph['edge_feat']).to(torch.int64)

            lg_x, lg_edge_index, lg_edge_attr = self._to_line_graph(
                x, edge_index, edge_attr)

            data = Data()
            # The line graph has one node per directed bond, not per atom;
            # this count drives batching offsets and the Laplacian size.
            data.num_nodes = lg_x.size(0)
            data.edge_index = lg_edge_index
            data.edge_attr = lg_edge_attr
            data.x = lg_x
            # `labels` is a Python list literal read from a downloaded file;
            # literal_eval parses it without executing arbitrary code.
            data.y = torch.Tensor([ast.literal_eval(labels_col.iloc[i])])

            data = compute_posenc_stats(data, ['LapPE'], False, cfg)

            data_list.append(data)

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        data, slices = self.collate(data_list)

        print('Saving...')
        torch.save((data, slices), self.processed_paths[0])

    def get_idx_split(self):
        """ Get dataset splits.

        Returns:
            Dict with 'train', 'val', 'test', splits indices.
        """
        split_file = osp.join(self.root,
                              "splits_random_stratified_peptide.pickle")
        # NOTE: unpickling a downloaded file; its MD5 is checked in
        # `download`, but only at download time.
        with open(split_file, 'rb') as f:
            splits = pickle.load(f)
        split_dict = replace_numpy_with_torchtensor(splits)
        return split_dict


if __name__ == '__main__':
    # Smoke test: build/load the dataset and print a few tensors and splits.
    dataset = PeptidesFunctional_LG_Dataset()
    print(dataset.data.edge_index)
    print(dataset.data.edge_index.shape)
    print(dataset.data.x.shape)
    print(dataset[100])
    print(dataset[100].y)
    print(dataset.get_idx_split())
