# Generate the application 1 dataset: the COLLAB validation split.
import pandas as pd
import numpy as np
import torch
from torch_geometric.data import InMemoryDataset
from torch_geometric.data import Data
import pickle
from pathlib import Path
import yaml
import re
import itertools
from torch_geometric.data import DataLoader
from utils import get_diracs
from tqdm import tqdm
from torch_geometric.datasets import TUDataset

class COLLAB_val(InMemoryDataset):
    """In-memory dataset built from a held-out slice of the COLLAB graphs.

    On first use, ``process`` takes the entries of a pre-shuffled index list
    that fall between the 80% and 90% marks of the TUDataset ``COLLAB``
    collection, passes each graph through ``get_diracs`` and caches the
    collated result in ``processed_paths[0]``. Later instantiations load
    that cached file directly.

    Args:
        config: Mapping that must contain ``'data_dir'``, used as the
            dataset root directory.
    """

    def __init__(self, config: dict):
        self.config = config
        self.data_path = Path(config['data_dir'])
        # The base-class constructor is what triggers process() when the
        # processed file does not exist yet.
        super().__init__(root=self.data_path)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        """No raw files are tracked under this dataset's own root."""
        return []

    @property
    def processed_file_names(self):
        """Name of the single cached file written by ``process``."""
        return ['data.pt']

    def download(self):
        """Nothing to download into ``self.raw_dir``."""
        pass

    def get_idx_split(self, split_type='Random'):
        """Return the index split as ``{'train': LongTensor}``.

        ``split_type`` is accepted for interface compatibility but is not
        used: every index is always placed in the ``'train'`` split.
        """
        # NOTE(review): 2389 is hard-coded and is not derived from the number
        # of samples that process() writes -- confirm it matches the cached
        # dataset this split is meant to index.
        train_idx = torch.arange(2389, dtype=torch.long)
        return {'train': train_idx}

    def process(self):
        """Build the processed dataset and save it to ``processed_paths[0]``.

        Reads ``./dataset/shuffled_list.pkl`` (a pickled dict with a
        ``'shuffled_list'`` entry) to decide which COLLAB graphs belong to
        this split, transforms each of them with ``get_diracs`` and stores
        the collated ``(data, slices)`` pair.
        """
        dataset = TUDataset(root='./dataset/raw/', name='COLLAB')
        total_sample = len(dataset)
        split_start = int(0.8 * total_sample)
        split_end = int(0.9 * total_sample)
        num_samples = split_end - split_start

        # NOTE: pickle executes arbitrary code on load; this file must come
        # from a trusted source. The context manager closes the handle,
        # which was previously left open.
        with open('./dataset/' + 'shuffled_list.pkl', 'rb') as index_file:
            shuffled_idx = pickle.load(index_file)['shuffled_list']

        subset = dataset[shuffled_idx[split_start:split_end]]

        data_list = []
        for task_index in tqdm(range(num_samples)):
            # A one-element loader turns the single graph into a batch
            # object, which is the form handed to get_diracs.
            loader = DataLoader([subset[task_index]], batch_size=1, shuffle=False)
            batch = next(iter(loader))
            data_prime = get_diracs(
                batch,
                1,
                sparse=True,
                effective_volume_range=0.15,
                receptive_field=5,
            )
            data_list.append(
                Data(
                    x=data_prime.x,
                    edge_index=data_prime.edge_index,
                    train_batch=data_prime.batch,
                    locations=data_prime.locations,
                )
            )

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])
        

if __name__ == '__main__':
    # Build (or load the cached) validation dataset once for every config
    # file found directly inside ./configs.
    configs = Path('./configs')
    for cfg in configs.iterdir():
        # Match on the file name rather than the stringified path so the
        # filter does not depend on the platform's path separator; skip
        # directories, which cannot be opened as YAML files.
        if cfg.is_file() and cfg.name.startswith("config"):
            # Close the config file as soon as it has been parsed.
            with cfg.open('r') as cfg_file:
                cfg_dict = yaml.safe_load(cfg_file)
            dataset = COLLAB_val(cfg_dict['val'])
