import torch
from torch import Tensor

from ..tensor import CompressedTensor
from ..utils import COMPRESS_FUNC_MAPPING, RECONSTRUCT_FUNC_MAPPING

# Method-name -> callable tables for the "lowrank" family (e.g. the default "rqb").
LOWRANK_COMPRESS_FUNC_MAPPING, LOWRANK_RECONSTRUCT_FUNC_MAPPING = (
    COMPRESS_FUNC_MAPPING["lowrank"],
    RECONSTRUCT_FUNC_MAPPING["lowrank"],
)


class LowRankDecomposedTensor(CompressedTensor):
    """Tensor stored as low-rank factors of its 2-D flattened form.

    All leading dimensions of the source tensor are flattened into rows,
    ``(*l, m, n) -> (l*m, n)``, before factorization. The factorization routine
    is looked up by ``method`` (default ``"rqb"``) in the "lowrank"
    compression/reconstruction registries.
    """

    def __new__(cls, tensor: Tensor, **kwargs):
        """Create a low-rank compressed tensor from ``tensor``.

        Args:
            tensor: Source tensor. A fractional ``rank`` is taken relative to
                ``tensor.shape[-1]``.
            method: Optional keyword; name of a registered low-rank method
                (default ``"rqb"``).
            rank: Required keyword. An ``int`` >= 1 is used as-is. A ``float``
                in ``(0, 1)`` is multiplied by ``tensor.shape[-1]`` and rounded
                down, with a floor of 1.
            **kwargs: Any other keywords are passed unchanged, together with
                the resolved ``method`` and ``rank``, to
                ``CompressedTensor.__new__``.

        Raises:
            ValueError: Unknown ``method``, or a non-positive integer ``rank``.
            TypeError: ``rank`` is neither an ``int`` (bools excluded) nor a
                ``float`` in ``(0, 1)``.
            KeyError: ``rank`` was not supplied.
        """
        method = kwargs.pop("method", "rqb")
        if method not in LOWRANK_COMPRESS_FUNC_MAPPING:
            raise ValueError("Invalid value of `method`.")
        kwargs["method"] = method

        rank = kwargs.pop("rank")
        # ``bool`` is a subclass of ``int``; exclude it so ``rank=True`` is not
        # silently treated as rank 1.
        if isinstance(rank, int) and not isinstance(rank, bool):
            if rank < 1:
                raise ValueError("Invalid value of `rank`, must be a positive integer.")
        elif isinstance(rank, float) and 0 < rank < 1:
            # Truncation can produce 0 for narrow tensors (e.g. n=3, rank=0.2).
            # A rank-0 factorization is degenerate, so clamp to at least 1.
            rank = max(1, int(tensor.shape[-1] * rank))
        else:
            raise TypeError("Invalid type of `rank`, must be `int` or `float` in range `(0,1)`.")
        kwargs["rank"] = rank

        compressed_tensor = super().__new__(cls, tensor, **kwargs)
        compressed_tensor.rank = rank
        return compressed_tensor

    @property
    def rank(self) -> int:
        """Target rank of the factorization (backed by ``self._rank``)."""
        return self._rank

    @rank.setter
    def rank(self, value: int) -> None:
        self._rank = value

    # ``@staticmethod`` must be the OUTERMOST decorator. ``torch.no_grad()``
    # returns a plain function. If it wrapped the staticmethod object instead,
    # the result would be bound as an instance method when accessed through an
    # instance, and on Python < 3.10 the staticmethod object is not callable.
    @staticmethod
    @torch.no_grad()
    def compress(tensor: Tensor, **kwargs) -> tuple[Tensor, ...]:
        """Factorize ``tensor`` with the registered low-rank method.

        All leading dims are flattened into rows first:
        ``(*l, m, n) -> (l*m, n)``.

        Args:
            tensor: Tensor with at least 2 dims (required by ``flatten(0, -2)``).
            **kwargs: Must contain ``method``, a key of the "lowrank" registry.
                Every other keyword (e.g. ``rank``) is forwarded to the
                compression function.

        Returns:
            The factors produced by the registered compression function for
            the flattened matrix.
        """
        compress_func = LOWRANK_COMPRESS_FUNC_MAPPING[kwargs.pop("method")]
        matrix = tensor.flatten(0, -2)  # (*l, m, n) -> (lm, n)
        return compress_func(matrix, **kwargs)  # (lm, n) -> factors

    @torch.no_grad()
    def reconstruct(self) -> Tensor:
        """Rebuild the dense tensor from ``self.factors``.

        Applies the reconstruction function registered for ``self.method``,
        reshapes the result to ``self.shape``, and copies
        ``self.requires_grad`` onto it. Runs under ``torch.no_grad()``, so the
        returned tensor carries no autograd history.
        """
        reconstruct_func = LOWRANK_RECONSTRUCT_FUNC_MAPPING[self.method]
        tensor = reconstruct_func(*self.factors)  # ... -> (lm, n)
        tensor = tensor.reshape(*self.shape)  # (lm, n) -> (*l, m, n)
        tensor.requires_grad = self.requires_grad
        return tensor

    def __repr__(self) -> str:
        """Return the reconstructed tensor's repr with ``rank=...`` appended.

        Note: this materializes the full dense tensor via ``reconstruct``.
        """
        # ``reconstruct`` already copies ``requires_grad`` onto its result.
        dense_repr = repr(self.reconstruct())
        # Drop the trailing ")" of the tensor repr so ``rank`` appears inside it.
        return f"{dense_repr[:-1]}, rank={self.rank})"

    @classmethod
    def from_factors(cls, factors: tuple[Tensor, ...], method: str = "rqb"):
        """Build an instance directly from precomputed ``(Q, B)`` factors.

        The dense shape is set to ``(Q.shape[-2], B.shape[-1])`` and the rank
        to ``Q.shape[-1]``. Leading dims of whatever tensor the factors came
        from cannot be recovered here. ``requires_grad`` is set to False.

        Args:
            factors: Exactly two tensors ``(Q, B)`` with
                ``Q.shape[-1] == B.shape[-2]``.
            method: Name of a registered low-rank reconstruction method.

        Raises:
            ValueError: Unknown ``method``, wrong number of factors, a factor
                with fewer than 2 dims, or mismatched inner dimensions.
        """
        # Check the method first so an unknown name is reported as such rather
        # than as a factor-count error.
        if method not in LOWRANK_RECONSTRUCT_FUNC_MAPPING:
            raise ValueError(f"Reconstruction method '{method}' not registered.")
        if len(factors) != 2:
            raise ValueError("Expected exactly 2 factors (Q, B) for QB decomposition.")

        Q, B = factors
        if Q.dim() < 2 or B.dim() < 2:
            raise ValueError(
                f"Factors must have at least 2 dims, got Q.dim()={Q.dim()}, B.dim()={B.dim()}."
            )
        if Q.shape[-1] != B.shape[-2]:
            raise ValueError(
                f"Factor shape mismatch: Q.shape[-1] ({Q.shape[-1]}) != B.shape[-2] ({B.shape[-2]}). "
                "Cannot perform Q @ B."
            )

        # NOTE(review): assumes CompressedTensor.__new__ can be called without a
        # source tensor (no compression runs here). Confirm against the base class.
        instance = super().__new__(cls)
        instance.factors = list(factors)
        instance.method = method
        instance.rank = Q.shape[-1]
        instance.requires_grad = False
        instance.shape = (Q.shape[-2], B.shape[-1])
        return instance