import torch
import torch.nn as nn
import functools
from functools import partial


def get_act_scales(model, cali_data):
    """Collect per-channel input statistics for every ``nn.Conv2d`` in ``model``.

    A forward hook is attached to each ``nn.Conv2d`` found by
    ``model.named_modules()``, then a single forward pass is run under
    ``torch.no_grad()`` on the first 256 entries of ``cali_data``.  Dimension 1
    of each convolution input is treated as the channel dimension.

    Args:
        model: Module to run.  The calibration batch is moved to the device of
            its first parameter; if the model has no parameters the batch is
            moved with ``.cuda()``.
        cali_data: Tensor of calibration samples, sliced along dimension 0.

    Returns:
        A tuple ``(act_scales, fp_inputs, act_cv)`` of dicts keyed by module
        name:

        * ``act_scales`` -- per-channel maximum absolute input value (float).
        * ``fp_inputs`` -- the most recent input tensor seen by the module.
        * ``act_cv`` -- per-channel ``|std / mean|`` of the input.

        When a module is called more than once, ``act_scales`` and ``act_cv``
        keep the element-wise maximum over calls.
    """
    act_scales = {}
    fp_inputs = {}
    act_cv = {}

    def stat_tensor(name, tensor):
        fp_inputs[name] = tensor
        in_channel = tensor.shape[1]
        # Bring channels to the front and flatten everything else so that
        # each row holds all values observed for one channel.
        tensor = torch.transpose(tensor, 0, 1)
        tensor = tensor.reshape(in_channel, -1).detach()

        mean = tensor.mean(dim=1)
        std = tensor.std(dim=1)

        # Avoid dividing by an exactly-zero mean.
        mean[mean == 0] = 1e-6

        cv = (std / mean).abs()

        if name in act_cv:
            # If multiple batches, take max CV per channel
            act_cv[name] = torch.max(act_cv[name], cv)
        else:
            act_cv[name] = cv

        incoming_max = torch.max(tensor.abs(), dim=1)[0].float()
        if name in act_scales:
            act_scales[name] = torch.max(act_scales[name], incoming_max)
        else:
            act_scales[name] = incoming_max

    def stat_input_hook(m, x, y, name):
        # Forward hooks receive the positional inputs as a tuple; only the
        # first one is analysed.
        if isinstance(x, tuple):
            x = x[0]
        stat_tensor(name, x)

    hooks = []
    try:
        for name, m in model.named_modules():
            if isinstance(m, nn.Conv2d):
                hooks.append(
                    m.register_forward_hook(functools.partial(stat_input_hook, name=name))
                )

        batch = cali_data[:256]
        first_param = next(model.parameters(), None)
        if first_param is not None:
            # Follow the model instead of assuming it lives on the GPU.
            batch = batch.to(first_param.device)
        else:
            batch = batch.cuda()

        with torch.no_grad():
            model(batch)
    finally:
        # Always detach the hooks, even if the forward pass raised, so the
        # model is not left with stale statistics hooks.
        for h in hooks:
            h.remove()

    return act_scales, fp_inputs, act_cv


@torch.no_grad()
def adab_conv(qconv, act_weight_module, act_scales, act_cv, beta=0.5, a=0.5, b=0.9, module=None, firstlayer=False):
    """Rebalance magnitudes between a convolution's input scaling and its weights.

    For every input channel a factor

        ``scales = act_scales ** alpha / weight_scale ** (1 - alpha)``

    is computed, where ``weight_scale`` is the maximum absolute weight that
    touches that input channel (clamped to at least ``1e-5``) and
    ``alpha = clamp(sigmoid(beta * act_cv), a, b)``.  The factor (itself
    clamped to at least ``1e-5``) is then applied in place:

    * ``act_weight_module.act_weight`` is divided by it,
    * ``qconv.module.weight`` is multiplied by it along the input-channel
      axis,
    * if ``firstlayer`` is true, ``module.shortcut_act_weight`` is multiplied
      by it.

    ``act_weight`` / ``shortcut_act_weight`` are first tiled to shape
    ``(1, C, 1, 1)`` if they do not already have it.

    Args:
        qconv: Object whose ``module`` attribute is the convolution to adjust.
        act_weight_module: Object holding the ``act_weight`` tensor.
        act_scales: Per-input-channel activation maxima, one entry per input
            channel of the convolution (``groups * weight.shape[1]``).
        act_cv: Per-input-channel coefficient of variation, same length.
        beta: Multiplier applied to ``act_cv`` before the sigmoid.
        a: Lower clamp for ``alpha``.
        b: Upper clamp for ``alpha``.
        module: Object holding ``shortcut_act_weight``; required when
            ``firstlayer`` is true.
        firstlayer: Whether to also rescale ``module.shortcut_act_weight``.

    Raises:
        ValueError: If ``firstlayer`` is true but ``module`` is ``None``.
    """
    if firstlayer and module is None:
        # Fail before any tensor is modified rather than half-way through.
        raise ValueError("adab_conv: `module` must be provided when firstlayer=True")

    conv = qconv.module
    weight = conv.weight
    groups = conv.groups
    device, dtype = weight.device, weight.dtype

    # Conv weights are laid out as (out_channel, in_channel_per_group, kH, kW).
    out_channel = weight.shape[0]
    in_per_group = weight.shape[1]
    out_per_group = out_channel // groups

    # Gather, for each (group, input channel) pair, every weight that reads it.
    per_in_channel = (
        weight.detach()
        .reshape(groups, out_per_group, in_per_group, -1)
        .transpose(1, 2)
        .reshape(groups * in_per_group, -1)
        .abs()
    )
    weight_scale = torch.max(per_in_channel, dim=1)[0].float().clamp(min=1e-5)

    # Keep both statistics on the weight's device before combining them.
    act_scales = act_scales.to(device)
    alpha = torch.sigmoid(beta * act_cv.to(device)).clamp(min=a, max=b)
    scales = (
        (act_scales.pow(alpha) / weight_scale.pow(1 - alpha))
        .clamp(min=1e-5)
        .to(device)
        .to(dtype)
    )

    channel_scales = scales.view(1, -1, 1, 1)

    act_weight = act_weight_module.act_weight
    if act_weight.data.shape != channel_scales.shape:
        act_weight.data = act_weight.data.repeat(channel_scales.shape)
    act_weight.data.div_(channel_scales)

    if firstlayer:
        shortcut = module.shortcut_act_weight
        if shortcut.data.shape != channel_scales.shape:
            shortcut.data = shortcut.data.repeat(channel_scales.shape)
        shortcut.data.mul_(channel_scales)

    # Output channel j belongs to group j // out_per_group, so each group's
    # row of scales must be repeated consecutively (repeat_interleave), not
    # tiled.  The trailing singleton dims broadcast over the kernel window.
    weight_factor = scales.view(groups, in_per_group, 1, 1).repeat_interleave(
        out_per_group, dim=0
    )
    conv.weight.data.mul_(weight_factor)