import torch
from torch import Tensor

from ..compressed import CompressedTensor
class LinearFunction(torch.autograd.Function):
    @staticmethod
    def forward(
        input: Tensor,
        weight: Tensor,
        bias: Tensor | None = None,
        compress_kwargs: dict | None = None,
    ) -> Tensor:
        return torch._C._nn.linear(input, weight, bias)

    @staticmethod
    def setup_context(ctx, inputs, output):
        input, weight, bias, compress_kwargs = inputs
        ctx.device_type = input.device.type        
        ctx.autocast_kwargs = {
            "dtype": torch.get_autocast_dtype(ctx.device_type),
            "enabled": torch.is_autocast_enabled(ctx.device_type),
            "cache_enabled": torch.is_autocast_cache_enabled(),
        }
        
        if compress_kwargs is not None:
            use_optimizer_compress = compress_kwargs.pop("use_optimizer_compress", False)
            use_gradient_compress = compress_kwargs.pop("use_gradient_compress", True)
            if use_optimizer_compress and use_gradient_compress:
                compress_kwargs["method"] = "rp"
            ctx.save_for_backward(CompressedTensor(input, **compress_kwargs), weight, bias)
            compress_kwargs["use_optimizer_compress"] = use_optimizer_compress
            compress_kwargs["use_gradient_compress"] = use_gradient_compress
            ctx.use_gradient_compress = use_gradient_compress
            ctx.use_optimizer_compress = use_optimizer_compress
        else:
            ctx.save_for_backward(input, weight, bias)

    @staticmethod
    def backward(ctx, grad_output: Tensor) -> tuple[Tensor | None, ...]:
        input, weight, bias = ctx.saved_tensors
        if isinstance(input, CompressedTensor) and ctx.use_gradient_compress == True:
            Q , B = input.factors
        elif isinstance(input, CompressedTensor) and ctx.use_gradient_compress == False:    
            input = input.reconstruct()

        with torch.autocast(ctx.device_type, **ctx.autocast_kwargs):
            grad_output_2d = grad_output.reshape(-1, grad_output.shape[-1])
            if isinstance(input, Tensor) and not isinstance(input, CompressedTensor):
                input_2d = input.reshape(-1, input.shape[-1])

            if ctx.needs_input_grad[0]:
                grad_input = grad_output @ weight
            else:
                grad_input = None

            if ctx.needs_input_grad[1]:
                if isinstance(input, CompressedTensor):
                    grad_weight_Q = grad_output_2d.T @ Q
                    weight._lowrank_grad = (grad_weight_Q, B)
                    if ctx.use_optimizer_compress == True:
                        weight.use_optimizer_compress = True
                    else:
                        weight.use_optimizer_compress = False
                elif isinstance(input, Tensor):
                    grad_weight = grad_output_2d.T @ input_2d
            else:
                grad_weight = None

            if bias is not None and ctx.needs_input_grad[2]:
                grad_bias = grad_output_2d.sum(dim=0)
            else:
                grad_bias = None
        if isinstance(input, CompressedTensor):
            return grad_input, None, grad_bias, None
        else:
            return grad_input, grad_weight, grad_bias, None
