import torch
import torch.nn as nn
from torch.autograd import Function
import torch.nn.functional as F

class LinearCompactOp(Function):
    """Autograd op: exact linear forward, low-rank-projected weight gradient.

    The forward pass is exactly ``F.linear(input, weight, bias)``. For the
    backward pass only the projected activation ``input @ P`` (with
    ``P = module.P``) is kept, instead of ``input`` itself. Backward returns:

    * ``grad_input = grad_output @ weight`` (exact),
    * ``grad_bias`` = ``grad_output`` summed over all leading dims (exact),
    * a substitute weight gradient ``(P @ hat_G).t()`` where
      ``hat_G = (input @ P).t() @ grad_output`` with leading dims flattened.

    As a side effect, ``hat_G`` is written to ``module.hat_G`` on every
    backward call.
    """

    @staticmethod
    def forward(ctx, input, weight, bias, module):
        """Compute ``F.linear(input, weight, bias)`` and stash ``input @ module.P``.

        Args:
            input: tensor whose last dim is ``in_features``; any number of
                leading dims (including none) is accepted.
            weight: ``(out_features, in_features)`` weight, as for ``F.linear``.
            bias: ``(out_features,)`` bias or ``None``.
            module: object exposing a non-None tensor ``P`` of shape
                ``(in_features, r)``; it receives ``hat_G`` during backward.

        Raises:
            RuntimeError: if ``module`` has no ``P`` or ``P`` is ``None``.
        """
        # Validate first so no work is wasted on a misconfigured module.
        if not hasattr(module, 'P') or module.P is None:
            raise RuntimeError(f"Module {module} doesn't have P")
        P = module.P

        output = F.linear(input, weight, bias)
        x_compressed = input @ P

        # Save only the compressed activation. Also saving `input` would keep
        # the full activation alive and cancel the memory saving of this op.
        ctx.save_for_backward(weight, P, x_compressed)
        ctx.has_bias = bias is not None
        ctx.module = module
        ctx.input_shape = input.shape
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for ``(input, weight, bias, module)``.

        All leading dims of ``grad_output`` are flattened, so inputs of any
        rank (1-D, 2-D, 3-D, 4-D, ...) are handled uniformly.
        """
        weight, P, x_compressed = ctx.saved_tensors
        module = ctx.module
        d_out = grad_output.shape[-1]

        grad_input = None
        if ctx.needs_input_grad[0]:
            grad_input = grad_output @ weight

        # (N, out_features) with N = product of all leading dims.
        grad_y_flat = grad_output.reshape(-1, d_out)

        grad_bias = None
        if ctx.has_bias:
            grad_bias = grad_y_flat.sum(dim=0)

        # hat_G has shape (r, out_features). Algebraically it equals
        # P.t() @ (grad_y^T @ x)^T, i.e. the true weight gradient projected by P.
        x_comp_flat = x_compressed.reshape(-1, P.shape[1])
        hat_G = x_comp_flat.t() @ grad_y_flat
        module.hat_G = hat_G

        # Lift back to the weight's shape (out_features, in_features).
        fake_grad = (P @ hat_G).t()
        return grad_input, fake_grad, grad_bias, None


class LinearCompactClass(nn.Linear):
    """``nn.Linear`` that can switch its autograd path to ``LinearCompactOp``.

    The compact path is used only when ``activate`` is truthy, the module is
    in training mode, and a non-None ``P`` attribute or buffer is present.
    Otherwise the module behaves exactly like ``nn.Linear``. ``hat_G`` starts
    as ``None``.
    """

    def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
        super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype)
        # Disabled until a caller explicitly turns it on.
        self.activate = False
        # Placeholder. LinearCompactOp.backward assigns module.hat_G.
        self.hat_G = None

    def _use_compact_path(self):
        """Return True when forward should go through ``LinearCompactOp``."""
        if not (self.activate and self.training):
            return False
        # hasattr-equivalent: a missing attribute yields the None default.
        return getattr(self, 'P', None) is not None

    def forward(self, input):
        if not self._use_compact_path():
            return F.linear(input, self.weight, self.bias)
        return LinearCompactOp.apply(input, self.weight, self.bias, self)


def wrap_linear_compression_layer(linear_layer: nn.Linear):
    """Return a ``LinearCompactClass`` copy of ``linear_layer``.

    Weight and bias values are copied (not shared), so the result owns new
    ``Parameter`` objects. Their ``requires_grad`` flags and the module's
    train/eval mode are carried over, so the wrapper is a drop-in
    replacement: a frozen layer stays frozen and an eval-mode layer stays in
    eval mode. If ``linear_layer`` has a non-None ``P``, a detached clone is
    registered as buffer ``P`` and ``activate`` is set to True. Otherwise
    ``activate`` is False.
    """
    weight = linear_layer.weight
    bias = linear_layer.bias
    wrapped = LinearCompactClass(
        in_features=linear_layer.in_features,
        out_features=linear_layer.out_features,
        bias=bias is not None,
        device=weight.device,
        dtype=weight.dtype,
    )

    with torch.no_grad():
        wrapped.weight.copy_(weight)
        if bias is not None:
            wrapped.bias.copy_(bias)

    # Freshly constructed Parameters default to requires_grad=True. Mirror the
    # source so wrapping never silently unfreezes a frozen layer.
    wrapped.weight.requires_grad_(weight.requires_grad)
    if bias is not None:
        wrapped.bias.requires_grad_(bias.requires_grad)
    # A new module starts in training mode; keep the source layer's mode.
    wrapped.train(linear_layer.training)

    P = getattr(linear_layer, 'P', None)
    if P is not None:
        wrapped.register_buffer('P', P.detach().clone())
        wrapped.activate = True
    else:
        wrapped.activate = False

    wrapped.hat_G = None
    return wrapped


def register_filter(model, modules_compressed):
    """Replace the named ``nn.Linear`` submodules of ``model`` in place.

    Every submodule whose qualified name (as produced by
    ``model.named_modules()``) is in ``modules_compressed`` and which is an
    ``nn.Linear`` is replaced by ``wrap_linear_compression_layer(submodule)``
    on its parent. Names that don't match, or that match a non-Linear
    module, are skipped.

    Returns:
        ``(model, replaced)``: the same ``model`` object, and a dict mapping
        each replaced qualified name to its new module.
    """
    print(f"\n[register_filter] Wrapping {len(modules_compressed)} modules...")

    replaced = {}
    # Snapshot the module list first, because the loop mutates the tree.
    candidates = list(model.named_modules())
    for qualified_name, submodule in candidates:
        if qualified_name not in modules_compressed:
            continue
        if not isinstance(submodule, nn.Linear):
            continue

        # Split "a.b.c" into owner path "a.b" and leaf "c". A top-level name
        # gives an empty owner path, meaning the owner is the model itself.
        owner_path, _, leaf = qualified_name.rpartition(".")
        owner = model
        if owner_path:
            for step in owner_path.split("."):
                owner = getattr(owner, step)

        print(f"  Wrapping: {qualified_name}")
        replacement = wrap_linear_compression_layer(submodule)
        setattr(owner, leaf, replacement)
        replaced[qualified_name] = replacement

    print(f"[register_filter] Completed wrapping {len(replaced)} modules\n")
    return model, replaced