import torch
from torch.utils.data import DataLoader, Dataset

class GraphDataset(Dataset):
    """Map-style dataset that yields ``(node_id, label)`` pairs.

    Args:
        indices: Indexable sequence of node ids making up this split
            (e.g. a 1-D tensor). Its length is the dataset length.
        labels: Indexable label container, looked up by node id (not by
            position within ``indices``).
    """

    def __init__(self, indices, labels):
        self.indices, self.labels = indices, labels

    def __len__(self):
        """Return the number of node ids in this split."""
        return len(self.indices)

    def __getitem__(self, idx):
        """Return the ``idx``-th node id together with that node's label."""
        node_id = self.indices[idx]
        node_label = self.labels[node_id]
        return node_id, node_label

def create_dataloaders(graph, train_idx, val_idx, test_idx, batch_size=1024, *,
                       oversample_train=True):
    """Build train/validation/test DataLoaders over node indices of ``graph``.

    The training split is rebalanced by repeating the anomalous nodes
    (label 1) ``max(n_normal // n_anomaly, 1)`` times, so the two classes end
    up of similar size. The validation and test splits are used as given.

    Args:
        graph: Object exposing an ``ndata['label']`` tensor indexed by node id
            (presumably a DGL graph -- confirm against callers).
        train_idx: 1-D tensor of training node ids.
        val_idx: 1-D tensor of validation node ids.
        test_idx: 1-D tensor of test node ids.
        batch_size: Batch size shared by all three loaders.
        oversample_train: Keyword-only. When False, ``train_idx`` is used
            without oversampling. Defaults to True (original behavior).

    Returns:
        ``(train_loader, val_loader, test_loader)``. Each batch is a
        ``(node_ids, labels)`` pair. Only the training loader shuffles.
    """
    # Hoisted: the same label tensor backs the oversampling and all datasets.
    labels = graph.ndata['label']

    def oversample(indices, labels):
        """Return ``indices`` with the anomaly (label 1) ids repeated.

        Every id in ``indices`` appears at least once in the result. Anomaly
        ids appear ``max(n_normal // n_anomaly, 1)`` times. Ids whose label
        is neither 0 nor 1 are kept exactly once, rather than being dropped
        and replaced with duplicated normals. If there are no anomalies,
        ``indices`` is returned unchanged.
        """
        node_labels = labels[indices]
        is_anomaly = node_labels == 1
        n_anomaly = int(is_anomaly.sum())
        if n_anomaly == 0:
            return indices
        n_normal = int((node_labels == 0).sum())
        # The integer ratio keeps the upsampled anomaly count <= n_normal.
        # If anomalies are already the majority, they are not repeated.
        repeat_times = max(n_normal // n_anomaly, 1)
        anomaly_upsampled = indices[is_anomaly].repeat(repeat_times)
        # No explicit shuffle here: the training DataLoader is built with
        # shuffle=True, which re-permutes the indices every epoch.
        return torch.cat([indices[~is_anomaly], anomaly_upsampled], dim=0)

    if oversample_train:
        train_indices = oversample(train_idx, labels)
    else:
        train_indices = train_idx

    train_loader = DataLoader(GraphDataset(train_indices, labels),
                              batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(GraphDataset(val_idx, labels),
                            batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(GraphDataset(test_idx, labels),
                             batch_size=batch_size, shuffle=False)
    return train_loader, val_loader, test_loader
