import os
from colorama import Fore, Style
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision
import torch

# One row per supported dataset:
#   (accepted ``args.dataset`` spellings, label sent to the logger,
#    number of classes, per-channel mean, per-channel std, image size)
# NOTE(review): STL10 and AIRCRAFT30 carry exactly the same mean/std as CINIC,
# and every dataset except Tiny-Imagenet uses img_size 32 -- confirm these are
# intended values rather than copy-paste placeholders.
_DATASET_SPECS = (
    (('CIFAR10', 'cifar10'), 'CIFAR10', 10,
     (0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616), 32),
    (('CIFAR100', 'cifar100'), 'CIFAR100', 100,
     (0.5070, 0.4865, 0.4409), (0.2673, 0.2564, 0.2762), 32),
    (('SVHN', 'svhn'), 'SVHN', 10,
     (0.4377, 0.4438, 0.4728), (0.1980, 0.2010, 0.1970), 32),
    (('Tiny-Imagenet', 'tiny-imnet'), 'T-IMNET', 200,
     (0.4802, 0.4481, 0.3975), (0.2770, 0.2691, 0.2821), 64),
    (('CINIC', 'cinic10'), 'CINIC', 10,
     (0.47889522, 0.47227842, 0.43047404),
     (0.24205776, 0.23828046, 0.25874835), 32),
    (('STL10', 'stl10'), 'STL10', 10,
     (0.47889522, 0.47227842, 0.43047404),
     (0.24205776, 0.23828046, 0.25874835), 32),
    (('AIRCRAFT30', 'aircraft30'), 'AIRCRAFT30', 30,
     (0.47889522, 0.47227842, 0.43047404),
     (0.24205776, 0.23828046, 0.25874835), 32),
    (('PETS', 'pets'), 'PETS', 37,
     (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), 32),
)


def datainfo(logger, args):
    """Return class count, normalisation statistics and image size for a dataset.

    The dataset is selected by ``args.dataset``, which must exactly match one
    of the spellings listed in ``_DATASET_SPECS`` (matching is case-sensitive).
    The dataset label is sent to ``logger.debug`` between two yellow banner
    lines printed to stdout.

    Args:
        logger: Object exposing ``debug(msg)``.
        args: Object exposing a ``dataset`` attribute.

    Returns:
        dict with keys ``'n_classes'`` (int), ``'stat'`` (a ``(mean, std)``
        pair of 3-tuples) and ``'img_size'`` (int).

    Raises:
        ValueError: If ``args.dataset`` is not a supported dataset name.
    """
    for names, label, n_classes, img_mean, img_std, img_size in _DATASET_SPECS:
        if args.dataset in names:
            break
    else:
        # Previously an unknown name fell through every branch and crashed
        # with an UnboundLocalError; fail early with a usable message instead.
        supported = ', '.join(name for spec in _DATASET_SPECS for name in spec[0])
        raise ValueError(
            'Unsupported dataset {!r}; expected one of: {}'.format(
                args.dataset, supported))

    # Banner: coloured rule, dataset label via the logger, closing rule + reset.
    print(Fore.YELLOW+'*'*80)
    logger.debug(label)
    print('*'*80 + Style.RESET_ALL)

    data_info = dict()
    data_info['n_classes'] = n_classes
    data_info['stat'] = (img_mean, img_std)
    data_info['img_size'] = img_size

    return data_info
# Class counts keyed by lowercase dataset name; the values equal the
# n_classes that datainfo() assigns to the same datasets. Not referenced by
# the functions visible in this part of the file.
ncls = dict(stl10=10, aircraft30=30, pets=37)
def dataload(args, augmentations, normalize, data_info):
    """Build the training and validation datasets for ``args.dataset``.

    Side effect: ``args.datapath`` is overwritten with a hard-coded relative
    path for the selected dataset. The CIFAR and SVHN branches pass
    ``download=True`` to torchvision; Tiny-Imagenet and CINIC are read with
    ``ImageFolder`` from ``train`` and ``val`` sub-directories of that path.

    Args:
        args: Object exposing a ``dataset`` attribute; its ``datapath``
            attribute is set by this function.
        augmentations: Transform passed to the training dataset.
        normalize: Iterable of transforms appended after ``ToTensor`` in the
            validation pipeline.
        data_info: Mapping providing ``'img_size'``, used to resize
            validation images.

    Returns:
        Tuple ``(train_dataset, val_dataset)``.

    Raises:
        ValueError: If ``args.dataset`` is not one of the datasets this
            loader handles (CIFAR10, CIFAR100, SVHN, Tiny-Imagenet, CINIC).
    """
    name = args.dataset

    def eval_transform(size):
        # Validation pipeline shared by every branch: resize, tensor, normalise.
        return transforms.Compose(
            [transforms.Resize(size), transforms.ToTensor(), *normalize])

    if name in ('CIFAR10', 'cifar10'):
        args.datapath = '../Datasets/data'
        train_dataset = datasets.CIFAR10(
            root=args.datapath, train=True, download=True,
            transform=augmentations)
        val_dataset = datasets.CIFAR10(
            root=args.datapath, train=False, download=True,
            transform=eval_transform(data_info['img_size']))

    elif name in ('CIFAR100', 'cifar100'):
        args.datapath = '../Datasets/data'
        train_dataset = datasets.CIFAR100(
            root=args.datapath, train=True, download=True,
            transform=augmentations)
        val_dataset = datasets.CIFAR100(
            root=args.datapath, train=False, download=True,
            transform=eval_transform(data_info['img_size']))

    elif name in ('SVHN', 'svhn'):
        args.datapath = '../Datasets/data'
        train_dataset = datasets.SVHN(
            root=args.datapath, split='train', download=True,
            transform=augmentations)
        val_dataset = datasets.SVHN(
            root=args.datapath, split='test', download=True,
            transform=eval_transform(data_info['img_size']))

    elif name in ('Tiny-Imagenet', 'tiny-imnet'):
        args.datapath = "../../../Datasets/tiny-imagenet"
        train_dataset = datasets.ImageFolder(
            root=os.path.join(args.datapath, 'train'),
            transform=augmentations)
        val_dataset = datasets.ImageFolder(
            root=os.path.join(args.datapath, 'val'),
            transform=eval_transform(data_info['img_size']))

    elif name in ('CINIC', 'cinic10'):
        args.datapath = "../../../Datasets/CINIC-10"
        train_dataset = torchvision.datasets.ImageFolder(
            root=os.path.join(args.datapath, 'train'),
            transform=augmentations)
        # CINIC alone resizes to an explicit (H, W) pair; the other branches
        # pass a single int to Resize. Kept as in the original.
        val_dataset = torchvision.datasets.ImageFolder(
            root=os.path.join(args.datapath, 'val'),
            transform=eval_transform(
                (data_info['img_size'], data_info['img_size'])))

    else:
        # Without this branch the return below failed with UnboundLocalError,
        # including for STL10 / AIRCRAFT30 / PETS, which datainfo() accepts
        # but which have no loader here.
        raise ValueError(
            'No dataset loader for {!r}; supported: CIFAR10/cifar10, '
            'CIFAR100/cifar100, SVHN/svhn, Tiny-Imagenet/tiny-imnet, '
            'CINIC/cinic10'.format(name))

    return train_dataset, val_dataset


