#%%
import torch
import math


class low_rank_linear(torch.nn.Module):
    """Linear layer whose weight is kept in factored form ``U diag(s) V^T``.

    With ``U = us[:, :rank]`` (``out_features x rank``), ``V = vs[:, :rank]``
    (``in_features x rank``) and ``s[:rank]``, ``forward(x)`` computes
    ``x @ (U diag(s) V^T).T + bias`` without materialising the full weight.

    ``us``/``vs`` are trainable in the "lower level" (``activate_lower_level``);
    ``s`` only receives gradients in the "upper level" (``activate_upper_level``),
    optionally augmented by ``get_hypergradient``.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        initial_rank: int = 5,
        bias: bool = True,
        device="cpu",
    ):
        """
        Args:
            in_features: size of each input sample.
            out_features: size of each output sample.
            initial_rank: starting rank. An ``int`` is used as-is; any other
                number is treated as a fraction of ``min(in_features,
                out_features)``. The result is clamped to that maximum.
            bias: whether to learn an additive bias.
            device: device on which all parameters are allocated.
        """
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.device = device
        self.maximal_rank = min(self.in_features, self.out_features)
        requested_rank = (
            initial_rank
            if isinstance(initial_rank, int)
            else int(self.maximal_rank * initial_rank)
        )
        # A rank above min(in, out) is not representable: the reduced QR in
        # reset_parameters would shrink us/vs to different widths and forward
        # would then fail on a shape mismatch.
        self.rank = min(requested_rank, self.maximal_rank)

        self.us = torch.nn.Parameter(
            torch.randn(self.out_features, self.rank, device=self.device),
            requires_grad=True,
        )
        self.vs = torch.nn.Parameter(
            torch.randn(self.in_features, self.rank, device=self.device),
            requires_grad=True,
        )
        self.s = torch.nn.Parameter(
            torch.randn(self.rank, device=self.device), requires_grad=False
        )
        if bias:
            self.bias = torch.nn.Parameter(
                torch.randn(self.out_features, device=self.device)
            )
        else:
            self.bias = None
        # Dense weight buffer, only read by format_weight(); frozen.
        self.weight = torch.nn.Parameter(
            torch.zeros((self.out_features, self.in_features), device=self.device),
            requires_grad=False,
        )
        self.reset_parameters()
        self.activate_lower_level()

    @torch.no_grad()
    def reset_parameters(self):
        """Re-initialise s ~ U(0, 10), orthonormal us/vs and a small bias."""
        torch.nn.init.uniform_(self.s, b=10)
        torch.nn.init.kaiming_uniform_(self.us, a=math.sqrt(5))
        torch.nn.init.kaiming_uniform_(self.vs, a=math.sqrt(5))
        # Orthonormalize the column bases (reduced QR keeps the shape because
        # rank <= min(in, out) is guaranteed by __init__).
        self.us.data, _ = torch.linalg.qr(self.us.data, "reduced")
        self.vs.data, _ = torch.linalg.qr(self.vs.data, "reduced")

        if self.bias is not None:
            torch.nn.init.uniform_(self.bias, -0.1, 0.1)

    def forward(self, input):
        """Apply ``input @ V diag(s) U^T + bias`` using the active rank."""
        out = torch.nn.functional.linear(input, self.vs[:, : self.rank].T)
        out = out * self.s[: self.rank]
        out = torch.nn.functional.linear(out, self.us[:, : self.rank], bias=self.bias)
        return out

    @torch.no_grad()
    def format_weight(self):
        """Overwrite us/s/vs with the rank-truncated SVD of ``self.weight``."""
        U, s, Vh = torch.linalg.svd(self.weight, full_matrices=False)
        V = Vh.T
        self.us.copy_(U[:, : self.rank])
        self.s.copy_(s[: self.rank])
        self.vs.copy_(V[:, : self.rank])

    def add_rank(self):
        """Append one random component (column of us/vs, entry of s).

        Does nothing once ``rank`` has reached ``maximal_rank``.
        """
        # Strict '<': with '<=' the rank could grow to maximal_rank + 1.
        if self.rank < self.maximal_rank:
            # Draw with the parameters' own dtype/device so the concatenation
            # stays consistent even after the module was moved with .to().
            factory = dict(device=self.us.device, dtype=self.us.dtype)
            self.us.data = torch.cat(
                [self.us.data, torch.randn(self.out_features, 1, **factory)],
                dim=1,
            )
            self.vs.data = torch.cat(
                [self.vs.data, torch.randn(self.in_features, 1, **factory)],
                dim=1,
            )
            self.s.data = torch.hstack(
                [self.s.data, torch.randn(1, **factory) ** 2]
            )
            # NOTE(review): any existing .grad / optimizer state for us, vs, s
            # keeps the old shape; callers presumably reset it after growing.
            self.rank += 1

    @torch.no_grad()
    def update_rank(self):
        """Re-activate stored components that lie beyond the active rank.

        ``forward`` only uses the first ``rank`` stored components. If fewer
        are active than are stored in ``s``, raise ``rank`` to cover them
        (capped at ``maximal_rank``). Otherwise this is a no-op.
        """
        # NOTE(review): the previous body `while self.rank < len(self.s):
        # self.add_rank()` never terminated, because add_rank() grows len(s)
        # in lockstep with rank. Presumed intent implemented here — confirm.
        stored = min(len(self.s), self.maximal_rank)
        self.rank = max(self.rank, stored)

    @torch.no_grad()
    def construct_weight_tensor(self):
        """
        just for debugging purposes, don't use it

        Returns the dense ``U diag(s) V^T`` for the active rank, i.e. the
        matrix whose ``F.linear`` equals ``forward`` minus the bias.
        """
        us = self.us[:, : self.rank]
        vs = self.vs[:, : self.rank]
        return us @ torch.diag(self.s[: self.rank]) @ vs.T

    def activate_upper_level(self):
        """Make us, vs and s all trainable."""
        self.us.requires_grad = True
        self.vs.requires_grad = True
        self.s.requires_grad = True

    def activate_lower_level(self):
        """Train only us and vs; freeze s and drop its gradient."""
        self.us.requires_grad = True
        self.vs.requires_grad = True
        self.s.requires_grad = False
        self.s.grad = None

    @torch.no_grad()
    def get_hypergradient(self):
        """Add ``(diag(us^T us.grad) + diag(vs^T vs.grad)) / (s + eps)`` to ``s.grad``.

        Requires ``us.grad`` and ``vs.grad`` to be populated (a backward pass
        has run). If ``s.grad`` is ``None`` (e.g. after
        ``activate_lower_level``), it is started from zeros.
        """
        eps = 1e-4  # guards against division by a (near-)zero singular value
        if self.s.grad is None:
            self.s.grad = torch.zeros_like(self.s)
        # diag(A^T B) == (A * B).sum(0): same values without the r x r product.
        correction = (self.us * self.us.grad).sum(dim=0) + (
            self.vs * self.vs.grad
        ).sum(dim=0)
        self.s.grad.add_(correction / (self.s + eps))


### test


def test():
    """Smoke test: one forward/backward pass, then report per-parameter grads.

    The targets are random same-shape float tensors, which CrossEntropyLoss
    treats as class-probability targets. They are not a valid distribution;
    they only serve to exercise the backward pass.
    """
    inputs = torch.randn((10, 28 * 28))
    layer = low_rank_linear(28 * 28, 10, 2)
    targets = torch.randn((10, 10))
    layer.add_rank()

    criterion = torch.nn.CrossEntropyLoss()
    predictions = layer(inputs)
    criterion(predictions, targets).backward()

    for param_name, param in layer.named_parameters():
        has_grad = param.grad is not None
        print(f"name {param_name},has grad {has_grad},shape {param.shape}")


# test()
