from torch_geometric.loader import GraphSAINTRandomWalkSampler as RawGraphSAINTRandomWalkSampler
from torch_geometric.loader import RandomNodeSampler as RawRandomNodeSampler
from torch_geometric.loader import NeighborLoader as RawNeighborLoader
from torch_geometric.loader import ClusterLoader as RawClusterLoader
from torch_geometric.loader import ClusterData
from torch_cluster import random_walk
from torch.utils.data.sampler import SubsetRandomSampler
import torch
import numpy as np


def ImageNetSampler(data,
                    valid_size=0.1,
                    shuffle=True,
                    train=True,
                    **kwargs):
    """Build a DataLoader over a train or held-out split of ``train_data``.

    The held-out ("validation") subset is carved out of ``train_data``, the
    first element of ``data``. The second element of ``data`` is unpacked but
    not used here.

    Args:
        data: A ``(train_data, val_data)`` pair. Only ``train_data`` is
            indexed and loaded.
        valid_size: Fraction of ``train_data`` held out. The count is
            ``floor(valid_size * len(train_data))``.
        shuffle: If True, shuffle the index list with NumPy's global RNG
            before splitting, so the held-out subset is random. Otherwise
            the first indices are held out.
        train: If True, load the remaining (training) indices. Otherwise
            load the held-out indices.
        **kwargs: Must contain ``batch_size``.

    Returns:
        A ``torch.utils.data.DataLoader`` over ``train_data``. It draws the
        selected indices in random order via ``SubsetRandomSampler``.

    Raises:
        KeyError: If ``batch_size`` is missing from ``kwargs``.
    """
    batch_size = kwargs['batch_size']

    train_data, _ = data
    num_train = len(train_data)
    indices = list(range(num_train))
    split = int(np.floor(valid_size * num_train))

    if shuffle:
        np.random.shuffle(indices)

    # The first ``split`` indices are held out; the rest are for training.
    train_idx, val_idx = indices[split:], indices[:split]
    idx = train_idx if train else val_idx

    sampler = SubsetRandomSampler(idx)
    # The sampled indices are positions in ``train_data``, so that is the
    # dataset to load from. Passing the ``data`` pair would index the tuple.
    return torch.utils.data.DataLoader(
        train_data, batch_size=batch_size, sampler=sampler
    )


class ClusterSampler(RawClusterLoader):
    """Partition ``data`` with ``ClusterData`` and load batches of partitions.

    Keyword args:
        num_parts (required): Number of partitions passed to ``ClusterData``.
        batch_size (required): Forwarded to the base loader as ``batch_size``.
        num_workers (optional): Worker processes for the base loader.
            Defaults to 32, the value that was previously hard-coded.
    """

    def __init__(self, data, shuffle=False, **kwargs):
        cluster_data = ClusterData(data,
                                   num_parts=kwargs["num_parts"])

        super().__init__(cluster_data,
                         shuffle=shuffle,
                         batch_size=kwargs["batch_size"],
                         # Overridable so the loader can run on machines with
                         # fewer cores than the old fixed 32 workers.
                         num_workers=kwargs.get("num_workers", 32))


class GraphSAINTRandomWalkSampler(RawGraphSAINTRandomWalkSampler):
    """Thin wrapper that maps config kwargs onto the GraphSAINT random-walk loader.

    Keyword args:
        roots (required): Forwarded as the base loader's ``batch_size``.
            Presumably this is the number of walk start nodes per batch;
            confirm against the PyG docs.
        walk_length (required): Forwarded unchanged.
        num_steps (required): Forwarded unchanged.
        sample_coverage (required): Forwarded unchanged.
        num_workers (optional): Worker processes for the base loader.
            Defaults to 32, the value that was previously hard-coded.
    """

    def __init__(self, data, shuffle=False, **kwargs):
        super().__init__(data,
                         shuffle=shuffle,
                         batch_size=kwargs["roots"],
                         walk_length=kwargs["walk_length"],
                         num_steps=kwargs["num_steps"],
                         sample_coverage=kwargs["sample_coverage"],
                         # Overridable; the default keeps the old fixed value.
                         num_workers=kwargs.get("num_workers", 32))


class RandomNodeSampler(RawRandomNodeSampler):
    """Random node-partition loader that post-processes image node features.

    Keyword args:
        num_parts (required): Passed positionally to the base loader.
        transform (optional): Callable applied to image-shaped node features
            in ``__collate__``. Defaults to None.

    The base loader is always created with ``num_workers=0``.
    """

    def __init__(self, data, shuffle=False, **kwargs):
        num_parts = kwargs["num_parts"]
        super().__init__(data, num_parts, shuffle=shuffle, num_workers=0)
        self.transform = kwargs.get('transform', None)

    def __collate__(self, node_idx):
        """Collate via the base class, then rescale image-shaped features.

        Features are touched only when ``batch.x.shape`` is exactly
        ``(3, 224, 224)``. In that case:

        - If a transform is set, the values are checked to reach at least 1.
          A ``ValueError`` is raised otherwise. They are then cast to
          ``uint8``, transformed, and cast back to float.
        - In every case the result is divided by 255 and written back to
          ``batch.x``.

        NOTE(review): a batch holding several node images would presumably
        have shape ``(N, 3, 224, 224)`` and skip this branch entirely.
        Confirm whether ``x.shape[-3:]`` was intended.
        """
        batch = super().__collate__(node_idx)
        feats = batch.x
        if feats.shape != (3, 224, 224):
            return batch

        if self.transform is not None:
            # Guards against features that are already scaled to [0, 1],
            # which would collapse to zeros in the uint8 cast.
            if feats.max() < 1.0:
                raise ValueError('Image max value smaller than 1?')
            feats = self.transform(feats.type(torch.uint8)).float()

        batch.x = feats / 255
        return batch


class NeighborSampler(RawNeighborLoader):
    """Neighbor-sampling loader with a positive/negative sampling hook.

    Keyword args:
        num_neighbors (required): Forwarded to the base loader.
        batch_size (required): Forwarded to the base loader.
        num_workers (optional): Worker processes for the base loader.
            Defaults to 32, the value that was previously hard-coded.
    """

    def __init__(self, data, shuffle=False, **kwargs):
        super().__init__(data,
                         num_neighbors=kwargs["num_neighbors"],
                         shuffle=shuffle,
                         batch_size=kwargs["batch_size"],
                         # Overridable; the default keeps the old fixed value.
                         num_workers=kwargs.get("num_workers", 32))

    def sample(self, batch):
        """Extend ``batch`` with one positive and one negative node per seed.

        For each seed node:

        - The positive is the node reached by a length-1 random walk.
        - The negative is a random index in ``[0, self.adj_t.size(1))``.

        ``[seeds, positives, negatives]`` is concatenated and passed to the
        base class's ``sample``.

        NOTE(review): this relies on ``self.adj_t`` and ``super().sample``.
        Neither is set or defined in this file. Confirm that the
        ``NeighborLoader`` base class provides them and actually calls this
        hook; otherwise this override is dead code.
        """
        # as_tensor avoids torch.tensor's copy and its copy-construct warning
        # when ``batch`` is already a tensor. For a list it behaves the same.
        batch = torch.as_tensor(batch)
        row, col, _ = self.adj_t.coo()

        # Column 1 of each walk is the node reached after one step.
        pos_batch = random_walk(row, col, batch, walk_length=1,
                                coalesced=False)[:, 1]

        neg_batch = torch.randint(0, self.adj_t.size(1), (batch.numel(), ),
                                  dtype=torch.long)

        batch = torch.cat([batch, pos_batch, neg_batch], dim=0)

        return super().sample(batch)


# Registry mapping each loader's name string to the class or factory
# function that builds it.
loaders = dict(
    GraphSAINTRandomWalkSampler=GraphSAINTRandomWalkSampler,
    RandomNodeSampler=RandomNodeSampler,
    NeighborSampler=NeighborSampler,
    ClusterSampler=ClusterSampler,
    ImageNetSampler=ImageNetSampler,
)

