"""Classification networks to evaluate heterogeneity."""
import torch
from pytorchcv.model_provider import get_model as ptcv_get_model
from torchvision import datasets
from torchvision import transforms as T

from ..pytorch_cifar.models.vgg import VGG as vgg_cifar
from .base import BaseNet


class CIFAR10Based(BaseNet):
    """Base wrapper for classifiers evaluated on CIFAR-10.

    Subclasses select a pretrained pytorchcv model by overriding
    ``_get_choices``. It returns either a single pytorchcv model name or a
    mapping from a version key (stored in ``self.type``) to a model name.

    Args:
        split: ``'val'`` for the CIFAR-10 test set, ``'train'`` for the
            training set.
    """

    def __init__(self, split='val'):
        # Stored before BaseNet.__init__ runs, presumably because it calls
        # methods that read it (e.g. get_dataset) -- TODO confirm in BaseNet.
        self.split = split
        super().__init__()

    def get_dataset(self):
        """Return the torchvision CIFAR-10 dataset for ``self.split``.

        The data is downloaded (if missing) into the ``datasets`` directory,
        relative to the current working directory.

        Raises:
            ValueError: if ``self.split`` is neither ``'val'`` nor ``'train'``.
        """
        if self.split == 'val':
            train = False
        elif self.split == 'train':
            train = True
        else:
            raise ValueError(f'Unknown split {self.split}')

        # Per-channel mean/std normalisation. NOTE(review): these std values
        # are the ones used by the pytorch-cifar code base; presumably they
        # match the preprocessing the pretrained checkpoints were trained
        # with, so do not change them without re-checking accuracy.
        transform = T.Compose([
            T.ToTensor(),
            T.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])

        return datasets.CIFAR10('datasets', train=train, download=True,
                                transform=transform)

    def get_dataset_name(self):
        """Return an identifier such as ``'cifar10_val'``."""
        return f'cifar10_{self.split}'

    def get_w(self):
        """Return the detached ``weight`` of ``self.last_layer``.

        NOTE(review): ``self.last_layer`` is presumably set by BaseNet from
        ``create_truncated_model`` -- confirm.
        """
        return self.last_layer.weight.detach()

    def get_intercept(self):
        """Return the detached ``bias`` of ``self.last_layer``."""
        return self.last_layer.bias.detach()

    def logits_to_scores(self, y_logits):
        """Convert logits to probabilities with a softmax over dim 1.

        Assumes ``y_logits`` is laid out as (batch, n_classes) -- TODO confirm.
        """
        return torch.nn.functional.softmax(y_logits, dim=1)

    def get_class_names(self):
        """Return the ten CIFAR-10 class names.

        The order presumably matches the dataset's label indices.
        """
        return [
            'airplane',
            'automobile',
            'bird',
            'cat',
            'deer',
            'dog',
            'frog',
            'horse',
            'ship',
            'truck',
        ]

    @staticmethod
    def _get_choices():
        """Return a pytorchcv model name, or a ``{version: model name}`` map.

        The base class offers no choices; subclasses override this.
        """
        return {}

    def create_model(self):
        """Instantiate the pretrained pytorchcv model selected by ``self.type``.

        Raises:
            ValueError: if the version in ``self.type`` (or its absence) is
                not one of the keys returned by ``_get_choices``.
        """
        choices = self._get_choices()
        if isinstance(choices, str):
            # A single model name: no version selection is needed.
            choice = choices
        else:
            # ``type`` is only set by subclasses. Falling back to None turns a
            # missing version into the explicit ValueError below instead of
            # an opaque AttributeError.
            version = getattr(self, 'type', None)
            if version not in choices:
                raise ValueError(f'Unknown version {version} for '
                                 f'{self.__class__.__name__.lower()}. '
                                 f'Choices: {list(choices.keys())}.')
            choice = choices[version]
        return ptcv_get_model(choice, pretrained=True)

    def create_truncated_model(self):
        """Return ``(model, last_layer)``, with the model's head removed.

        The model's ``output`` module is replaced by an identity, so the
        model returns what used to be fed into that head. The removed
        module is returned separately as ``last_layer``.
        """
        model = self.create_model()
        last_layer = model.output
        model.output = torch.nn.Identity()
        return model, last_layer


class VGG(CIFAR10Based):
    """VGG11 from the bundled pytorch-cifar code, restored from a local checkpoint.

    The network is wrapped in ``torch.nn.DataParallel`` before loading,
    presumably because the checkpoint was saved from a DataParallel-wrapped
    model. The inner network is therefore reached through ``.module``.
    """

    @staticmethod
    def create_model():
        """Build VGG11 and load the ``ckpt_vgg11.pth`` weights onto the CPU.

        The checkpoint path is relative to the current working directory.
        """
        net = torch.nn.DataParallel(vgg_cifar('VGG11'))
        ckpt_path = 'calibration/xp_nn_calibration/pytorch_cifar/checkpoint/ckpt_vgg11.pth'
        state = torch.load(ckpt_path, map_location=torch.device('cpu'))
        net.load_state_dict(state['net'])
        return net

    @staticmethod
    def create_truncated_model():
        """Return ``(model, classifier)``, with the classifier swapped for an identity."""
        net = VGG.create_model()
        inner = net.module
        head = inner.classifier
        inner.classifier = torch.nn.Identity()
        return net, head


class ResNet(CIFAR10Based):
    """Pretrained pytorchcv ResNet for CIFAR-10, selected by a depth key."""

    def __init__(self, type='20', split='val'):
        # create_model reads self.type, so it is stored before the base
        # initialiser runs (which presumably builds the model).
        self.type = type
        super().__init__(split=split)

    @staticmethod
    def _get_choices():
        """Map each supported version key to its pytorchcv model name."""
        versions = ('20', '56', '110', '164bn', '272bn', '542bn', '1001', '1202')
        return {v: f'resnet{v}_cifar10' for v in versions}


class PreResNet(CIFAR10Based):
    """Pretrained pytorchcv pre-activation ResNet for CIFAR-10."""

    def __init__(self, type='20', split='val'):
        # create_model reads self.type, so it is stored before the base
        # initialiser runs (which presumably builds the model).
        self.type = type
        super().__init__(split=split)

    @staticmethod
    def _get_choices():
        """Map each supported version key to its pytorchcv model name."""
        versions = ('20', '56', '110', '164bn', '272bn', '542bn', '1001', '1202')
        return {v: f'preresnet{v}_cifar10' for v in versions}


class ResNext(CIFAR10Based):
    """Pretrained pytorchcv ResNeXt for CIFAR-10."""

    def __init__(self, type='29_32x4d', split='val'):
        # create_model reads self.type, so it is stored before the base
        # initialiser runs (which presumably builds the model).
        self.type = type
        super().__init__(split=split)

    @staticmethod
    def _get_choices():
        """Map each supported version key to its pytorchcv model name."""
        versions = ('29_32x4d', '29_16x64d', '272_1x64d', '272_2x32d')
        return {v: f'resnext{v}_cifar10' for v in versions}


class SEResNet(CIFAR10Based):
    """Pretrained pytorchcv SE-ResNet for CIFAR-10."""

    def __init__(self, type='20', split='val'):
        # create_model reads self.type, so it is stored before the base
        # initialiser runs (which presumably builds the model).
        self.type = type
        super().__init__(split=split)

    @staticmethod
    def _get_choices():
        """Map each supported version key to its pytorchcv model name."""
        versions = ('20', '56', '110', '164bn', '272bn', '542bn')
        return {v: f'seresnet{v}_cifar10' for v in versions}


class SEPreResNet(CIFAR10Based):
    """Pretrained pytorchcv SE pre-activation ResNet for CIFAR-10."""

    def __init__(self, type='20', split='val'):
        # create_model reads self.type, so it is stored before the base
        # initialiser runs (which presumably builds the model).
        self.type = type
        super().__init__(split=split)

    @staticmethod
    def _get_choices():
        """Map each supported version key to its pytorchcv model name."""
        versions = ('20', '56', '110', '164bn', '272bn', '542bn')
        return {v: f'sepreresnet{v}_cifar10' for v in versions}


class DIAResNet(CIFAR10Based):
    """Pretrained pytorchcv DIA-ResNet for CIFAR-10."""

    def __init__(self, type='20', split='val'):
        # create_model reads self.type, so it is stored before the base
        # initialiser runs (which presumably builds the model).
        self.type = type
        super().__init__(split=split)

    @staticmethod
    def _get_choices():
        """Map each supported version key to its pytorchcv model name."""
        versions = ('20', '56', '110', '164bn')
        return {v: f'diaresnet{v}_cifar10' for v in versions}


class DIAPreResNet(CIFAR10Based):
    """Pretrained pytorchcv DIA pre-activation ResNet for CIFAR-10."""

    def __init__(self, type='20', split='val'):
        # create_model reads self.type, so it is stored before the base
        # initialiser runs (which presumably builds the model).
        self.type = type
        super().__init__(split=split)

    @staticmethod
    def _get_choices():
        """Map each supported version key to its pytorchcv model name."""
        versions = ('20', '56', '110', '164bn')
        return {v: f'diapreresnet{v}_cifar10' for v in versions}


class PyramidNet(CIFAR10Based):
    """Pretrained pytorchcv PyramidNet for CIFAR-10."""

    def __init__(self, type='110_a48', split='val'):
        # create_model reads self.type, so it is stored before the base
        # initialiser runs (which presumably builds the model).
        self.type = type
        super().__init__(split=split)

    @staticmethod
    def _get_choices():
        """Map each supported version key to its pytorchcv model name."""
        versions = (
            '110_a48',
            '110_a84',
            '110_a270',
            '164_a270_bn',
            '200_a240_bn',
            '236_a220_bn',
            '272_a200_bn',
        )
        return {v: f'pyramidnet{v}_cifar10' for v in versions}


class DenseNet(CIFAR10Based):
    """Pretrained pytorchcv DenseNet for CIFAR-10."""

    def __init__(self, type='40_k12', split='val'):
        # create_model reads self.type, so it is stored before the base
        # initialiser runs (which presumably builds the model).
        self.type = type
        super().__init__(split=split)

    @staticmethod
    def _get_choices():
        """Map each supported version key to its pytorchcv model name."""
        versions = (
            '40_k12',
            '40_k12_bc',
            '40_k24_bc',
            '40_k36_bc',
            '100_k12',
            '100_k24',
            '100_k12_bc',
            '190_k40_bc',
            '250_k24_bc',
        )
        return {v: f'densenet{v}_cifar10' for v in versions}


class WideResNet(CIFAR10Based):
    """Pretrained pytorchcv Wide ResNet for CIFAR-10."""

    def __init__(self, type='16_10', split='val'):
        # create_model reads self.type, so it is stored before the base
        # initialiser runs (which presumably builds the model).
        self.type = type
        super().__init__(split=split)

    @staticmethod
    def _get_choices():
        """Map each supported version key to its pytorchcv model name."""
        versions = ('16_10', '28_10', '40_8')
        return {v: f'wrn{v}_cifar10' for v in versions}


class WideResNet1b(CIFAR10Based):
    """Pretrained pytorchcv 1-bit Wide ResNet for CIFAR-10."""

    def __init__(self, type='20_10', split='val'):
        # create_model reads self.type, so it is stored before the base
        # initialiser runs (which presumably builds the model).
        self.type = type
        super().__init__(split=split)

    @staticmethod
    def _get_choices():
        """Map the single supported version key to its pytorchcv model name."""
        versions = ('20_10',)
        return {v: f'wrn{v}_1bit_cifar10' for v in versions}


class WideResNet32b(CIFAR10Based):
    """Pretrained pytorchcv 32-bit Wide ResNet for CIFAR-10."""

    def __init__(self, type='20_10', split='val'):
        # create_model reads self.type, so it is stored before the base
        # initialiser runs (which presumably builds the model).
        self.type = type
        super().__init__(split=split)

    @staticmethod
    def _get_choices():
        """Map the single supported version key to its pytorchcv model name."""
        versions = ('20_10',)
        return {v: f'wrn{v}_32bit_cifar10' for v in versions}


class XDenseNet(CIFAR10Based):
    """Pretrained pytorchcv X-DenseNet for CIFAR-10."""

    def __init__(self, type='40_2_k24_bc', split='val'):
        # create_model reads self.type, so it is stored before the base
        # initialiser runs (which presumably builds the model).
        self.type = type
        super().__init__(split=split)

    @staticmethod
    def _get_choices():
        """Map each supported version key to its pytorchcv model name."""
        versions = ('40_2_k24_bc', '40_2_k36_bc')
        return {v: f'xdensenet{v}_cifar10' for v in versions}


class NIN(CIFAR10Based):
    """Pretrained pytorchcv Network-in-Network for CIFAR-10 (single version)."""

    def __init__(self, type='default', split='val'):
        # create_model reads self.type, so it is stored before the base
        # initialiser runs (which presumably builds the model).
        self.type = type
        super().__init__(split=split)

    @staticmethod
    def _get_choices():
        """Map the single ``'default'`` key to its pytorchcv model name."""
        return dict(default='nin_cifar10')


class RoR3(CIFAR10Based):
    """Pretrained pytorchcv RoR-3 for CIFAR-10."""

    def __init__(self, type='56', split='val'):
        # create_model reads self.type, so it is stored before the base
        # initialiser runs (which presumably builds the model).
        self.type = type
        super().__init__(split=split)

    @staticmethod
    def _get_choices():
        """Map each supported version key to its pytorchcv model name."""
        versions = ('56', '110', '164')
        return {v: f'ror3_{v}_cifar10' for v in versions}


class RiR(CIFAR10Based):
    """Pretrained pytorchcv ResNet-in-ResNet for CIFAR-10 (single version)."""

    def __init__(self, type='default', split='val'):
        # create_model reads self.type, so it is stored before the base
        # initialiser runs (which presumably builds the model).
        self.type = type
        super().__init__(split=split)

    @staticmethod
    def _get_choices():
        """Map the single ``'default'`` key to its pytorchcv model name."""
        return dict(default='rir_cifar10')


class ShakeShakeResNet(CIFAR10Based):
    """Pretrained pytorchcv Shake-Shake ResNet for CIFAR-10."""

    def __init__(self, type='20_2x16d', split='val'):
        # create_model reads self.type, so it is stored before the base
        # initialiser runs (which presumably builds the model).
        self.type = type
        super().__init__(split=split)

    @staticmethod
    def _get_choices():
        """Map each supported version key to its pytorchcv model name."""
        versions = ('20_2x16d', '26_2x32d')
        return {v: f'shakeshakeresnet{v}_cifar10' for v in versions}
