import torch
import torch.nn as nn
import numpy as np


class SkipConv2d(nn.Module):
  """Wrap a convolution so its output value is kept but its gradient is re-weighted.

  The wrapped conv's kernel ``W`` is split into a fixed "skip" kernel ``S``
  and a residual kernel ``W - S``:

  * ``S`` is zero everywhere except the centre tap, which is
    ``1 / (in_channels // groups)`` for every (out, in) pair.
  * The passed-in conv is modified IN PLACE so that its weight becomes ``W - S``.

  ``forward`` returns ``skip + gamma * r + (1 - gamma) * r.detach()``, where
  ``r`` is the residual conv output and ``skip`` is the ``S`` path. The forward
  value is ``skip + r``. This equals the original conv output when the skip
  path is spatially aligned with the conv output, i.e. stride 1 and
  size-preserving padding. The callers in this file appear to select only
  such convs. Gradients through ``r`` are scaled by ``gamma``.

  ``gamma`` is stored as given by the constructor. Assigning to the
  ``gamma`` property afterwards stores the square root of the assigned value.
  """

  def __init__(self, conv2d, gamma) -> None:
    super().__init__()
    self.conv2d, self.skip_conv2d = self._skip_conv(conv2d)
    # Bypasses the property setter on purpose: no square root here.
    self._gamma = gamma

  @property
  def gamma(self):
    """Weight applied to the gradient of the residual conv path."""
    return self._gamma

  @gamma.setter
  def gamma(self, val):
    # NOTE(review): the setter stores sqrt(val), unlike __init__ — confirm
    # this asymmetry is intended.
    self._gamma = np.power(val, 0.5)

  def _skip_conv(self, conv):
    """Subtract the centre-tap skip kernel from ``conv`` and build the skip path.

    Args:
      conv: the convolution to split; its weight is modified in place.

    Returns:
      ``(conv, skip_conv)``. ``skip_conv`` is ``None`` for ``groups == 1``;
      ``forward`` then uses a channel mean instead. For grouped convs it is
      a bias-free ``nn.Conv2d`` carrying the skip kernel, placed on the same
      device and dtype as ``conv.weight``.
    """
    # Based on kernel_size, subtract 1 / (in_channels // groups) at the centre tap.
    skip_kernel = torch.zeros_like(conv.weight)
    k_size = conv.kernel_size
    if isinstance(k_size, int):
      centre_h = centre_w = k_size // 2
    else:
      centre_h, centre_w = k_size[0] // 2, k_size[1] // 2
    skip_kernel[:, :, centre_h, centre_w] = 1 / (conv.in_channels // conv.groups)
    conv.weight.data.sub_(skip_kernel)

    skip_conv = None
    if conv.groups != 1:
      skip_conv = torch.nn.Conv2d(
        conv.in_channels, conv.out_channels, conv.kernel_size,
        stride=conv.stride, padding=conv.padding, dilation=conv.dilation,
        groups=conv.groups, bias=False, padding_mode=conv.padding_mode
      )
      # Keep the skip path next to the wrapped conv. Otherwise a GPU or
      # non-float32 conv fails in forward() with a device/dtype mismatch.
      skip_conv = skip_conv.to(device=conv.weight.device, dtype=conv.weight.dtype)
      # NOTE(review): this weight is an ordinary trainable Parameter; if the
      # model is trained it can drift from the mean kernel — confirm whether
      # it should be frozen.
      skip_conv.weight.data.copy_(skip_kernel)
    return conv, skip_conv

  def forward(self, *inputs, **kwargs):
    """Return ``skip + r`` with the gradient of ``r`` scaled by ``gamma``."""
    x = inputs[0]
    residual = self.conv2d(*inputs, **kwargs)

    if self.skip_conv2d is None:
      # groups == 1: the centre-tap kernel averages all input channels,
      # identically for every output channel, so a broadcast mean suffices.
      skip = x.mean(dim=1, keepdim=True)
    else:
      skip = self.skip_conv2d(*inputs, **kwargs)

    # The two residual terms sum to `residual` in value; only the first
    # carries gradient.
    return skip + self.gamma * residual + (1 - self.gamma) * residual.detach()


def scm_for_resnet(model, arch, gamma, scale, device):
  """Wrap eligible convs of a ResNet-style ``model`` in ``SkipConv2d``, in place.

  A module is wrapped when all of the following hold:

  * it is an ``nn.Conv2d`` or ``nn.Conv1d`` whose ``stride == (1, 1)``;
  * its kernel is not ``(1, 1)``;
  * its qualified name does not contain ``'.0.conv'`` (with torchvision
    naming this presumably excludes the first block of each stage — verify);
  * ``2 * padding + 1 == kernel_size`` in every dimension.

  The gamma of a wrapped conv depends on the first stage name
  (``layer1``..``layer4``) found in its qualified name. Other convs use ``gamma``.

  Args:
    model: module whose submodules are replaced in place.
    arch: only printed.
    gamma: fallback gamma for eligible convs outside ``layer1``..``layer4``.
    scale: divisor applied to the chosen gamma before it reaches ``SkipConv2d``.
    device: device each new ``SkipConv2d`` is moved to.

  Returns:
    None. ``model`` is modified in place and printed before and after.
  """
  # (name substring, gamma) pairs, checked in order like an if/elif chain.
  stage_gammas = (
    ('layer1', 0.89 / 2),
    ('layer2', 0.88 / 2),
    ('layer3', 0.7 / 2),
    ('layer4', 0.2),
  )
  print(arch, '\n', model)

  def _eligible(name, module):
    # Check the type first so `.stride` etc. are only read on convs.
    if not isinstance(module, (nn.Conv2d, nn.Conv1d)):
      return False
    return (module.stride == (1, 1)
            and module.kernel_size != (1, 1)
            and '.0.conv' not in name
            and [p * 2 + 1 for p in module.padding] == list(module.kernel_size))

  # Snapshot the module list first: the loop replaces modules in the tree,
  # which should not happen while a live named_modules() generator walks it.
  for name, module in list(model.named_modules()):
    if not _eligible(name, module):
      continue

    stage_gamma = next((g for key, g in stage_gammas if key in name), gamma)
    print(f"SCM skip {name} with gamma={stage_gamma}")

    *parent_path, attr = name.split('.')
    parent = model
    for token in parent_path:
      parent = getattr(parent, token)
    setattr(parent, attr, SkipConv2d(module, stage_gamma / scale).to(device))
  print(model)


def scm_for_densenet(model, arch, gamma, scale, device):
  """Wrap eligible convs inside ``denseblock4`` of ``model`` in ``SkipConv2d``, in place.

  A module is wrapped when all of the following hold:

  * it is an ``nn.Conv2d`` or ``nn.Conv1d`` whose ``stride == (1, 1)``;
  * its kernel is not ``(1, 1)``;
  * ``2 * padding + 1 == kernel_size`` in every dimension;
  * its qualified name contains ``'denseblock4'``.

  Every wrapped conv gets the fixed gamma 0.2, divided by ``scale``.

  Args:
    model: module whose submodules are replaced in place.
    arch: unused.
    gamma: unused. NOTE(review): the fixed 0.2 below always wins; kept for
      signature compatibility with the other ``scm_for_*`` functions.
    scale: divisor applied to the gamma before it reaches ``SkipConv2d``.
    device: device each new ``SkipConv2d`` is moved to.

  Returns:
    None. ``model`` is modified in place and printed before and after.
  """
  fixed_gamma = 0.2
  print(model)

  def _eligible(name, module):
    # Check the type first so `.stride` etc. are only read on convs.
    if not isinstance(module, (nn.Conv2d, nn.Conv1d)):
      return False
    return (module.stride == (1, 1)
            and module.kernel_size != (1, 1)
            and [p * 2 + 1 for p in module.padding] == list(module.kernel_size)
            and 'denseblock4' in name)

  # Snapshot the module list first: the loop replaces modules in the tree,
  # which should not happen while a live named_modules() generator walks it.
  for name, module in list(model.named_modules()):
    if not _eligible(name, module):
      continue

    print(f"SCM skip {name} with gamma={fixed_gamma}")

    *parent_path, attr = name.split('.')
    parent = model
    for token in parent_path:
      parent = getattr(parent, token)
    setattr(parent, attr, SkipConv2d(module, fixed_gamma / scale).to(device))
  print(model)


def scm_for_inc(model, arch, gamma, scale, device):
  """Wrap eligible convs of an Inception-style ``model`` in ``SkipConv2d``, in place.

  A module is wrapped when all of the following hold:

  * it is an ``nn.Conv2d`` or ``nn.Conv1d`` whose ``stride == (1, 1)``;
  * its kernel is not ``(1, 1)``;
  * ``2 * padding + 1 == kernel_size`` in every dimension;
  * the third dot-separated component of its qualified name, parsed as an
    int, is at least 15.

  Every wrapped conv gets the fixed gamma 0.5, divided by ``scale``.

  Args:
    model: module whose submodules are replaced in place.
    arch: unused.
    gamma: unused. NOTE(review): the fixed 0.5 below always wins; kept for
      signature compatibility with the other ``scm_for_*`` functions.
    scale: divisor applied to the gamma before it reaches ``SkipConv2d``.
    device: device each new ``SkipConv2d`` is moved to.

  Returns:
    None. ``model`` is modified in place and printed before and after.

  Raises:
    IndexError: an eligible conv's name has fewer than three components.
    ValueError: an eligible conv's third name component is not an integer.
  """
  fixed_gamma = 0.5
  print(model)

  def _eligible(module):
    # Check the type first so `.stride` etc. are only read on convs.
    if not isinstance(module, (nn.Conv2d, nn.Conv1d)):
      return False
    return (module.stride == (1, 1)
            and module.kernel_size != (1, 1)
            and [p * 2 + 1 for p in module.padding] == list(module.kernel_size))

  # Snapshot the module list first: the loop replaces modules in the tree,
  # which should not happen while a live named_modules() generator walks it.
  for name, module in list(model.named_modules()):
    if not _eligible(module):
      continue

    tokens = name.split('.')
    # Assumes names look like '<a>.<b>.<int index>...' (e.g. a Sequential of
    # blocks nested two levels deep) — TODO confirm against the models used.
    if int(tokens[2]) < 15:
      continue
    print(f"SCM skip {name} with gamma={fixed_gamma}")

    parent = model
    for token in tokens[:-1]:
      parent = getattr(parent, token)
    setattr(parent, tokens[-1], SkipConv2d(module, fixed_gamma / scale).to(device))

  print(model)