import logging

from torchvision import transforms

import src.utils.vision.architectures.cifar as cifar_arch

from . import BaseModelWrapper, register_model

# Module-level logger named after this module (``__name__``); not referenced by
# the code in this section.
logger = logging.getLogger(__name__)


def _get_default_cifar10_transforms():
    """Build the default CIFAR-10 preprocessing pipelines.

    Both pipelines convert the input with ``ToTensor`` and finish with the
    same per-channel ``Normalize``. The evaluation pipeline resizes to 32x32.
    The training pipeline instead takes a random 32x32 crop (after 4-pixel
    padding) and applies a random horizontal flip.

    Returns:
        A ``(train_transforms, test_transforms)`` tuple of
        ``transforms.Compose`` objects.
    """
    # NOTE(review): this std triple differs from the per-pixel CIFAR-10 std
    # often quoted as (0.2470, 0.2435, 0.2616); presumably kept so inputs match
    # what the pretrained weights expect -- confirm before changing.
    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.2023, 0.1994, 0.2010)

    eval_pipeline = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Resize((32, 32)),
            transforms.Normalize(channel_mean, channel_std),
        ]
    )
    augment_pipeline = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.Normalize(channel_mean, channel_std),
        ]
    )
    return augment_pipeline, eval_pipeline


def _get_default_cifar100_transforms():
    """Build the default CIFAR-100 preprocessing pipelines.

    Both pipelines run ``ToTensor`` first and ``Normalize`` last, with the
    same mean/std. Evaluation resizes to 32x32. Training takes a random 32x32
    crop of the 4-pixel-padded image and then a random horizontal flip.

    Returns:
        A ``(train_transforms, test_transforms)`` tuple of
        ``transforms.Compose`` objects.
    """
    # NOTE(review): these values match commonly quoted CIFAR-10 statistics
    # rather than CIFAR-100's; presumably intentional to stay compatible with
    # pretrained checkpoints -- confirm before changing.
    mean, std = (0.4914, 0.482158, 0.446531), (0.247032, 0.243486, 0.261588)

    evaluation = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Resize((32, 32)),
            transforms.Normalize(mean, std),
        ]
    )
    training = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.Normalize(mean, std),
        ]
    )
    return training, evaluation


def _get_default_svhn_transforms():
    """Build the default SVHN preprocessing pipelines.

    Both pipelines convert with ``ToTensor``, resize to 32x32 and normalize
    with the same per-channel mean/std. The training pipeline additionally
    applies a random horizontal flip between resizing and normalization.

    Returns:
        A ``(train_transforms, test_transforms)`` tuple of
        ``transforms.Compose`` objects.
    """
    svhn_mean = (0.437682, 0.44377, 0.472805)
    svhn_std = (0.19803, 0.201016, 0.197036)

    inference_ops = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Resize((32, 32)),
            transforms.Normalize(svhn_mean, svhn_std),
        ]
    )
    # NOTE(review): horizontal flips are applied to digit images here; confirm
    # this augmentation is intended for SVHN.
    fitting_ops = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Resize((32, 32)),
            transforms.RandomHorizontalFlip(),
            transforms.Normalize(svhn_mean, svhn_std),
        ]
    )
    return fitting_ops, inference_ops


# Every model registered below consumes 3-channel 32x32 inputs.
_CIFAR_INPUT_DIM = (3, 32, 32)


def _make_cifar_model_wrapper(
    model,
    model_name,
    transforms_factory,
    head_node,
    features_nodes,
    download,
    url,
    *args,
    **kwargs,
):
    """Wrap a 32x32-input classifier in ``BaseModelWrapper`` with shared defaults.

    Args:
        model: The freshly constructed network to wrap.
        model_name: Registry name, forwarded as ``model_name``.
        transforms_factory: Zero-argument callable returning
            ``(train_transforms, test_transforms)``.
        head_node: Node name mapped to ``"linear"`` in the default
            ``features_nodes`` (``"linear"`` for DenseNet/ResNet,
            ``"classifier"`` for VGG).
        features_nodes: Caller-supplied node mapping, or ``None`` to use
            ``{"view": "features", head_node: "linear"}``.
        download: Forwarded unchanged to ``BaseModelWrapper``.
        url: Forwarded unchanged to ``BaseModelWrapper``.
        *args: Extra positional arguments, passed to ``BaseModelWrapper``
            right after ``input_dim``.
        **kwargs: Extra keyword arguments for ``BaseModelWrapper``. If
            ``train_transforms`` / ``test_transforms`` are supplied here they
            replace the defaults from ``transforms_factory``.

    Returns:
        The constructed ``BaseModelWrapper``.
    """
    default_train, default_test = transforms_factory()
    if features_nodes is None:
        # Build a fresh dict per call so wrapped models never share one mapping.
        features_nodes = {"view": "features", head_node: "linear"}
    # Caller-supplied transforms override the defaults instead of colliding
    # with them (which would raise "got multiple values" TypeError).
    kwargs.setdefault("train_transforms", default_train)
    kwargs.setdefault("test_transforms", default_test)
    # ``*args`` sits directly after the explicit positionals: that is where
    # Python binds it anyway, and placing it after keywords is misleading
    # (flake8-bugbear B026).
    return BaseModelWrapper(
        model,
        features_nodes,
        _CIFAR_INPUT_DIM,
        *args,
        download=download,
        url=url,
        model_name=model_name,
        **kwargs,
    )


@register_model("densenet121_cifar10")
def DenseNet121Cifar10(features_nodes=None, download=False, url=None, *args, **kwargs):
    """DenseNet-121 (``DenseNet121Small``, 10 outputs) with CIFAR-10 preprocessing.

    Default ``features_nodes``: ``{"view": "features", "linear": "linear"}``.
    Argument forwarding is described in ``_make_cifar_model_wrapper``.
    """
    return _make_cifar_model_wrapper(
        cifar_arch.densenet.DenseNet121Small(10),
        "densenet121_cifar10",
        _get_default_cifar10_transforms,
        "linear",
        features_nodes,
        download,
        url,
        *args,
        **kwargs,
    )


@register_model("densenet121_cifar100")
def DenseNet121Cifar100(features_nodes=None, download=False, url=None, *args, **kwargs):
    """DenseNet-121 (``DenseNet121Small``, 100 outputs) with CIFAR-100 preprocessing.

    Default ``features_nodes``: ``{"view": "features", "linear": "linear"}``.
    Argument forwarding is described in ``_make_cifar_model_wrapper``.
    """
    return _make_cifar_model_wrapper(
        cifar_arch.densenet.DenseNet121Small(100),
        "densenet121_cifar100",
        _get_default_cifar100_transforms,
        "linear",
        features_nodes,
        download,
        url,
        *args,
        **kwargs,
    )


@register_model("densenet121_svhn")
def DenseNet121SVHN(features_nodes=None, download=False, url=None, *args, **kwargs):
    """DenseNet-121 (``DenseNet121Small``, 10 outputs) with SVHN preprocessing.

    Default ``features_nodes``: ``{"view": "features", "linear": "linear"}``.
    Argument forwarding is described in ``_make_cifar_model_wrapper``.
    """
    return _make_cifar_model_wrapper(
        cifar_arch.densenet.DenseNet121Small(10),
        "densenet121_svhn",
        _get_default_svhn_transforms,
        "linear",
        features_nodes,
        download,
        url,
        *args,
        **kwargs,
    )


@register_model("vgg16_cifar10")
def VGG16Cifar10(features_nodes=None, download=False, url=None, *args, **kwargs):
    """VGG-16 (10 outputs) with CIFAR-10 preprocessing.

    Default ``features_nodes``: ``{"view": "features", "classifier": "linear"}``.
    Argument forwarding is described in ``_make_cifar_model_wrapper``.
    """
    return _make_cifar_model_wrapper(
        cifar_arch.vgg.VGG16(10),
        "vgg16_cifar10",
        _get_default_cifar10_transforms,
        "classifier",
        features_nodes,
        download,
        url,
        *args,
        **kwargs,
    )


@register_model("vgg16_cifar100")
def VGG16Cifar100(features_nodes=None, download=False, url=None, *args, **kwargs):
    """VGG-16 (100 outputs) with CIFAR-100 preprocessing.

    Default ``features_nodes``: ``{"view": "features", "classifier": "linear"}``.
    Argument forwarding is described in ``_make_cifar_model_wrapper``.
    """
    return _make_cifar_model_wrapper(
        cifar_arch.vgg.VGG16(100),
        "vgg16_cifar100",
        _get_default_cifar100_transforms,
        "classifier",
        features_nodes,
        download,
        url,
        *args,
        **kwargs,
    )


@register_model("vgg16_svhn")
def VGG16SVHN(features_nodes=None, download=False, url=None, *args, **kwargs):
    """VGG-16 (10 outputs) with SVHN preprocessing.

    Default ``features_nodes``: ``{"view": "features", "classifier": "linear"}``.
    Argument forwarding is described in ``_make_cifar_model_wrapper``.
    """
    return _make_cifar_model_wrapper(
        cifar_arch.vgg.VGG16(10),
        "vgg16_svhn",
        _get_default_svhn_transforms,
        "classifier",
        features_nodes,
        download,
        url,
        *args,
        **kwargs,
    )


@register_model("resnet34_cifar10")
def ResNet34Cifar10(features_nodes=None, download=False, url=None, *args, **kwargs):
    """ResNet-34 (10 outputs) with CIFAR-10 preprocessing.

    Default ``features_nodes``: ``{"view": "features", "linear": "linear"}``.
    Argument forwarding is described in ``_make_cifar_model_wrapper``.
    """
    return _make_cifar_model_wrapper(
        cifar_arch.resnet.ResNet34(10),
        "resnet34_cifar10",
        _get_default_cifar10_transforms,
        "linear",
        features_nodes,
        download,
        url,
        *args,
        **kwargs,
    )


@register_model("resnet34_cifar100")
def ResNet34Cifar100(features_nodes=None, download=False, url=None, *args, **kwargs):
    """ResNet-34 (100 outputs) with CIFAR-100 preprocessing.

    Default ``features_nodes``: ``{"view": "features", "linear": "linear"}``.
    Argument forwarding is described in ``_make_cifar_model_wrapper``.
    """
    return _make_cifar_model_wrapper(
        cifar_arch.resnet.ResNet34(100),
        "resnet34_cifar100",
        _get_default_cifar100_transforms,
        "linear",
        features_nodes,
        download,
        url,
        *args,
        **kwargs,
    )


@register_model("resnet34_svhn")
def ResNet34SVHN(features_nodes=None, download=False, url=None, *args, **kwargs):
    """ResNet-34 (10 outputs) with SVHN preprocessing.

    Default ``features_nodes``: ``{"view": "features", "linear": "linear"}``.
    Argument forwarding is described in ``_make_cifar_model_wrapper``.
    """
    return _make_cifar_model_wrapper(
        cifar_arch.resnet.ResNet34(10),
        "resnet34_svhn",
        _get_default_svhn_transforms,
        "linear",
        features_nodes,
        download,
        url,
        *args,
        **kwargs,
    )
