import torch
import numpy as np
import torch.nn as nn


def backward_hook(gamma):
    """Build a backward hook that multiplies ReLU input-gradients by ``gamma``.

    Used to implement SGM by scaling the gradient flowing back through ReLU
    modules. The returned callable has the ``(module, grad_in, grad_out)``
    signature of a PyTorch module backward hook.

    Args:
        gamma: scale factor applied to ``grad_in[0]``.

    Returns:
        A hook that returns ``(gamma * grad_in[0],)`` when ``module`` is an
        ``nn.ReLU`` and ``None`` (gradients left untouched) otherwise.
    """
    def _hook(module, grad_in, grad_out):
        # Only ReLU modules are rescaled; every other module is a no-op.
        if not isinstance(module, nn.ReLU):
            return None
        scaled = gamma * grad_in[0]
        return (scaled,)

    return _hook


def backward_hook_norm(module, grad_in, grad_out):
    """Backward hook that rescales ``grad_in[0]`` to unit standard deviation.

    Normalizing the gradient helps avoid gradient explosion or vanishing.

    Args:
        module: the hooked module (unused).
        grad_in: gradients w.r.t. the module inputs; only ``grad_in[0]`` is used.
        grad_out: gradients w.r.t. the module outputs (unused).

    Returns:
        ``(grad_in[0] / std,)`` in the normal case. Returns ``(grad_in[0],)``
        unscaled when the standard deviation is zero or not finite, and
        ``None`` when ``grad_in[0]`` is ``None``.
    """
    grad = grad_in[0]
    if grad is None:
        # Nothing to normalize; returning None leaves grad_in unchanged.
        return None
    std = torch.std(grad)
    # A constant gradient (std == 0) or a single-element tensor (the default
    # unbiased std is NaN) would turn every entry into inf/NaN; pass it through.
    if not torch.isfinite(std) or std == 0:
        return (grad,)
    return (grad / std,)


class SkipRELU(nn.Module):
    """Wrap a ReLU so its forward value is kept but its gradient is scaled.

    ``forward`` returns ``gamma * r + (1 - gamma) * r.detach()`` where
    ``r = relu(x)``. The value equals ``r`` up to floating-point rounding.
    Only the ``gamma`` share carries gradient, so backpropagation through this
    module is multiplied by ``gamma``.

    NOTE(review): the constructor stores ``gamma`` verbatim, whereas assigning
    via the ``gamma`` property stores ``sqrt(val)``; confirm callers rely on
    this asymmetry.
    """

    def __init__(self, relu, gamma) -> None:
        super().__init__()
        # Wrapped activation (registered as a submodule when it is an nn.Module).
        self.relu = relu
        # Gradient scale factor read by forward().
        self._gamma = gamma

    @property
    def gamma(self):
        """Current gradient scale factor."""
        return self._gamma

    @gamma.setter
    def gamma(self, val):
        # Assignment stores the square root of the given value.
        self._gamma = np.power(val, 0.5)

    def forward(self, *args, **kwargs):
        activated = self.relu(*args, **kwargs)
        g = self.gamma
        # Gradient-carrying part and detached (gradient-free) remainder.
        scaled = g * activated
        passthrough = (1 - g) * activated.detach()
        return scaled + passthrough

def sgm_for_resnet(model, arch, gamma, device):
    """Apply SGM to a ResNet by wrapping its ReLU submodules in ``SkipRELU``.

    The model is modified in place. Every submodule whose dotted name contains
    ``'relu'`` is replaced by ``SkipRELU(module, gamma).to(device)``, except
    names containing ``'.0.relu'``.

    Args:
        model: the ``nn.Module`` to patch.
        arch: architecture name. For ``'resnet50'``, ``'resnet101'`` and
            ``'resnet152'`` the square root of ``gamma`` is used instead.
            NOTE(review): presumably because a single ReLU module is applied
            several times per block in those models; other architectures with
            the same block type are not in this list, so confirm.
        gamma: gradient scale factor handed to each ``SkipRELU``.
        device: target passed to ``.to(device)`` on each wrapper.

    Prints each replaced module name and, finally, the patched model.
    """
    if arch in ['resnet50', 'resnet101', 'resnet152']:
        gamma = np.power(gamma, 0.5)

    # Collect targets up front so the module tree is not mutated while
    # named_modules() is still lazily traversing it.
    targets = [
        (name, module)
        for name, module in model.named_modules()
        if 'relu' in name and '.0.relu' not in name
    ]

    for name, module in targets:
        print(f"SGM skip {name} with gamma={gamma}")
        *parent_path, attr = name.split('.')
        parent = model
        for token in parent_path:
            parent = getattr(parent, token)
        setattr(parent, attr, SkipRELU(module, gamma).to(device))
    print(model)

def sgm_for_densenet(model, arch, gamma, device):
    """Apply SGM to a DenseNet by wrapping its ReLU submodules in ``SkipRELU``.

    The model is modified in place. Every submodule whose dotted name contains
    ``'relu'`` is replaced by ``SkipRELU(module, sqrt(gamma)).to(device)``,
    except names containing ``'transition'``.

    Args:
        model: the ``nn.Module`` to patch.
        arch: unused here; kept so the signature matches ``sgm_for_resnet``.
        gamma: gradient scale factor; its square root is what each wrapper gets.
        device: target passed to ``.to(device)`` on each wrapper.

    Prints each replaced module name and, finally, the patched model.
    """
    gamma = np.power(gamma, 0.5)

    # Collect targets up front so the module tree is not mutated while
    # named_modules() is still lazily traversing it.
    targets = [
        (name, module)
        for name, module in model.named_modules()
        if 'relu' in name and 'transition' not in name
    ]

    for name, module in targets:
        print(f"SGM skip {name} with gamma={gamma}")
        *parent_path, attr = name.split('.')
        parent = model
        for token in parent_path:
            parent = getattr(parent, token)
        setattr(parent, attr, SkipRELU(module, gamma).to(device))
    print(model)
