from .cifar100 import get_cifar100_dataloaders_default, get_cifar100_dataloaders_crd
from .imagenet import get_imagenet_dataloaders_default, get_imagenet_dataloaders_crd
from .tiny_imagenet import get_tinyimagenet_dataloader, get_tinyimagenet_dataloader_sample
from .imagenet import get_imagenet_dataloaders_default_dist, get_imagenet_dataloaders_crd_dist


def get_mean_std(cfg):
    """Return per-channel normalization statistics for ``cfg.DATASET.TYPE``.

    Args:
        cfg: config object exposing ``cfg.DATASET.TYPE``.

    Returns:
        A ``(mean, std)`` tuple of freshly built three-element lists, so
        callers may mutate them without affecting later calls.

    Raises:
        NotImplementedError: if the dataset type is not "imagenet",
            "cifar100" or "tiny_imagenet" (the type is the exception arg).
    """
    stats_by_dataset = {
        "imagenet": (
            (0.485, 0.456, 0.406),
            (0.229, 0.224, 0.225),
        ),
        "cifar100": (
            (0.5071, 0.4867, 0.4408),
            (0.2675, 0.2565, 0.2761),
        ),
        "tiny_imagenet": (
            (0.4802, 0.4481, 0.3975),
            (0.2302, 0.2265, 0.2262),
        ),
    }
    dataset_type = cfg.DATASET.TYPE
    if dataset_type not in stats_by_dataset:
        raise NotImplementedError(dataset_type)
    channel_mean, channel_std = stats_by_dataset[dataset_type]
    return list(channel_mean), list(channel_std)


def _normalization_stats(cfg, args):
    """Pick the ``(mean, std)`` lists handed to the CIFAR-100 / ImageNet loaders.

    With ``args.JPEG_enable`` set, mean is 0 and std is 1/255 for every
    channel. Assuming the loaders apply the usual ``(x - mean) / std``, this
    scales inputs by 255 instead of standardizing them. Presumably this keeps
    [0, 255] pixel values for the JPEG path -- TODO confirm. Otherwise the
    dataset statistics from ``get_mean_std`` are used.
    """
    if args.JPEG_enable:
        return [0.0, 0.0, 0.0], [1 / 255.0, 1 / 255.0, 1 / 255.0]
    return get_mean_std(cfg)


def get_dataset(cfg, args):
    """Build the train/val dataloaders described by ``cfg``.

    Args:
        cfg: experiment config. Reads ``DATASET.TYPE``, ``DATASET.NUM_WORKERS``,
            ``DATASET.TEST.BATCH_SIZE``, ``SOLVER.BATCH_SIZE`` and
            ``DISTILLER.TYPE``. For CRD-style loaders it also reads
            ``CRD.NCE.K``, plus ``CRD.MODE`` on CIFAR-100.
        args: runtime arguments. Reads ``JPEG_enable`` (CIFAR-100 / ImageNet)
            and ``distributed`` (ImageNet only).

    Returns:
        ``(train_loader, val_loader, num_data, num_classes)``. The first three
        are returned unchanged by the selected loader factory; ``num_data`` is
        presumably the training-set size -- confirm against the factories.
        ``num_classes`` is 100, 1000 or 200.

    Raises:
        NotImplementedError: if ``cfg.DATASET.TYPE`` is not "cifar100",
            "imagenet" or "tiny_imagenet".
    """
    dataset_type = cfg.DATASET.TYPE
    # Reject unknown datasets before touching any other config fields.
    if dataset_type not in ("cifar100", "imagenet", "tiny_imagenet"):
        raise NotImplementedError(dataset_type)

    # Arguments shared by every loader factory below.
    loader_kwargs = dict(
        batch_size=cfg.SOLVER.BATCH_SIZE,
        val_batch_size=cfg.DATASET.TEST.BATCH_SIZE,
        num_workers=cfg.DATASET.NUM_WORKERS,
    )

    if dataset_type == "cifar100":
        mean, std = _normalization_stats(cfg, args)
        # NOTE(review): here and for imagenet only "CRD" selects the CRD
        # loader, whereas tiny_imagenet also does so for "CRDKD" -- confirm
        # whether that asymmetry is intended.
        if cfg.DISTILLER.TYPE == "CRD":
            train_loader, val_loader, num_data = get_cifar100_dataloaders_crd(
                **loader_kwargs,
                k=cfg.CRD.NCE.K,
                mode=cfg.CRD.MODE,
                mean=mean,
                std=std,
            )
        else:
            train_loader, val_loader, num_data = get_cifar100_dataloaders_default(
                **loader_kwargs, mean=mean, std=std
            )
        num_classes = 100

    elif dataset_type == "imagenet":
        mean, std = _normalization_stats(cfg, args)
        # Choose the distributed or single-process variants of both factories.
        if args.distributed:
            crd_factory = get_imagenet_dataloaders_crd_dist
            default_factory = get_imagenet_dataloaders_default_dist
        else:
            crd_factory = get_imagenet_dataloaders_crd
            default_factory = get_imagenet_dataloaders_default
        if cfg.DISTILLER.TYPE == "CRD":
            train_loader, val_loader, num_data = crd_factory(
                **loader_kwargs, k=cfg.CRD.NCE.K, mean=mean, std=std
            )
        else:
            train_loader, val_loader, num_data = default_factory(
                **loader_kwargs, mean=mean, std=std
            )
        num_classes = 1000

    else:  # "tiny_imagenet"
        # NOTE(review): this branch ignores args.JPEG_enable and passes no
        # mean/std, so normalization is not controlled from here.
        if cfg.DISTILLER.TYPE in ("CRD", "CRDKD"):
            train_loader, val_loader, num_data = get_tinyimagenet_dataloader_sample(
                **loader_kwargs, k=cfg.CRD.NCE.K
            )
        else:
            train_loader, val_loader, num_data = get_tinyimagenet_dataloader(
                **loader_kwargs
            )
        num_classes = 200

    return train_loader, val_loader, num_data, num_classes
