import torch
import torchvision.models as models
import torchvision.datasets as datasets
import torchvision.transforms as transforms

from .task import Task
from ..utils import log, log_dict


class Cifar10Task(Task):
    """CIFAR-10 classification task: model, loss and dataloader factory."""

    def get_model(self):
        """Return a torchvision ResNet-18 built with default arguments.

        NOTE(review): ``resnet18()`` is called without ``num_classes``, so the
        classifier head keeps torchvision's default size rather than the 10
        CIFAR-10 classes. An earlier, now removed, variant of this method
        built a ResNet-20 with ``num_classes=10``. Confirm which head size is
        intended before changing it; doing so alters the parameter shapes.
        """
        return models.resnet18()

    def get_loss_function(self):
        """Return the cross-entropy criterion for this task."""
        return torch.nn.modules.loss.CrossEntropyLoss()

    def get_dataloader(self):
        """Return the ``cifar10`` factory function that builds dataloaders.

        The factory is returned uncalled; the caller supplies the data
        directory, split, batch size, etc. when invoking it.
        """
        def cifar10(data_dir,
                    train,
                    download,
                    batch_size,
                    shuffle=None,
                    sampler_callback=None,
                    dataset_cls=datasets.CIFAR10,
                    drop_last=True,
                    **loader_kwargs):
            """Build a ``DataLoader`` over a normalized CIFAR-10 split.

            Args:
                data_dir: Passed to ``dataset_cls`` as ``root``.
                train: Passed to ``dataset_cls`` as ``train``; when true,
                    random flip/crop augmentation is applied before
                    normalization.
                download: Passed to ``dataset_cls`` as ``download``.
                batch_size: Batch size given to the ``DataLoader``.
                shuffle: Forwarded to the ``DataLoader``. Must be left as
                    ``None`` when ``sampler_callback`` is given.
                sampler_callback: Optional callable that receives the
                    constructed dataset and returns the sampler to use.
                dataset_cls: Dataset constructor; defaults to
                    ``torchvision.datasets.CIFAR10``.
                drop_last: Forwarded to the ``DataLoader``.
                **loader_kwargs: Extra keyword arguments forwarded to the
                    ``DataLoader``.

            Returns:
                A ``torch.utils.data.DataLoader`` over the dataset.

            Raises:
                ValueError: If both ``sampler_callback`` and ``shuffle`` are
                    not ``None``.
            """
            if sampler_callback is not None and shuffle is not None:
                raise ValueError(
                    "`sampler_callback` and `shuffle` are mutually exclusive; "
                    "leave `shuffle` as None when providing a sampler."
                )

            # Per-channel mean/std handed to transforms.Normalize.
            cifar10_stats = {"mean": (0.4914, 0.4822, 0.4465),
                             "std": (0.2023, 0.1994, 0.2010)}

            # Augmentation is applied to the training split only; tensor
            # conversion and normalization are shared by both splits.
            if train:
                augmentation = [
                    transforms.RandomHorizontalFlip(),
                    transforms.RandomCrop(32, padding=4),
                ]
            else:
                augmentation = []
            transform = transforms.Compose(
                augmentation
                + [
                    transforms.ToTensor(),
                    transforms.Normalize(cifar10_stats["mean"],
                                         cifar10_stats["std"]),
                ]
            )

            dataset = dataset_cls(root=data_dir,
                                  train=train,
                                  download=download,
                                  transform=transform)

            # Replace the targets with a LongTensor so they can be indexed
            # and compared as a tensor (e.g. by the sampler callback).
            dataset.targets = torch.LongTensor(dataset.targets)

            # Compare against None explicitly: truth-testing a sampler would
            # call its __len__ (when defined), so an empty sampler would be
            # mistaken for "no sampler".
            if sampler_callback is not None:
                sampler = sampler_callback(dataset)
            else:
                sampler = None

            log("Getting dataloader for cifar10:")
            log_dict(
                {
                    "Type": "Setup",
                    "Dataset": "cifar10",
                    "train": train,
                    "download": download,
                    "batch_size": batch_size,
                    "shuffle": shuffle,
                    "sampler": str(sampler) if sampler is not None else None,
                }
            )
            return torch.utils.data.DataLoader(dataset,
                                               batch_size=batch_size,
                                               shuffle=shuffle,
                                               sampler=sampler,
                                               drop_last=drop_last,
                                               **loader_kwargs)

        return cifar10
