import importlib.util

import torch


def is_flash_attn_available():
    return importlib.util.find_spec("flash_attn") is not None


def is_flash_attn_3_available():
    return importlib.util.find_spec("flash_attn_interface") is not None


def is_torch_version(operator: str, version: str):
    from packaging import version as pversion

    torch_version = pversion.parse(torch.__version__)
    target_version = pversion.parse(version)

    # print(f"torch_version: {torch_version}, target: torch{operator}{target_version}")
    if operator == ">":
        return torch_version > target_version
    elif operator == ">=":
        return torch_version >= target_version
    elif operator == "==":
        return torch_version == target_version
    elif operator == "<=":
        return torch_version <= target_version
    elif operator == "<":
        return torch_version < target_version
    return False
