"""
    Based on code taken from https://github.com/facebookresearch/open_lth

    Copyright (c) Facebook, Inc. and its affiliates.

    This source code is licensed under the MIT license found in the
    LICENSE file in the root directory of this source tree.
    Adapted from https://github.com/ganguli-lab/Synaptic-Flow/blob/master/Models/lottery_vgg.py
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from Layers.layers import conv3x3_block, Conv2d, BatchNorm2d, Linear, sparse_initialize, ConvBlock


class VGG(nn.Module):
    """A VGG-style neural network designed for CIFAR-10.

    The feature extractor is an ``nn.Sequential`` assembled from ``plan``;
    it is followed by a 2x2 average pool, a flatten and a single linear
    classifier. Every ``ConvBlock`` found in the model is tagged in
    ``self.pruned_types`` (see ``set_prune_types``).
    """

    def __init__(self, plan, conv, num_classes=10, dense_classifier=False):
        """Build the network from a layer plan.

        Args:
            plan: Iterable of layer specs. The string ``'M'`` appends a 2x2
                max-pool with stride 2; any other entry is passed to
                ``conv`` as the output channel count of the next stage.
            conv: Callable ``conv(in_channels, out_channels)`` returning the
                module used for one convolutional stage.
            num_classes: Output size of the final linear classifier.
            dense_classifier: Accepted for interface compatibility; this
                implementation does not use it.
        """
        super().__init__()
        stages = []
        in_channels = 3  # the first stage consumes 3-channel input

        for spec in plan:
            if spec == 'M':
                stages.append(nn.MaxPool2d(kernel_size=2, stride=2))
            else:
                stages.append(conv(in_channels, spec))
                in_channels = spec

        self.layers = nn.Sequential(*stages)

        # forward() flattens the pooled feature map straight into this layer,
        # so the flattened features must total exactly 512.
        self.fc = Linear(512, num_classes)

        self._initialize_weights()
        self.set_prune_types()

    def forward(self, x):
        """Run the feature extractor, pool, flatten and classify.

        Args:
            x: Input batch. ``clean`` feeds a (64, 3, 32, 32) tensor, so
                presumably (N, 3, 32, 32) is expected -- confirm for other
                plans.

        Returns:
            The output of ``self.fc`` applied to the flattened features.
        """
        x = self.layers(x)
        # Functional pooling avoids constructing a module on every call.
        x = F.avg_pool2d(x, 2)
        x = x.view(x.size(0), -1)
        return self.fc(x)

    @torch.no_grad()
    def _initialize_weights(self, sparse_init=True):
        """(Re-)initialize the parameters of the model.

        Args:
            sparse_init: When true (the default) initialization is delegated
                to ``sparse_initialize(self)``. Otherwise linear and
                convolution weights get Kaiming-normal init with zero
                biases, and batch-norm layers get weight 1 and bias 0.
        """
        if sparse_init:
            sparse_initialize(self)
            return

        # The layer classes are imported by name at the top of the file;
        # no ``layers`` module object is in scope here.
        for m in self.modules():
            if isinstance(m, (Linear, nn.Linear, Conv2d)):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def clean(self):
        """Move the model to CUDA, re-initialize it and run one forward pass.

        The forward pass uses a random (64, 3, 32, 32) batch with gradients
        disabled. Requires a CUDA device.
        """
        with torch.no_grad():
            self.cuda()
            self._initialize_weights()
            self(torch.randn(64, 3, 32, 32).cuda())

    def set_prune_types(self):
        """Populate ``self.pruned_types``, mapping modules to a prune tag.

        ``self.layers[0]`` and the first ``ConvBlock`` encountered in
        ``self.modules()`` order are tagged ``'in'``; every later
        ``ConvBlock`` is tagged ``'vgg_out'``.
        """
        self.pruned_types = {}
        conv_blocks = [m for m in self.modules() if isinstance(m, ConvBlock)]

        self.pruned_types[self.layers[0]] = 'in'
        for index, unit in enumerate(conv_blocks):
            # Must be an assignment: a comparison here would leave the first
            # block untagged (or raise KeyError when it is not layers[0]).
            self.pruned_types[unit] = 'in' if index == 0 else 'vgg_out'


def _plan(num):
    if num == 16:
        plan = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512]
    else:
        raise ValueError('Unknown VGG model: {}'.format(num))
    return plan


def _vgg(plan, conv_block, num_classes):
    """Construct a ``VGG`` model.

    Args:
        plan: Layer plan forwarded to ``VGG`` (see ``_plan``).
        conv_block: Callable used by ``VGG`` to build each convolutional
            stage; forwarded as its ``conv`` argument.
        num_classes: Output size of the classifier. This is forwarded to
            ``VGG`` rather than being replaced by a hard-coded 10.

    Returns:
        The constructed ``VGG`` instance.
    """
    return VGG(plan, conv=conv_block, num_classes=num_classes)


def vgg16_bn(groups=1, width_factor=1, **kwargs):
    """Build the 16-layer VGG model with ``conv3x3_block`` stages.

    Only ``groups == 1`` and ``width_factor == 1`` are supported; anything
    else trips the assertion. Extra keyword arguments are accepted and
    ignored, and the model is always built for 10 classes.
    """
    supported = groups == 1 and width_factor == 1
    assert supported

    return _vgg(_plan(16), conv3x3_block, 10)


if __name__ == "__main__":
    # Smoke test: build the model, push one random CIFAR-sized image through
    # it, backpropagate the summed output, then print the output and model.
    import torch
    net = vgg16_bn()
    sample = torch.randn(1, 3, 32, 32)
    logits = net(sample)
    total = logits.sum()
    total.backward()
    print(logits)
    print(net)