# SPDX-License-Identifier: MIT
from __future__ import annotations
import torch
import torch.nn as nn


def orthogonal_init_global_(linear: nn.Linear) -> None:
    """Overwrite ``linear.weight`` in place with a random (semi-)orthogonal matrix.

    A Gaussian random matrix is QR-factorised and the signs of the columns of
    ``Q`` are flipped so that the diagonal of ``R`` is non-negative, which
    removes the sign ambiguity of the factorisation.

    For a square weight the result is orthogonal. For a rectangular weight of
    shape ``(out_features, in_features)`` the result has orthonormal rows when
    ``out_features < in_features`` and orthonormal columns otherwise.

    Only the weight is modified; a bias, if present, is left untouched.

    Args:
        linear: Layer whose weight is re-initialised in place.
    """
    with torch.no_grad():
        w = linear.weight
        rows, cols = w.size(0), w.size(1)
        # torch.linalg.qr is documented for float/double (and complex) inputs
        # only, so low-precision weights are factorised in float32 and cast
        # back by copy_().
        if w.dtype in (torch.float32, torch.float64):
            work_dtype = w.dtype
        else:
            work_dtype = torch.float32
        # Factorise a tall (or square) matrix so the reduced Q has orthonormal
        # columns; it is transposed below when the weight is wider than tall.
        a = torch.randn(
            max(rows, cols), min(rows, cols), device=w.device, dtype=work_dtype
        )
        q, r = torch.linalg.qr(a)
        sign = torch.sign(torch.diagonal(r))
        sign[sign == 0] = 1
        # Column-wise sign flip; same values as q @ diag(sign) without building
        # a dense diagonal matrix or doing a full matmul.
        q = q * sign
        if rows < cols:
            q = q.T
        w.copy_(q)


class FunctionalReversalNet(nn.Module):
    """Single bias-free square linear map over a flattened input.

    The layer ``fc`` has ``max_tokens * d_model`` input and output features.
    Its weight is initialised by ``orthogonal_init_global_`` and the module is
    then cast to ``dtype``.
    """

    def __init__(self, max_tokens: int, d_model: int, dtype: torch.dtype = torch.float32):
        """Build the layer.

        Args:
            max_tokens: Stored on the instance; multiplied by ``d_model`` to
                get the feature count of ``fc``.
            d_model: Stored on the instance; see ``max_tokens``.
            dtype: Dtype the module's parameters are cast to after
                initialisation.
        """
        super().__init__()
        self.max_tokens = max_tokens
        self.d_model = d_model
        flat_features = max_tokens * d_model
        self.fc = nn.Linear(flat_features, flat_features, bias=False)
        # Initialise before the cast so the random draw happens in the
        # layer's default dtype.
        orthogonal_init_global_(self.fc)
        self.to(dtype=dtype)

    def forward(self, x_seq: torch.Tensor) -> torch.Tensor:
        """Flatten everything after the first dimension and apply ``fc``.

        Args:
            x_seq: Input tensor; presumably ``(batch, max_tokens, d_model)``
                so that the flattened width matches ``fc`` (TODO confirm
                against callers).

        Returns:
            The output of ``fc`` on the ``(x_seq.size(0), -1)`` reshaped input.
        """
        batch_size = x_seq.shape[0]
        flattened = x_seq.reshape(batch_size, -1)
        return self.fc(flattened)


def reverse_fc_layer(output: torch.Tensor, fc_layer: nn.Linear) -> torch.Tensor:
    """Recover the input of a linear layer from its output via the pseudo-inverse.

    The layer computes ``y = x @ W.T + b``; this returns
    ``(y - b) @ pinv(W).T``. That is the exact input when ``W`` has full column
    rank, and the least-squares / minimum-norm solution otherwise.

    The computation runs without gradient tracking. It is done in float64 when
    either ``output`` or the layer weight is float64 (so double-precision
    callers do not silently lose accuracy), and in float32 otherwise.

    Args:
        output: Output of ``fc_layer``; its last dimension is the layer's
            output feature dimension.
        fc_layer: The linear layer to invert. Its bias may be ``None``.

    Returns:
        The reconstructed input, cast back to ``output``'s dtype.
    """
    with torch.no_grad():
        orig_dtype = output.dtype
        weight = fc_layer.weight
        bias = fc_layer.bias
        # Low-precision tensors are promoted to float32 for the pseudo-inverse;
        # double-precision data stays in float64 instead of being truncated.
        if torch.float64 in (orig_dtype, weight.dtype):
            work_dtype = torch.float64
        else:
            work_dtype = torch.float32
        y = output.to(dtype=work_dtype)
        if bias is not None:
            y = y - bias.to(dtype=work_dtype)
        w_pinv = torch.linalg.pinv(weight.to(dtype=work_dtype))
        x = torch.matmul(y, w_pinv.T)
        return x.to(dtype=orig_dtype)