from ogb.nodeproppred import PygNodePropPredDataset
import pandas as pd
from sklearn import preprocessing
import os
import numpy as np
import os.path
import torch_geometric as tg
import torch_geometric.transforms as T
from torch_sparse import SparseTensor
import torch
import pickle
import scipy.sparse as sp
from torch_scatter import scatter
import random

from ipdb import set_trace as stc


class OGBNDataset(object):
    """Transductive wrapper around an OGB node-property-prediction dataset.

    After construction the nodes are renumbered so that training nodes come
    first, then validation nodes, then test nodes (see ``idx_mapping``), and
    ``self.adj`` holds a symmetrized SciPy CSR adjacency matrix that is
    optionally augmented with self-loops and/or degree-normalized.
    """

    def __init__(self, args, dataset_name='ogbn-arxiv'):
        """Load ``dataset_name`` through OGB and build the graph structures.

        :param args: configuration object. Reads ``args.norm_feats`` (apply
            ``T.NormalizeFeatures``), ``args.self_loop`` (add the identity to
            the adjacency), ``args.norm_adj`` (degree-normalize the adjacency)
            and, when ``norm_adj`` is True, ``args.adjacency``, which must be
            one of ``'DAD'``, ``'DA'`` or ``'AD'``.
        :param dataset_name: name passed to ``PygNodePropPredDataset``.
        """
        self.dataset_name = dataset_name

        if args.norm_feats == False:
            self.dataset = PygNodePropPredDataset(name=self.dataset_name)
        else:
            transform = T.NormalizeFeatures()
            self.dataset = PygNodePropPredDataset(name=self.dataset_name, transform=transform)

        self.num_tasks = self.dataset.num_tasks
        self.num_classes = self.dataset.num_classes
        self.num_features = self.dataset.num_features

        self.splitted_idx = self.dataset.get_idx_split()
        self.whole_graph = self.dataset[0]
        # Deep copy of the graph in the original OGB node numbering.
        self.origin_whole_graph = self.whole_graph.clone()
        self.idx_mapping()
        self.length = 1

        self.total_no_of_edges = self.whole_graph.edge_index.shape[1]
        self.total_no_of_nodes = self.whole_graph.y.shape[0]
        self.x = self.whole_graph.x
        self.y = self.whole_graph.y.view(-1)
        self.edge_index = self.whole_graph.edge_index
        self.edge_attr = torch.ones(self.total_no_of_edges, 1)
        self.edge_index_array = self.edge_index.t().numpy()
        self.adj = self.construct_adj()
        # Symmetrize: for every stored entry (r, c), copy its value to (c, r).
        rows, cols = self.adj.nonzero()
        self.adj[cols, rows] = self.adj[rows, cols]
        # NOTE(review): when neither self_loop nor norm_adj is set,
        # self.edge_index keeps the input edge directions while self.adj is
        # symmetric, so generate_sub_graphs may look up reverse edges that are
        # missing from edge_index_dict if the input graph is directed - confirm.

        if args.self_loop == True:
            identity = sp.csr_matrix(
                (np.ones(self.total_no_of_nodes), (np.arange(self.total_no_of_nodes), np.arange(self.total_no_of_nodes))),
                shape=(self.total_no_of_nodes, self.total_no_of_nodes))
            self.adj = self.adj + identity
            coo = self.adj.tocoo()
            edge_index = np.vstack((coo.row, coo.col))
            edge_attr = coo.data
            self.edge_index, self.whole_graph.edge_index = torch.from_numpy(edge_index).long(), torch.from_numpy(edge_index).long()
            self.edge_attr, self.whole_graph.edge_attr = torch.from_numpy(edge_attr).float().view(-1, 1), torch.from_numpy(edge_attr).float().view(-1, 1)
            self.total_no_of_edges = self.whole_graph.edge_index.shape[1]

        if args.norm_adj == True:
            # Normalize with SciPy. The previous code called torch_sparse
            # methods (``sum(dim=1)``, ``.t().coo()``) on this SciPy matrix,
            # which raises TypeError. Keeping ``self.adj`` a SciPy matrix also
            # keeps ``generate_sub_graphs`` working.
            self.adj, edge_index, edge_attr = self._normalize_adj(self.adj, args.adjacency)

            self.edge_index = edge_index
            self.whole_graph.edge_index = edge_index
            self.edge_attr = edge_attr
            self.whole_graph.edge_attr = edge_attr
            self.edge_index_array = self.edge_index.t().numpy()
            self.total_no_of_edges = edge_index.shape[1]

        self.edge_index_dict = self.edge_features_index(args)

    @staticmethod
    def _normalize_adj(adj, adjacency):
        """Return a degree-normalized copy of ``adj`` plus its edge list.

        Let ``D`` be the diagonal matrix of ``deg ** -0.5``, where ``deg`` are
        the row sums of ``adj``. Zero-degree nodes get 0 instead of inf.

        * ``'DAD'`` -> ``D @ adj @ D``
        * ``'DA'``  -> ``D @ D @ adj`` (row i scaled by ``1 / deg[i]``)
        * ``'AD'``  -> ``adj @ D @ D`` (column j scaled by ``1 / deg[j]``)

        :param adj: square SciPy sparse matrix.
        :param adjacency: one of ``'DAD'``, ``'DA'``, ``'AD'``; any other
            value raises ``KeyError``.
        :return: ``(norm_adj, edge_index, edge_attr)``.
            ``norm_adj`` is a CSR matrix.
            ``edge_index`` (``[2, E]`` long) and ``edge_attr`` (``[E]``
            float32) are the row-major entries of ``norm_adj.T``, so edge
            ``(i, j)`` carries ``norm_adj[j, i]``. This mirrors the transpose
            taken by the original implementation.
        """
        deg = np.asarray(adj.sum(axis=1), dtype=np.float64).ravel()
        with np.errstate(divide='ignore'):
            deg_inv_sqrt = np.power(deg, -0.5)
        deg_inv_sqrt[np.isinf(deg_inv_sqrt)] = 0.0
        d_inv_sqrt = sp.diags(deg_inv_sqrt)

        # Only the requested variant is materialized.
        builders = {
            'DAD': lambda: d_inv_sqrt @ adj @ d_inv_sqrt,
            'DA': lambda: d_inv_sqrt @ d_inv_sqrt @ adj,
            'AD': lambda: adj @ d_inv_sqrt @ d_inv_sqrt,
        }
        norm_adj = sp.csr_matrix(builders[adjacency]())

        transposed = norm_adj.T.tocsr()
        transposed.sort_indices()
        coo = transposed.tocoo()
        edge_index = torch.stack([torch.from_numpy(coo.row).long(),
                                  torch.from_numpy(coo.col).long()], dim=0)
        edge_attr = torch.from_numpy(coo.data).float()
        return norm_adj, edge_index, edge_attr

    def idx_mapping(self):
        """Renumber nodes so that the train, valid and test splits are contiguous.

        Sets the following attributes:

        * ``self.train_idx = arange(0, n_train)``
        * ``self.valid_idx = arange(n_train, n_train + n_valid)``
        * ``self.test_idx`` = the remaining ids
        * ``self.idx_mapper[old_id] == new_id``

        ``whole_graph.x``, ``whole_graph.y`` (reshaped to ``[N, 1]`` long) and
        ``whole_graph.edge_index`` are rewritten in the new numbering.

        Assumes the three splits together form a permutation of the node ids
        - TODO confirm for each dataset used.
        """
        splits = [torch.as_tensor(self.splitted_idx[key], dtype=torch.long).view(-1)
                  for key in ("train", "valid", "test")]
        train_num, valid_num, test_num = (s.numel() for s in splits)
        nnum = train_num + valid_num + test_num
        # perm[new_id] == old_id
        perm = torch.cat(splits)

        # Vectorized gathers replace the previous per-node Python loops.
        self.whole_graph.x = self.whole_graph.x[perm].float()
        self.whole_graph.y = self.whole_graph.y[perm].reshape(nnum, 1).long()

        # Inverse permutation: idx_mapper[old_id] == new_id.
        idx_mapper = torch.zeros(nnum, dtype=torch.long)
        idx_mapper[perm] = torch.arange(nnum)

        self.train_idx = torch.arange(0, train_num)
        self.valid_idx = torch.arange(train_num, train_num + valid_num)
        self.test_idx = torch.arange(train_num + valid_num, nnum)
        self.idx_mapper = idx_mapper

        # Relabel both endpoints of every edge in one indexing op. This builds
        # a new tensor rather than writing into the loaded one in place.
        self.whole_graph.edge_index = idx_mapper[self.whole_graph.edge_index]

    def generate_one_hot_encoding(self):
        """One-hot encode ``self.species`` with scikit-learn encoders.

        NOTE(review): ``self.species`` is not assigned anywhere in this class.
        It must be set by the caller first, otherwise this raises
        AttributeError.

        :return: FloatTensor with one row per entry of ``self.species``.
        """
        le = preprocessing.LabelEncoder()
        species_unique = torch.unique(self.species)
        max_no = species_unique.max()
        le.fit(species_unique % max_no)
        species = le.transform(self.species.squeeze() % max_no)
        species = np.expand_dims(species, axis=1)

        enc = preprocessing.OneHotEncoder()
        enc.fit(species)
        one_hot_encoding = enc.transform(species).toarray()

        return torch.FloatTensor(one_hot_encoding)

    def extract_node_features(self, aggr='add'):
        """Aggregate ``self.edge_attr`` onto source nodes and cache the result.

        :param aggr: ``'add'``, ``'mean'`` or ``'max'``; anything else raises.
        :return: path of the ``.pt`` file holding the ``[N, ...]`` features.

        NOTE(review): the cache path depends only on ``aggr`` (not on the
        dataset or the self_loop/norm_adj flags), so an existing file is
        reused as-is.
        """
        file_path = 'init_node_features_{}.pt'.format(aggr)

        if os.path.isfile(file_path):
            print('{} exists'.format(file_path))
        else:
            if aggr in ['add', 'mean', 'max']:
                node_features = scatter(self.edge_attr,
                                        self.edge_index[0],
                                        dim=0,
                                        dim_size=self.total_no_of_nodes,
                                        reduce=aggr)
            else:
                raise Exception('Unknown Aggr Method')
            torch.save(node_features, file_path)
            print('Node features extracted are saved into file {}'.format(file_path))
        return file_path

    def construct_adj(self):
        """Build an ``N x N`` uint8 CSR matrix with a 1 for each edge.

        Edges come from ``self.edge_index_array``. SciPy sums duplicate
        ``(row, col)`` pairs.
        """
        adj = sp.csr_matrix((np.ones(self.total_no_of_edges, dtype=np.uint8),
                             (self.edge_index_array[:, 0], self.edge_index_array[:, 1])),
                            shape=(self.total_no_of_nodes, self.total_no_of_nodes))

        return adj

    def edge_features_index(self, args):
        """Map each ``(src, dst)`` pair to its column in ``whole_graph.edge_index``.

        The dict is cached as a pickle in the working directory. The file name
        encodes the norm_adj/norm_feats/self_loop flags. If an edge appears
        more than once, the last column wins.

        NOTE(review): the cache name does not include the dataset name, and
        the pickle is loaded without validation - only use trusted files.
        """
        file_name = 'edge_features_index'
        if args.norm_adj == True:
            file_name = file_name + '_normadj'
        if args.norm_feats == True:
            file_name = file_name + '_normfeats'
        if args.self_loop == True:
            file_name = file_name + '_selfloop'
        file_name = file_name + '_v2.pkl'
        if os.path.isfile(file_name):
            print('{} exists'.format(file_name))
            with open(file_name, 'rb') as edge_features_index:
                edge_index_dict = pickle.load(edge_features_index)
        else:
            df = pd.DataFrame()
            df['1st_index'] = self.whole_graph.edge_index[0]
            df['2nd_index'] = self.whole_graph.edge_index[1]
            df_reset = df.reset_index()
            key = zip(df_reset['1st_index'], df_reset['2nd_index'])
            edge_index_dict = df_reset.set_index(key)['index'].to_dict()
            with open(file_name, 'wb') as edge_features_index:
                pickle.dump(edge_index_dict, edge_features_index)
            print('Edges\' indexes information is saved into file {}'.format(file_name))
        return edge_index_dict

    @staticmethod
    def random_partition_graph(num_nodes, cluster_number=100):
        """Assign each node a uniformly random cluster id in ``[0, cluster_number)``."""
        parts = np.random.randint(cluster_number, size=num_nodes)
        return parts

    def generate_sub_graphs(self, parts, cluster_number=10, batch_size=1):
        """Build node-induced sub-graphs of ``self.adj`` from cluster ids.

        :param parts: array with one cluster id per node.
        :param cluster_number: number of clusters used to produce ``parts``.
        :param batch_size: ``cluster_number // batch_size`` sub-graphs are
            built, one per cluster id ``0 .. cluster_number // batch_size - 1``.
        :return: ``(sg_nodes, sg_edges, sg_edges_index, sg_edges_orig)``, all
            per sub-graph:
            ``sg_nodes`` - global node ids;
            ``sg_edges`` - edges in local numbering;
            ``sg_edges_index`` - column of each edge in ``whole_graph.edge_index``;
            ``sg_edges_orig`` - edges in global numbering.

        NOTE(review): with ``batch_size > 1``, nodes whose cluster id is
        ``>= cluster_number // batch_size`` belong to no sub-graph - confirm
        this is intended.
        """
        no_of_batches = cluster_number // batch_size

        print('The number of clusters: {}'.format(cluster_number))

        sg_nodes = [[] for _ in range(no_of_batches)]
        sg_edges = [[] for _ in range(no_of_batches)]
        sg_edges_orig = [[] for _ in range(no_of_batches)]
        sg_edges_index = [[] for _ in range(no_of_batches)]

        edges_no = 0

        for cluster in range(no_of_batches):
            sg_nodes[cluster] = np.where(parts == cluster)[0]
            sg_edges[cluster] = tg.utils.from_scipy_sparse_matrix(self.adj[sg_nodes[cluster], :][:, sg_nodes[cluster]])[0]
            edges_no += sg_edges[cluster].shape[1]
            # local sub-graph index -> global node id
            mapper = {nd_idx: nd_orig_idx for nd_idx, nd_orig_idx in enumerate(sg_nodes[cluster])}
            # map edges to original edges
            sg_edges_orig[cluster] = OGBNDataset.edge_list_mapper(mapper, sg_edges[cluster])
            # edge index
            sg_edges_index[cluster] = [self.edge_index_dict[(edge[0], edge[1])] for edge in
                                       sg_edges_orig[cluster].t().numpy()]
        print('Total number edges of sub graphs: {}, of whole graph: {}, {:.2f} % edges are lost'.
              format(edges_no, self.total_no_of_edges, (1 - edges_no / self.total_no_of_edges) * 100))

        return sg_nodes, sg_edges, sg_edges_index, sg_edges_orig

    @staticmethod
    def edge_list_mapper(mapper, sg_edges_list):
        """Translate a ``[2, E]`` local edge list to global ids via ``mapper``."""
        idx_1st = list(map(lambda x: mapper[x], sg_edges_list[0].tolist()))
        idx_2nd = list(map(lambda x: mapper[x], sg_edges_list[1].tolist()))
        sg_edges_orig = torch.LongTensor([idx_1st, idx_2nd])
        return sg_edges_orig


class OGBNDatasetInductive(object):
    """Inductive wrapper around an OGB node-property-prediction dataset.

    Nodes are renumbered train -> valid -> test (see ``idx_mapping``). Test
    nodes, and every edge touching them, are then removed from
    ``self.whole_graph`` (see ``inductive``), so the training graph contains
    only observed (train + valid) nodes.
    """

    def __init__(self, args, dataset_name='ogbn-arxiv'):
        """Load ``dataset_name`` through OGB and build the observed-node graph.

        :param args: configuration object. Reads ``args.norm_feats`` (apply
            ``T.NormalizeFeatures``), ``args.self_loop`` (add the identity to
            the adjacency), ``args.norm_adj`` (degree-normalize the adjacency)
            and, when ``norm_adj`` is True, ``args.adjacency``, which must be
            one of ``'DAD'``, ``'DA'`` or ``'AD'``.
        :param dataset_name: name passed to ``PygNodePropPredDataset``.

        NOTE(review): ``self.test_idx`` keeps ids ``>= self.obs_num``, which no
        longer exist in ``self.whole_graph``. Evaluation presumably uses
        ``self.origin_whole_graph`` together with ``self.idx_mapper`` - confirm.
        """
        self.dataset_name = dataset_name

        if args.norm_feats == False:
            self.dataset = PygNodePropPredDataset(name=self.dataset_name)
        else:
            transform = T.NormalizeFeatures()
            self.dataset = PygNodePropPredDataset(name=self.dataset_name, transform=transform)
        self.num_tasks = self.dataset.num_tasks
        self.num_classes = self.dataset.num_classes
        self.num_features = self.dataset.num_features
        self.splitted_idx = self.dataset.get_idx_split()
        self.whole_graph = self.dataset[0]
        # Deep copy of the full graph in the original OGB node numbering.
        self.origin_whole_graph = self.whole_graph.clone()
        self.idx_mapping()

        self.total_no_of_edges = self.whole_graph.edge_index.shape[1]
        self.whole_graph.edge_attr = torch.ones(self.total_no_of_edges, 1)
        self.length = 1

        # After idx_mapping the observed (train + valid) nodes are ids [0, obs_num).
        self.obs_idx = torch.cat((self.train_idx, self.valid_idx), dim=0)
        self.obs_num = len(self.obs_idx)
        self.inductive()

        self.total_no_of_edges = self.whole_graph.edge_index.shape[1]
        self.total_no_of_nodes = self.whole_graph.y.shape[0]
        self.x = self.whole_graph.x
        self.y = self.whole_graph.y.view(-1)

        self.edge_index = self.whole_graph.edge_index
        self.edge_attr = self.whole_graph.edge_attr
        self.edge_index_array = self.edge_index.t().numpy()
        self.adj = self.construct_adj()
        # Symmetrize: for every stored entry (r, c), copy its value to (c, r).
        rows, cols = self.adj.nonzero()
        self.adj[cols, rows] = self.adj[rows, cols]
        # NOTE(review): when neither self_loop nor norm_adj is set,
        # self.edge_index keeps the input edge directions while self.adj is
        # symmetric, so generate_sub_graphs may look up reverse edges that are
        # missing from edge_index_dict if the input graph is directed - confirm.

        if args.self_loop == True:
            identity = sp.csr_matrix(
                (np.ones(self.total_no_of_nodes), (np.arange(self.total_no_of_nodes), np.arange(self.total_no_of_nodes))),
                shape=(self.total_no_of_nodes, self.total_no_of_nodes))
            self.adj = self.adj + identity
            coo = self.adj.tocoo()
            edge_index = np.vstack((coo.row, coo.col))
            edge_attr = coo.data
            self.edge_index, self.whole_graph.edge_index = torch.from_numpy(edge_index).long(), torch.from_numpy(edge_index).long()
            self.edge_attr, self.whole_graph.edge_attr = torch.from_numpy(edge_attr).float().view(-1, 1), torch.from_numpy(edge_attr).float().view(-1, 1)
            self.total_no_of_edges = self.whole_graph.edge_index.shape[1]

        if args.norm_adj == True:
            # Normalize with SciPy. The previous code called torch_sparse
            # methods (``sum(dim=1)``, ``.t().coo()``) on this SciPy matrix,
            # which raises TypeError. Keeping ``self.adj`` a SciPy matrix also
            # keeps ``generate_sub_graphs`` working.
            self.adj, edge_index, edge_attr = self._normalize_adj(self.adj, args.adjacency)

            self.edge_index = edge_index
            self.edge_attr = edge_attr
            self.whole_graph.edge_index = edge_index
            self.whole_graph.edge_attr = edge_attr
            self.edge_index_array = self.edge_index.t().numpy()
            self.total_no_of_edges = edge_index.shape[1]

        self.edge_index_dict = self.edge_features_index(args)

    @staticmethod
    def _normalize_adj(adj, adjacency):
        """Return a degree-normalized copy of ``adj`` plus its edge list.

        Let ``D`` be the diagonal matrix of ``deg ** -0.5``, where ``deg`` are
        the row sums of ``adj``. Zero-degree nodes get 0 instead of inf.

        * ``'DAD'`` -> ``D @ adj @ D``
        * ``'DA'``  -> ``D @ D @ adj`` (row i scaled by ``1 / deg[i]``)
        * ``'AD'``  -> ``adj @ D @ D`` (column j scaled by ``1 / deg[j]``)

        :param adj: square SciPy sparse matrix.
        :param adjacency: one of ``'DAD'``, ``'DA'``, ``'AD'``; any other
            value raises ``KeyError``.
        :return: ``(norm_adj, edge_index, edge_attr)``.
            ``norm_adj`` is a CSR matrix.
            ``edge_index`` (``[2, E]`` long) and ``edge_attr`` (``[E]``
            float32) are the row-major entries of ``norm_adj.T``, so edge
            ``(i, j)`` carries ``norm_adj[j, i]``. This mirrors the transpose
            taken by the original implementation.
        """
        deg = np.asarray(adj.sum(axis=1), dtype=np.float64).ravel()
        with np.errstate(divide='ignore'):
            deg_inv_sqrt = np.power(deg, -0.5)
        deg_inv_sqrt[np.isinf(deg_inv_sqrt)] = 0.0
        d_inv_sqrt = sp.diags(deg_inv_sqrt)

        # Only the requested variant is materialized.
        builders = {
            'DAD': lambda: d_inv_sqrt @ adj @ d_inv_sqrt,
            'DA': lambda: d_inv_sqrt @ d_inv_sqrt @ adj,
            'AD': lambda: adj @ d_inv_sqrt @ d_inv_sqrt,
        }
        norm_adj = sp.csr_matrix(builders[adjacency]())

        transposed = norm_adj.T.tocsr()
        transposed.sort_indices()
        coo = transposed.tocoo()
        edge_index = torch.stack([torch.from_numpy(coo.row).long(),
                                  torch.from_numpy(coo.col).long()], dim=0)
        edge_attr = torch.from_numpy(coo.data).float()
        return norm_adj, edge_index, edge_attr

    def idx_mapping(self):
        """Renumber nodes so that the train, valid and test splits are contiguous.

        Sets the following attributes:

        * ``self.train_idx = arange(0, n_train)``
        * ``self.valid_idx = arange(n_train, n_train + n_valid)``
        * ``self.test_idx`` = the remaining ids
        * ``self.idx_mapper[old_id] == new_id``

        ``whole_graph.x``, ``whole_graph.y`` (reshaped to ``[N, 1]`` long) and
        ``whole_graph.edge_index`` are rewritten in the new numbering.

        Assumes the three splits together form a permutation of the node ids
        - TODO confirm for each dataset used.
        """
        splits = [torch.as_tensor(self.splitted_idx[key], dtype=torch.long).view(-1)
                  for key in ("train", "valid", "test")]
        train_num, valid_num, test_num = (s.numel() for s in splits)
        nnum = train_num + valid_num + test_num
        # perm[new_id] == old_id
        perm = torch.cat(splits)

        # Vectorized gathers replace the previous per-node Python loops.
        self.whole_graph.x = self.whole_graph.x[perm].float()
        self.whole_graph.y = self.whole_graph.y[perm].reshape(nnum, 1).long()

        # Inverse permutation: idx_mapper[old_id] == new_id.
        idx_mapper = torch.zeros(nnum, dtype=torch.long)
        idx_mapper[perm] = torch.arange(nnum)

        self.train_idx = torch.arange(0, train_num)
        self.valid_idx = torch.arange(train_num, train_num + valid_num)
        self.test_idx = torch.arange(train_num + valid_num, nnum)
        self.idx_mapper = idx_mapper

        # Relabel both endpoints of every edge in one indexing op. This builds
        # a new tensor rather than writing into the loaded one in place.
        self.whole_graph.edge_index = idx_mapper[self.whole_graph.edge_index]

    def inductive(self):
        """Restrict ``self.whole_graph`` to the observed nodes ``self.obs_idx``.

        Keeps the ``x``/``y`` rows of observed nodes and drops every edge, and
        its ``edge_attr`` row, that touches a node id ``>= self.obs_num``.
        """
        self.whole_graph.num_nodes = self.obs_num
        self.whole_graph.x = self.whole_graph.x[self.obs_idx]
        self.whole_graph.y = self.whole_graph.y[self.obs_idx]

        mask, self.whole_graph.edge_index = self.delete_columns(self.whole_graph.edge_index, self.test_idx)
        self.whole_graph.edge_attr = self.whole_graph.edge_attr[mask]

    def delete_columns(self, edge_index: torch.Tensor, ind_idx: torch.Tensor) -> "tuple[torch.Tensor, torch.Tensor]":
        """Drop every edge column that touches a node id ``>= self.obs_num``.

        :param edge_index: ``[2, E]`` edge index in the remapped numbering.
        :param ind_idx: unused; kept for backward compatibility. The cut-off is
            ``self.obs_num``, which equals the first test id after
            ``idx_mapping``.
        :return: ``(mask, edge_index[:, mask])``, where ``mask`` is a bool
            tensor of length ``E`` that is True for the kept edges.
        """
        touches_unobserved = (edge_index[0] >= self.obs_num) | (edge_index[1] >= self.obs_num)
        mask = ~touches_unobserved
        return mask, edge_index[:, mask]

    def generate_one_hot_encoding(self):
        """One-hot encode ``self.species`` with scikit-learn encoders.

        NOTE(review): ``self.species`` is not assigned anywhere in this class.
        It must be set by the caller first, otherwise this raises
        AttributeError.

        :return: FloatTensor with one row per entry of ``self.species``.
        """
        le = preprocessing.LabelEncoder()
        species_unique = torch.unique(self.species)
        max_no = species_unique.max()
        le.fit(species_unique % max_no)
        species = le.transform(self.species.squeeze() % max_no)
        species = np.expand_dims(species, axis=1)

        enc = preprocessing.OneHotEncoder()
        enc.fit(species)
        one_hot_encoding = enc.transform(species).toarray()

        return torch.FloatTensor(one_hot_encoding)

    def extract_node_features(self, aggr='add'):
        """Aggregate ``self.edge_attr`` onto source nodes and cache the result.

        :param aggr: ``'add'``, ``'mean'`` or ``'max'``; anything else raises.
        :return: path of the ``.pt`` file holding the features.

        NOTE(review): the cache path depends only on ``aggr`` (not on the
        dataset or the self_loop/norm_adj flags), so an existing file is
        reused as-is.
        """
        file_path = 'init_node_features_ind_{}.pt'.format(aggr)

        if os.path.isfile(file_path):
            print('{} exists'.format(file_path))
        else:
            if aggr in ['add', 'mean', 'max']:
                node_features = scatter(self.edge_attr,
                                        self.edge_index[0],
                                        dim=0,
                                        dim_size=self.total_no_of_nodes,
                                        reduce=aggr)
            else:
                raise Exception('Unknown Aggr Method')
            torch.save(node_features, file_path)
            print('Node features extracted are saved into file {}'.format(file_path))
        return file_path

    def construct_adj(self):
        """Build an ``N x N`` uint8 CSR matrix with a 1 for each edge.

        Edges come from ``self.edge_index_array``. SciPy sums duplicate
        ``(row, col)`` pairs.
        """
        adj = sp.csr_matrix((np.ones(self.total_no_of_edges, dtype=np.uint8),
                             (self.edge_index_array[:, 0], self.edge_index_array[:, 1])),
                            shape=(self.total_no_of_nodes, self.total_no_of_nodes))

        return adj

    def edge_features_index(self, args):
        """Map each ``(src, dst)`` pair to its column in ``whole_graph.edge_index``.

        The dict is cached as a pickle in the working directory. The file name
        encodes the norm_adj/norm_feats/self_loop flags. If an edge appears
        more than once, the last column wins.

        NOTE(review): the cache name does not include the dataset name, and
        the pickle is loaded without validation - only use trusted files.
        """
        file_name = 'edge_features_index'
        if args.norm_adj == True:
            file_name = file_name + '_normadj'
        if args.norm_feats == True:
            file_name = file_name + '_normfeats'
        if args.self_loop == True:
            file_name = file_name + '_selfloop'
        file_name = file_name + '_ind_v2.pkl'
        if os.path.isfile(file_name):
            print('{} exists'.format(file_name))
            with open(file_name, 'rb') as edge_features_index:
                edge_index_dict = pickle.load(edge_features_index)
        else:
            df = pd.DataFrame()
            df['1st_index'] = self.whole_graph.edge_index[0]
            df['2nd_index'] = self.whole_graph.edge_index[1]
            df_reset = df.reset_index()
            key = zip(df_reset['1st_index'], df_reset['2nd_index'])
            edge_index_dict = df_reset.set_index(key)['index'].to_dict()
            with open(file_name, 'wb') as edge_features_index:
                pickle.dump(edge_index_dict, edge_features_index)
            print('Edges\' indexes information is saved into file {}'.format(file_name))
        return edge_index_dict

    @staticmethod
    def random_partition_graph(num_nodes, cluster_number=100):
        """Assign each node a uniformly random cluster id in ``[0, cluster_number)``."""
        parts = np.random.randint(cluster_number, size=num_nodes)
        return parts

    def generate_sub_graphs(self, parts, cluster_number=10, batch_size=1):
        """Build node-induced sub-graphs of ``self.adj`` from cluster ids.

        :param parts: array with one cluster id per node.
        :param cluster_number: number of clusters used to produce ``parts``.
        :param batch_size: ``cluster_number // batch_size`` sub-graphs are
            built, one per cluster id ``0 .. cluster_number // batch_size - 1``.
        :return: ``(sg_nodes, sg_edges, sg_edges_index, sg_edges_orig)``, all
            per sub-graph:
            ``sg_nodes`` - global node ids;
            ``sg_edges`` - edges in local numbering;
            ``sg_edges_index`` - column of each edge in ``whole_graph.edge_index``;
            ``sg_edges_orig`` - edges in global numbering.

        NOTE(review): with ``batch_size > 1``, nodes whose cluster id is
        ``>= cluster_number // batch_size`` belong to no sub-graph - confirm
        this is intended.
        """
        no_of_batches = cluster_number // batch_size

        print('The number of clusters: {}'.format(cluster_number))

        sg_nodes = [[] for _ in range(no_of_batches)]
        sg_edges = [[] for _ in range(no_of_batches)]
        sg_edges_orig = [[] for _ in range(no_of_batches)]
        sg_edges_index = [[] for _ in range(no_of_batches)]

        edges_no = 0

        for cluster in range(no_of_batches):
            sg_nodes[cluster] = np.where(parts == cluster)[0]
            sg_edges[cluster] = tg.utils.from_scipy_sparse_matrix(self.adj[sg_nodes[cluster], :][:, sg_nodes[cluster]])[0]
            edges_no += sg_edges[cluster].shape[1]
            # local sub-graph index -> global node id
            mapper = {nd_idx: nd_orig_idx for nd_idx, nd_orig_idx in enumerate(sg_nodes[cluster])}
            # Use this class's own mapper instead of reaching into OGBNDataset.
            sg_edges_orig[cluster] = self.edge_list_mapper(mapper, sg_edges[cluster])
            # edge index
            sg_edges_index[cluster] = [self.edge_index_dict[(edge[0], edge[1])] for edge in
                                       sg_edges_orig[cluster].t().numpy()]
        print('Total number edges of sub graphs: {}, of whole graph: {}, {:.2f} % edges are lost'.
              format(edges_no, self.total_no_of_edges, (1 - edges_no / self.total_no_of_edges) * 100))

        return sg_nodes, sg_edges, sg_edges_index, sg_edges_orig

    @staticmethod
    def edge_list_mapper(mapper, sg_edges_list):
        """Translate a ``[2, E]`` local edge list to global ids via ``mapper``."""
        idx_1st = list(map(lambda x: mapper[x], sg_edges_list[0].tolist()))
        idx_2nd = list(map(lambda x: mapper[x], sg_edges_list[1].tolist()))
        sg_edges_orig = torch.LongTensor([idx_1st, idx_2nd])
        return sg_edges_orig