# from typing import Set

# try:
#     import spconv.pytorch as spconv
# except:
#     import spconv as spconv

# import torch.nn as nn


# def find_all_spconv_keys(model: nn.Module, prefix="") -> Set[str]:
#     """
#     Finds all spconv keys that need to have weight's transposed
#     """
#     found_keys: Set[str] = set()
#     for name, child in model.named_children():
#         new_prefix = f"{prefix}.{name}" if prefix != "" else name

#         if isinstance(child, spconv.conv.SparseConvolution):
#             new_prefix = f"{new_prefix}.weight"
#             found_keys.add(new_prefix)

#         found_keys.update(find_all_spconv_keys(child, prefix=new_prefix))

#     return found_keys


# def replace_feature(out, new_features):
#     if "replace_feature" in out.__dir__():
#         # spconv 2.x behaviour
#         return out.replace_feature(new_features)
#     else:
#         out.features = new_features
#         return out
from typing import Set

import re

import spconv


def _version_at_least(version: str, major: int, minor: int) -> bool:
    """Return True if ``version`` is at least ``major.minor``.

    Only the leading ``<major>.<minor>`` numbers are compared, as integers, so
    trailing components or suffixes (e.g. ".post1") are ignored. A string
    without a leading ``<int>.<int>`` yields False.

    This replaces ``float(version[2:]) >= 2.2``, which compared everything
    after the first two characters as a float. That check misordered releases
    ("2.2.0" -> 2.0, "2.10.0" -> 10.0) and raised ValueError for suffixed
    version strings.
    """
    match = re.match(r"\s*(\d+)\.(\d+)", version)
    if match is None:
        return False
    return (int(match.group(1)), int(match.group(2))) >= (major, minor)


# getattr: a release that does not expose ``__version__`` is treated as
# "older than 2.2" instead of failing the whole import.
if _version_at_least(getattr(spconv, "__version__", ""), 2, 2):
    # NOTE(review): the motivation for turning the direct table off is not
    # visible here -- presumably a workaround needed on spconv >= 2.2; confirm.
    spconv.constants.SPCONV_USE_DIRECT_TABLE = False

try:
    # Newer spconv exposes its PyTorch layers under ``spconv.pytorch``.
    import spconv.pytorch as spconv
except ImportError:
    # Fallback for releases without a ``pytorch`` subpackage (presumably
    # spconv 1.x). Only ImportError is caught, so genuine failures inside
    # ``spconv.pytorch`` are no longer silently masked.
    import spconv as spconv

import torch.nn as nn


def find_all_spconv_keys(model: nn.Module, prefix="") -> Set[str]:
    """
    Recursively collect the dotted ``<path>.weight`` keys of every
    ``spconv.conv.SparseConvolution`` submodule beneath ``model``.

    These are the spconv weights that need to be transposed.

    Args:
        model: module whose children (recursively) are inspected. ``model``
            itself is not checked, only its descendants.
        prefix: dotted path of ``model`` within the root module; ``""`` for
            the root.

    Returns:
        Set of keys such as ``"backbone.conv1.weight"``.
    """
    collected: Set[str] = set()
    for child_name, child_module in model.named_children():
        if prefix == "":
            child_prefix = child_name
        else:
            child_prefix = f"{prefix}.{child_name}"

        if isinstance(child_module, spconv.conv.SparseConvolution):
            child_prefix += ".weight"
            collected.add(child_prefix)

        # NOTE(review): for a sparse conv the recursion continues under the
        # "<path>.weight" prefix, so any descendants it had would be keyed
        # beneath ".weight." -- presumably harmless because such layers have
        # no child modules; confirm before relying on it.
        collected |= find_all_spconv_keys(child_module, prefix=child_prefix)

    return collected


def replace_feature(out, new_features):
    """Return ``out`` carrying ``new_features`` as its feature data.

    If ``out`` lists a ``replace_feature`` attribute (spconv 2.x behaviour),
    that method is called and its result is returned. Otherwise ``out`` is
    updated in place by assigning ``out.features`` and is returned itself.
    """
    if "replace_feature" not in out.__dir__():
        # No replace_feature available (presumably spconv 1.x): mutate in place.
        out.features = new_features
        return out
    # spconv 2.x behaviour: delegate to the tensor's own method.
    return out.replace_feature(new_features)