import torch.nn as nn
from abc import ABC, abstractmethod
from transformers.pytorch_utils import Conv1D

# Default module types treated as quantisable when a module carries no
# explicit ``quantisable`` attribute of its own (see ``quantisable`` below).
# This is a PEP 604 union (Python 3.10+), so it works directly as the second
# argument to ``isinstance`` as well as in type annotations. ``Conv1D`` is the
# transformers layer imported above.
Quantisable = nn.Linear | nn.Conv2d | Conv1D


def quantisable(m: nn.Module) -> bool:
    """Decide whether module ``m`` should be quantised.

    A module may opt in or out explicitly by carrying a ``quantisable``
    attribute; when present, that value is returned unchanged (it is not
    coerced to ``bool``). Otherwise the answer is whether ``m`` is an
    instance of one of the ``Quantisable`` types.

    Args:
        m: The module to inspect.

    Returns:
        The module's own ``quantisable`` attribute if it has one, else
        ``isinstance(m, Quantisable)``.

    Raises:
        ValueError: If ``m`` is not a ``torch.nn.Module``.
    """
    if isinstance(m, nn.Module):
        # A private sentinel separates "attribute absent" from "attribute
        # present but falsy" (False, None, 0, ...), which must be returned as-is.
        absent = object()
        override = getattr(m, "quantisable", absent)
        if override is not absent:
            return override
        return isinstance(m, Quantisable)
    raise ValueError(
        f"Tried to call quantisable on {m}, which is of type {type(m)}, not torch.nn.Module."
    )


class CompressedTensor(ABC):
    """Abstract interface for a tensor held in some compressed form.

    Concrete subclasses must implement both read-only properties below;
    because this is an ``ABC``, a subclass that leaves either of them
    abstract cannot be instantiated.
    """

    @property
    @abstractmethod
    def x(self):
        """The tensor data exposed by this object.

        NOTE(review): presumably the decompressed / dequantised tensor —
        confirm against the concrete subclasses.
        """
        raise NotImplementedError()

    @property
    @abstractmethod
    def bpp(self) -> int:
        """Storage cost of the compressed representation, as an ``int``.

        NOTE(review): presumably "bits per parameter" — confirm the unit
        with the concrete subclasses.
        """
        raise NotImplementedError()


# Public API of this module: the default quantisable-type union, the predicate
# that consults it, and the abstract base class for compressed tensors (which
# must be listed here or ``from <module> import *`` silently omits it).
__all__ = ["Quantisable", "quantisable", "CompressedTensor"]
