import sys, os
import glob
import random
import numpy as np
import torch
from torch.utils.data import Dataset
from torch_geometric.data import Data
import igraph as ig
from einops import repeat
import torch_geometric.utils as geo_utils
from torch_geometric.utils import k_hop_subgraph
from torch_geometric.data import Data
import torch_geometric
import glob
import pickle

# Run on the current CUDA device when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')


def load_from_pickle(filename):
    """Unpickle and return the object stored in ``filename``.

    The file is opened in binary mode and closed again before returning.

    Warning: unpickling can execute arbitrary code, so only call this on
    files from a trusted source.
    """
    with open(filename, 'rb') as handle:
        return pickle.load(handle)


class FolderDataLoader(Dataset):
    """Dataset that serves one pickled item per index of a train/val/test split.

    The directory ``<folder_path>/<data_name>`` is expected to contain
    ``data.pkl`` (an object exposing ``x``, ``y``, ``train_mask``,
    ``val_mask`` and ``test_mask``) plus one ``data_<i>.pkl`` file for each
    row index ``i`` of ``data.x`` (presumably one per node — confirm against
    the preprocessing script). Item ``k`` of the dataset is the unpickled
    ``data_<i>.pkl`` where ``i`` is the ``k``-th index of the selected split.

    Note: all files are loaded with ``pickle``, which can execute arbitrary
    code; only point this at trusted data.
    """

    # Maps each accepted ``mode`` to the mask attribute read from ``self.data``.
    _MASK_ATTRS = {
        'train': 'train_mask',
        'validation': 'val_mask',
        'test': 'test_mask',
    }

    def __init__(self,
                 data_name='squirrel',
                 folder_path='/home/user/data/graph_datasets/1_hop_nbd',
                 mode='train'):
        """Load ``data.pkl`` and select the indices of split ``mode``.

        Args:
            data_name: Sub-directory of ``folder_path`` holding the dataset.
            folder_path: Root directory containing one folder per dataset.
            mode: One of ``'train'``, ``'validation'`` or ``'test'``.

        Raises:
            ValueError: If ``mode`` is not one of the accepted split names.
        """
        # Validate first so a typo fails before the (possibly large) data.pkl is read.
        if mode not in self._MASK_ATTRS:
            raise ValueError(f"Invalid mode type {mode}")
        self.mode = mode
        self.folder_path = os.path.join(folder_path, data_name)
        self.data = load_from_pickle(os.path.join(self.folder_path, "data.pkl"))
        # Number of distinct label values present in data.y.
        self.num_classes = len(self.data.y.unique())
        self.num_features = self.data.x.shape[-1]
        self.indices = self.get_indices(mode)

    def get_indices(self, mode):
        """Return the row indices of ``self.data.x`` that belong to split ``mode``.

        Args:
            mode: One of ``'train'``, ``'validation'`` or ``'test'``.

        Returns:
            A numpy integer array of indices selected by the split's mask.

        Raises:
            ValueError: If ``mode`` is not one of the accepted split names.
        """
        try:
            mask_attr = self._MASK_ATTRS[mode]
        except KeyError:
            # The original code built this error but never raised it,
            # silently returning None for an unknown mode.
            raise ValueError(f"Invalid mode type {mode}") from None
        # np.asarray turns a CPU torch tensor (or list) into a numpy array, so
        # indexing behaves the same regardless of how the mask was stored.
        # NOTE(review): assumes each mask is 1-D over the rows of data.x; a
        # multi-split mask of shape (N, k) would need a split column — confirm.
        mask = np.asarray(getattr(self.data, mask_attr))
        return np.arange(self.data.x.shape[0])[mask]

    def __getitem__(self, index):
        """Load and return the pickled item for the ``index``-th entry of the split."""
        data_index = self.indices[index]
        filename = os.path.join(self.folder_path, f"data_{data_index}.pkl")
        return load_from_pickle(filename)

    def __len__(self):
        """Return the number of indices in the selected split."""
        return len(self.indices)


class InductiveFolderLoader(Dataset):
    """Dataset whose constructor records where and how to load inductive data.

    NOTE(review): only ``__init__`` is visible here and ``filename`` is not yet
    used by any loading logic — confirm whether ``__getitem__``/``__len__``
    still need to be implemented.
    """

    def __init__(self, 
                 folder_path: str,
                 filename='example.pt',
                 mode='train'):
        # Keep the constructor arguments verbatim as attributes.
        self.folder_path = folder_path
        self.filename = filename
        self.mode = mode


    
