import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
import math

 
class SkewSymmetricLinear(nn.Linear):
    """Square, bias-free linear layer applied with a skew-symmetric weight.

    The stored parameter ``self.weight`` is an unconstrained
    ``in_features x in_features`` matrix. On every forward pass it is
    projected to its skew-symmetric part ``(W - W.T) / 2`` and that
    projection, not the raw parameter, is used as the weight.

    Args:
        in_features: Size of both the input and the output feature dimension.
        device: Forwarded to ``nn.Linear``.
        dtype: Forwarded to ``nn.Linear``.
    """

    def __init__(self, in_features, device=None, dtype=None):
        super().__init__(
            in_features,
            in_features,
            bias=False,
            device=device,
            dtype=dtype,
        )

    def make_skew_symmetric(self, A):
        """Return a new matrix holding the skew-symmetric part of ``A``."""
        transposed = A.T
        difference = A - transposed
        return difference / 2

    def forward(self, x):
        """Apply the layer to ``x`` using the skew-symmetrized weight."""
        skew_weight = self.make_skew_symmetric(self.weight)
        return F.linear(x, skew_weight)
    

class BlockDiagonalMatrixMultiplication(nn.Module):
    """Apply one shared linear map to every ``n x n`` patch of an image.

    The input is cut into non-overlapping ``n x n`` patches, each patch is
    flattened to a vector of length ``n * n``, passed through ``linear``,
    and the results are reassembled into an image of the same layout.

    Args:
        linear: Module exposing ``in_features``. It is applied to tensors
            whose last dimension has size ``n * n`` and must return the same
            last-dimension size so the patches can be reassembled.

    Raises:
        ValueError: If ``linear.in_features`` is not a perfect square, since
            no integer patch side ``n`` would match the layer's input size.
    """

    def __init__(self, linear):
        super().__init__()
        features = linear.in_features
        side = int(math.sqrt(features))
        # Fail early: a truncated square root would otherwise only surface
        # later as an opaque shape error inside forward().
        if side * side != features:
            raise ValueError(
                f"linear.in_features must be a perfect square, got {features}"
            )
        self.linear = linear
        self.n = side

    def forward(self, x, shift, vertical):
        """Run the patch-wise linear map on a shifted (and optionally rotated) grid.

        Args:
            x: Tensor matching the pattern ``n 1 (h n1) (w n2)``: the second
                dimension must have size 1 and, after the optional rotation,
                both trailing dimensions must be divisible by ``self.n``.
            shift: Offset of the patch grid along the last dimension. The
                input is rolled by ``-shift`` before patching and the result
                is rolled back by ``shift``.
            vertical: If true, the last two dimensions are rotated by 90
                degrees before processing and rotated back afterwards, so
                the shift acts along the other spatial axis.

        Returns:
            Tensor with the same layout as ``x``.
        """
        if vertical:
            x = torch.rot90(x, k=1, dims=(-2, -1))
        x = x.roll(shifts=-shift, dims=-1)
        # Flatten every n x n patch into the last dimension.
        x = rearrange(x, 'n 1 (h n1) (w n2) -> n h w (n1 n2)', n1=self.n, n2=self.n)
        x = self.linear(x)
        # Put the transformed patches back at their spatial positions.
        x = rearrange(x, 'n h w (n1 n2) -> n 1 (h n1) (w n2)', n1=self.n, n2=self.n)
        x = x.roll(shifts=shift, dims=-1)
        if vertical:
            x = torch.rot90(x, k=-1, dims=(-2, -1))
        return x
    

class DivFree(nn.Module):
    """Wrap ``model`` and build an output from skew-symmetrically mixed gradients.

    For every patch-grid shift ``0 .. n-1`` and both orientations, a scalar
    potential is formed from the wrapped model's output and a learned
    patch-wise linear map of the input. Its gradient with respect to the
    input is passed through a patch-wise skew-symmetric linear map and the
    results are summed.

    NOTE(review): the class name suggests the result is meant to be a
    divergence-free field; that property is not checked here.

    Args:
        model: Module called as ``model(x)`` on the padded input. Its output
            must be combinable with a tensor shaped like the padded input.
        n: Side length of the square patches used by both patch-wise maps.
    """

    def __init__(self, model, n=4):
        super().__init__()
        self.model = model
        self.n = n
        self.skewsymmetric = BlockDiagonalMatrixMultiplication(SkewSymmetricLinear(n**2))
        self.linear = BlockDiagonalMatrixMultiplication(nn.Linear(n**2, n**2, bias=False))

    def forward(self, x, sigma=None):
        """Compute the summed, skew-symmetrically mixed gradient field.

        Args:
            x: 4-D tensor. The patch rearrangement in
                ``BlockDiagonalMatrixMultiplication`` requires its second
                dimension to have size 1.
            sigma: Accepted but not used by this method.

        Returns:
            Tensor with the same shape as ``x`` (padding is cropped away).
        """
        _, _, h, w = x.size()
        # The padded size must be a multiple of 8 (original granularity,
        # presumably required by ``self.model`` -- confirm) AND of the patch
        # side ``n``, otherwise the patch rearrangement fails. Use their
        # least common multiple; for the default n=4 this is still 8.
        base = 8
        patch = max(self.n, 1)
        divisor = base * patch // math.gcd(base, patch)
        pad_h = (-h) % divisor
        pad_w = (-w) % divisor
        x = F.pad(x, pad=(0, pad_w, 0, pad_h), mode='reflect')

        output = torch.zeros_like(x)
        # Gradients w.r.t. the input are needed even if the caller runs
        # under torch.no_grad().
        with torch.enable_grad():
            x.requires_grad_(True)
            out = self.model(x)
            for i in range(self.n):
                for v in [True, False]:
                    y = self.linear(x, i, v)
                    # sum(y^2) - sum((y - out)^2) == 2*sum(y*out) - sum(out^2);
                    # halved after differentiation. create_graph=True keeps
                    # the result differentiable for training.
                    potential = torch.sum(y**2) - torch.sum((y - out)**2)
                    grad = torch.autograd.grad(potential, x, create_graph=True)[0] / 2
                    output += self.skewsymmetric(grad, i, v)
        # Drop the padding added above.
        return output[..., :h, :w]