from abc import abstractmethod
from typing import Any, Callable, Generator

from serde import JSON, Serializable


class Payload:
    """A base class for a type of payload.

    A subclass describes one kind of payload and hands out
    ``PayloadInstance`` objects for it.

    Note:
        Several methods are decorated with ``@abstractmethod``, but this class
        does not derive from ``abc.ABC``, so Python does not enforce them:
        the class can still be instantiated, and calling a hook that a
        subclass did not override simply returns ``None``.
    """

    bit_length: int = 0
    bit_length_exact: float = 0  # A non rounded-down version of the above

    @abstractmethod
    def get_payload_description_prompt(self) -> str:
        """Describe the payload type for e.g. the encoder, decoder etc."""

    def get_payload_values(self) -> list["PayloadInstance"]:
        """Return every possible payload value.

        Raises:
            NotImplementedError: always, in this base implementation; a
                subclass with a finite set of values is expected to override.
        """
        raise NotImplementedError(
            "By default, there is not a finite set of payload values"
        )

    @abstractmethod
    def get_random_payload(self) -> "PayloadInstance":
        """Return a single random payload instance."""

    @abstractmethod
    def get_n_balanced_random_payloads(self, n: int) -> list["PayloadInstance"]:
        """Return ``n`` random payload instances.

        NOTE(review): what "balanced" means is not defined by the base class;
        confirm against the concrete subclasses.
        """

    @abstractmethod
    def create_payload_safe(self, payload: JSON) -> "PayloadInstance":
        """Create a payload instance with some type checking."""

    def random_generator(self) -> Generator["PayloadInstance", None, None]:
        """Generate random payloads.

        The generator never terminates on its own; each step delegates to
        ``get_random_payload``.

        Yields:
            payload instances
        """
        while True:
            yield self.get_random_payload()

    @classmethod
    @abstractmethod
    def from_json(cls, config: JSON) -> "Payload":
        """Build a payload type from a JSON config.

        NOTE(review): presumably the inverse of ``serialize`` - confirm
        against the concrete subclasses.
        """

    @abstractmethod
    def serialize(self) -> dict[str, Serializable]:
        """Return a serializable dictionary describing this payload type."""

    @abstractmethod
    def get_payload_size(self) -> int:
        """Return the size of the payload.

        NOTE(review): the unit of the returned size is not established by the
        base class; confirm against the concrete subclasses.
        """


class PayloadInstance:
    """This is an instance of a payload.

    A thin wrapper around one raw payload value. An instance compares equal
    both to another ``PayloadInstance`` holding an equal value and to an
    equal raw value of the wrapped value's type.
    """

    def __init__(self, payload: Any):
        """Store the raw payload value on ``self.payload``."""
        self.payload = payload

    def __eq__(self, other: Any) -> bool:
        """Compare against another instance or against a raw payload value.

        Returns:
            True when ``other`` is a ``PayloadInstance`` with an equal
            payload, or when ``other`` is an instance of the wrapped value's
            type and equals the wrapped value.
        """
        return (
            isinstance(other, PayloadInstance) and self.payload == other.payload
        ) or (isinstance(other, type(self.payload)) and self.payload == other)

    def __str__(self) -> str:
        """Return the string form of the wrapped payload."""
        return str(self.payload)

    def serialize(self) -> Serializable:
        """Return the wrapped payload itself as the serialized form."""
        return self.payload

    def __hash__(self) -> int:
        """Return the hash of the wrapped payload.

        ``__hash__`` must return an integer, so the payload is passed through
        ``hash()`` rather than returned directly; returning e.g. a ``str``
        payload makes ``hash(instance)`` raise ``TypeError``. Hashing the
        payload also keeps the hash consistent with ``__eq__``, under which
        an instance equals its raw payload.

        Raises:
            TypeError: if the wrapped payload is itself unhashable.
        """
        return hash(self.payload)

    @classmethod
    def deserialize(cls, deserialize: Callable, payload: Any) -> "PayloadInstance":
        """Rebuild an instance from a serialized payload.

        Args:
            deserialize: accepted as part of the signature but not used by
                this implementation.
            payload: the raw payload value to wrap.

        Returns:
            A new instance of ``cls`` wrapping ``payload``.
        """
        return cls(payload)
