import numpy as np
import torch
import torch.nn as nn

from functools import partial

def backward_hook(gamma):
    """Build the SGM backward hook that scales gradients flowing back through ReLUs.

    The returned callable follows the ``Module.register_backward_hook`` signature
    ``(module, grad_in, grad_out)``. For ``nn.ReLU`` modules it replaces the first
    input-gradient with ``gamma * grad_in[0]``. For any other module it returns
    ``None``, which leaves the gradients untouched.

    Args:
        gamma: factor applied to the gradient at each hooked ReLU.

    Returns:
        The hook function.
    """
    def _scale_relu_grad(module, grad_in, grad_out):
        # Hooks are registered by name pattern, so non-ReLU modules can reach
        # here; ignore them.
        if not isinstance(module, nn.ReLU):
            return None
        return (gamma * grad_in[0],)

    return _scale_relu_grad


def backward_hook_norm(module, grad_in, grad_out):
    """Backward hook that rescales the first input-gradient to unit standard deviation.

    It is registered on residual blocks so the gradient does not explode or vanish
    as it is repeatedly scaled on the way back.

    Args:
        module: module the hook is attached to (unused).
        grad_in: tuple of gradients delivered by the backward hook. Only the first
            entry is normalized; any further entries are passed through unchanged,
            so the returned tuple has the length PyTorch expects.
        grad_out: output gradients (unused).

    Returns:
        A tuple of the same length as ``grad_in``. Returns ``None`` (gradients left
        unchanged) when the first gradient is ``None``.
    """
    grad = grad_in[0]
    if grad is None:
        return None
    std = torch.std(grad)
    # A constant gradient has std 0, and a degenerate one can give NaN/inf. Dividing
    # by that would poison the gradient, so fall back to dividing by 1 in that case.
    # torch.where keeps this on-device (no host sync).
    valid = torch.isfinite(std) & (std > 0)
    safe_std = torch.where(valid, std, torch.ones_like(std))
    return (grad / safe_std,) + tuple(grad_in[1:])


def register_hook_for_resnet(model, arch, gamma):
    """Attach the SGM ReLU hook and the gradient-normalization hook to a ResNet.

    Args:
        model: ResNet whose activation modules have ``act`` in their qualified
            name (presumably timm-style ``act1``/``act2``... — confirm against
            caller). Names may carry a wrapper prefix such as ``1.`` when the
            network sits inside an ``nn.Sequential``.
        arch: architecture name. For ``resnet50``/``resnet101``/``resnet152``,
            gamma is square-rooted.
        gamma: gradient decay factor for the SGM hook.
    """
    # There is only 1 ReLU in Conv module of ResNet-18/34
    # and 2 ReLU in Conv module ResNet-50/101/152
    if arch in ['resnet50', 'resnet101', 'resnet152']:
        gamma = np.power(gamma, 0.5)
    backward_hook_sgm = backward_hook(gamma)

    for name, module in model.named_modules():
        parts = name.split('.')
        # Skip activations inside the first block (index 0) of each stage. Match
        # the block-index component exactly: a substring test for '0.act' would
        # also skip blocks 10, 20, 30, ... in the deep stages of ResNet-101/152.
        in_first_block = len(parts) >= 2 and parts[-2] == '0'
        if 'act' in name and not in_first_block:
            # The SGM hook itself ignores non-ReLU modules.
            module.register_backward_hook(backward_hook_sgm)

        # Residual blocks themselves, e.g., 1.layer1.1, 1.layer4.2, ...
        if len(parts) >= 2 and 'layer' in parts[-2]:
            module.register_backward_hook(backward_hook_norm)


def register_hook_for_densenet(model, arch, gamma):
    """Attach the SGM ReLU hook to a DenseNet.

    Every submodule whose qualified name contains ``relu`` gets the hook, except
    those whose name contains ``transition``.

    Args:
        model: DenseNet whose submodule names follow the pattern filtered below
            (presumably torchvision naming — confirm against caller).
        arch: architecture name; not used here.
        gamma: decay factor; its square root is applied at each hooked ReLU.
    """
    # Each Conv module of DenseNet-121/169/201 has 2 ReLUs, so split gamma evenly.
    sgm_hook = backward_hook(np.power(gamma, 0.5))
    targets = (
        module
        for name, module in model.named_modules()
        if 'relu' in name and 'transition' not in name
    )
    for module in targets:
        module.register_backward_hook(sgm_hook)

def attn_drop_mask_grad(module, grad_in, grad_out, gamma):
    """Backward hook that scales the first input-gradient of attention dropout by ``gamma``.

    Args:
        module: module the hook is registered on (unused).
        grad_in: tuple of gradients delivered by the backward hook. Only the first
            entry is scaled; any further entries are passed through unchanged.
        grad_out: output gradients (unused).
        gamma: scale factor, bound via ``functools.partial`` (0 when registered by
            ``register_hook_for_vit``).

    Returns:
        A tuple of the same length as ``grad_in``. Returns ``None`` (gradients left
        unchanged) when the first gradient is ``None``.
    """
    grad = grad_in[0]
    if grad is None:
        return None
    # Scale directly instead of building a ones_like(grad) * gamma mask first.
    return (grad * gamma,) + tuple(grad_in[1:])

def mlp_mask_grad(module, grad_in, grad_out, gamma):
    """Backward hook that scales the first input-gradient of an MLP by ``gamma``.

    Args:
        module: module the hook is registered on (unused).
        grad_in: tuple of gradients delivered by the backward hook. Only the first
            entry is scaled; every other entry is passed through unchanged. The
            previous version assumed exactly two entries.
        grad_out: output gradients (unused).
        gamma: scale factor, bound via ``functools.partial``.

    Returns:
        A tuple of the same length as ``grad_in``. Returns ``None`` (gradients left
        unchanged) when the first gradient is ``None``.
    """
    grad = grad_in[0]
    if grad is None:
        return None
    # Scale directly instead of building a ones_like(grad) * gamma mask first.
    return (grad * gamma,) + tuple(grad_in[1:])

def attn_mask_grad(module, grad_in, grad_out, gamma):
    """Backward hook that scales the first input-gradient of the attention qkv layer by ``gamma``.

    Args:
        module: module the hook is registered on (unused).
        grad_in: tuple of gradients delivered by the backward hook. Only the first
            entry is scaled; every other entry is passed through unchanged. The
            previous version assumed exactly two entries.
        grad_out: output gradients (unused).
        gamma: scale factor, bound via ``functools.partial``.

    Returns:
        A tuple of the same length as ``grad_in``. Returns ``None`` (gradients left
        unchanged) when the first gradient is ``None``.
    """
    grad = grad_in[0]
    if grad is None:
        return None
    # Scale directly instead of building a ones_like(grad) * gamma mask first.
    return (grad * gamma,) + tuple(grad_in[1:])

def register_hook_for_vit(model, arch, gamma=0.5, sgm_control='1,0'):
    """Register gradient-scaling backward hooks on every transformer block.

    For each block in ``model.blocks``:
      * ``attn.attn_drop`` always gets a hook that multiplies its first
        input-gradient by 0;
      * ``mlp`` gets a hook scaling its first input-gradient by ``gamma`` when the
        first ``sgm_control`` flag is ``'1'``;
      * ``attn.qkv`` gets the same kind of hook when the second flag is ``'1'``.

    Args:
        model: ViT-like model exposing ``model.blocks``. Each block must have
            ``attn.attn_drop``, ``attn.qkv`` and ``mlp`` submodules.
        arch: architecture name; unused here, kept so this shares the leading
            ``(model, arch, gamma)`` signature of the other ``register_hook_for_*``
            functions.
        gamma: gradient scale for the MLP / qkv hooks. Previously this argument was
            ignored and 0.5 was hard-coded; the default keeps that behavior.
        sgm_control: comma-separated ``'<mlp>,<qkv>'`` flags. A missing flag counts
            as disabled instead of raising ``IndexError``.
    """
    drop_hook_func = partial(attn_drop_mask_grad, gamma=0)
    mlp_hook_func = partial(mlp_mask_grad, gamma=gamma)
    attn_hook_func = partial(attn_mask_grad, gamma=gamma)

    flags = sgm_control.split(',')
    use_mlp = flags[0] == '1'
    use_attn = len(flags) > 1 and flags[1] == '1'

    # Iterate the actual blocks rather than a hard-coded 12, so ViTs of any depth
    # are fully hooked (a 12-block model behaves exactly as before).
    for block in model.blocks:
        block.attn.attn_drop.register_backward_hook(drop_hook_func)
        if use_mlp:
            block.mlp.register_backward_hook(mlp_hook_func)
        if use_attn:
            block.attn.qkv.register_backward_hook(attn_hook_func)