import torch
import torch.nn as nn
import math
from ..utils.parsing import get_class
from . import quantized


class ModelConfig:
    ''' Holds the layer factories used to build a model.

    Each of ``conv``, ``act`` and ``fc`` may name a layer class. When set, the
    corresponding attribute is a factory that forwards its arguments to
    ``get_class(name, quantized, *args, **kwargs)`` (presumably resolving the
    class inside the ``quantized`` module and instantiating it — see
    ``utils.parsing.get_class``). When a name is ``None`` or ``""`` the plain
    PyTorch layer is used instead: ``nn.Conv2d``, ``nn.ReLU`` or ``nn.Linear``.

    ``bn`` is accepted for interface compatibility but currently ignored;
    ``self.bn`` is always ``nn.BatchNorm2d``.
    '''
    def __init__(self, bn=None, act=None, conv=None, fc=None):
        self.conv = self._layer_factory(conv, nn.Conv2d)

        # TODO: BN not supported yet -> should add instance BN in the future
        self.bn = nn.BatchNorm2d

        self.act = self._layer_factory(act, nn.ReLU)
        self.fc = self._layer_factory(fc, nn.Linear)

    @staticmethod
    def _layer_factory(name, default):
        '''Return a callable building layer ``name``, or ``default`` if unset.

        ``None`` and ``""`` both count as "unset", so all layer kinds follow
        the same rule.
        '''
        if not name:
            return default

        def factory(*args, **kwargs):
            # Resolved at call time, so get_class/quantized are looked up lazily.
            return get_class(name, quantized, *args, **kwargs)
        return factory


# Substrings used to pick init defaults from the configured activation name.
# Order matters: "leaky" must be tested before "relu" because "leaky_relu"
# also contains "relu".
_ACTIVATION_INIT = (
    # (substring, negative slope, nonlinearity for torch.nn.init, sampling)
    ("leaky", 0.01, "leaky_relu", "kaiming"),
    ("relu", 0, "relu", "kaiming"),
    ("tanh", 0, "tanh", "kaiming"),
)


def _activation_init_params(act):
    '''Return ``(neg_slope, nonlin, sampling)`` defaults for activation ``act``.

    Raises:
        ValueError: if ``act`` matches none of the supported activations.
    '''
    for key, neg_slope, nonlin, sampling in _ACTIVATION_INIT:
        if key in act:
            return neg_slope, nonlin, sampling
    raise ValueError(f"Activation of type {act} is not supported yet")


def _check_choice(name, value, choices):
    '''Raise ValueError unless ``value`` is one of ``choices``.'''
    if value not in choices:
        raise ValueError(
            f"weight_init.{name} must be one of {list(choices)}, got {value!r}"
        )


def _load_pretrained(model, path):
    '''Load the state dict stored at ``path`` into ``model`` (non-strict).

    Keys containing ``.total_ops`` or ``.total_params`` (presumably buffers
    added by a FLOP/parameter profiler — confirm) are dropped first. Missing
    or unexpected keys are reported as a printed warning, not an error.
    '''
    device = torch.device(
        "cuda" if torch.cuda.device_count() > 0 else "cpu"
    )
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state = torch.load(path, map_location=device)
    # Delete in place (rather than rebuilding the dict) so any `_metadata`
    # attribute on the loaded state dict is kept for load_state_dict.
    for key in [k for k in state if ".total_ops" in k or ".total_params" in k]:
        del state[key]
    # FIXME : don't use strict
    load_result = model.load_state_dict(state, strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        print("WARNING: missing or unexpected keys found")
        print("Missing:\n", load_result.missing_keys)
        print("Unexpected:\n", load_result.unexpected_keys)


@torch.no_grad()
def init_weights(model, cfg):
    '''Initialise ``model`` in place, either from a checkpoint or from scratch.

    Reads ``cfg["training"]["weight_init"]`` (optional dict) with keys:
        pretrained:   path of a state dict; if non-empty it is loaded and
                      nothing else is done.
        gamma:        fill value for BatchNorm2d weights (default 1.0).
        sampling:     "kaiming" | "xavier" (default taken from the activation).
        distribution: "normal" | "uniform" (default "normal").
        fan_mode:     "fan_in" | "fan_out" | "fan_avg" (default "fan_in";
                      "fan_avg" is only valid with weight_fx="custom" or with
                      xavier sampling).
        weight_fx:    "pytorch" (torch.nn.init functions) | "custom"
                      (explicit std computation, groups-aware for convs).
        gain:         overrides the activation-derived gain. It is unused by
                      the pytorch kaiming path, which computes its own gain.
        seed:         if present, passed to torch.manual_seed before init.
    Also reads ``cfg["network"]["model_cfg"]["act"]``. A missing, None or ""
    value is treated as "relu", matching ModelConfig's default activation.

    Conv2d/Linear weights are re-sampled, BatchNorm2d weights (when affine)
    are filled with ``gamma``, and every module's ``bias`` is zeroed.

    Raises:
        ValueError: for an unsupported activation or an invalid option value.
    '''
    params = cfg["training"].get("weight_init", {})
    pretrained = params.get("pretrained", "")
    if pretrained:  # Pretrained weights take priority over init
        _load_pretrained(model, pretrained)
        return

    act = cfg["network"].get("model_cfg", {}).get("act") or "relu"
    neg_slope, nonlin, sampling = _activation_init_params(act)
    gain = nn.init.calculate_gain(nonlin, neg_slope)

    # Override default params
    gamma = params.get("gamma", 1.0)
    sampling = params.get("sampling", sampling)
    distribution = params.get("distribution", "normal")
    fan_mode = params.get("fan_mode", "fan_in")
    weight_fx = params.get("weight_fx", "pytorch")

    _check_choice("weight_fx", weight_fx, ("pytorch", "custom"))
    _check_choice("sampling", sampling, ("kaiming", "xavier"))
    _check_choice("distribution", distribution, ("normal", "uniform"))
    _check_choice("fan_mode", fan_mode, ("fan_in", "fan_out", "fan_avg"))
    if weight_fx == "pytorch" and sampling == "kaiming" and fan_mode == "fan_avg":
        raise ValueError(
            "weight_init.fan_mode 'fan_avg' requires weight_fx='custom' "
            "when sampling='kaiming'"
        )

    # torch's xavier_* already uses std = gain * sqrt(2 / (fan_in + fan_out)),
    # so the activation gain is reset to 1.0 there (plain Glorot). The custom
    # xavier below uses std = gain / sqrt(fan_in + fan_out), which equals
    # Glorot when gain = sqrt(2) (the ReLU gain).
    if weight_fx == "pytorch" and sampling == "xavier":
        gain = 1.0
    gain = params.get("gain", gain)

    def init_bn_and_bias(m):
        # BN scale -> gamma; affine=False BN layers have weight=None.
        if isinstance(m, nn.BatchNorm2d) and m.weight is not None:
            m.weight.fill_(gamma)
        # Bias may be absent or None (e.g. bias=False); only zero real tensors.
        if hasattr(m, "bias") and hasattr(m.bias, "data"):
            m.bias.zero_()

    def pytorch_weights_init(m):
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            uniform = distribution == "uniform"
            if sampling == "kaiming":
                init_fn = (nn.init.kaiming_uniform_ if uniform
                           else nn.init.kaiming_normal_)
                init_fn(m.weight, a=neg_slope, mode=fan_mode,
                        nonlinearity=nonlin)
            else:  # xavier
                init_fn = (nn.init.xavier_uniform_ if uniform
                           else nn.init.xavier_normal_)
                init_fn(m.weight, gain=gain)
        init_bn_and_bias(m)

    def custom_weights_init(m):
        # Computes std explicitly (Keras-style) instead of going through
        # torch.nn.init. For convs, both fans are divided by `groups`, so
        # depthwise convolutions get per-group fans.
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            if isinstance(m, nn.Conv2d):
                ksize = m.kernel_size[0] * m.kernel_size[1] / m.groups
                fan_in = m.in_channels * ksize
                fan_out = m.out_channels * ksize
            else:
                fan_in, fan_out = m.in_features, m.out_features

            if sampling == "xavier":
                std = gain / math.sqrt(fan_in + fan_out)
            else:  # kaiming
                fan = {
                    "fan_in": fan_in,
                    "fan_out": fan_out,
                    "fan_avg": (fan_in + fan_out) / 2,
                }[fan_mode]
                std = gain / math.sqrt(fan)

            if distribution == "normal":
                m.weight.normal_(0, std)
            else:
                # U(-a, a) has std a / sqrt(3); pick a to match `std`.
                limit = math.sqrt(3) * std
                m.weight.uniform_(-limit, limit)
        init_bn_and_bias(m)

    if "seed" in params:
        torch.manual_seed(params["seed"])
    model.apply(
        pytorch_weights_init if weight_fx == "pytorch" else custom_weights_init
    )
