from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
import os
import torch
import numpy as np
import scipy.io as scio
import cv2

class Train_dataset(Dataset):
    """Training-time CIFAR dataset with random-crop / horizontal-flip augmentation.

    Wraps in-memory image and label arrays and applies a dataset-specific
    torchvision augmentation + normalization pipeline on every access.

    Args:
        train_data: image array indexed along axis 0. Each sample is transposed
            with ``(1, 2, 0)`` before ``Image.fromarray``, so it is presumably
            laid out channel-first, e.g. (N, C, H, W) uint8 -- TODO confirm.
        train_gt: labels indexed along axis 0 in step with ``train_data``
            (either per-sample arrays or scalar labels).
        dataset: ``'CIFAR10'`` or ``'CIFAR100'``; selects the pipeline.

    Raises:
        ValueError: if ``dataset`` is not one of the supported names.
    """

    _SUPPORTED = ('CIFAR10', 'CIFAR100')

    def __init__(self, train_data, train_gt, dataset):
        super(Train_dataset, self).__init__()
        # Fail fast: an unknown name used to leave ``self.norm`` unset and only
        # surfaced as an AttributeError on the first ``__getitem__`` call.
        if dataset not in self._SUPPORTED:
            raise ValueError(
                f"Unsupported dataset {dataset!r}; expected one of {self._SUPPORTED}"
            )
        self.data = train_data
        self.gt = train_gt
        self.N = self.data.shape[0]
        # Not read inside this class; kept for any external code that uses it.
        self.order = np.arange(self.N)
        if dataset == 'CIFAR100':
            # NOTE(review): here the augmentations run *after* Normalize, so
            # RandomCrop pads with 0 in normalized space (the per-channel mean)
            # rather than black as in the CIFAR10 branch. The mean/std also
            # appear to be the commonly quoted CIFAR-10 statistics. Both are
            # left unchanged to preserve training behavior; verify before
            # changing, and keep in sync with evaluation-time normalization.
            self.norm = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize([0.4914, 0.4822, 0.4465],
                                     [0.2023, 0.1994, 0.2010]),
                transforms.RandomHorizontalFlip(),
                transforms.RandomCrop(32, padding=4, padding_mode='constant'),
            ])
        else:  # 'CIFAR10'
            # Augment on the PIL image first, then convert and normalize.
            self.norm = transforms.Compose([
                transforms.RandomCrop(32, padding=4, padding_mode='constant'),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize((0.49139968, 0.48215827, 0.44653124),
                                     (0.24703233, 0.24348505, 0.26158768)),
            ])

    def __getitem__(self, index):
        """Return ``(image_tensor, label_tensor)`` for sample ``index``."""
        image = Image.fromarray(self.data[index].transpose(1, 2, 0))
        # Both pipelines include ToTensor, so the result is already a tensor;
        # the former numpy round-trip only made an extra copy.
        image = self.norm(image)
        # as_tensor accepts array rows *and* numpy scalar labels (from_numpy
        # rejects scalars), and shares memory with ndarray inputs like before.
        return image, torch.as_tensor(self.gt[index])

    def __len__(self):
        """Number of samples (size of ``train_data`` along axis 0)."""
        return self.N


class Test_dataset(Dataset):
    """Evaluation-time CIFAR dataset: tensor conversion + normalization only.

    Args:
        test_data: image array indexed along axis 0. Each sample is transposed
            with ``(1, 2, 0)`` before ``Image.fromarray``, so it is presumably
            laid out channel-first, e.g. (N, C, H, W) uint8 -- TODO confirm.
        test_gt: labels indexed along axis 0 in step with ``test_data``
            (either per-sample arrays or scalar labels).
        dataset: ``'CIFAR10'`` or ``'CIFAR100'``; selects the normalization.

    Raises:
        ValueError: if ``dataset`` is not one of the supported names.
    """

    _SUPPORTED = ('CIFAR10', 'CIFAR100')

    def __init__(self, test_data, test_gt, dataset):
        super(Test_dataset, self).__init__()
        # Fail fast instead of leaving ``self.norm`` unset for unknown names.
        if dataset not in self._SUPPORTED:
            raise ValueError(
                f"Unsupported dataset {dataset!r}; expected one of {self._SUPPORTED}"
            )
        self.data = test_data
        self.gt = test_gt
        self.N = self.data.shape[0]
        if dataset == 'CIFAR100':
            # NOTE(review): these mean/std values appear to be the commonly
            # quoted CIFAR-10 statistics; they must match training-time
            # normalization, so verify both before changing either.
            self.norm = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize([0.4914, 0.4822, 0.4465],
                                     [0.2023, 0.1994, 0.2010]),
            ])
        else:  # 'CIFAR10'
            self.norm = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.49139968, 0.48215827, 0.44653124),
                                     (0.24703233, 0.24348505, 0.26158768)),
            ])

    def __getitem__(self, index):
        """Return ``(image_tensor, label_tensor)`` for sample ``index``."""
        image = Image.fromarray(self.data[index].transpose(1, 2, 0))
        # The pipeline ends in Normalize, so this is already a tensor; no need
        # for the former numpy round-trip.
        image = self.norm(image)
        # as_tensor handles numpy scalar labels that from_numpy rejects.
        return image, torch.as_tensor(self.gt[index])

    def __len__(self):
        """Number of samples (size of ``test_data`` along axis 0)."""
        return self.N



class Finetune_dataset(Dataset):
    """Minimal map-style dataset over pre-processed samples and labels.

    No transform or tensor conversion is applied; ``__getitem__`` returns the
    stored items exactly as indexed.

    Args:
        train_data: samples indexed along the first axis; must expose
            ``.shape`` (e.g. a NumPy array or torch tensor).
        train_gt: labels indexed in step with ``train_data``.
    """

    def __init__(self, train_data, train_gt):
        super().__init__()
        sample_count = train_data.shape[0]
        self.data, self.gt, self.N = train_data, train_gt, sample_count

    def __getitem__(self, index):
        """Return the untransformed ``(sample, label)`` pair at ``index``."""
        sample = self.data[index]
        label = self.gt[index]
        return sample, label

    def __len__(self):
        """Number of samples (size of ``train_data`` along axis 0)."""
        return self.N