from __future__ import print_function

import math
import torch.nn as nn
import torch
from .utils.utils import get_activation_function, make_conv_block, make_fc_block


# VGG feature-extractor layouts, keyed by network depth as a *string*.
# Each integer is the output channel count of one conv layer; 'M' marks a
# 2x2 max-pool. Entries are grouped one resolution stage per line.
# Treat these lists as read-only templates: copy before modifying them.
default_cfg = {
    '5': [64, 64],
    '7': [64, 64, 'M',
          128, 128],
    '11': [64, 'M',
           128, 'M',
           256, 256, 'M',
           512, 512, 'M',
           512, 512],
    '13': [64, 64, 'M',
           128, 128, 'M',
           256, 256, 'M',
           512, 512, 'M',
           512, 512],
    '16': [64, 64, 'M',
           128, 128, 'M',
           256, 256, 256, 'M',
           512, 512, 512, 'M',
           512, 512, 512],
}

class VGG(nn.Module):
    """VGG-style network built from a ``default_cfg`` layout.

    The feature extractor is a stack of ``make_conv_block`` blocks and 2x2
    max-pools. The pooling layer and classifier head depend on ``dataset``
    (see ``make_fc_layers``).

    Args:
        activation_type: Forwarded to ``get_activation_function``. The returned
            iterator is handed to every block builder, and ``next()`` on it
            must yield an activation module (its type is used to locate
            activations in ``get_minmax`` / ``get_activation``).
        num_classes: Width of the final classifier layer.
        depth: Key of ``default_cfg`` ('5', '7', '11', '13' or '16'). Integers
            are accepted and converted with ``str`` (the default is ``16``).
        oper_order: Operation-order string forwarded, as a list of characters,
            to ``make_conv_block``. Its meaning is defined by that helper.
        dataset: Selects the classifier head. It must contain 'cifar',
            'tinyImageNet' or 'ImageNet'.
        bn_momentum: Forwarded to ``make_conv_block`` / ``make_fc_block``.
        cut_block: Number of trailing conv layers to drop from the layout.
            Max-pools removed along the way are re-appended at the end, so the
            number of max-pool layers is unchanged.
        tau: Forwarded to ``get_activation_function``.

    Raises:
        KeyError: if ``depth`` is not a key of ``default_cfg``.
        ValueError: if ``cut_block`` is out of range or ``dataset`` is not
            supported.
    """

    def __init__(self, activation_type, num_classes=10, depth=16,
                 oper_order='cba', dataset='cifar', bn_momentum=0.1,
                 cut_block=0, tau=0):
        super(VGG, self).__init__()
        # default_cfg is keyed by strings. Normalizing here makes the int
        # default (depth=16) usable and keeps the ImageNet check below consistent.
        depth = str(depth)
        # _trim_cfg works on a copy, so the shared default_cfg lists are never mutated.
        cfg = self._trim_cfg(default_cfg[depth], cut_block)

        self.bn_momentum = bn_momentum

        self.activation_generator = get_activation_function(activation_type, tau=tau)
        self.dataset = dataset
        # 'full' is the complete operation order; 'front2' drops the last
        # operation. Only 'full' is used inside this class.
        self.oper_order = {'full': list(oper_order), 'front2': list(oper_order)[:-1]}

        self.cfg = cfg
        if dataset == 'ImageNet' and depth == '16':
            # Exact match only (so not 'tinyImageNet'): adds a fifth max-pool
            # after the four in the '16' layout.
            self.cfg = self.cfg + ['M']

        self.feature = self.make_conv_layers()
        self.avgpool, self.classifier = self.make_fc_layers(num_classes)

        self._initialize_weights()

    @staticmethod
    def _trim_cfg(base_cfg, cut_block):
        """Return a copy of ``base_cfg`` with its last ``cut_block`` conv entries removed.

        'M' entries popped while removing convs are re-appended at the end, so
        the number of max-pools is preserved. Conv entries are converted to
        ``int``. ``base_cfg`` itself is left untouched.

        Raises:
            ValueError: if ``cut_block`` is negative, or would remove every
                conv entry (the classifier needs at least one conv layer).
        """
        cfg = list(base_cfg)
        cut_cnt = int(cut_block)
        n_conv = sum(1 for c in cfg if c != 'M')
        if not 0 <= cut_cnt < n_conv:
            raise ValueError(
                'cut_block must satisfy 0 <= cut_block < {} for this depth, '
                'got {!r}'.format(n_conv, cut_block))

        pool_count = 0
        while cut_cnt:
            last = cfg.pop()
            if last == 'M':
                pool_count += 1
            else:
                cut_cnt -= 1

        cfg += ['M'] * pool_count
        return [c if c == 'M' else int(c) for c in cfg]

    def make_conv_layers(self):
        """Build the feature extractor from ``self.cfg``.

        Each integer entry becomes a ``make_conv_block`` with kernel_size=3,
        stride=1 and padding=1. Each 'M' becomes a 2x2, stride-2 max-pool.
        The first block expects 3 input channels.
        """
        layers = []
        in_channels = 3
        for v in self.cfg:
            if v == 'M':
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            else:
                layers.append(make_conv_block(in_channels, v, self.activation_generator,
                                              kernel_size=3, stride=1, padding=1,
                                              oper_order=self.oper_order['full'],
                                              bn_momentum=self.bn_momentum))
                in_channels = v

        return nn.Sequential(*layers)

    def make_fc_layers(self, out_classes):
        """Build ``(pool, classifier)`` for ``self.dataset``.

        * 'cifar' in dataset: ``AvgPool2d(2)``, then fc(last_channel -> 512)
          and fc(512 -> out_classes).
        * 'tinyImageNet' in dataset: 2x2 max-pool, then
          fc(dim -> dim/2 -> dim/2 -> out_classes) with dim = last_channel * 4.
        * 'ImageNet' in dataset (checked last): ``AdaptiveAvgPool2d((7, 7))``,
          then fc(last_channel*49 -> 4096 -> 4096 -> out_classes).

        Blocks are created in forward order, so ``self.activation_generator``
        is consumed in that order.

        Raises:
            ValueError: if ``self.dataset`` matches none of the above.
        """
        last_channel = [c for c in self.cfg if c != 'M'][-1]

        def fc(n_in, n_out, order):
            # Single place for the shared make_fc_block arguments.
            return make_fc_block(activations_in=n_in, activations_out=n_out,
                                 activation_generator=self.activation_generator,
                                 oper_order=order, bn_momentum=self.bn_momentum)

        if 'cifar' in self.dataset:
            # Assumes a 2x2 feature map reaches the pool, so the flattened
            # output has exactly last_channel features. TODO: confirm for shallow depths.
            avgpool = nn.AvgPool2d(2)
            hidden = 512
            layers = [fc(last_channel, hidden, self.oper_order['full']),
                      fc(hidden, out_classes, 'f')]
        elif 'tinyImageNet' in self.dataset:
            # Must be tested before 'ImageNet', which is a substring of it.
            # Assumes a 4x4 feature map (2x2 after pooling). TODO: confirm.
            avgpool = nn.MaxPool2d(kernel_size=2, stride=2)
            dim = last_channel * 2 * 2
            layers = [fc(dim, dim // 2, 'faD'),
                      fc(dim // 2, dim // 2, 'faD'),
                      fc(dim // 2, out_classes, 'f')]
        elif 'ImageNet' in self.dataset:
            avgpool = nn.AdaptiveAvgPool2d((7, 7))
            layers = [fc(last_channel * 7 * 7, 4096, 'faD'),
                      fc(4096, 4096, 'faD'),
                      fc(4096, out_classes, 'f')]
        else:
            raise ValueError(
                "Unsupported dataset {!r}: expected a name containing 'cifar', "
                "'tinyImageNet' or 'ImageNet'".format(self.dataset))

        return avgpool, nn.Sequential(*layers)

    def forward(self, x):
        """Run feature -> avgpool (if any) -> flatten -> classifier."""
        x = self.feature(x)

        if self.avgpool is not None:
            x = self.avgpool(x)
        x = x.view(x.size(0), -1)

        return self.classifier(x)

    def _forward_probe(self, x, feature_type, classifier_type, recurse,
                       skip=None, transform=None):
        """Forward ``x`` like ``forward`` and record selected module outputs.

        The output of every module in ``feature`` that is an instance of
        ``feature_type`` is recorded. Likewise for ``classifier`` and
        ``classifier_type``. ``transform``, if given, is applied at the moment
        of recording, before later modules run.

        With ``recurse=False`` the direct children of each container are
        applied in order. With ``recurse=True`` ``.modules()`` is walked and
        every module except ``nn.Sequential`` containers is applied; their
        children are visited individually. The output of ``skip`` (an element
        of ``classifier``, compared by identity) is computed but not recorded.

        NOTE(review): ``.modules()`` yields each module object once, so a
        module instance shared between blocks would be applied only once in
        recurse mode. Confirm the block builders never share instances.
        """
        records = []

        def run(container, h, probe_type, skip_module):
            modules = container.modules() if recurse else container
            for module in modules:
                if isinstance(module, probe_type):
                    h = module(h)
                    if module is not skip_module:
                        records.append(transform(h) if transform is not None else h)
                elif not (recurse and isinstance(module, nn.Sequential)):
                    h = module(h)
            return h

        h = run(self.feature, x, feature_type, None)
        if self.avgpool is not None:
            h = self.avgpool(h)
        h = h.view(h.size(0), -1)
        run(self.classifier, h, classifier_type, skip)

        return records

    def get_minmax(self, x, block_output=False, channel_flag=False):
        """Forward ``x`` and collect min/max statistics of intermediate outputs.

        Args:
            x: Input batch.
            block_output: If True, record after every ``nn.Sequential`` child of
                ``feature`` and ``classifier``, including the final classifier
                block. Otherwise record after every module whose type matches
                the activation returned by ``next(self.activation_generator)``.
            channel_flag: If True, each record is a (C, 2) tensor of per-channel
                [min, max], treating dim 1 of the output as the channel axis.
                Otherwise each record is a (2,) tensor [min, max] over the
                whole output.

        Returns:
            List of statistics tensors in forward order.
        """
        def _stats(t):
            if channel_flag:
                cbhw = t.transpose(0, 1)
                flat = cbhw.reshape(cbhw.size(0), -1)
                c_max, _ = torch.max(flat, dim=1)
                c_min, _ = torch.min(flat, dim=1)
                return torch.stack([c_min, c_max], dim=1)
            return torch.stack([t.min(), t.max()], dim=0)

        if block_output:
            return self._forward_probe(x, nn.Sequential, nn.Sequential,
                                       recurse=False, transform=_stats)

        # Draws one item from the activation iterator only to learn its type.
        acti_cls = type(next(self.activation_generator))
        return self._forward_probe(x, acti_cls, acti_cls, recurse=True,
                                   transform=_stats)

    def get_activation(self, x, target='activation'):
        """Forward ``x`` and return selected intermediate outputs.

        Args:
            x: Input batch.
            target: 'weights' returns the outputs of every ``nn.Conv2d`` in
                ``feature`` and every ``nn.Linear`` in ``classifier``.
                'activation' returns the outputs of every module of the
                activation type. Anything else returns the outputs of every
                ``nn.Sequential`` child block, except the last child of
                ``classifier``.

        Returns:
            List of tensors in forward order.
        """
        if target == 'weights':
            return self._forward_probe(x, nn.Conv2d, nn.Linear, recurse=True)

        if target == 'activation':
            # Draws one item from the activation iterator only to learn its type.
            acti_cls = type(next(self.activation_generator))
            return self._forward_probe(x, acti_cls, acti_cls, recurse=True)

        # Block outputs; the final classifier block (producing the logits) is excluded.
        last = self.classifier[-1] if len(self.classifier) else None
        return self._forward_probe(x, nn.Sequential, nn.Sequential,
                                   recurse=False, skip=last)

    def _initialize_weights(self):
        """Re-initialize parameters in place.

        Conv2d weights ~ N(0, sqrt(2 / n)) with n = k_h * k_w * out_channels
        (fan-out He-style). BatchNorm2d gets weight 1 and bias 0. Linear
        weights ~ N(0, 0.01). All biases are zeroed. Other module types keep
        their default initialization.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))

                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                if m.weight is not None:
                    m.weight.data.fill_(1)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                if m.bias is not None:
                    m.bias.data.zero_()
