import os, pickle, random
from load_data import load_data
import torch
import numpy as np
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import Dataset, random_split, Subset
from torch_geometric.loader import DataLoader

class GNNDataset(Dataset):
    """Map-style dataset that loads every sample up front and keeps it in memory.

    One sample is produced per (simulation, step) pair by calling
    ``load_data(simulation, step, args)`` for every simulation in
    ``range(args['num_of_simulations'])`` and every step in
    ``range(args['num_of_steps'] - 1)``. Samples are stored simulation-major:
    all steps of simulation 0 come first, then simulation 1, and so on.
    """

    def __init__(self, args):
        self.args = args
        samples = []
        for sim in range(self.args['num_of_simulations']):
            # The last step is excluded; presumably each sample pairs step t
            # with step t + 1 inside load_data (TODO confirm).
            for step in range(self.args['num_of_steps'] - 1):
                samples.append(load_data(sim, step, self.args))
        self.data = samples

    def __len__(self):
        """Return the number of preloaded samples."""
        samples = self.data
        return len(samples)

    def __getitem__(self, idx):
        """Return the preloaded sample at position ``idx`` (plain list indexing)."""
        samples = self.data
        return samples[idx]

class GNNDataset_online_loader(Dataset):
    """Dataset that loads each sample from storage on demand instead of preloading.

    A flat index ``idx`` maps to ``(simulation_num, step_num)`` in
    simulation-major order, with ``args['num_of_steps'] - 1`` steps per
    simulation. This is the same ordering as the eager ``GNNDataset``.

    Args:
        args: config mapping; ``args['num_of_steps']`` is read on every access
            and the whole mapping is forwarded to ``load_data``.
        data_number: number of samples this dataset reports via ``len``.
    """

    def __init__(self, args, data_number=20000):
        # NOTE(review): the 20000 default is not derived from args. If it exceeds
        # num_of_simulations * (num_of_steps - 1), the highest indices would map to
        # simulations that presumably do not exist — confirm with callers.
        self.data_number = data_number
        self.args = args

    def __len__(self):
        return self.data_number

    def __getitem__(self, idx):
        """Load and return the sample at flat index ``idx``.

        Negative indices count from the end, as for Python sequences.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``. This also
                lets plain ``for sample in dataset`` iteration terminate, because the
                legacy sequence-iteration protocol stops only on ``IndexError``.
        """
        size = self.data_number
        pos = idx + size if idx < 0 else idx
        if not 0 <= pos < size:
            raise IndexError(f"index {idx} out of range for dataset of size {size}")
        steps_per_simulation = self.args['num_of_steps'] - 1
        simulation_num, step_num = divmod(pos, steps_per_simulation)
        return load_data(simulation_num, step_num, self.args)

def data_loader(args):
    """Build train/val/test loaders from a seeded random 80/10/10 split of GNNDataset.

    Args:
        args: config mapping. ``args['batch_size']`` is read here, and the whole
            mapping is passed to ``GNNDataset``. An optional ``'split_seed'`` key
            (default 0) seeds the split.

    Returns:
        ``(train_loader, val_loader, test_loader)``. The train loader draws from a
        ``DistributedSampler``. Per the PyTorch docs, that sampler needs an
        initialized process group, and ``train_loader.sampler.set_epoch(epoch)``
        should be called every epoch to get a different shuffle per epoch. The
        val/test loaders are not sharded, so every rank sees the full split.
    """
    dataset = GNNDataset(args)
    n_total = len(dataset)
    train_size = int(0.8 * n_total)
    val_size = int(0.1 * n_total)
    # The test split absorbs the rounding remainder so the sizes sum to n_total.
    test_size = n_total - train_size - val_size

    # Each DDP rank runs this function independently. An unseeded random_split
    # draws from the process-global RNG and can therefore give ranks different
    # splits, leaking val/test samples into another rank's training data.
    # A dedicated fixed-seed generator makes the split identical on every rank.
    generator = torch.Generator().manual_seed(args.get('split_seed', 0))
    train_dataset, val_dataset, test_dataset = random_split(
        dataset, [train_size, val_size, test_size], generator=generator
    )

    sampler = DistributedSampler(train_dataset)
    # Shuffling is delegated to the sampler; shuffle=True cannot be combined with one.
    train_loader = DataLoader(train_dataset, batch_size=args['batch_size'], shuffle=False, sampler=sampler)
    val_loader = DataLoader(val_dataset, batch_size=args['batch_size'], shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=args['batch_size'], shuffle=False)
    return train_loader, val_loader, test_loader

def data_loader_index(args, indices_path='./indices/indices_100_80_10_10.pkl'):
    """Build train/val/test loaders from a precomputed index split stored in a pickle.

    Args:
        args: config mapping. ``args['batch_size']`` is read here, and the whole
            mapping is passed to ``GNNDataset``.
        indices_path: pickle file holding a mapping with ``'train_idx'``,
            ``'val_idx'`` and ``'test_idx'`` entries (index sequences into the
            dataset). Another split file previously referenced here was
            ``./indices/indices_2_15-2-2.pkl``.

    Returns:
        ``(train_loader, val_loader, test_loader)``. The train loader draws from a
        ``DistributedSampler`` (call ``train_loader.sampler.set_epoch(epoch)`` each
        epoch to reshuffle). The val/test loaders are not sharded.

    Raises:
        FileNotFoundError / KeyError: if the index file or one of its keys is
            missing. Both are raised before any data is loaded.
    """
    import logging

    logger = logging.getLogger(__name__)

    # Read the split first: GNNDataset loads every sample eagerly, so a missing
    # or malformed index file should fail before that cost is paid.
    # SECURITY: pickle.load can execute arbitrary code — only use trusted index files.
    with open(indices_path, 'rb') as f:
        indices = pickle.load(f)
    train_idx = indices['train_idx']
    val_idx = indices['val_idx']
    test_idx = indices['test_idx']

    dataset = GNNDataset(args)
    train_dataset = Subset(dataset, train_idx)
    val_dataset = Subset(dataset, val_idx)
    test_dataset = Subset(dataset, test_idx)

    sampler = DistributedSampler(train_dataset)
    # Shuffling is delegated to the sampler; shuffle=True cannot be combined with one.
    train_loader = DataLoader(train_dataset, batch_size=args['batch_size'], sampler=sampler)
    val_loader = DataLoader(val_dataset, batch_size=args['batch_size'], shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=args['batch_size'], shuffle=False)

    # Replaces unconditional debug prints, which every DDP rank would emit.
    logger.info(
        "train split: %d indices, %d samples, %d batches on this rank",
        len(train_idx), len(train_dataset), len(train_loader),
    )
    return train_loader, val_loader, test_loader