import torch
import torchvision.models as models
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from .cifar10_models.resnet_gn import get_resnet_model_gn
from .cifar10_models.resnet import get_resnet_model

from .task import Task
from ..utils import log, log_dict


class Cifar100Task(Task):
    """CIFAR-100 classification task.

    Provides a ResNet-18 model with 100 output classes, a cross-entropy loss,
    and a factory that builds CIFAR-100 ``DataLoader`` objects.
    """

    def get_model(self, group_normalization=True):
        """Build a ResNet-18 (``version=1``, ``dtype="fp32"``, 100 classes).

        Args:
            group_normalization: When True (the default, and the only value the
                task previously used), build the model with
                ``get_resnet_model_gn``; otherwise use ``get_resnet_model``.

        Returns:
            The model returned by the selected factory.
        """
        # GroupNorm variant by default, based on https://arxiv.org/pdf/2003.00295.pdf
        factory = get_resnet_model_gn if group_normalization else get_resnet_model
        return factory(model="resnet18",
                       version=1,
                       dtype="fp32",
                       num_classes=100)

    def get_loss_function(self):
        """Return a fresh ``CrossEntropyLoss`` instance with default settings."""
        return torch.nn.modules.loss.CrossEntropyLoss()

    def get_dataloader(self):
        """Return a factory function that builds CIFAR-100 ``DataLoader`` objects.

        The returned callable has the signature
        ``cifar100(data_dir, train, download, batch_size, shuffle=None,
        sampler_callback=None, dataset_cls=datasets.CIFAR100, drop_last=True,
        **loader_kwargs)``.
        """
        def cifar100(data_dir,
                     train,
                     download,
                     batch_size,
                     shuffle=None,
                     sampler_callback=None,
                     dataset_cls=datasets.CIFAR100,
                     drop_last=True,
                     **loader_kwargs):
            """Build a ``DataLoader`` over CIFAR-100.

            Args:
                data_dir: Root directory passed to ``dataset_cls``.
                train: If True, load the training split with random
                    horizontal-flip and padded random-crop augmentation;
                    otherwise load the test split without augmentation.
                download: Forwarded to ``dataset_cls``.
                batch_size: Forwarded to ``DataLoader``.
                shuffle: Forwarded to ``DataLoader``. Must stay None for the
                    training split when ``sampler_callback`` is given.
                sampler_callback: Optional callable that receives the dataset
                    and returns the sampler handed to ``DataLoader``.
                dataset_cls: Dataset class to instantiate.
                drop_last: Forwarded to ``DataLoader``. It defaults to True for
                    the test split too, so a trailing partial batch is skipped
                    there as well.
                **loader_kwargs: Extra keyword arguments for ``DataLoader``.

            Returns:
                A ``torch.utils.data.DataLoader``.

            Raises:
                ValueError: If ``train`` is True and both ``sampler_callback``
                    and ``shuffle`` are provided.
            """
            if sampler_callback is not None and shuffle is not None and train:
                raise ValueError(
                    "cifar100: 'shuffle' must be None when 'sampler_callback' "
                    "is given for the training split (the sampler controls "
                    "the sample order)"
                )

            # NOTE(review): these mean/std values appear to be the widely used
            # CIFAR-10 statistics rather than CIFAR-100's. Kept unchanged so
            # existing results stay reproducible; confirm this is intended.
            cifar100_stats = {"mean": (0.4914, 0.4822, 0.4465),
                              "std": (0.2023, 0.1994, 0.2010)}

            # Augmentation is applied to the training split only.
            augmentation = [
                transforms.RandomHorizontalFlip(),
                transforms.RandomCrop(32, padding=4),
            ] if train else []
            transform = transforms.Compose(
                augmentation + [
                    transforms.ToTensor(),
                    transforms.Normalize(cifar100_stats["mean"],
                                         cifar100_stats["std"]),
                ]
            )

            dataset = dataset_cls(root=data_dir,
                                  train=train,
                                  download=download,
                                  transform=transform)

            # Store targets as an int64 tensor; as_tensor also accepts targets
            # that are already tensors.
            dataset.targets = torch.as_tensor(dataset.targets, dtype=torch.long)

            sampler = (sampler_callback(dataset)
                       if sampler_callback is not None else None)
            log("Getting dataloader for cifar100:")
            log_dict(
                {
                    "Type": "Setup",
                    "Dataset": "cifar100",
                    "train": train,
                    "download": download,
                    "batch_size": batch_size,
                    "shuffle": shuffle,
                    # Use an explicit None check: the truthiness of a sampler
                    # falls back to __len__, so an empty sampler would
                    # otherwise be logged as None.
                    "sampler": str(sampler) if sampler is not None else None,
                }
            )
            return torch.utils.data.DataLoader(dataset,
                                               batch_size=batch_size,
                                               shuffle=shuffle,
                                               sampler=sampler,
                                               drop_last=drop_last,
                                               **loader_kwargs)

        return cifar100

if __name__ == "__main__":
    def count_parameters(model):
        """Return the total element count of ``model``'s parameters that require gradients."""
        total = 0
        for param in model.parameters():
            if param.requires_grad:
                total += param.numel()
        return total

    # Same GroupNorm ResNet-18 configuration that the task builds by default.
    model = get_resnet_model_gn(
        model="resnet18",
        version=1,
        dtype="fp32",
        num_classes=100,
    )
    print(count_parameters(model))
