"""
Orthogonalization by Newton’s Iteration (ONI) and related weight-normalization modules.
"""
import torch.nn
import torch.nn.functional as F
from torch.nn import Parameter
from torch.autograd import Variable
from typing import List
from torch.autograd.function import once_differentiable

# Public names exported by ``from <this module> import *``. The helper
# modules in the normalization section (IdentityModule, WNorm, ONINorm) are
# not listed here.
__all__ = ['WN_Conv2d', 'OWN_Conv2d', 'ONI_Conv2d','ONI_ConvTranspose2d',
           'ONI_Linear']

# Normalization functions --------------------------------


class IdentityModule(torch.nn.Module):
     """No-op module whose ``forward`` hands back its argument unchanged.

     Any positional or keyword arguments passed to the constructor are
     accepted and discarded, so this can stand in for a configurable module.
     """

     def __init__(self, *args, **kwargs):
          # Constructor arguments are deliberately ignored.
          super().__init__()

     def forward(self, input: torch.Tensor):
          """Return ``input`` itself (the very same object, not a copy)."""
          result = input
          return result

class WNorm(torch.nn.Module):
     """Scale each output slice of a weight tensor to (approximately) unit L2 norm.

     The weight is flattened to ``(weight.size(0), -1)``. Every row is divided
     by its L2 norm plus ``1e-5``, and the result is reshaped back to the input
     shape. The ``1e-5`` keeps an all-zero row from causing a division by zero.
     """

     def forward(self, weight):
          """Return ``weight`` with every row of its 2-D flattening L2-normalized.

          Args:
               weight: tensor of rank >= 1; dimension 0 indexes the rows.

          Returns:
               Tensor with the same shape as ``weight``.
          """
          # ``reshape`` (unlike ``view``) also accepts non-contiguous weights,
          # such as a transposed or permuted kernel. It is a free view whenever
          # the memory layout allows it.
          flat = weight.reshape(weight.size(0), -1)
          norm = flat.norm(dim=1, keepdim=True) + 1e-5
          # ``flat / norm`` is a fresh contiguous tensor, so ``view`` is safe here.
          return (flat / norm).view(weight.size())



class ONINorm(torch.nn.Module):
     def __init__(self, T=5, norm_groups=1, *args, **kwargs):
          super(ONINorm, self).__init__()
          self.T = T
          self.norm_groups = norm_groups
          self.eps = 1e-5

     def matrix_power3(self, Input):
          B=torch.bmm(Input, Input)
          return torch.bmm(B, Input)

     def forward(self, weight: torch.Tensor):
          # print(weight.shape)
          assert weight.shape[0] % self.norm_groups == 0
          Z = weight.view(self.norm_groups, weight.shape[0] // self.norm_groups, -1)  # type: torch.Tensor
          Zc = Z - Z.mean(dim=-1, keepdim=True)
          S = torch.matmul(Zc, Zc.transpose(1, 2))
          eye = torch.eye(S.shape[-1]).to(S).expand(S.shape)
          S = S + self.eps*eye
          # print(S.shape)
          # frobenious norm
          norm_S = S.norm(p='fro', dim=(1, 2), keepdim=True)
          # spectral norm
          # norm_S = torch.linalg.norm(S,dim=(1,2),keepdim=True,ord=2)
          S = S.div(norm_S)
          B = [torch.Tensor([]) for _ in range(self.T + 1)]
          B[0] = torch.eye(S.shape[-1]).to(S).expand(S.shape)
          for t in range(self.T):
               #B[t + 1] = torch.baddbmm(1.5, B[t], -0.5, torch.matrix_power(B[t], 3), S)
               B[t + 1] = torch.baddbmm(1.5, B[t], -0.5, self.matrix_power3(B[t]), S)
          W = B[self.T].matmul(Zc).div_(norm_S.sqrt())
          #print(W.matmul(W.transpose(1,2)))
          # W = oni_py.apply(weight, self.T, ctx.groups)
          return W.view_as(weight)

     def extra_repr(self):
          fmt_str = ['T={}'.format(self.T)]
          if self.norm_groups > 1:
               fmt_str.append('groups={}'.format(self.norm_groups))
          return ', '.join(fmt_str)