import random
from math import floor, log2
from typing import Any

from serde import JSON

from .payload import Payload, PayloadInstance
from .payload import PayloadInstance as BasePayloadInstance
from .payload_utils import get_balanced_integer_set


class IntegerPayload(Payload):
    """Payload type whose values are the integers in ``[min, max]`` inclusive.

    Attributes set by ``__init__``:
        min: Smallest allowed value (inclusive).
        max: Largest allowed value (inclusive).
        bit_length_exact: ``log2`` of the number of distinct values (float).
        bit_length: ``bit_length_exact`` rounded down to a whole number.
    """

    __NAME = "integer"

    def __init__(self, min_value: int, max_value: int, **kwargs):
        """Store the inclusive range and derive its capacity in bits.

        Args:
            min_value: Smallest allowed payload value (inclusive).
            max_value: Largest allowed payload value (inclusive).
            **kwargs: Accepted but not used by this constructor.

        Raises:
            ValueError: If the range contains no values
                (``max_value < min_value``).
        """
        size = max_value - min_value + 1
        if size <= 0:
            # Without this guard log2() raises the same exception type, but
            # with an opaque "math domain error" message.
            raise ValueError(
                f"max_value ({max_value}) must be greater than or equal to "
                f"min_value ({min_value})"
            )

        self.min = min_value
        self.max = max_value
        self.bit_length_exact = log2(size)
        if isinstance(size, int):
            # Exact integer arithmetic. floor(log2(size)) goes through a float
            # and overshoots for large sizes just below a power of two, e.g.
            # floor(log2(2**60 - 1)) == 60 although the correct value is 59.
            self.bit_length = size.bit_length() - 1
        else:
            # Non-int bounds: keep the float-based computation.
            self.bit_length = floor(self.bit_length_exact)

    def get_payload_description_prompt(self) -> str:
        """Return a natural-language description of the allowed values."""
        return f"The payload is an integer between {self.min} and {self.max} inclusive. It cannot be a decimal number e.g. 0.5, it will be a whole number e.g. 0, 1."

    def _create_instance(self, val: int) -> "IntegerPayload.PayloadInstance":
        """Wrap ``val`` in a ``PayloadInstance`` without range checking."""
        return IntegerPayload.PayloadInstance(val)

    def get_random_payload(self) -> "PayloadInstance":
        """Return an instance drawn with ``random.randint(min, max)``."""
        return self._create_instance(random.randint(self.min, self.max))

    def get_n_balanced_random_payloads(self, n: int) -> list["PayloadInstance"]:
        """Return range-checked instances for a balanced set of integers.

        The integers come from ``get_balanced_integer_set(min, max, n)``;
        each one is passed through ``create_payload_safe``.
        NOTE(review): presumably this yields ``n`` values - confirm against
        the helper, which is defined outside this file.
        """
        return [
            self.create_payload_safe(i)
            for i in get_balanced_integer_set(self.min, self.max, n)
        ]

    def create_payload_safe(self, payload: Any) -> "PayloadInstance":
        """Convert ``payload`` with ``int()`` and validate it against the range.

        Note that ``int()`` truncates floats (``2.7`` becomes ``2``) and
        parses integer strings; whatever it raises for unconvertible input
        propagates unchanged.

        Raises:
            ValueError: If the converted value lies outside ``[min, max]``.
        """
        payload_as_int = int(payload)
        if not (self.min <= payload_as_int <= self.max):
            raise ValueError(
                f"Payload must be an integer between {self.min} and {self.max} inclusive"
            )

        return self._create_instance(payload_as_int)

    def get_payload_values(self) -> list["PayloadInstance"]:
        """Return one instance per value in the range, in ascending order."""
        return [self._create_instance(i) for i in range(self.min, self.max + 1)]

    def serialize(self) -> JSON:
        """Return a dict with ``bit_length``, ``min_value`` and ``max_value``."""
        return {
            "bit_length": self.bit_length,
            "min_value": self.min,
            "max_value": self.max,
        }

    def __str__(self) -> str:
        """Return a short identifier of the form ``int_<min>-<max>``."""
        return f"int_{self.min}-{self.max}"

    def get_payload_size(self) -> int:
        """Return the number of distinct values in the range."""
        return self.max - self.min + 1

    @classmethod
    def from_json(cls, config: dict[str, Any]) -> "IntegerPayload":
        """Build an instance from the ``min_value``/``max_value`` config keys.

        Other keys in ``config`` (such as ``bit_length``) are not read.
        """
        return cls(config["min_value"], config["max_value"])

    class PayloadInstance(BasePayloadInstance):
        """A single integer payload value."""

        def __init__(self, payload: int):
            """Check that ``payload`` is an ``int`` and hand it to the base class."""
            # Internal invariant: create_payload_safe() has already converted
            # its input with int() before an instance is constructed.
            assert isinstance(payload, int), "Payload must be an integer"
            super().__init__(payload)

        def __eq__(self, other: Any) -> bool:
            """Compare equal to same-class instances and to plain ints.

            Either way the comparison is on the wrapped ``payload`` value.
            """
            return (
                isinstance(other, self.__class__) and self.payload == other.payload
            ) or (isinstance(other, int) and self.payload == other)

        def __hash__(self) -> int:
            """Hash like the wrapped int.

            Defining ``__eq__`` alone makes a class unhashable. Delegating to
            the int keeps the hash consistent with ``__eq__``, which also
            accepts plain ints.
            """
            return hash(self.payload)

        def __repr__(self) -> str:
            """Return ``IntegerPayloadInstance(payload=<value>)``."""
            return f"IntegerPayloadInstance(payload={self.payload})"
