import torch
import numpy as np
from torch.utils.data.sampler import Sampler

from utils.names import Datasets
from datasets.shapenet.shapes_3d_transforms import Compose, Normalize, RandomBackground, ToTensor, CenterCrop, RandomCrop, ColorJitter, \
    RandomNoise, RandomFlip, RandomPermuteRGB


def _to_numpy(image):
    """Convert one image into a one-element list holding a float32 array scaled to [0, 1].

    The array is wrapped in a list because the crop/background transforms that
    follow appear to operate on sequences of images (TODO confirm against
    ``shapes_3d_transforms``); ``_take_first`` unwraps it again later.

    NOTE(review): no ``image.convert("RGB")`` here on purpose -- an earlier
    version called it but discarded the result, so it never had an effect, and
    actually converting would drop the alpha channel that ``RandomBackground``
    presumably relies on to composite the background. Verify before adding it.
    """
    return [np.asarray(image, dtype=np.float32) / 255]


def _take_first(images):
    """Return the first element of a sequence (unwraps the single-image list)."""
    return images[0]


def _to_signed_unit_range(x):
    """Map values from [0, 1] to [-1, 1] via ``x * 2 - 1``."""
    return x * 2 - 1


def _shapenet_image_transforms():
    """Build the image pipeline shared by the ShapeNet train and validation splits.

    Only module-level callables are used (no lambdas or closures) so that the
    pipeline stays picklable -- provided the project transform classes are --
    which is required when DataLoader workers are started with ``spawn``.
    """
    return Compose([
        _to_numpy,
        CenterCrop((224, 224), (128, 128)),
        RandomBackground(((240, 240), (240, 240), (240, 240))),
        ToTensor(),
        _take_first,
        _to_signed_unit_range,
    ])


def load_data(
        arguments,
        dataset,
        batch_size,
        data_path,
        num_workers=4,
        val_shuffle=True,
):
    """Create the training and validation DataLoaders for ``dataset``.

    Args:
        arguments: Configuration object; only ``arguments.num_of_views`` is read
            here (forwarded to ``ShapeNetDataset`` as ``view_num``).
        dataset: Dataset identifier; only ``Datasets.ShapeNet.value`` is supported.
        batch_size: Batch size used by both loaders.
        data_path: Root directory expected to contain ``ShapeNet.json``,
            ``ShapeNetVox32`` and ``ShapeNetRendering``.
        num_workers: Number of worker processes for each DataLoader.
        val_shuffle: Whether the validation loader shuffles its samples. Defaults
            to ``True`` to keep the historical behaviour; pass ``False`` for a
            deterministic evaluation order.

    Returns:
        A ``(train_loader, val_loader)`` tuple. The train loader always shuffles;
        neither loader drops the last incomplete batch.

    Raises:
        NotImplementedError: If ``dataset`` is not a supported identifier.
    """
    if dataset != Datasets.ShapeNet.value:
        raise NotImplementedError(f"no such dataset {dataset}")

    from .aligned_dataset import ShapeNetDataset

    # Keyword arguments shared by both splits; only ``split``/``mode`` differ.
    # The same transform pipeline instance is reused for train and validation.
    common_params = {
        'annot_path': f'{data_path}/ShapeNet.json',
        'model_path': f'{data_path}/ShapeNetVox32',
        'image_path': f'{data_path}/ShapeNetRendering',
        'image_transforms': _shapenet_image_transforms(),
        'background': (0, 0, 0),
        'view_num': arguments.num_of_views,
        'data_direction': dataset,
    }

    trainset = ShapeNetDataset(**common_params, split='train')
    valset = ShapeNetDataset(**common_params, split='test', mode='first')

    loader = torch.utils.data.DataLoader(
        trainset, batch_size=batch_size, num_workers=num_workers, drop_last=False, shuffle=True)

    val_loader = torch.utils.data.DataLoader(
        valset, batch_size=batch_size, num_workers=num_workers, drop_last=False, shuffle=val_shuffle)

    return loader, val_loader
