from torch import Size, Tensor
import torch

class LowRankGradientTensor(Tensor):
    """Tensor subclass that holds a gradient in factored ``Q @ B`` form.

    The instance stores the two factors instead of the dense product and
    exposes ``rank``, ``factors``, ``requires_grad`` and ``shape`` through
    Python properties backed by plain instance attributes. These properties
    shadow the attributes of the same name on ``Tensor``: ``shape`` reports
    the shape of the product ``Q @ B``, and ``requires_grad`` reports a flag
    stored on the instance (initialised to ``False``).
    """

    def __new__(cls, factors, **kwargs):
        """Create the wrapper from a pair of factors ``(Q, B)``.

        Args:
            factors: Sequence of exactly two array-like factors ``(Q, B)``
                with ``Q`` of shape ``(..., m, r)`` and ``B`` of shape
                ``(..., r, n)``. The sequence is stored as given (no copy).
            **kwargs: Accepted for call-site compatibility; currently ignored.

        Raises:
            ValueError: If ``factors`` does not have exactly two entries, if
                either factor has fewer than two dimensions, if the inner
                dimensions differ, or if the leading (batch) dimensions
                cannot be broadcast together.
        """
        if len(factors) != 2:
            raise ValueError("Expected exactly 2 factors (Q, B) for QB decomposition.")
        Q, B = factors

        # Guard the [-2] indexing below so malformed input yields a clear
        # ValueError instead of a bare IndexError.
        if len(Q.shape) < 2 or len(B.shape) < 2:
            raise ValueError(
                "Factors must be at least 2-dimensional: "
                f"got Q.shape={tuple(Q.shape)}, B.shape={tuple(B.shape)}."
            )
        if Q.shape[-1] != B.shape[-2]:
            raise ValueError(
                f"Factor shape mismatch: Q.shape[-1] ({Q.shape[-1]}) != B.shape[-2] ({B.shape[-2]}). "
                "Cannot perform Q @ B."
            )

        m, r = Q.shape[-2], Q.shape[-1]
        n = B.shape[-1]

        # ``Q @ B`` broadcasts leading batch dimensions, so the reported shape
        # must include them to agree with ``reconstruct().shape``.
        try:
            batch_shape = torch.broadcast_shapes(
                tuple(Q.shape[:-2]), tuple(B.shape[:-2])
            )
        except (RuntimeError, ValueError) as err:
            raise ValueError(
                f"Factor batch dimensions are not broadcastable: "
                f"Q.shape={tuple(Q.shape)}, B.shape={tuple(B.shape)}. "
                "Cannot perform Q @ B."
            ) from err
        original_shape = torch.Size((*batch_shape, m, n))

        # The base tensor object is created without any data arguments; the
        # state used by this class lives in the attributes assigned below.
        instance = super().__new__(cls)
        instance.factors = factors
        instance.rank = r
        instance.requires_grad = False
        instance.shape = original_shape
        return instance

    # ``torch.no_grad()`` (with parentheses) is the decorator form accepted by
    # every PyTorch release; the bare ``@torch.no_grad`` form is not.
    @torch.no_grad()
    def reconstruct(self) -> Tensor:
        """Return the dense product ``Q @ B`` computed with autograd disabled."""
        Q, B = self.factors
        gradient_tensor = Q @ B
        # Explicitly mark the result as not requiring grad (the product is
        # already computed under no_grad).
        gradient_tensor.requires_grad = False
        return gradient_tensor

    @property
    def rank(self) -> int:
        """Inner dimension ``r`` shared by the two factors."""
        return self._rank

    @rank.setter
    def rank(self, value):
        self._rank = value

    @property
    def factors(self) -> list:
        """The ``(Q, B)`` factor sequence exactly as passed to the constructor."""
        return self._factors

    @factors.setter
    def factors(self, value):
        self._factors = value

    @property
    def requires_grad(self) -> bool:
        """Flag stored on the instance; shadows ``Tensor.requires_grad``."""
        return self._fake_requires_grad

    @requires_grad.setter
    def requires_grad(self, value):
        self._fake_requires_grad = value

    @property
    def shape(self) -> Size:
        """Shape of the product ``Q @ B``; shadows ``Tensor.shape``."""
        return self._original_shape

    @shape.setter
    def shape(self, value):
        self._original_shape = value
