import torch
import math

def get_conf_params(method="he"):
    """Look up the initialization preset named *method*.

    Returns a newly built dict with keys ``"c"`` (a float coefficient) and
    ``"init"`` / ``"spp"`` (lists of dimension indices). The recognised names
    are ``"he"`` and ``"brock"``; any other value falls back to the ``"he"``
    preset rather than raising.
    """
    he_preset = {"c": 2.0, "init": [0, 1, 2, 3], "spp": [0, 1, 2, 3]}
    if method == "he":
        return he_preset
    if method == "brock":
        # NOTE(review): the derivation of this coefficient is not documented
        # in this file — confirm its source before changing it.
        return {"c": 2.93388441385, "init": [1, 2, 3], "spp": [0, 1, 2, 3]}
    # Unrecognised names are not rejected; they silently receive the "he" preset.
    return he_preset

def _calculate_fan_in_and_fan_out(tensor):
    dimensions = tensor.dim()
    if dimensions < 2:
        raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")

    num_input_fmaps = tensor.size(1)
    num_output_fmaps = tensor.size(0)
    receptive_field_size = 1
    if tensor.dim() > 2:
        receptive_field_size = tensor[0][0].numel()
    fan_in = num_input_fmaps * receptive_field_size
    fan_out = num_output_fmaps * receptive_field_size

    return fan_in, fan_out

def _calculate_correct_fan(tensor, mode):
    mode = mode.lower()
    valid_modes = ['fan_in', 'fan_out']
    if mode not in valid_modes:
        raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))

    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    return fan_in if mode == 'fan_in' else fan_out

def kaiming_normal_residual_conv_(tensor, mode="fan_in", fan_expansion=2.0, coeff=2, dim=(0, 1, 2, 3)):
    """Fill *tensor* in place with exactly standardized normal samples.

    The target standard deviation is ``sqrt(coeff / (fan * fan_expansion))``.
    Samples are drawn from ``N(0, std)`` and divided by their own std over
    *dim*, which uses torch's default unbiased estimator. They are then
    multiplied by the target std and shifted to zero mean over *dim*. Every
    slice reduced over *dim* therefore ends with exactly the target std and
    mean 0.

    Args:
        tensor: weight tensor with at least 2 dims, modified in place.
        mode: ``'fan_in'`` or ``'fan_out'`` (case-insensitive).
        fan_expansion: extra divisor applied to the fan.
        coeff: numerator of the target variance (2.0 is the "he" preset in
            ``get_conf_params``).
        dim: dimension(s) to standardize over — an int, a sequence of ints,
            or ``None`` for all dims. The default assumes a 4-D weight.

    Returns:
        ``(tensor, fan)`` — the same tensor object, now initialized, and the
        fan value that was used.

    Raises:
        ValueError: unsupported *mode*, a tensor with fewer than 2 dims, or a
            non-empty tensor whose slices over *dim* hold a single element
            (their unbiased std would be NaN and poison every weight).
        IndexError: an entry of *dim* is out of range; the tensor is left
            untouched in that case.
    """
    fan = _calculate_correct_fan(tensor, mode)
    std = math.sqrt(coeff / (fan * fan_expansion))

    if dim is None:
        dims = tuple(range(tensor.dim()))
    elif isinstance(dim, int):
        dims = (dim,)
    else:
        dims = tuple(dim)

    # Validate the reduction before mutating anything: tensor.size() raises
    # IndexError for an out-of-range dim before normal_ touches the data.
    group_size = 1
    for d in dims:
        group_size *= tensor.size(d)
    if group_size == 1 and tensor.numel() > 0:
        raise ValueError(
            "Cannot standardize over dims {}: each slice has a single element, "
            "so its std is undefined".format(dims)
        )

    with torch.no_grad():
        tensor.normal_(0, std)
        tensor.div_(tensor.std(dim=dims, keepdim=True))
        tensor.mul_(std)
        # Subtracting a per-slice constant leaves the per-slice std unchanged.
        tensor.sub_(tensor.mean(dim=dims, keepdim=True))
    return tensor, fan

def xavier_normal_residual_conv_(tensor, mode="fan_in", fan_expansion=2.0, dim=(0, 1, 2, 3)):
    """Fill *tensor* in place with exactly standardized normal samples, coeff 1.

    The target standard deviation is ``sqrt(1 / (fan * fan_expansion))``.
    Samples are drawn from ``N(0, std)`` and rescaled so that each slice
    reduced over *dim* has exactly that std. They are then shifted to zero
    mean over *dim*.

    Args:
        tensor: weight tensor with at least 2 dims, modified in place.
        mode: ``'fan_in'`` or ``'fan_out'`` (case-insensitive).
        fan_expansion: extra divisor applied to the fan.
        dim: dimension(s) to standardize over. The default assumes a 4-D weight.

    Returns:
        ``(tensor, fan)`` — the same tensor object, now initialized, and the
        fan value that was used.
    """
    # This was a line-for-line copy of kaiming_normal_residual_conv_ with the
    # variance numerator fixed at 1.0. Delegating keeps the two in sync and
    # produces the same std, since coeff / (fan * fan_expansion) with
    # coeff=1.0 equals 1.0 / (fan * fan_expansion).
    return kaiming_normal_residual_conv_(
        tensor, mode=mode, fan_expansion=fan_expansion, coeff=1.0, dim=dim
    )
