import os.path as osp

import torch
import numpy as np
import scipy.io as sio
from scipy.special import comb
from torch_geometric.data import InMemoryDataset
from torch_geometric.data.data import Data
import networkx as nx
import networkx

class GraphCountDataset(InMemoryDataset):
    """Random-graph dataset whose targets are substructure statistics.

    Adjacency matrices are read from ``randomgraph.mat`` (key ``'A'``); for
    every graph five targets are computed from the adjacency matrix and
    stored in ``y`` in the order ``[tri, tailed, star, cyc4, cus]``.
    After loading, every target column is divided by its standard deviation
    over the whole dataset.  Train/validation/test indices are read from the
    same ``.mat`` file (keys ``'train_idx'``, ``'val_idx'``, ``'test_idx'``).

    Args:
        root: Root directory of the dataset.
        name: Sub-directory of ``root`` that holds the processed files.
        transform: Forwarded to ``InMemoryDataset``.
        pre_transform: Forwarded to ``InMemoryDataset``; applied to every
            graph in :meth:`process` before the dataset is saved.
    """

    # Name of the MATLAB archive with adjacency matrices and split indices.
    _MAT_FILE = "randomgraph.mat"

    def __init__(self, root, name, transform=None, pre_transform=None):
        self.name = name
        super(GraphCountDataset, self).__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

        # Scale every target column to unit standard deviation.  A column
        # that is constant over the dataset has std == 0; dividing by it
        # would fill the column with NaN/inf, so such columns are left as-is.
        std = self.data.y.std(0)
        std = torch.where(std > 0, std, torch.ones_like(std))
        self.data.y = self.data.y / std

        splits = sio.loadmat(self._mat_path())
        self.train_idx = torch.from_numpy(splits['train_idx'][0])
        self.val_idx = torch.from_numpy(splits['val_idx'][0])
        self.test_idx = torch.from_numpy(splits['test_idx'][0])

    def _mat_path(self):
        """Return the path of the ``.mat`` archive to read.

        The file in the current working directory is preferred (the location
        this class has always read from); if it is absent, the copy in the
        raw directory (``raw_paths[0]``) is used instead.
        """
        if osp.isfile(self._MAT_FILE):
            return self._MAT_FILE
        return self.raw_paths[0]

    @property
    def raw_file_names(self):
        """Raw files the dataset is built from."""
        return [self._MAT_FILE]

    @property
    def processed_file_names(self):
        """File name of the collated, processed dataset."""
        return 'data.pt'

    @property
    def processed_dir(self):
        """Directory ``<root>/<name>/processed`` holding the processed file."""
        name = 'processed'
        return osp.join(self.root, self.name, name)

    def download(self):
        """No-op: the raw ``.mat`` file has to be provided manually."""
        pass

    @property
    def eval_metric(self):
        """Name of the evaluation metric (mean absolute error)."""
        return 'mae'

    @property
    def task_type(self):
        """Kind of prediction task."""
        return 'regression'

    @property
    def num_tasks(self):
        """Number of tasks reported to callers.

        NOTE(review): ``process`` stores five target columns per graph while
        this reports 1 -- presumably callers pick a single column; confirm.
        """
        return 1

    @staticmethod
    def _graph_to_data(adj):
        """Build the ``Data`` object (edges, constant features, targets) for one graph.

        Args:
            adj: Square adjacency matrix (array-like).  The count
                interpretations below hold for a simple undirected graph.

        Returns:
            ``Data`` with ``edge_index`` (entries of ``adj`` that are > 0),
            ``x`` (a single constant feature 1 per node) and ``y`` of shape
            ``(1, 5)`` holding ``[tri, tailed, star, cyc4, cus]``.
        """
        # Work in float64 so that integer (possibly unsigned) matrices cannot
        # wrap around in ``deg - 2`` or overflow in the matrix powers.
        a = np.asarray(adj, dtype=np.float64)
        a2 = a.dot(a)
        a3 = a2.dot(a)
        deg = a.sum(0)

        # Triangles: every triangle contributes 6 closed walks of length 3.
        tri = np.trace(a3) / 6
        # Tailed triangles: triangles through a node times its remaining
        # (non-triangle) neighbours.
        tailed = ((np.diag(a3) / 2) * (deg - 2)).sum()
        # 4-cycles: closed 4-walks minus the degenerate back-and-forth walks.
        cyc4 = 1 / 8 * (np.trace(a3.dot(a)) + np.trace(a2) - 2 * a2.sum())
        # Custom statistic: sum of all entries of A @ diag(exp(-rowsum(A^2))) @ A.
        # Scaling the columns of ``a`` is the same as multiplying by the
        # diagonal matrix without materialising it.
        cus = (a * np.exp(-a2.sum(1))).dot(a).sum()
        # 3-stars: choose 3 neighbours around every centre node (degrees are
        # truncated to integers first).
        star = comb(deg.astype(np.int64), 3).sum()

        y = torch.tensor([[tri, tailed, star, cyc4, cus]])

        rows, cols = np.nonzero(a > 0)
        # Convert the integer indices directly; going through a float32
        # tensor would be inexact for indices above 2**24.
        edge_index = torch.from_numpy(np.vstack((rows, cols))).to(torch.int64)
        x = torch.ones(a.shape[0], 1)
        return Data(edge_index=edge_index, x=x, y=y)

    def process(self):
        """Read the adjacency matrices, build one ``Data`` per graph and save.

        ``pre_filter`` and ``pre_transform`` are applied (in that order) when
        set; the collated result is written to ``processed_paths[0]``.
        """
        mat = sio.loadmat(self._mat_path())
        data_list = [self._graph_to_data(adj) for adj in mat['A'][0]]

        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])

    def separate_data(self, seed, fold_idx):
        """Return the fixed train/valid/test index split.

        ``seed`` and ``fold_idx`` are accepted but not used: the split is the
        one read from the ``.mat`` file at construction time.
        """
        return {'train': self.train_idx, 'valid': self.val_idx, 'test': self.test_idx}


class GraphCountDatasetNnodes(InMemoryDataset):
    """Synthetic Erdos-Renyi dataset with substructure statistics as targets.

    :meth:`process` draws a fixed number of undirected Erdos-Renyi graphs
    with ``n_nodes`` nodes each and stores five targets per graph in ``y`` in
    the order ``[tri, tailed, star, cyc4, cus]``.  After loading, every
    target column is divided by its standard deviation over the dataset.

    NOTE(review): the train/validation/test indices are read from
    ``randomgraph.mat`` rather than derived from the generated graphs --
    presumably that file indexes a dataset of matching size; confirm.
    NOTE(review): graphs are drawn with ``seed=None``, so a regenerated
    dataset differs from a previously cached one.

    Args:
        root: Root directory of the dataset.
        name: Sub-directory of ``root`` that holds the processed files.
        n_nodes: Number of nodes of every generated graph; also part of the
            processed directory and file names.
        transform: Forwarded to ``InMemoryDataset``.
        pre_transform: Forwarded to ``InMemoryDataset``; applied to every
            graph in :meth:`process` before the dataset is saved.
    """

    # Name of the MATLAB archive the split indices are read from.
    _MAT_FILE = "randomgraph.mat"
    # Number of random graphs generated by ``process``.
    _NUM_GRAPHS = 5000
    # Edge probability of the Erdos-Renyi model.
    _EDGE_PROB = 0.1

    def __init__(self, root, name, n_nodes, transform=None, pre_transform=None):
        self.name = name
        self.n_nodes = n_nodes
        super(GraphCountDatasetNnodes, self).__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

        # Scale every target column to unit standard deviation.  A column
        # that is constant over the dataset (e.g. a count that is zero for
        # every generated graph) has std == 0; dividing by it would fill the
        # column with NaN/inf, so such columns are left as-is.
        std = self.data.y.std(0)
        std = torch.where(std > 0, std, torch.ones_like(std))
        self.data.y = self.data.y / std

        splits = sio.loadmat(self._mat_path())
        self.train_idx = torch.from_numpy(splits['train_idx'][0])
        self.val_idx = torch.from_numpy(splits['val_idx'][0])
        self.test_idx = torch.from_numpy(splits['test_idx'][0])

    def _mat_path(self):
        """Return the path of the ``.mat`` archive to read.

        The file in the current working directory is preferred (the location
        this class has always read from); if it is absent, the copy in the
        raw directory (``raw_paths[0]``) is used instead.
        """
        if osp.isfile(self._MAT_FILE):
            return self._MAT_FILE
        return self.raw_paths[0]

    @property
    def raw_file_names(self):
        """Raw files the dataset declares as required."""
        return [self._MAT_FILE]

    @property
    def processed_file_names(self):
        """File name of the processed dataset, specific to ``n_nodes``."""
        return f'data{self.n_nodes}.pt'

    @property
    def processed_dir(self):
        """Directory ``<root>/<name>/processed<n_nodes>`` for the processed file."""
        name = f'processed{self.n_nodes}'
        return osp.join(self.root, self.name, name)

    def download(self):
        """No-op: nothing is downloaded."""
        pass

    @property
    def eval_metric(self):
        """Name of the evaluation metric (mean absolute error)."""
        return 'mae'

    @property
    def task_type(self):
        """Kind of prediction task."""
        return 'regression'

    @property
    def num_tasks(self):
        """Number of tasks reported to callers.

        NOTE(review): ``process`` stores five target columns per graph while
        this reports 1 -- presumably callers pick a single column; confirm.
        """
        return 1

    @staticmethod
    def _graph_to_data(adj):
        """Build the ``Data`` object (edges, constant features, targets) for one graph.

        Args:
            adj: Square adjacency matrix (array-like).  The count
                interpretations below hold for a simple undirected graph.

        Returns:
            ``Data`` with ``edge_index`` (entries of ``adj`` that are > 0),
            ``x`` (a single constant feature 1 per node) and ``y`` of shape
            ``(1, 5)`` holding ``[tri, tailed, star, cyc4, cus]``.
        """
        # Work in float64 so integer inputs cannot wrap around or overflow.
        a = np.asarray(adj, dtype=np.float64)
        a2 = a.dot(a)
        a3 = a2.dot(a)
        deg = a.sum(0)

        # Triangles: every triangle contributes 6 closed walks of length 3.
        tri = np.trace(a3) / 6
        # Tailed triangles: triangles through a node times its remaining
        # (non-triangle) neighbours.
        tailed = ((np.diag(a3) / 2) * (deg - 2)).sum()
        # 4-cycles: closed 4-walks minus the degenerate back-and-forth walks.
        cyc4 = 1 / 8 * (np.trace(a3.dot(a)) + np.trace(a2) - 2 * a2.sum())
        # Custom statistic: sum of all entries of A @ diag(exp(-rowsum(A^2))) @ A.
        # Scaling the columns of ``a`` is the same as multiplying by the
        # diagonal matrix without materialising it.
        cus = (a * np.exp(-a2.sum(1))).dot(a).sum()
        # 3-stars: choose 3 neighbours around every centre node (degrees are
        # truncated to integers first).
        star = comb(deg.astype(np.int64), 3).sum()

        y = torch.tensor([[tri, tailed, star, cyc4, cus]])

        rows, cols = np.nonzero(a > 0)
        # Convert the integer indices directly; going through a float32
        # tensor would be inexact for indices above 2**24.
        edge_index = torch.from_numpy(np.vstack((rows, cols))).to(torch.int64)
        x = torch.ones(a.shape[0], 1)
        return Data(edge_index=edge_index, x=x, y=y)

    def process(self):
        """Generate the random graphs, build one ``Data`` per graph and save.

        ``pre_filter`` and ``pre_transform`` are applied (in that order) when
        set; the collated result is written to ``processed_paths[0]``.
        """
        # Graphs are converted one at a time so that only a single networkx
        # graph and dense matrix are alive at once.
        data_list = []
        for _ in range(self._NUM_GRAPHS):
            graph = nx.erdos_renyi_graph(
                self.n_nodes, self._EDGE_PROB, seed=None, directed=False
            )
            data_list.append(self._graph_to_data(nx.to_numpy_array(graph)))

        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])

    def separate_data(self, seed, fold_idx):
        """Return the fixed train/valid/test index split.

        ``seed`` and ``fold_idx`` are accepted but not used: the split is the
        one read from the ``.mat`` file at construction time.
        """
        return {'train': self.train_idx, 'valid': self.val_idx, 'test': self.test_idx}
