'''
Base code taken from https://github.com/facebookresearch/open_lth/blob/main/models/cifar_vgg.py
'''

import torch.nn as nn
import torch.nn.functional as F

from models.base_model import BaseModel
from layers import ModuleInjection

class Model(BaseModel):
    """A VGG-style neural network designed for CIFAR-10.

    The network is described by a *plan*: a sequence whose integer entries
    are the output channel counts of 3x3 conv/batch-norm/activation blocks
    and whose ``'M'`` entries are 2x2 max-pool layers. The feature extractor
    is followed by a 2x2 average pool and one linear classifier.
    """

    # Layer plans keyed by VGG depth. Shared by `is_valid_model_name` and
    # `get_model_from_config` so the set of supported depths lives in one place.
    _PLANS = {
        11: (64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512),
        13: (64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512),
        16: (64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M',
             512, 512, 512, 'M', 512, 512, 512),
        19: (64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M',
             512, 512, 512, 512, 'M', 512, 512, 512, 512),
    }

    class ConvModule(nn.Module):
        """A single convolutional module in a VGG network."""

        def __init__(self, in_filters, out_filters, keep_full_precision=False):
            """Build a 3x3 conv + batch-norm block and hand it to ModuleInjection.

            Args:
                in_filters: Number of input channels of the convolution.
                out_filters: Number of output channels of the convolution.
                keep_full_precision: Forwarded unchanged to
                    ``ModuleInjection.get_conv_bn_act``.
            """
            super().__init__()
            self.conv = nn.Conv2d(in_filters, out_filters, kernel_size=3, padding=1, bias=False)
            self.bn = nn.BatchNorm2d(out_filters)

            # ModuleInjection returns the conv, bn and activation modules that are
            # actually used in `forward`. NOTE(review): presumably it wraps or
            # replaces them according to the active compression strategy --
            # confirm in `layers.ModuleInjection`.
            self.conv, self.bn, self.act = ModuleInjection.get_conv_bn_act(
                self.conv, self.bn, keep_full_precision=keep_full_precision)

        def forward(self, x):
            """Apply conv, then batch-norm, then the activation to ``x``."""
            return self.act(self.bn(self.conv(x)))

    def __init__(self, plan, keep_full_precision_list, outputs=10):
        """Assemble the network described by ``plan``.

        Args:
            plan: Iterable of layer specs; ``'M'`` adds a max-pool, any other
                value is used as the output channel count of a ``ConvModule``.
            keep_full_precision_list: Sequence indexed by *position in plan*
                (max-pool positions included, their entries are ignored). The
                entry at a conv position is passed to that ``ConvModule``.
            outputs: Number of output features of the final linear layer.

        Raises:
            IndexError: If ``keep_full_precision_list`` has no entry for the
                position of some conv layer in ``plan``.
        """
        super().__init__()

        layers = []
        filters = 3  # the first convolution takes 3 input channels

        for index, spec in enumerate(plan):
            if spec == 'M':
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            else:
                layers.append(Model.ConvModule(filters, spec, keep_full_precision_list[index]))
                filters = spec

        self.layers = nn.Sequential(*layers)
        # 512 is the width of the last conv in every plan in `_PLANS`.
        # NOTE(review): this assumes the feature map is 1x1 after the average
        # pool in `forward` (e.g. 32x32 inputs with four max-pools) -- confirm
        # before using other input sizes or custom plans.
        self.fc = nn.Linear(512, outputs)

        # Record, for each block's bn module, the bn module of the preceding
        # conv block (None for the first). `self.prev_module` is not defined in
        # this class; presumably BaseModel provides it -- TODO confirm.
        prev = None
        for block in self.layers:
            if isinstance(block, Model.ConvModule):
                self.prev_module[block.bn] = prev
                prev = block.bn

    def forward(self, x):
        """Run the conv stack, 2x2 average pool, flatten, and classify."""
        x = self.layers(x)
        # Functional form: same pooling as nn.AvgPool2d(2) without building a
        # fresh module object on every forward pass.
        x = F.avg_pool2d(x, 2)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

    @property
    def output_layer_names(self):
        """Parameter names of the final classifier layer."""
        return ['fc.weight', 'fc.bias']

    @staticmethod
    def is_valid_model_name(model_name):
        """Return True iff ``model_name`` is ``cifar_vgg_<N>`` with N a supported depth.

        ``<N>`` must parse as a base-10 integer that is a key of ``_PLANS``
        (11, 13, 16 or 19).
        """
        parts = model_name.split('_')
        # isdecimal() rather than isdigit(): isdigit() also accepts characters
        # such as superscript digits that int() cannot parse, which would make
        # this predicate raise ValueError instead of returning False.
        return (model_name.startswith('cifar_vgg_') and
                len(parts) == 3 and
                parts[2].isdecimal() and
                int(parts[2]) in Model._PLANS)

    @staticmethod
    def get_model_from_config(config):
        """Build a ``Model`` from a configuration object.

        Reads ``config.model_name``, ``config.dataset_num_classes``,
        ``config.model_compression_strategy`` and ``config.keep_full_precision``
        (a container of plan positions whose conv layers get
        ``keep_full_precision=True``).

        Returns:
            The constructed network, with ``prunable_modules`` set to
            ``ModuleInjection.prunable_modules``.

        Raises:
            ValueError: If ``config.model_name`` is not a valid model name.
        """
        model_name = config.model_name

        # Validate first so an invalid name fails before ModuleInjection is
        # updated with the compression strategy.
        if not Model.is_valid_model_name(model_name):
            raise ValueError('Invalid model name: {}'.format(model_name))

        outputs = config.dataset_num_classes
        ModuleInjection.update_model_compression_strategy(config.model_compression_strategy)

        num = int(model_name.split('_')[2])
        if num not in Model._PLANS:
            # Defensive: not reachable while is_valid_model_name checks _PLANS.
            raise ValueError('Unknown VGG model: {}'.format(model_name))
        plan = list(Model._PLANS[num])

        # One flag per plan position (max-pool positions included), sized from
        # the plan itself rather than from the depth number.
        full_precision_positions = config.keep_full_precision
        keep_full_precision_list = [i in full_precision_positions for i in range(len(plan))]

        net = Model(plan, keep_full_precision_list, outputs)
        net.prunable_modules = ModuleInjection.prunable_modules

        return net

    @staticmethod
    def get_default_args(training_config):
        """Fill ``training_config`` with the default CIFAR training settings.

        The object is modified in place and also returned.

        Raises:
            ValueError: If ``training_config.dataset_name`` does not start
                with ``'cifar'``.
        """
        if not training_config.dataset_name.startswith('cifar'):
            raise ValueError('Invalid dataset name: {}'.format(training_config.dataset_name))

        # Data loading.
        training_config.epochs = 160
        training_config.training_batch_size = 128
        training_config.test_batch_size = 256
        training_config.train_shuffle = True
        training_config.test_shuffle = False

        # Learning-rate schedule.
        training_config.scheduler_type = "multi_step"
        training_config.scheduler_gamma = 0.1
        training_config.scheduler_milestones = [80, 120]

        # Optimizer.
        training_config.optimizer_type = "sgd"
        training_config.optimizer_lr = 0.1
        training_config.optimizer_momentum = 0.9
        training_config.optimizer_weight_decay = 1e-4
        training_config.optimizer_no_decay = ["bias"]

        return training_config