"""
Loader Utils
"""

import torch

from torch_geometric.datasets import TUDataset

import networkx as nx
from torch_geometric.utils.convert import to_networkx
import numpy as np
import json

import random

# Datasets for which the loader does not read node features from the dataset
# and instead uses the row sums of each graph's adjacency matrix.
list_binary = [
    "IMDB-BINARY",
    "IMDB-MULTI",
    "COLLAB",
    "REDDIT-BINARY",
    "REDDIT-MULTI-5K",
]

class data_loader:
    """Load a TUDataset and serve it fold by fold.

    Attributes set by ``__init__``:
        num_examples: number of graphs in the dataset.
        Adj_list: one dense float adjacency matrix per graph.
        X_list: one node-feature tensor per graph.
        y_list: one label (``data.y``) per graph.
        input_dim: number of columns of the node-feature tensors.
        num_classes: number of distinct values found in ``y_list``.
        split: content of ``<name_dataset>_splits.json``.
    """

    def __init__(self, name_dataset):
        """Read the dataset ``name_dataset`` and its split file.

        Args:
            name_dataset: TUDataset name; also used to build the dataset
                root directory and the split file name.
        """
        # Load TUDataset
        dataset = TUDataset(root='./datasets/'+name_dataset, name=name_dataset)

        self.num_examples = len(dataset)
        self.Adj_list = []

        self.y_list = []
        self.X_list = []

        # Loop-invariant: decided once instead of once per graph.
        use_adj_row_sums = name_dataset in list_binary

        for i in range(len(dataset)):
            # Index the dataset once per graph and reuse the result.
            data = dataset[i]

            G = nx.to_undirected(to_networkx(data))
            adj = torch.FloatTensor(nx.adjacency_matrix(G).toarray())
            self.Adj_list.append(adj)

            if use_adj_row_sums:
                # One feature per node: the row sum of the adjacency matrix.
                self.X_list.append(torch.FloatTensor(adj.sum(axis = 1)).unsqueeze(1))
            else:
                self.X_list.append(torch.FloatTensor(data.x.float()))

            self.y_list.append(data.y)

        if use_adj_row_sums:
            self.input_dim = 1
        elif self.X_list:
            self.input_dim = self.X_list[0].shape[1]

        self.num_classes = len(np.unique(self.y_list))

        # Load the train/validation/test splits; the context manager closes
        # the file handle as soon as the JSON has been parsed.
        with open('../data_splits/' + name_dataset + '_splits.json',) as f:
            self.split = json.load(f)

    def _fold_indices(self, int_fold):
        """Return the (train, validation, test) index lists of one fold."""
        fold = self.split[int_fold]
        test_indices = fold['test']
        # Only the first model-selection entry of the fold is used.
        inner = fold['model_selection'][0]
        return inner['train'], inner['validation'], test_indices

    def _gather(self, indices):
        """Return new (Adj, X, y) lists holding the examples at ``indices``."""
        return (
            [self.Adj_list[i] for i in indices],
            [self.X_list[i] for i in indices],
            [self.y_list[i] for i in indices],
        )

    def get_fold_data(self, int_fold):
        """Return the train/validation/test data of fold ``int_fold``.

        Args:
            int_fold: index of the fold in ``self.split``.

        Returns:
            A 9-tuple ``(Adj_train, X_train, y_train, Adj_val, X_val, y_val,
            Adj_test, X_test, y_test)`` of lists.
        """
        train_indices, val_indices, test_indices = self._fold_indices(int_fold)

        Adj_train, X_train, y_train = self._gather(train_indices)
        Adj_val, X_val, y_val = self._gather(val_indices)
        Adj_test, X_test, y_test = self._gather(test_indices)
        print(len(Adj_train), len(Adj_test), len(Adj_val))

        return Adj_train, X_train, y_train, Adj_val, X_val, y_val, Adj_test, X_test, y_test

    def get_adv_fold_data(self, int_fold, Adj_adv, X_adv, y_adv):
        """Return fold ``int_fold`` with extra examples mixed into training.

        The examples in ``Adj_adv``, ``X_adv`` and ``y_adv`` are appended to
        the training portion of the fold, which is then shuffled with the
        ``random`` module (the same permutation is applied to the three
        training lists). Validation and test data are left unchanged.

        Args:
            int_fold: index of the fold in ``self.split``.
            Adj_adv: adjacency matrices of the extra examples.
            X_adv: node features of the extra examples.
            y_adv: labels of the extra examples.

        Returns:
            A 9-tuple ``(Adj_train, X_train, y_train, Adj_val, X_val, y_val,
            Adj_test, X_test, y_test)`` of lists.
        """
        train_indices, val_indices, test_indices = self._fold_indices(int_fold)

        Adj_train, X_train, y_train = self._gather(train_indices)
        Adj_train = Adj_train + list(Adj_adv)
        X_train = X_train + list(X_adv)
        y_train = y_train + list(y_adv)

        # Shuffle one index list so the three training lists stay aligned.
        order = list(range(len(Adj_train)))
        random.shuffle(order)

        Adj_train = [Adj_train[i] for i in order]
        X_train = [X_train[i] for i in order]
        y_train = [y_train[i] for i in order]

        Adj_val, X_val, y_val = self._gather(val_indices)
        Adj_test, X_test, y_test = self._gather(test_indices)
        print(len(Adj_train), len(Adj_test), len(Adj_val))

        return Adj_train, X_train, y_train, Adj_val, X_val, y_val, Adj_test, X_test, y_test
