import os
import shutil
import numpy as np
import torch.utils.data
import torch
import time

from utils.get_data_iter import get_cifar10_iter, get_cifar10_all_iter
from data_providers import DataProvider


class CIFAR10DataProvider(DataProvider):
    """Data provider that builds a CIFAR-10 training iterator via ``get_cifar10_iter``.

    Only the training iterator is constructed. ``self.valid`` is always ``None``,
    and no ``self.test`` iterator is created because the code that built it is
    currently disabled (see the commented-out block in ``__init__``).
    """

    def __init__(self, save_path=None, train_batch_size=256, test_batch_size=256, valid_size=None, num_workers=24, manual_seed=12,
                 load_type='dali', local_rank=0, world_size=1, **kwargs):
        """Build the training iterator.

        Args:
            save_path: Dataset root directory. ``None`` falls back to the
                default returned by the ``save_path`` property.
            train_batch_size: Batch size forwarded to the training iterator.
            test_batch_size: Unused while the test iterator is disabled.
            valid_size: Must be ``None`` for an iterator to be built. A
                train/valid split is not implemented; passing a value leaves
                ``self.train`` unset.
            num_workers: Forwarded to the loader as ``num_threads``.
            manual_seed: Forwarded to the loader.
            load_type: Not used by this class.
            local_rank: Forwarded as both ``device_id`` and ``local_rank``.
            world_size: Forwarded to the loader.
            **kwargs: Accepted for interface compatibility and ignored.
        """
        self._save_path = save_path
        self.valid = None
        if valid_size is None:
            self.train = get_cifar10_iter(data_type='train', image_dir=self.train_path,
                                          batch_size=train_batch_size, num_threads=num_workers,
                                          device_id=local_rank, manual_seed=manual_seed,
                                          num_gpus=torch.cuda.device_count(), crop=self.image_size,
                                          val_size=self.image_size, world_size=world_size, local_rank=local_rank)
        # TODO: a train/valid split driven by `valid_size` is not implemented;
        # when `valid_size` is given, `self.train` is never assigned.
        # Test-iterator construction is currently disabled:
        #     self.test = get_cifar10_iter(data_type='val', image_dir=self.valid_path, manual_seed=manual_seed,
        #                                         batch_size=test_batch_size, num_threads=num_workers, device_id=local_rank,
        #                                         num_gpus=torch.cuda.device_count(), crop=self.image_size,
        #                                         val_size=256, world_size=world_size, local_rank=local_rank)
        # if self.valid is None:
        #     self.valid = self.test

    @staticmethod
    def name():
        """Return the dataset identifier used by this provider."""
        return 'cifar10'

    @property
    def data_shape(self):
        """Return the per-sample tensor shape as ``(C, H, W)``."""
        return 3, self.image_size, self.image_size  # C, H, W

    @property
    def n_classes(self):
        """Return the number of target classes (CIFAR-10 has 10)."""
        # Previously returned 1000 (the ImageNet class count, copied from the
        # ImageNet provider).
        return 10

    @property
    def save_path(self):
        """Return the dataset root, lazily defaulting it when unset."""
        if self._save_path is None:
            # NOTE(review): this default points at an ImageNet directory;
            # confirm it is the intended CIFAR-10 location.
            self._save_path = '/userhome/data/imagenet'
        return self._save_path

    @property
    def data_url(self):
        """Downloading is not supported; always raises ``ValueError``."""
        raise ValueError('unable to download CIFAR-10')

    @property
    def train_path(self):
        """Return the training-split directory under ``save_path``."""
        return os.path.join(self.save_path, 'train')

    @property
    def valid_path(self):
        """Return the validation-split directory under ``save_path``."""
        # Go through the property so the default root applies; the raw
        # `_save_path` may still be None here and would make os.path.join fail.
        return os.path.join(self.save_path, 'val')

    @property
    def resize_value(self):
        """Return the resize value (not referenced within this class)."""
        return 256

    @property
    def image_size(self):
        """Return the crop size passed to the loader and used in ``data_shape``."""
        # NOTE(review): CIFAR-10 images are natively 32x32; presumably the
        # loader upsamples to this size — confirm this is intended.
        return 224