import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class LinReluWOSGM(nn.Module):
    """ReLU whose forward value is unchanged but whose backward pass is linear.

    The output equals ``relu(x)``. In backward, the gradient is passed to every
    element as if the ReLU were the identity, scaled by ``alpha``: the number of
    positive elements divided by the number of nonzero elements of the input.
    An all-zero input yields ``alpha == 0`` (a zero output and zero gradient)
    instead of NaN.

    Args:
        relu: The original ReLU module being replaced. It is stored on the
            instance but never called by ``forward``.
        gamma: Stored as ``_gamma``; ``forward`` of this class does not use it.
    """

    def __init__(self, relu, gamma=1.0) -> None:
        super().__init__()
        # Original note: module.relu(inplace=True) must not be used here.
        # The module is only kept as a reference, never invoked.
        self.relu = relu
        self._gamma = gamma

    @property
    def gamma(self):
        return self._gamma

    @gamma.setter
    def gamma(self, val):
        # NOTE(review): __init__ stores gamma as given, but this setter stores
        # sqrt(val). Confirm the asymmetry is intended.
        self._gamma = np.power(val, 0.5)

    def forward(self, x):
        # Number of nonzero input elements. Clamp to 1 so an all-zero input
        # gives alpha = 0 / 1 instead of 0 / 0 = NaN (which made the output NaN).
        n_nonzero = x.detach().bool().sum().clamp(min=1)

        # Value becomes x + max(-x, 0) == relu(x). The added term is detached,
        # so backward treats this step as the identity.
        x = x + F.relu(-x).detach()

        # Fraction of the nonzero inputs that survived the ReLU (i.e. were positive).
        alpha = x.bool().sum().detach() / n_nonzero

        # Forward value is still x; only the alpha-weighted part carries gradient.
        return alpha * x + (1 - alpha) * x.detach()


class LinRelu(nn.Module):
    """ReLU whose backward pass mixes a linear gradient with the true ReLU gradient.

    The forward output equals ``relu(x)``. In backward, the gradient w.r.t. ``x``
    is ``gamma * alpha + (1 - gamma) * (ordinary ReLU gradient)``. Here ``alpha``
    is the number of positive elements divided by the number of nonzero elements
    of the input. An all-zero input yields ``alpha == 0`` instead of NaN.

    Args:
        relu: The original ReLU module being replaced. It is stored on the
            instance but ``forward`` calls ``F.relu`` instead of it.
        gamma: Mixing weight; ``1.0`` (default) keeps only the linear gradient.
    """

    def __init__(self, relu, gamma=1.0) -> None:
        super().__init__()
        # Original note: module.relu(inplace=True) must not be used here, which
        # is why forward calls F.relu rather than self.relu.
        self.relu = relu
        self._gamma = gamma

    @property
    def gamma(self):
        return self._gamma

    @gamma.setter
    def gamma(self, val):
        # NOTE(review): __init__ stores gamma as given, but this setter stores
        # sqrt(val). Confirm the asymmetry is intended.
        self._gamma = np.power(val, 0.5)

    def forward(self, x):
        # Ordinary ReLU branch (standard ReLU gradient).
        x_relu = F.relu(x)

        # Number of nonzero input elements. Clamp to 1 so an all-zero input
        # gives alpha = 0 / 1 instead of 0 / 0 = NaN (which made the output NaN).
        n_nonzero = x.detach().bool().sum().clamp(min=1)

        # Value is relu(x), but the detached correction makes backward see the identity.
        x_lin = x + F.relu(-x).detach()
        alpha = x_lin.bool().sum().detach() / n_nonzero

        # Same value as x_lin; gradient is scaled by alpha.
        linear = alpha * x_lin + (1 - alpha) * x_lin.detach()

        return self.gamma * linear + (1 - self.gamma) * x_relu


def _resnet_block_position(name):
    """Return (layer, block) parsed from a '...layer<J>.<K>...' module path, or None."""
    tokens = name.split('.')
    for i, token in enumerate(tokens[:-1]):
        suffix = token[len('layer'):]
        if token.startswith('layer') and suffix.isdigit() and tokens[i + 1].isdigit():
            return int(suffix), int(tokens[i + 1])
    return None


def linbp_for_resnet50(model, arch, device, *, linbp_layer='3_1', verbose=True):
    """Replace ResNet block ReLUs with ``LinRelu`` from a given block onward, in place.

    An ``nn.ReLU`` is replaced when its qualified name contains a
    ``layer<J>.<K>`` pair and block (J, K) is at or after ``linbp_layer``.
    Example names are ``layer3.1.relu``, or ``module.layer3.1.relu`` when the
    model is wrapped. Names with no such pair are left untouched.

    Args:
        model: Network to modify in place.
        arch: Architecture name; only used in the debug print.
        device: Device each replacement ``LinRelu`` is moved to.
        linbp_layer: First block to linearize, written as ``"<layer>_<block>"``.
            Defaults to ``"3_1"``.
        verbose: When True (default), print ``arch`` and the model before the
            swap and the model after it.
    """
    start_layer, start_block = map(int, linbp_layer.split('_'))

    if verbose:
        print(arch, '\n', model)

    # Collect targets first so the module tree is not mutated while
    # named_modules() is still walking it.
    targets = []
    for name, module in model.named_modules():
        if not isinstance(module, nn.ReLU):
            continue
        position = _resnet_block_position(name)
        if position is None:
            continue
        layer_idx, block_idx = position
        if layer_idx > start_layer or (layer_idx == start_layer and block_idx >= start_block):
            targets.append((name, module))

    for name, module in targets:
        *parent_path, attr = name.split('.')
        parent = model
        # getattr also resolves numeric child names such as '1' inside nn.Sequential.
        for token in parent_path:
            parent = getattr(parent, token)
        setattr(parent, attr, LinRelu(module).to(device))

    if verbose:
        print(model)