import logging
import math
from collections import OrderedDict
# from typing import Union

import torch
# import torch.nn as nn
from torch import nn
import torch.nn.functional as func
from torch.nn.modules.conv import _ConvNd


class BaseModule(nn.Module):
    """Classifier skeleton split into an encoder and a classification head.

    Subclasses are expected to define three callables/sub-modules:
    ``features`` and ``avgpool`` (chained by ``encode``) and ``classifier``
    (applied by ``decode_clf``).
    """

    def forward(self, x):
        """Return logits: ``decode_clf(encode(x))``."""
        return self.decode_clf(self.encode(x))

    def encode(self, x):
        """Run ``features`` then ``avgpool`` and flatten everything after dim 0."""
        pooled = self.avgpool(self.features(x))
        return torch.flatten(pooled, start_dim=1)

    def decode_clf(self, z):
        """Map embeddings ``z`` to logits with ``classifier``."""
        return self.classifier(z)


def kaiming_uniform_in_(tensor, a=0, mode='fan_in', scale=1., nonlinearity='leaky_relu'):
    """Fill ``tensor`` in place from a Kaiming-uniform distribution with a scaled fan.

    Modified from ``torch.nn.init.kaiming_uniform_``: the fan chosen by
    ``mode`` is multiplied by ``scale`` before the bound is computed, so
    values are drawn from ``U(-b, b)`` with ``b = gain * sqrt(3 / (fan * scale))``.

    Args:
        tensor: Tensor to initialize in place (fan computation requires >= 2 dims).
        a: Negative slope forwarded to ``nn.init.calculate_gain``.
        mode: ``'fan_in'`` or ``'fan_out'`` -- which fan to use.
        scale: Multiplier applied to the selected fan.
        nonlinearity: Name forwarded to ``nn.init.calculate_gain``.

    Returns:
        ``tensor`` itself.
    """
    if 0 in tensor.shape:
        # Same behavior as torch.nn.init.kaiming_uniform_: nothing to fill, and
        # the fan may be 0, which would raise ZeroDivisionError below.
        import warnings
        warnings.warn("Initializing zero-element tensors is a no-op")
        return tensor
    # Named `fan` (not `fan_in`) because `mode` may select fan_out.
    fan = nn.init._calculate_correct_fan(tensor, mode) * scale
    gain = nn.init.calculate_gain(nonlinearity, a)
    std = gain / math.sqrt(fan)
    bound = math.sqrt(3.0) * std  # U(-bound, bound) has standard deviation `std`
    with torch.no_grad():
        return tensor.uniform_(-bound, bound)


def scale_init_param(m, scale_in=1.):
    """Re-initialize a Linear/Conv module with its fan-in scaled by ``scale_in``.

    Follows the recipe of PyTorch's default ``nn.Linear`` / ``_ConvNd`` init
    (Kaiming-uniform weights with ``a=sqrt(5)``; bias drawn from
    ``U(-1/sqrt(fan_in), 1/sqrt(fan_in))``), except that ``fan_in`` is
    multiplied by ``scale_in`` for both weight and bias.

    Args:
        m: Any module; only ``nn.Linear`` and ``_ConvNd`` instances are modified.
        scale_in: Multiplier applied to the weight's fan-in.

    Returns:
        ``m`` itself (modified in place when it is Linear/Conv).
    """
    if not isinstance(m, (nn.Linear, _ConvNd)):
        return m
    kaiming_uniform_in_(m.weight, a=math.sqrt(5), scale=scale_in, mode='fan_in')
    if m.bias is not None:
        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(m.weight)
        fan_in *= scale_in
        # Guard copied from torch's Linear.reset_parameters: avoid 1/sqrt(0)
        # for layers with zero input features.
        bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
        nn.init.uniform_(m.bias, -bound, bound)
    return m


class ScalableModule(BaseModule):
    """``BaseModule`` whose Linear/Conv initialization can be rescaled for a width factor.

    Subclasses call ``reset_parameters`` at the end of their constructor.
    With the default arguments that call leaves every parameter as PyTorch
    initialized it. With ``rescale_init=True`` and ``width_scale != 1``, it
    re-initializes Linear/Conv layers through ``scale_init_param`` using
    ``scale_in = 1 / width_scale``.

    Args:
        width_scale: Width multiplier the subclass was built with.
        rescale_init: Whether ``reset_parameters`` applies the rescaled init.
    """

    def __init__(self, width_scale=1., rescale_init=False):
        super(ScalableModule, self).__init__()
        self.width_scale = width_scale
        self.rescale_init = rescale_init

    def reset_parameters(self, inp_nonscale_layers=()):
        """Re-initialize child Linear/Conv layers for ``width_scale`` when enabled.

        This is a no-op unless ``rescale_init`` is True and ``width_scale != 1``.

        Args:
            inp_nonscale_layers: Names of direct child modules to leave
                untouched. Presumably these are layers whose input width is
                not scaled, such as a first layer reading raw input channels.
        """
        if not self.rescale_init or self.width_scale == 1.:
            return
        skipped = set(inp_nonscale_layers)
        scale_in = 1. / self.width_scale
        for name, child in self.named_children():
            if name in skipped:
                continue
            # apply() visits `child` and all of its submodules; scale_init_param
            # returns non-Linear/Conv modules unchanged.
            child.apply(lambda mod: scale_init_param(mod, scale_in=scale_in))


class DigitModel(ScalableModule):
    """
    Model for benchmark experiment on Digits.

    Three 5x5 conv+BN+ReLU stages (the first two followed by 2x2 max-pooling)
    feed a 3-layer MLP head (two Linear+BN+ReLU stages and a final Linear).
    Inputs are expected as ``(N, 3, 28, 28)``: the two pools reduce 28x28 to
    the 7x7 maps that ``fc1`` is sized for.
    """
    input_shape = [None, 3, 28, 28]

    def __init__(self, num_classes=10, bn_type='bn', track_running_stats=True,
                 share_affine=True):
        super().__init__()
        # NOTE(review): `share_affine` is never read and `bn_type` is only
        # stored on the instance; every norm layer below is a plain BatchNorm
        # regardless of either value. Confirm whether that is intended.
        self.bn_type = bn_type
        conv_channels = (64, 64, 128)
        hidden_dims = (2048, 512)
        norm_kwargs = {'track_running_stats': track_running_stats}

        # 5x5 kernels with stride 1 and padding 2 keep the spatial size.
        self.conv1 = nn.Conv2d(3, conv_channels[0], 5, 1, 2)
        self.bn1 = nn.BatchNorm2d(conv_channels[0], **norm_kwargs)
        self.conv2 = nn.Conv2d(conv_channels[0], conv_channels[1], 5, 1, 2)
        self.bn2 = nn.BatchNorm2d(conv_channels[1], **norm_kwargs)
        self.conv3 = nn.Conv2d(conv_channels[1], conv_channels[2], 5, 1, 2)
        self.bn3 = nn.BatchNorm2d(conv_channels[2], **norm_kwargs)

        self.fc1 = nn.Linear(conv_channels[2] * 7 * 7, hidden_dims[0])
        self.bn4 = nn.BatchNorm1d(hidden_dims[0], **norm_kwargs)
        self.fc2 = nn.Linear(hidden_dims[0], hidden_dims[1])
        self.bn5 = nn.BatchNorm1d(hidden_dims[1], **norm_kwargs)
        self.fc3 = nn.Linear(hidden_dims[1], num_classes)

        # conv1 reads the raw 3-channel image and is passed as the only
        # non-scaled-input layer.
        self.reset_parameters(inp_nonscale_layers=['conv1'])

    def forward(self, x):
        """Return class logits for a batch of images."""
        return self.decode_clf(self.encode(x))

    def encode(self, x):
        """Convolutional trunk: return per-sample flattened feature vectors."""
        stages = (
            (self.conv1, self.bn1, True),
            (self.conv2, self.bn2, True),
            (self.conv3, self.bn3, False),
        )
        for conv, norm, pool in stages:
            x = func.relu(norm(conv(x)))
            if pool:
                x = func.max_pool2d(x, 2)
        return x.view(x.shape[0], -1)

    def decode_clf(self, x):
        """MLP head: map flattened features to logits."""
        for linear, norm in ((self.fc1, self.bn4), (self.fc2, self.bn5)):
            x = func.relu(norm(linear(x)))
        return self.fc3(x)

