# Copyright (c) Facebook, Inc. and its affiliates.

# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch.nn as nn
import torch.nn.functional as F

from foundations import hparams
from lottery.desc import LotteryDesc
from models import base
from pruning import sparse_global


class Model(base.Model):
    """A VGG-style neural network designed for CIFAR-10.

    Model names have the form ``cifar_vgg_<depth>`` or
    ``cifar_vgg_<depth>_<width>``. ``depth`` must be 11, 13, 16 or 19 and
    ``width`` is the number of filters in the first convolution (64 when
    omitted); the later stages use 2x, 4x and 8x that width.
    """

    # Per-depth layer plans expressed as multiples of the base width.
    # 'M' marks a 2x2 max-pooling layer; a number n marks a ConvModule with
    # n * width output filters.
    _PLAN_MULTIPLIERS = {
        11: (1, 'M', 2, 'M', 4, 4, 'M', 8, 8, 'M', 8, 8),
        13: (1, 1, 'M', 2, 2, 'M', 4, 4, 'M', 8, 8, 'M', 8, 8),
        16: (1, 1, 'M', 2, 2, 'M', 4, 4, 4, 'M', 8, 8, 8, 'M', 8, 8, 8),
        19: (1, 1, 'M', 2, 2, 'M', 4, 4, 4, 4, 'M', 8, 8, 8, 8, 'M', 8, 8, 8, 8),
    }

    # Base width used when the model name does not specify one.
    _DEFAULT_WIDTH = 64

    class ConvModule(nn.Module):
        """A single convolutional module in a VGG network.

        Applies a 3x3 convolution (padding 1), batch normalization and ReLU.
        """

        def __init__(self, in_filters, out_filters):
            super(Model.ConvModule, self).__init__()
            self.conv = nn.Conv2d(in_filters, out_filters, kernel_size=3, padding=1)
            self.bn = nn.BatchNorm2d(out_filters)

        def forward(self, x):
            return F.relu(self.bn(self.conv(x)))

    def __init__(self, plan, initializer, outputs=10):
        """Build the network described by ``plan``.

        Args:
            plan: Iterable of layer specs. ``'M'`` adds a 2x2 max-pooling
                layer; any other entry is the number of output filters of a
                ConvModule.
            initializer: Callable applied to every submodule via
                ``nn.Module.apply``.
            outputs: Number of output units of the final linear layer.
        """
        super(Model, self).__init__()

        layers = []
        filters = 3  # Channel count of the input fed to the first convolution.

        for spec in plan:
            if spec == 'M':
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            else:
                layers.append(Model.ConvModule(filters, spec))
                filters = spec

        self.layers = nn.Sequential(*layers)
        # Size the classifier from the running channel count rather than
        # plan[-1], so plans that are empty or end with a pooling layer
        # ('M') still yield a valid linear layer.
        self.fc = nn.Linear(filters, outputs)
        self.criterion = nn.CrossEntropyLoss()

        self.apply(initializer)

    def forward(self, x):
        """Run the conv stack, 2x2 average pooling, flatten and classify.

        NOTE(review): ``fc`` expects exactly one value per channel after
        flattening, i.e. a 1x1 feature map after the average pool. That holds
        for 32x32 inputs with the four max-pooling stages of the built-in
        plans -- confirm for other input sizes or custom plans.
        """
        x = self.layers(x)
        # Functional form avoids constructing a new module on every call.
        x = F.avg_pool2d(x, 2)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

    @property
    def output_layer_names(self):
        """Names of the weight and bias parameters of the final ``fc`` layer."""
        return ['fc.weight', 'fc.bias']

    @staticmethod
    def _parse_model_name(model_name):
        """Return ``(depth, width)`` for a valid model name, else ``None``."""
        parts = model_name.split('_')
        if not model_name.startswith('cifar_vgg_') or len(parts) not in (3, 4):
            return None

        # isdecimal() rather than isdigit(): isdigit() is also true for
        # characters such as superscript digits, which int() cannot parse.
        if not all(part.isdecimal() for part in parts[2:]):
            return None

        depth = int(parts[2])
        width = int(parts[3]) if len(parts) == 4 else Model._DEFAULT_WIDTH

        # A width of zero would produce convolutions with no channels.
        if depth not in Model._PLAN_MULTIPLIERS or width < 1:
            return None

        return depth, width

    @staticmethod
    def is_valid_model_name(model_name):
        """Return True if ``model_name`` names a supported VGG configuration.

        Valid names are ``cifar_vgg_<depth>`` and
        ``cifar_vgg_<depth>_<width>`` with depth in {11, 13, 16, 19} and a
        positive integer width.
        """
        return Model._parse_model_name(model_name) is not None

    @staticmethod
    def get_model_from_name(model_name, initializer, outputs=10):
        """Construct the model identified by ``model_name``.

        Args:
            model_name: A name accepted by ``is_valid_model_name``.
            initializer: Callable applied to every submodule of the model.
            outputs: Number of output units; a falsy value (e.g. ``None``)
                falls back to 10.

        Raises:
            ValueError: If ``model_name`` is not a valid model name.
        """
        parsed = Model._parse_model_name(model_name)
        if parsed is None:
            raise ValueError('Invalid model name: {}'.format(model_name))
        depth, width = parsed

        outputs = outputs or 10

        # Scale the per-depth multipliers by the base width; pooling markers
        # pass through unchanged.
        plan = [spec if spec == 'M' else spec * width
                for spec in Model._PLAN_MULTIPLIERS[depth]]

        return Model(plan, initializer, outputs)

    @property
    def loss_criterion(self):
        """The loss module created in ``__init__`` (cross-entropy)."""
        return self.criterion

    @staticmethod
    def default_hparams():
        """Return the default ``LotteryDesc`` for this model family.

        Bundles the default model (``cifar_vgg_16``), dataset, training and
        pruning hyperparameters defined below.
        """
        model_hparams = hparams.ModelHparams(
            model_name='cifar_vgg_16',
            model_init='kaiming_normal',
            batchnorm_init='uniform',
        )

        dataset_hparams = hparams.DatasetHparams(
            dataset_name='cifar10',
            batch_size=128
        )

        training_hparams = hparams.TrainingHparams(
            optimizer_name='sgd',
            momentum=0.9,
            milestone_steps='80ep,120ep',
            lr=0.1,
            gamma=0.1,
            weight_decay=1e-4,
            training_steps='160ep'
        )

        pruning_hparams = sparse_global.PruningHparams(
            pruning_strategy='sparse_global',
            pruning_fraction=0.2,
            # pruning_layers_to_ignore='fc.weight'
        )

        return LotteryDesc(model_hparams, dataset_hparams, training_hparams, pruning_hparams)
