import logging
from typing import Callable
import torch
import torch.nn as nn
import torch.nn.functional as F
import xformers.ops as xops
from dinov3.utils import named_apply, named_replace
# Module-level logger; uses the shared "dinov3" logger name rather than __name__.
logger = logging.getLogger("dinov3")
class LinearW24(torch.nn.Linear):
    """``torch.nn.Linear`` whose weight can be sparsified to a 2:4 pattern on the fly.

    While ``sparsity_enabled`` is False the layer behaves exactly like
    ``torch.nn.Linear``. When it is True, every forward call sparsifies
    ``self.weight`` with ``xformers.ops.sparsify24`` (algorithm ``ALGO``,
    ``gradient="ste"``, ``backend="cusparselt"``) and uses the result as the
    weight of the matmul.
    """

    # Sparsification algorithm name passed to ``xops.sparsify24``.
    ALGO = "largest_abs_values_greedy"

    def __init__(self, *args, **kwargs) -> None:
        # Accepts exactly the same arguments as ``torch.nn.Linear``.
        super().__init__(*args, **kwargs)
        # Off by default: a freshly built layer computes a plain dense Linear.
        self.sparsity_enabled = False

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply the (optionally 2:4-sparsified) linear map to ``input``.

        Args:
            input: Tensor of shape ``(*, in_features)``. As with
                ``torch.nn.Linear``, ``*`` may be empty (a 1-D input).

        Returns:
            Tensor of shape ``(*, out_features)``.
        """
        if not self.sparsity_enabled:
            return super().forward(input)
        leading_shape = input.shape[:-1]
        # Collapse all leading dims into a single row dimension. Unlike
        # ``flatten(end_dim=-2)``, ``reshape`` also accepts a 1-D input
        # (it becomes a single row), matching the dense path.
        rows = input.reshape(-1, input.shape[-1])
        num_rows = rows.shape[0]
        # Zero-pad the row count up to a multiple of 8 (presumably an
        # alignment requirement of the cuSPARSELt kernel — confirm); the padded
        # rows are sliced off again below.
        pad = -num_rows % 8
        if pad:
            rows = F.pad(rows, [0, 0, 0, pad])
        w_sparse = xops.sparsify24(
            self.weight,
            algo=self.ALGO,
            gradient="ste",
            backend="cusparselt",
        )
        out = F.linear(rows, w_sparse, self.bias)[:num_rows]
        return out.reshape(*leading_shape, out.shape[-1])
def replace_linears_with_sparse_linear(root_module: nn.Module, *, filter_fn: Callable[[str], bool]) -> nn.Module:
    """Swap selected plain ``nn.Linear`` layers of ``root_module`` for ``LinearW24``.

    Each replacement reuses the original ``weight``/``bias`` Parameter objects
    (no copy) and starts with ``sparsity_enabled = False``, so it computes the
    same dense result until sparsity is switched on (e.g. via
    ``update_24sparsity``).

    Args:
        root_module: Module tree to rewrite; traversal and substitution are
            delegated to ``named_replace``.
        filter_fn: Called with a candidate module's name; only ``nn.Linear``
            modules for which it returns True are replaced.

    Returns:
        Whatever ``named_replace`` returns (presumably the rewritten
        ``root_module`` — confirm against ``dinov3.utils``).

    Raises:
        AssertionError: if a selected module is a strict subclass of
            ``nn.Linear`` (including an already-converted ``LinearW24``), or if
            no module was replaced at all.
    """
    total_count = 0

    # Parameter names ``module``/``name`` are kept as-is: the traversal helper
    # may pass them by keyword.
    def replace(module: nn.Module, name: str) -> nn.Module:
        nonlocal total_count
        if not isinstance(module, nn.Linear) or not filter_fn(name):
            return module
        assert type(module) is nn.Linear, "Subtypes not supported"
        # Build on the meta device: the freshly initialized weight/bias would be
        # thrown away immediately, so avoid allocating and randomly filling a
        # full-size real tensor for every replaced layer.
        new_module = LinearW24(
            in_features=module.in_features,
            out_features=module.out_features,
            bias=module.bias is not None,
            dtype=module.weight.dtype,
            device="meta",
        )
        # Reuse the original Parameters (shared, not copied); this also moves
        # the new module onto the original module's device.
        new_module.weight = module.weight
        new_module.bias = module.bias
        # A newly constructed Module defaults to training mode; keep the
        # original's train/eval mode instead.
        new_module.train(module.training)
        total_count += 1
        return new_module

    out = named_replace(replace, root_module)
    assert total_count > 0, "2:4 sparsity: no layer found to sparsify"
    return out
def update_24sparsity(root_module: nn.Module, enabled: bool) -> int:
    """Set ``sparsity_enabled`` on every ``LinearW24`` found under ``root_module``.

    Each visited ``LinearW24`` is logged. After the flags are updated, the
    dynamo code caches and the inductor CUDA-graph trees are reset (presumably
    so that compiled code specialized on the previous flag value is not
    reused — confirm).

    Args:
        root_module: Module tree to walk via ``named_apply``.
        enabled: Value assigned to each ``LinearW24.sparsity_enabled``.

    Returns:
        The number of ``LinearW24`` modules that were updated.
    """
    touched = 0

    # ``module``/``name`` parameter names are preserved in case the traversal
    # helper passes them by keyword.
    def _set_flag(module: nn.Module, name: str) -> nn.Module:
        nonlocal touched
        if isinstance(module, LinearW24):
            touched += 1
            module.sparsity_enabled = enabled
            logger.info(f"- {'' if module.sparsity_enabled else 'de'}sparsifying {name}")
        return module

    named_apply(_set_flag, root_module)

    # Invalidate previously compiled artifacts.
    torch._dynamo.reset_code_caches()
    from torch._inductor.cudagraph_trees import reset_cudagraph_trees

    reset_cudagraph_trees()
    return touched
