'''
    From the LoG 2022 tutorial, slightly adapted to include other source graphs and 
    path counting as a target
'''

import os.path as osp
import scipy.io as sio
import numpy as np
import torch
import networkx as nx

from collections import Counter
from tqdm import tqdm

from scipy.special import comb
from torch_geometric.data import InMemoryDataset, Data

from graphgps.loader.data_generation_utils import generate_RR_mat_dataset

# Directory that holds the raw generated ``.mat`` files (see ``raw_dir``).
DATA_PATH: str = '/data/'
# NOTE(review): not referenced by the code visible in this module.
PATHS_MAX_LENGTH: int = 4
# Generation spec used when ``SyntheticDataset`` gets no ``parameters``; its
# format is whatever ``generate_RR_mat_dataset`` parses.
DEFAULT_LARGER_RR_PARAMETERS: str = '20-5_30-5_40-5_60-4'

class SyntheticDataset(InMemoryDataset):
    """Synthetic graph-regression dataset built from generated adjacency matrices.

    Raw graphs are produced on demand by ``generate_RR_mat_dataset`` into a
    ``.mat`` file under ``DATA_PATH``. ``process`` turns each adjacency matrix
    into a PyG ``Data`` object with all-ones node and edge features and a
    scalar "mixed graph moments" target. After loading, targets are rescaled
    to unit standard deviation. The graphs are split 80/10/10 in storage order.

    Args:
        root: Root directory; processed data is stored in ``root/name/processed``.
        name: Dataset name (only used to build the processed directory).
        transform: Standard PyG per-access transform.
        pre_transform: Standard PyG transform applied once in ``process``.
        include_paths: Stored on the instance; not used by this class.
        parameters: Generation spec forwarded to ``generate_RR_mat_dataset``.
            Falls back to ``DEFAULT_LARGER_RR_PARAMETERS`` when falsy.
        y_func: Suffix appended to the raw file name. NOTE(review): the target
            formula in ``process`` is fixed and does not depend on this value.
        largest_component: Forwarded to ``generate_RR_mat_dataset``.
    """

    def __init__(self, root, name, transform=None, pre_transform=None, include_paths=False, parameters=None, y_func='mixedmoments', largest_component=False):
        self.name = name
        self.largest_component = largest_component
        if not parameters:
            parameters = DEFAULT_LARGER_RR_PARAMETERS
        self.parameters = parameters
        self.include_paths = include_paths
        self.y_func = y_func
        self.source = f"randomgraph_RR_{parameters}{y_func}"
        super().__init__(root, transform, pre_transform)
        # NOTE(review): torch>=2.6 defaults torch.load(weights_only=True), which
        # cannot unpickle PyG Data objects; weights_only=False may be needed.
        self.data, self.slices = torch.load(self.processed_paths[0])

        # Normalize as in GNN-AKs: rescale targets to unit standard deviation.
        # A zero std (every graph has the same target) or a NaN std (a single
        # graph) would turn all labels into inf/NaN, so fall back to 1 there.
        std = self.data.y.std(0)
        std = torch.where(torch.isfinite(std) & (std > 0), std, torch.ones_like(std))
        self.data.y = self.data.y / std

        # Sequential 80/10/10 split over the graphs actually stored (after any
        # pre_filter), instead of re-reading the raw .mat file to count them.
        num_graphs = len(self)
        num_train = int(0.8 * num_graphs)
        num_val = int(0.1 * num_graphs)
        all_samples = torch.arange(num_graphs)
        self.train_idx = all_samples[:num_train]
        self.val_idx = all_samples[num_train:num_train + num_val]
        self.test_idx = all_samples[num_train + num_val:]
        print(len(self.train_idx))
        print(len(self.val_idx))
        print(len(self.test_idx))

    @property
    def raw_dir(self):
        """Directory holding the raw ``.mat`` file (``DATA_PATH``)."""
        return DATA_PATH

    @property
    def raw_file_names(self):
        """Single raw file whose name encodes the generation parameters."""
        return [f"{self.source}.mat"]

    @property
    def processed_file_names(self):
        return "data.pt"

    @property
    def processed_dir(self):
        # NOTE(review): this path does not include `parameters`/`y_func`, so a
        # previously processed data.pt is reused even if those change.
        return osp.join(self.root, self.name, "processed")

    def download(self):
        """Generate the raw ``.mat`` file instead of downloading it."""
        print(f'Generating dataset with parameters {self.parameters}...')
        generate_RR_mat_dataset(self.raw_paths[0], self.parameters, largest_component=self.largest_component)

    @property
    def num_tasks(self):
        """Number of regression targets per graph."""
        return 1

    @staticmethod
    def _mixed_moments_target(adj):
        """Return the graph target: node mean of a fixed mix of A, A^2, A^3 column sums.

        Returns a 1-element tensor (float64 when ``adj`` is a float64 array).
        """
        adj2 = adj.dot(adj)
        adj3 = adj2.dot(adj)
        deg = adj.sum(0)
        total = -6 + 8 * deg + adj2.sum(0) + adj3.sum(0)
        return torch.tensor([np.mean(total)])

    @classmethod
    def _adjacency_to_data(cls, adj):
        """Build a ``Data`` object (all-ones features) from one adjacency matrix."""
        rows, cols = np.where(adj > 0)
        edge_index = torch.as_tensor(np.vstack((rows, cols)), dtype=torch.long)
        return Data(
            edge_index=edge_index,
            x=torch.ones(adj.shape[0], 1),
            y=cls._mixed_moments_target(adj),
            edge_attr=torch.ones(len(rows), 1),
        )

    def process(self):
        """Convert every adjacency matrix in the raw file and save the collated result."""
        # Read from the same location `download` writes to, not a CWD-relative
        # './data/' path that only coincides with DATA_PATH when run from '/'.
        mat = sio.loadmat(self.raw_paths[0])
        # `A` holds the list of adjacency matrices.
        adjacencies = mat["A"][0]
        print(len(adjacencies))

        data_list = [self._adjacency_to_data(adj) for adj in tqdm(adjacencies)]

        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])

    def separate_data(self, seed):
        """Return the fixed train/val/test index split (``seed`` is ignored)."""
        return {"train": self.train_idx, "val": self.val_idx, "test": self.test_idx}
