import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

def ONI_linear_normalize(weight, T=10, norm_groups=1, eps=1e-5):
    """Whiten the rows of ``weight`` with a fixed number of Newton iterations.

    The weight is reshaped to ``(norm_groups, weight.shape[0] // norm_groups, -1)``
    and every group is processed independently:

    1. Each row is centered by subtracting its mean over the last axis.
    2. The Gram matrix ``S = Zc @ Zc^T + eps * I`` is formed and divided by its
       Frobenius norm.
    3. Starting from ``Y = S`` and ``B = I``, the coupled update
       ``Y <- 0.5 * Y @ (3I - B @ Y)``, ``B <- 0.5 * (3I - B @ Y) @ B`` is run
       ``T`` times. This is the Newton-Schulz scheme, under which ``B`` tends
       toward the inverse square root of the scaled ``S``.
    4. The result is ``B @ Zc / sqrt(||S||_F)``, reshaped back to the shape of
       ``weight``.

    Args:
        weight: Tensor whose first dimension is divisible by ``norm_groups``.
            All remaining dimensions are flattened into one. It does not have
            to be contiguous.
        T: Number of iteration steps. ``T=0`` leaves ``B`` at the identity.
        norm_groups: Number of independent row groups.
        eps: Value added to the diagonal of the Gram matrix before scaling.

    Returns:
        A new tensor with the same shape as ``weight``.

    Raises:
        AssertionError: If ``weight.shape[0]`` is not divisible by
            ``norm_groups``.
    """
    assert weight.shape[0] % norm_groups == 0, (
        "weight.shape[0] must be divisible by norm_groups"
    )
    rows_per_group = weight.shape[0] // norm_groups

    # reshape (rather than view) also accepts non-contiguous weights, e.g. a
    # channels_last convolution kernel, for which view raises a RuntimeError.
    Z = weight.reshape(norm_groups, rows_per_group, -1)
    Zc = Z - Z.mean(dim=-1, keepdim=True)
    S = torch.matmul(Zc, Zc.transpose(1, 2))

    # One identity, created directly with S's dtype/device and broadcast over
    # the groups; it serves both the eps regularizer and the iteration.
    I = torch.eye(S.shape[-1], dtype=S.dtype, device=S.device).expand(S.shape)
    S = S + eps * I

    # Scaling by the Frobenius norm bounds the spectrum of S before iterating;
    # the factor is undone by the final division by sqrt(norm_S).
    norm_S = S.norm(p='fro', dim=(1, 2), keepdim=True)
    S = S.div(norm_S)

    Y, B = S, I
    for _ in range(T):
        # Both updates share the same (3I - B @ Y) term, so compute it once.
        residual = 3 * I - B @ Y
        Y, B = 0.5 * Y @ residual, 0.5 * residual @ B

    W = (B @ Zc).div(norm_S.sqrt())
    return W.reshape(weight.shape)

class ONI_Linear(nn.Linear):
    """``nn.Linear`` variant that applies a normalized copy of its weight.

    The stored ``weight`` parameter is never used directly in the forward
    pass; it is first passed through ``ONI_linear_normalize`` and multiplied
    by ``scale``.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample; asserted to be no larger
            than ``in_features`` (presumably because the rows can only be
            made mutually orthogonal in that case -- confirm with the author).
        bias: Forwarded to ``nn.Linear``.
        scale: Factor applied to the normalized weight.
        iter_T: Iteration count handed to ``ONI_linear_normalize`` as ``T``.
    """

    def __init__(self, in_features, out_features, bias=True, scale=1.0, iter_T=10):
        super().__init__(in_features, out_features, bias)
        assert in_features >= out_features
        self.iter_T = iter_T
        self.scale = scale

    def normed_weight(self):
        """Return the normalized weight multiplied by ``self.scale``."""
        return ONI_linear_normalize(self.weight, T=self.iter_T) * self.scale

    def forward(self, input_f: torch.Tensor) -> torch.Tensor:
        """Apply the affine map using the normalized weight and the raw bias."""
        return F.linear(input_f, self.normed_weight(), self.bias)

class GroupSort(nn.Module):
    """Pairwise max/min activation over the two halves of one dimension.

    The input is split into a first and a second half along ``dim``. The
    output holds the element-wise maximum of the two halves followed by their
    element-wise minimum, concatenated along ``dim``, so the output has the
    same shape as the input.

    Args:
        dim: Dimension that is split in two. Its size must be even and
            non-zero.
    """

    def __init__(self, dim=1):
        super(GroupSort, self).__init__()
        self.dim = dim

    def forward(self, x):
        """Return ``cat([max(a, b), min(a, b)])`` for the halves ``a``, ``b`` of ``x``.

        Raises:
            ValueError: If the size of ``x`` along ``self.dim`` is odd or zero.
        """
        size = x.size(self.dim)
        # An odd size would make split() return three chunks and surface as an
        # opaque unpacking error, so reject it up front with a clear message.
        if size == 0 or size % 2 != 0:
            raise ValueError(
                "GroupSort requires an even, non-zero size along dim {}, "
                "got {}".format(self.dim, size)
            )
        a, b = x.split(size // 2, self.dim)
        return torch.cat([torch.max(a, b), torch.min(a, b)], dim=self.dim)
