from torchvision.datasets import MNIST, CIFAR10
from torchvision.transforms import ToTensor, Normalize, Compose
import torchvision.transforms as transforms
import torch

import socket

from timm.data import create_transform

from .imagenet import imagenet

# Name of the machine this module was imported on; give_dataloader() uses it
# to decide which DataLoader implementation to instantiate.
_hostname = socket.gethostname()

# Dataset name -> root directory handed to the dataset constructor as `root`.
_factory_datapath = {
    "mnist": "~/datasets/mnist",
    "cifar10": "~/datasets/cifar10",
    "imagenet": "~/datasets/imagenet",
}

# Dataset name -> dataset class instantiated by give_dataset().
_factory_dataset = {
    "mnist": MNIST,
    "cifar10": CIFAR10,
    "imagenet": imagenet,
}

def _get_transform(config, is_train: bool = True):
    """Build the input transform for ``config.dataset`` and ``config.model``.

    Args:
        config: Object exposing ``dataset`` (``"cifar10"``, ``"mnist"`` or
            ``"imagenet"``) and ``model`` (the model name as a string).
        is_train: Forwarded to timm's ``create_transform`` as ``is_training``
            for ImageNet. The CIFAR-10 and MNIST pipelines do not use it.

    Returns:
        A ``Compose`` pipeline for CIFAR-10 / MNIST, or whatever
        ``timm.data.create_transform`` returns for ImageNet.

    Raises:
        NotImplementedError: MNIST combined with a resnet/vgg/allconv model.
        AssertionError: ImageNet combined with any other model.
        KeyError: ``config.dataset`` is not one of the supported names.
    """
    dataset = config.dataset
    model = config.model
    # Models whose name contains one of these substrings are fed 224x224 inputs.
    wants_224_input = any(tag in model for tag in ("resnet", "vgg", "allconv"))

    # TODO: No CutMix, MixUp, and etc.

    if dataset == "cifar10":
        if wants_224_input:
            # Resize first, then convert exactly once: ToTensor only accepts a
            # PIL image or ndarray, so applying it twice (as this branch used
            # to) raises TypeError on the already-converted tensor.
            return Compose([
                transforms.Resize((224, 224)),
                ToTensor(),
                Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
            ])
        return Compose([
            ToTensor(),
            Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

    if dataset == "mnist":
        if wants_224_input:
            raise NotImplementedError()  # TODO:
        return Compose([ToTensor()])

    if dataset == "imagenet":
        assert wants_224_input, (
            f"imagenet transform only supports resnet/vgg/allconv models, got {model!r}"
        )
        return create_transform(
            input_size=224,
            is_training=is_train,
            use_prefetcher=False,
            scale=[0.08, 1.0],
            ratio=[3. / 4., 4. / 3.],
            hflip=0.5,
            vflip=0.0,
            color_jitter=0.4,
            auto_augment="rand-m9-mstd0.5-inc1",
            re_prob=0.6,
            re_mode="pixel",
        )

    raise KeyError(f"Invalid dataset {dataset}")

def give_dataset(config, is_train: bool = True):
    """Instantiate the dataset named by ``config.dataset``.

    Args:
        config: Object exposing ``dataset`` (a key of the module's dataset
            tables) plus whatever ``_get_transform`` reads from it.
        is_train: Passed to the dataset constructor as ``train`` and to
            ``_get_transform``.

    Returns:
        The constructed dataset object; it is always created with
        ``download=True``.

    Raises:
        KeyError: ``config.dataset`` is not a registered dataset name.
    """
    name = config.dataset
    root_dir = _factory_datapath[name]
    pipeline = _get_transform(config, is_train)
    dataset_cls = _factory_dataset[name]
    return dataset_cls(
        root=root_dir,
        train=is_train,
        transform=pipeline,
        download=True,
    )

def give_dataloader(config, dataset, is_train: bool = True):
    """Wrap ``dataset`` in a data loader configured from ``config``.

    On hosts other than the two listed below, the ``ffrecord`` DataLoader is
    used when it can be imported; otherwise (and on the listed hosts) the
    stock ``torch.utils.data.DataLoader`` is used.

    Args:
        config: Object exposing ``is_dist`` (use a distributed sampler),
            ``bs`` (batch size) and ``nw`` (number of worker processes).
        dataset: The dataset to load from.
        is_train: When true, the last incomplete batch is dropped and, in the
            non-distributed case, the data is shuffled.

    Returns:
        An instance of the selected loader class.
    """
    loader_class = torch.utils.data.DataLoader
    if _hostname not in ["zebra", "server231-SYS-4028GR-TR"]:
        try:
            from ffrecord.torch import DataLoader
        except ImportError:
            # Keep the torch loader instead of failing later on an unbound
            # `DataLoader` name, which is what happened before this fallback.
            print(f"Warning-{__file__}: No ffrecord installed!")
        else:
            loader_class = DataLoader

    loader_kwargs = {
        "dataset": dataset,
        "batch_size": config.bs,
        "num_workers": config.nw,
        "drop_last": is_train,
    }
    if config.is_dist:
        # `sampler` and `shuffle` are mutually exclusive in DataLoader.
        # NOTE(review): DistributedSampler shuffles by default, so evaluation
        # data is shuffled too in distributed mode - confirm this is intended.
        loader_kwargs["sampler"] = torch.utils.data.distributed.DistributedSampler(dataset)
    else:
        loader_kwargs["shuffle"] = is_train
    return loader_class(**loader_kwargs)
