import os, sys
import torch
import torch.nn.functional as F
import torch.nn as nn
import wandb
import numpy as np
import pickle
import glob


from torch.utils.data import Dataset
from torch_geometric.data import Data



# Prefer the default CUDA device when one is available; otherwise fall back to CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


def load_from_pickle(filename):
    """Unpickle and return the single object stored in ``filename``.

    The file is opened in binary mode and closed before returning.
    NOTE: ``pickle.load`` can execute arbitrary code; only use this on files
    produced by a trusted source.
    """
    with open(filename, 'rb') as handle:
        return pickle.load(handle)


class FolderDataLoader(Dataset):
    """Map-style dataset that loads numbered pickled files from one folder.

    Files are looked up as
    ``<folder_path>/<dataset_name>/<file_prefix>_NNNN.pt`` (zero-padded to at
    least four digits) and read with ``load_from_pickle``.

    Index-to-file mapping:
      * ``mode == 'test'``: dataset index ``i`` reads file number
        ``len(self.filenames) - i - 1``, i.e. the numbering is walked from
        its end.
      * any other mode: dataset index ``i`` reads file number ``i``.
    NOTE(review): this mapping assumes files are numbered contiguously from
    0 with no gaps — confirm against the dataset generator.

    Args:
        dataset_name: Sub-folder of ``folder_path`` containing the files.
        folder_path: Root directory holding the dataset sub-folders.
        mode: ``'test'`` reads from the end of the numbering; anything else
            reads from the start.
        file_prefix: Filename prefix preceding ``_NNNN.pt``.
        edge_wise_file_prefix: Stored on the instance but not used by this
            class.
        num_samples: Maximum number of samples to expose. The dataset length
            is clamped to the number of matching files actually found.
    """

    def __init__(self, 
                dataset_name='flipflop_prop_flip_0.7_total_edges_50_num_samples_10000',
                folder_path='/home/user/data/graph_datasets/flipflop',
                mode='train', 
                file_prefix='example',
                edge_wise_file_prefix='graph_example',
                num_samples=5000,):
        self.mode = mode
        self.dataset_folder = os.path.join(folder_path, dataset_name)
        self.file_prefix = file_prefix
        self.edge_wise_file_prefix = edge_wise_file_prefix
        self.filenames = self.get_filenames(self.dataset_folder)
        self.num_samples = num_samples

    def get_filenames(self, dataset_folder: str):
        """Return the sorted paths in ``dataset_folder`` matching ``<file_prefix>_*.pt``.

        Sorting makes the result deterministic (``glob`` order is
        filesystem-dependent). Only the count is used for indexing.
        """
        pattern = os.path.join(dataset_folder, f'{self.file_prefix}_*.pt')
        return sorted(glob.glob(pattern))

    def __getitem__(self, index):
        """Load and return the sample at ``index``.

        Negative indices count from the end, as for Python sequences.

        Raises:
            IndexError: if ``index`` is outside ``[-len(self), len(self))``.
        """
        length = len(self)
        if index < 0:
            index += length
        # Without this check, indices past the available files would build
        # names like "example_-001.pt" (test mode) or point at missing files.
        if not 0 <= index < length:
            raise IndexError(f"index out of range for dataset of length {length}")
        if self.mode == 'test':
            fileindex = len(self.filenames) - index - 1
        else:
            fileindex = index
        filename = f"{self.file_prefix}_{fileindex:04}.pt"
        return load_from_pickle(os.path.join(self.dataset_folder, filename))

    def __len__(self):
        """Return ``num_samples`` clamped to the number of files found on disk."""
        # Reporting more samples than files exist would make samplers request
        # indices that fail with FileNotFoundError partway through an epoch.
        return min(self.num_samples, len(self.filenames))
    


        