import torch
import torch.nn as nn
import torch.nn.functional as F

from .conv import Conv2d
from .head import head


@torch.compile
def halfabs(x):
    """Take the absolute value of the leading three quarters of the last axis.

    The first ``3 * C // 4`` entries along the last dimension (``C`` being its
    size) are replaced by their absolute value; the remaining entries are
    passed through untouched. The output has the same shape as ``x``.
    """
    split_at = (3 * x.shape[-1]) // 4
    rectified = x[..., :split_at].abs()
    passthrough = x[..., split_at:]
    return torch.cat((rectified, passthrough), dim=-1)



@torch.compile
def shift_fn(x):
    """Shift four small channel groups by one step along dims 1 and 2.

    The last axis (dim 3) of the 4-D input is split into an untouched leading
    part followed by four groups of ``C // 16`` channels each. The groups are
    rolled (with wrap-around, see ``torch.roll``) by +1 and -1 along dim 1 and
    by +1 and -1 along dim 2 respectively, then everything is concatenated
    back in the original channel order.

    Args:
        x: 4-D tensor whose dim 3 holds the channels. The caller in this file
            passes a channels-last ``(B, H, W, C)`` tensor.

    Returns:
        Tensor of the same shape as ``x``.
    """
    channels = x.shape[3]
    group = channels // 16
    # The untouched part is whatever the four groups leave over, so the split
    # sizes always sum to ``channels``. The previous ``c - c // 4`` only did
    # so when ``channels`` was a multiple of 16 and made torch.split raise
    # for other widths; for multiples of 16 both expressions agree.
    keep = channels - 4 * group
    a0, a1, a2, a3, a4 = torch.split(
        x, [keep, group, group, group, group], dim=3)

    a1 = torch.roll(a1, dims=1, shifts=1)
    a2 = torch.roll(a2, dims=1, shifts=-1)
    a3 = torch.roll(a3, dims=2, shifts=1)
    a4 = torch.roll(a4, dims=2, shifts=-1)

    return torch.cat([a0, a1, a2, a3, a4], dim=3)

class lip_layers(nn.Module):
    """Stack of ``depth`` square linear layers applied to channels-last features.

    All matrices live in one parameter ``weights`` of shape
    ``(2 * depth, width, width)``: the first ``depth`` entries are initialised
    as matrix exponentials of random skew-symmetric matrices (hence orthogonal
    at initialisation), the last ``depth`` entries as random permutation
    matrices. Both halves are ordinary trainable parameters.

    Args:
        depth: Number of layers.
        width: Channel count; every matrix is ``width x width``.
        spatial_size: Spatial extent of the per-layer ``position`` offset,
            either a single int for a square map or a ``(height, width)``
            pair. Defaults to 16, the value that used to be hard-coded.
    """

    def __init__(self,
                 depth=12,
                 width=2048,
                 spatial_size=16,
                 ):
        super().__init__()
        self.depth = depth
        self.width = width

        # Draw on the CPU, run the (expensive) matrix exponential on the GPU
        # when one is available, then bring the result back to the CPU.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        weights = torch.randn(depth, width, width).to(device) / width / 2
        # ``weights - weights.mT`` is skew-symmetric, so its exponential is an
        # orthogonal matrix.
        weights = torch.matrix_exp(weights - weights.mT).cpu()
        torch.cuda.empty_cache()

        # One random permutation matrix per layer (rows of the identity in
        # shuffled order).
        rotation = [torch.eye(width)[torch.randperm(width)] for _ in range(depth)]
        rotation = torch.stack(rotation).to(weights.dtype)
        weights = torch.cat([weights, rotation])

        self.weights = nn.Parameter(weights)

        self.bias = nn.Parameter(torch.zeros(depth, width))

        if isinstance(spatial_size, int):
            grid_h = grid_w = spatial_size
        else:
            grid_h, grid_w = spatial_size
        # Per-layer additive offset, one scalar per spatial location that is
        # broadcast over the batch and channel axes in ``forward``.
        self.position = nn.Parameter(torch.zeros(depth, grid_h, grid_w, 1))

    def forward(self, x):
        """Run the layer stack.

        Args:
            x: 4-D tensor laid out as ``(B, C, H, W)``. ``C`` has to equal
                ``width`` for the linear maps, and ``(H, W)`` has to
                broadcast against the ``position`` parameter.

        Returns:
            The transformed features in channels-last layout
            ``(B, H, W, width)``; the input permutation is not undone.

        Raises:
            ValueError: If ``x`` is not 4-dimensional.
        """
        if x.dim() != 4:
            raise ValueError(
                f"expected a 4-D input, got shape {tuple(x.shape)}")

        weights = self.weights[:self.depth]
        rotation = self.weights[self.depth:]

        # Pre-multiply every main matrix by the transpose of its permutation
        # matrix (batched over the layer axis).
        weights = rotation.mT @ weights

        # Channels-last so that F.linear and the channel split act on dim 3.
        x = x.permute(0, 2, 3, 1)
        x = halfabs(x)

        for w, r, b, p in zip(weights, rotation, self.bias, self.position):
            x = x + p
            x = F.linear(x, r)
            x = shift_fn(x)
            x = F.linear(x, w, b)
            x = halfabs(x)

        return x

    def extra_repr(self):
        """Return the ``depth``/``width`` summary shown in the module repr."""
        return f'depth={self.depth}, width={self.width}'


class lip_net(nn.Module):
    """Classifier made of a strided conv stem, a ``lip_layers`` stack and a head.

    Args:
        depth: Number of layers handed to ``lip_layers``.
        width: Channel count of the stem output, the layer stack and the
            head input.
        num_classes: Output size passed to ``head``.
        **kwargs: Accepted and ignored.
    """

    def __init__(self,
                 depth: int = 32,
                 width: int = 2048,
                 num_classes: int = 1000,
                 **kwargs
                 ):
        super().__init__()

        # Stem: 3 -> width channels, 3x3 kernel, stride 2, built for
        # ``input_size=32``.
        self.conv1 = Conv2d(3, width, kernel_size=3, stride=2, padding=1, input_size=32)

        self.layers = lip_layers(depth, width)
        # Learnable scalar multiplier applied to the features before pooling.
        self.lip_scale = nn.Parameter(torch.tensor(1.0))

        self.head = head(width, num_classes, use_lln=True)

    def forward(self, x):
        """Map an input batch to the output of ``head``.

        The input is shifted by ``-0.5`` before the stem (presumably to centre
        images given in ``[0, 1]`` -- confirm with the data pipeline). The
        scaled features are reduced with an L2 norm over dims 1 and 2, which
        are the spatial axes of the channels-last tensor that
        ``lip_layers.forward`` returns.
        """
        features = self.layers(self.conv1(x - 0.5))
        pooled = (features * self.lip_scale).norm(dim=(1, 2))
        return self.head(pooled)

    def sub_lipschitz(self):
        """Return ``conv1.lipschitz()`` multiplied by ``|lip_scale|``.

        NOTE(review): presumably a Lipschitz bound for everything in front of
        the head, under the assumption that the layer stack contributes a
        factor of 1 -- confirm.
        """
        return self.conv1.lipschitz() * self.lip_scale.abs()

    def set_iter(self, num_iter=10):
        """Set ``conv1.num_lc_iter`` to ``num_iter``; returns ``None``."""
        self.conv1.num_lc_iter = num_iter
