from abc import abstractmethod

from torch import Size, Tensor
import torch
from . import config
from .utils import COMPRESS_FUNC_MAPPING, get_compress_cache, get_compress_processor
def _debug_print_factors(factors, original_tensor):
    """Print diagnostics for a two-factor (Q, B) decomposition of a tensor.

    The original tensor is flattened to a 2-D matrix of shape
    ``(-1, original_tensor.shape[-1])`` and compared against ``Q @ B``.
    Ranks are reported both for that global matrix and for every "local"
    matrix formed by the tensor's last two dimensions.

    Args:
        factors: sequence whose first two items are ``Q`` and ``B``; any
            further items are ignored. Assumes ``Q @ B`` has the flattened
            shape above — TODO confirm against each compression method.
        original_tensor: the uncompressed tensor (at least 1-D). It must have
            a floating/complex dtype because ``torch.linalg.matrix_rank``
            requires one.

    Prints a short message and returns early if ``factors`` has fewer than
    two items. Returns ``None``.
    """
    if len(factors) < 2:
        print("[DEBUG] factors does not contain Q and B.")
        return

    q, b = factors[0], factors[1]

    # Pure diagnostics: avoid recording an autograd graph.
    with torch.no_grad():
        reconstructed = q @ b
        original_reshaped = original_tensor.reshape(-1, original_tensor.shape[-1])

        global_rank = torch.linalg.matrix_rank(original_reshaped).item()

        # Batch of matrices over the real trailing two dims. `reshape` (not
        # `view`) also accepts non-contiguous tensors, and using the actual
        # trailing shape keeps non-square matrices intact for any ndim >= 1
        # (a 1-D tensor is treated as a single 1×n matrix).
        local_matrices = original_tensor.reshape(-1, *original_tensor.shape[-2:])
        rows, cols = local_matrices.shape[-2:]

        # matrix_rank is batched over leading dims; flatten to a 1-D count list.
        local_ranks = torch.linalg.matrix_rank(local_matrices).reshape(-1)
        num_local = local_ranks.numel()
        avg_local_rank = (
            local_ranks.double().mean().item() if num_local else float("nan")
        )

        error = (reconstructed - original_reshaped).abs().mean().item()

    print(
        f"[DEBUG] QB Decomposition:\n"
        f"  Original tensor shape: {original_tensor.shape}\n"
        f"  Reshaped matrix shape (global): {original_reshaped.shape}\n"
        f"  Global matrix rank: {global_rank}\n"
        f"  Number of local {rows}×{cols} matrices: {num_local}\n"
        f"  Average local rank (over {num_local} matrices): {avg_local_rank:.2f}\n"
        f"  Q shape: {q.shape}, B shape: {b.shape}\n"
        f"  Mean absolute reconstruction error: {error:.6e}"
    )

class CompressedTensor(Tensor):
    """
    This class implements **CompressedTensor** as a **Tensor** subclass.

    Instances hold the compression ``factors`` and report the *original*
    tensor's ``shape`` and ``requires_grad`` through overridden properties.

    Args:
        tensor (Tensor):
            &#45; the original tensor.
        method (str, optional):
            &#45; the compressing method.
            Default: `'rqb'`.
        **kwargs:
            &#45; additional keyword arguments used by compression.

    Returns:
        CompressedTensor:
            *-* return as a subclass type (e.g. `LowRankDecomposedTensor`).

    Raises:
        ValueError:
            &#45; if `method` is not a known low-rank method.
        TypeError:
            &#45; if `tensor` is not a `torch.Tensor`.
    """

    def __new__(cls, tensor: Tensor, **kwargs):
        # `CompressedTensor` itself is never instantiated: it acts as a factory
        # that dispatches on `method` and returns an instance of a subclass.
        if cls is CompressedTensor:
            method = kwargs.get("method", "rqb")
            if method in COMPRESS_FUNC_MAPPING["lowrank"]:
                from .lowrank import LowRankDecomposedTensor

                # Forward the resolved method explicitly. The subclass branch
                # below reads `kwargs["method"]`, so relying on the implicit
                # default here raised `KeyError` when `method` was omitted.
                return LowRankDecomposedTensor(tensor, **{**kwargs, "method": method})
            raise ValueError("Invalid value of `method`.")

        if not isinstance(tensor, Tensor):
            raise TypeError("Invalid type of `tensor`, must be `torch.Tensor`.")

        # Cache keys: the kwargs (sorted, so order-independent; values must be
        # hashable and comparable) and the caller's tensor object. Torch
        # tensors hash by identity, so this is presumably a per-object cache.
        kwargs_key = tuple(sorted(kwargs.items()))
        tensor_key = tensor

        # NOTE(review): `compress_cache[kwargs_key]` is indexed before any
        # insert, so the cache is presumably a defaultdict-like mapping — confirm.
        compress_cache = get_compress_cache()

        if config.CACHE_COMPRESS and tensor_key in compress_cache[kwargs_key]:
            factors = compress_cache[kwargs_key][tensor_key]
        else:
            # Re-compressing an already-compressed tensor starts from its
            # dense reconstruction.
            if isinstance(tensor, CompressedTensor):
                tensor = tensor.reconstruct()

            compress_processor = get_compress_processor()

            if config.ASYNC_COMPRESS and compress_processor.running:
                # `factors` starts empty and is handed to the processor as its
                # output container; presumably it is filled when the submitted
                # job completes — until then this instance has no factors.
                factors = []
                compress_processor.submit(
                    cls.compress,
                    args=(tensor,),
                    kwargs=kwargs,
                    outputs=factors,
                )
            else:
                factors = [*cls.compress(tensor, **kwargs)]

            if config.CACHE_COMPRESS:
                # Use the original `tensor_key` instead of `tensor`, which may
                # have been replaced by its reconstruction above.
                compress_cache[kwargs_key][tensor_key] = factors

        # Debug hook (disabled). On the async path `factors` may still be empty here.
        # _debug_print_factors(factors, tensor)
        compressed_tensor = super().__new__(cls)
        compressed_tensor.factors = factors
        compressed_tensor.method = kwargs["method"]
        # Mirror the *caller's* tensor metadata, not the reconstructed one.
        compressed_tensor.requires_grad = tensor_key.requires_grad
        compressed_tensor.shape = tensor_key.shape
        return compressed_tensor

    # NOTE(review): `Tensor` does not appear to use `ABCMeta`, so these
    # `@abstractmethod` markers are likely documentation only and not
    # enforced at instantiation — confirm.
    @staticmethod
    @abstractmethod
    def compress(tensor: Tensor, **kwargs) -> tuple[Tensor, ...]:
        """Compress ``tensor`` into a tuple of factor tensors (subclass hook)."""

    @abstractmethod
    def reconstruct(self) -> Tensor:
        """Return the dense tensor this instance represents (subclass hook)."""

    @property
    def factors(self) -> list:
        """List of compression factors produced by ``compress``."""
        return self._factors

    @factors.setter
    def factors(self, value):
        self._factors = value

    @property
    def method(self) -> str:
        """Name of the compression method used to build this instance."""
        return self._method

    @method.setter
    def method(self, value):
        self._method = value

    @property
    def requires_grad(self) -> bool:
        """Stored flag copied from the source tensor; overrides ``Tensor.requires_grad``."""
        return self._fake_requires_grad

    @requires_grad.setter
    def requires_grad(self, value):
        self._fake_requires_grad = value

    @property
    def shape(self) -> Size:
        """Shape of the original (uncompressed) tensor; overrides ``Tensor.shape``."""
        return self._original_shape

    @shape.setter
    def shape(self, value):
        self._original_shape = value
