import os
import pathlib

import torch
from torch.utils.data import random_split
import torch_geometric.utils
from torch_geometric.data import InMemoryDataset, download_url

from src.datasets.abstract_dataset import AbstractDataModule, AbstractDatasetInfos

import pdb

class SpectreGraphDataset(InMemoryDataset):
    """
    In-memory PyG dataset holding one split of a graph benchmark from the SPECTRE paper.

    Data:
        The raw adjacency matrices are fetched from the SPECTRE repository
        (https://github.com/KarolisMart/SPECTRE). The supported datasets are
        1. sbm_200                       (``dataset_name='sbm'``)
        2. planar_64_200                 (``dataset_name='planar'``)
        3. community_12_21_100           (``dataset_name='comm-20'``)
        4. DD proteins, 100 to 500 nodes (``dataset_name='protein'``)

    Args:
        dataset_name: Choose one from ['sbm', 'planar', 'comm-20', 'protein'].
        split: Get the training, val, or test set. Choose one from ['train', 'val', 'test'].
        root (str): Root directory where the dataset should be saved.
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in an
            :obj:`torch_geometric.data.Data` object and returns a boolean
            value, indicating whether the data object should be included in the
            final dataset. (default: :obj:`None`)
    """
    def __init__(self, dataset_name, split, root, transform=None, pre_transform=None, pre_filter=None):
        self.sbm_file = 'sbm_200.pt'
        self.planar_file = 'planar_64_200.pt'
        self.comm20_file = 'community_12_21_100.pt'
        self.dataset_name = dataset_name
        self.split = split
        # Number of indices the split permutation is drawn over; replaced by the
        # actual number of graphs for the protein dataset (see ``download``).
        self.num_graphs = 200
        # These attributes must be set before the parent constructor runs: it is
        # the parent that triggers ``download`` / ``process``, which read them.
        super().__init__(root, transform, pre_transform, pre_filter)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        """Names of the per-split raw files written by ``download``."""
        return ['train.pt', 'val.pt', 'test.pt']

    @property
    def processed_file_names(self):
        """Name of the processed file for the split held by this instance."""
        return [self.split + '.pt']

    def download(self):
        """
        Fetch the raw adjacency matrices and write the train/val/test splits to ``raw_dir``.

        The graphs are shuffled with a fixed seed, split 20% test and, of the
        remainder, 80% train / 20% val, and each split is saved as a list of
        dense adjacency matrices under the names in ``raw_file_names``.

        Raises:
            ValueError: If ``dataset_name`` is unknown, or if a graph index is
                not covered by any of the three splits.
        """
        if self.dataset_name == 'sbm':
            raw_url = 'https://raw.githubusercontent.com/KarolisMart/SPECTRE/main/data/sbm_200.pt'
        elif self.dataset_name == 'planar':
            raw_url = 'https://raw.githubusercontent.com/KarolisMart/SPECTRE/main/data/planar_64_200.pt'
        elif self.dataset_name == 'comm-20':
            raw_url = 'https://raw.githubusercontent.com/KarolisMart/SPECTRE/main/data/community_12_21_100.pt'
        elif self.dataset_name == 'protein':
            raw_url = None
        else:
            raise ValueError(f'Unknown dataset {self.dataset_name}')
        print(self.raw_dir)

        # Only the adjacency matrices are used below; the remaining fields are
        # still unpacked so that an unexpected file layout fails loudly.
        if self.dataset_name == 'protein':
            adjs, _eigvals, _eigvecs, _n_nodes, _max_eigval, _min_eigval, _n_max = self.load_protein_dataset()
        else:
            file_path = download_url(raw_url, self.raw_dir)
            (adjs, _eigvals, _eigvecs, _n_nodes, _max_eigval, _min_eigval,
             _same_sample, _n_max) = torch.load(file_path)

        g_cpu = torch.Generator()
        # g_cpu.manual_seed(0) # SPECTRE actually used 1234
        g_cpu.manual_seed(1234)

        # for protein, we use all data
        if self.dataset_name == "protein":
            self.num_graphs = len(adjs)
        # NOTE(review): for the other datasets the permutation is drawn over
        # ``self.num_graphs`` (200) indices regardless of ``len(adjs)``; if a file
        # holds fewer graphs the printed sizes will not match the real split sizes.
        test_len = int(round(self.num_graphs * 0.2))
        train_len = int(round((self.num_graphs - test_len) * 0.8))
        val_len = self.num_graphs - train_len - test_len
        indices = torch.randperm(self.num_graphs, generator=g_cpu)
        print(f'Dataset sizes: train {train_len}, val {val_len}, test {test_len}')

        # Sets give O(1) membership tests instead of scanning a tensor per graph.
        train_indices = set(indices[:train_len].tolist())
        val_indices = set(indices[train_len:train_len + val_len].tolist())
        test_indices = set(indices[train_len + val_len:].tolist())

        train_data = []
        val_data = []
        test_data = []

        for i, adj in enumerate(adjs):
            if i in train_indices:
                train_data.append(adj)
            elif i in val_indices:
                val_data.append(adj)
            elif i in test_indices:
                test_data.append(adj)
            else:
                raise ValueError(f'Index {i} not in any split')

        torch.save(train_data, self.raw_paths[0])
        torch.save(val_data, self.raw_paths[1])
        torch.save(test_data, self.raw_paths[2])

    def process(self):
        """
        Convert the raw adjacency matrices of this split into PyG ``Data`` objects.

        Each graph gets a constant node feature ``x`` of ones with shape (n, 1),
        an empty graph-level target ``y`` of shape (1, 0), two-column edge
        attributes whose second column is 1 for every edge, and its node count
        in ``n_nodes``. Graphs rejected by ``pre_filter`` are dropped and
        ``pre_transform`` is applied before the collated result is saved.
        """
        file_idx = {'train': 0, 'val': 1, 'test': 2}
        raw_dataset = torch.load(self.raw_paths[file_idx[self.split]])

        data_list = []
        for adj in raw_dataset:
            n = adj.shape[-1]
            X = torch.ones(n, 1, dtype=torch.float)
            y = torch.zeros([1, 0]).float()
            edge_index, _ = torch_geometric.utils.dense_to_sparse(adj)
            edge_attr = torch.zeros(edge_index.shape[-1], 2, dtype=torch.float)
            edge_attr[:, 1] = 1
            num_nodes = n * torch.ones(1, dtype=torch.long)
            data = torch_geometric.data.Data(x=X, edge_index=edge_index, edge_attr=edge_attr,
                                             y=y, n_nodes=num_nodes)

            # Filter and transform BEFORE storing, and store each graph exactly
            # once: appending ahead of the filter would duplicate every graph and
            # let filtered-out graphs into the dataset.
            if self.pre_filter is not None and not self.pre_filter(data):
                continue
            if self.pre_transform is not None:
                data = self.pre_transform(data)

            data_list.append(data)
        torch.save(self.collate(data_list), self.processed_paths[0])

    def load_protein_dataset(self):
        """
        Load the protein graphs, building and caching them from the raw DD files if needed.

        Adapted from https://github.com/KarolisMart/SPECTRE/blob/main/data.py

        If the cached ``proteins_100_500.pt`` file exists in ``raw_dir`` it is
        loaded directly. Otherwise the raw DD text files are downloaded (when
        missing), split into one graph per graph indicator, restricted to graphs
        with 100 to 500 nodes, and the result is saved to that cache file.

        Returns:
            tuple: ``(adjs, eigvals, eigvecs, n_nodes, max_eigval, min_eigval, n_max)``
            where ``adjs`` are dense float adjacency matrices, ``eigvals`` /
            ``eigvecs`` come from the eigendecomposition of each graph's
            normalized Laplacian, ``n_nodes`` are the per-graph node counts,
            ``max_eigval`` / ``min_eigval`` are the extrema over all kept graphs
            and ``n_max`` is the largest node count.
        """
        min_num_nodes = 100
        max_num_nodes = 500
        filename = os.path.join(self.raw_dir, f'proteins_{min_num_nodes}_{max_num_nodes}.pt')

        if os.path.isfile(filename):
            adjs, eigvals, eigvecs, n_nodes, max_eigval, min_eigval, n_max = torch.load(filename)
            print(f'Dataset {filename} loaded from file')
            return adjs, eigvals, eigvecs, n_nodes, max_eigval, min_eigval, n_max

        adjs = []
        eigvals = []
        eigvecs = []
        n_nodes = []
        n_max = 0
        max_eigval = 0
        min_eigval = 0

        # Imported lazily: only needed when the protein cache has to be built.
        import networkx as nx
        import numpy as np

        path = os.path.join(self.raw_dir, 'DD')
        dd_file_list = ['DD_A.txt', 'DD_graph_indicator.txt', 'DD_graph_labels.txt', 'DD_node_labels.txt']
        if not all(os.path.exists(os.path.join(path, dd_file)) for dd_file in dd_file_list):
            # Download the raw txt files
            os.makedirs(path, exist_ok=True)
            print('Downloading the raw Protein dataset in txt')
            dd_dir_url = 'https://raw.githubusercontent.com/KarolisMart/SPECTRE/main/data/DD'
            for dd_file in dd_file_list:
                # URLs always use '/', so do not build them with os.path.join.
                download_url(f'{dd_dir_url}/{dd_file}', path)

        # Load data
        data_adj = np.loadtxt(os.path.join(path, 'DD_A.txt'), delimiter=',').astype(int)
        data_graph_indicator = np.loadtxt(os.path.join(path, 'DD_graph_indicator.txt'), delimiter=',').astype(int)
        data_graph_types = np.loadtxt(os.path.join(path, 'DD_graph_labels.txt'), delimiter=',').astype(int)

        data_tuple = list(map(tuple, data_adj))

        print(f'Converting the raw Protein dataset from .txt to .pt, with min_num_nodes={min_num_nodes}, max_num_nodes={max_num_nodes}')
        G = nx.Graph()
        # Add edges
        G.add_edges_from(data_tuple)
        G.remove_nodes_from(list(nx.isolates(G)))

        # remove self-loop
        G.remove_edges_from(nx.selfloop_edges(G))

        # Split into graphs; node ids and graph ids in the DD files are 1-based.
        graph_num = data_graph_indicator.max()
        node_list = np.arange(data_graph_indicator.shape[0]) + 1

        for i in range(graph_num):
            # Find the nodes for each graph
            nodes = node_list[data_graph_indicator == i + 1]
            G_sub = G.subgraph(nodes)
            G_sub.graph['label'] = data_graph_types[i]

            num_nodes = G_sub.number_of_nodes()
            if not (min_num_nodes <= num_nodes <= max_num_nodes):
                continue

            adj = torch.from_numpy(nx.adjacency_matrix(G_sub).toarray()).float()
            L = nx.normalized_laplacian_matrix(G_sub).toarray()
            L = torch.from_numpy(L).float()
            eigval, eigvec = torch.linalg.eigh(L)

            eigvals.append(eigval)
            eigvecs.append(eigvec)
            adjs.append(adj)
            n_nodes.append(num_nodes)
            n_max = max(n_max, num_nodes)

            # Track the extrema across ALL graphs: compare this graph's values
            # against the running ones instead of overwriting them.
            graph_max_eigval = torch.max(eigval)
            if graph_max_eigval > max_eigval:
                max_eigval = graph_max_eigval
            graph_min_eigval = torch.min(eigval)
            if graph_min_eigval < min_eigval:
                min_eigval = graph_min_eigval

        torch.save([adjs, eigvals, eigvecs, n_nodes, max_eigval, min_eigval, n_max], filename)
        print(f'Dataset {filename} saved')

        return adjs, eigvals, eigvecs, n_nodes, max_eigval, min_eigval, n_max

class PlanarDataset(SpectreGraphDataset):
    """
    Pass-through subclass of :class:`SpectreGraphDataset`.

    It adds no behaviour of its own: every constructor argument is forwarded
    unchanged to the parent class.

    Args:
        dataset_name: Name of the dataset to load; the parent's ``download``
            recognises 'sbm', 'planar', 'comm-20' and 'protein'.
        split: Which split to load. Choose one from ['train', 'val', 'test'].
        root (str): Root directory where the dataset should be saved.
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in an
            :obj:`torch_geometric.data.Data` object and returns a boolean
            value, indicating whether the data object should be included in the
            final dataset. (default: :obj:`None`)
    """
    def __init__(self, dataset_name, split, root, transform=None, pre_transform=None, pre_filter=None):
        super().__init__(
            dataset_name=dataset_name,
            split=split,
            root=root,
            transform=transform,
            pre_transform=pre_transform,
            pre_filter=pre_filter,
        )

class SpectreGraphDataModule(AbstractDataModule):
    """
    Data module that builds the train/val/test ``SpectreGraphDataset`` splits.

    Args:
        cfg: Configuration object; ``cfg.dataset.datadir`` and
            ``cfg.dataset.name`` are read here.
        n_graphs: Accepted for interface compatibility; not used by this
            constructor.
    """
    def __init__(self, cfg, n_graphs=200):
        self.cfg = cfg
        self.datadir = cfg.dataset.datadir

        # The data directory is resolved relative to ``parents[2]`` of this
        # file's real path.
        anchor_dir = pathlib.Path(os.path.realpath(__file__)).parents[2]
        root_path = os.path.join(anchor_dir, self.datadir)

        # One dataset object per split, created in train -> val -> test order.
        datasets = {
            split_name: SpectreGraphDataset(dataset_name=self.cfg.dataset.name,
                                            split=split_name, root=root_path)
            for split_name in ('train', 'val', 'test')
        }

        super().__init__(cfg, datasets)
        # ``train_dataset`` is presumably set by the parent constructor - confirm
        # against AbstractDataModule.
        self.inner = self.train_dataset

    def __getitem__(self, item):
        """Index into the training dataset."""
        return self.inner[item]

class PlanarDataModule(SpectreGraphDataModule):
    """Pass-through subclass of ``SpectreGraphDataModule``; forwards its arguments unchanged."""

    def __init__(self, cfg, n_graphs=200):
        super().__init__(cfg=cfg, n_graphs=n_graphs)

class SpectreDatasetInfos(AbstractDatasetInfos):
    """
    Dataset statistics for the SPECTRE graph datasets.

    Args:
        datamodule: Data module queried through ``node_counts()`` and
            ``edge_counts()``.
        dataset_config: Accepted for interface compatibility; not used by this
            constructor.
    """
    def __init__(self, datamodule, dataset_config):
        self.datamodule = datamodule
        self.name = 'nx_graphs'

        # Node statistics are gathered first, then the edge statistics.
        node_count_stats = datamodule.node_counts()
        self.n_nodes = node_count_stats
        # There are no node types, so a single placeholder entry is used.
        self.node_types = torch.tensor([1])
        edge_count_stats = datamodule.edge_counts()
        self.edge_types = edge_count_stats

        super().complete_infos(self.n_nodes, self.node_types)

