import random
from math import prod
from typing import Any, Sequence

from logger import logger
from serde import JSON

from .integer_payload import IntegerPayload


class MultiPartPayload(IntegerPayload):
    """Integer payload made of several independently sized parts.

    The parts are packed into one integer with a mixed-radix encoding:
    ``part_sizes[i]`` is the number of values part ``i`` can take, and the
    last part is the least significant digit. The packed integer therefore
    lies in ``[0, prod(part_sizes) - 1]``.
    """

    def __init__(self, part_sizes: list[int], **kwargs):
        """Create a payload whose i-th part takes values in ``range(part_sizes[i])``.

        Args:
            part_sizes: Number of distinct values of each part, in order.
            **kwargs: Accepted but not used by this constructor.
        """
        # Copy so that later mutation of the caller's list cannot desynchronize
        # the stored sizes from the maximum handed to the base class, and so
        # that the list concatenations in the augment helpers work even when a
        # tuple was passed in.
        self.part_sizes = list(part_sizes)
        _max = prod(self.part_sizes) - 1
        super().__init__(0, _max)

        def _create_instance(value: int) -> "MultiPartPayload.PayloadInstance":
            return self.__class__.PayloadInstance(value, self.part_sizes)

        # Assigned after the base initializer has run.
        # NOTE(review): presumably the base class calls this factory to build
        # instances - confirm against IntegerPayload.
        self._create_instance = _create_instance

    def get_payload_description_prompt(self) -> str:
        """Return the text describing this payload.

        A payload with exactly one part defers to the base class description;
        otherwise the text interpolates the full ``part_sizes`` list.
        """
        if len(self.part_sizes) == 1:
            # A 1 part payload is just an integer.
            return super().get_payload_description_prompt()
        return f"The payload is a {self.part_sizes}-part integer."

    def get_n_random_payload_with_coverage(
        self, n: int
    ) -> Sequence["MultiPartPayload.PayloadInstance"]:
        """Build payloads from ``n`` randomly ordered values per part.

        For every part with ``n >= part_size`` each value of that part appears
        at least once; the remaining ``n - part_size`` slots are filled with
        uniformly random values. Coverage is per part, not per combination of
        parts. Each part's values are shuffled independently and then combined
        position by position.

        Args:
            n: Number of values generated for each part.

        Returns:
            One payload instance per zipped tuple of part values.
        """
        payload_part_values = []
        for part_size in self.part_sizes:
            if n < part_size:
                logger.warning(
                    f"Cannot ensure each payload appears at least once as n={n} < {part_size}"
                )
                # NOTE(review): only the values 0..n-1 are ever used here, so
                # the higher values of this part never appear. Confirm this is
                # intended rather than e.g. random.sample(range(part_size), n).
                values = list(range(n))
            else:
                values = list(range(part_size)) + [
                    random.randint(0, part_size - 1) for _ in range(n - part_size)
                ]
            random.shuffle(values)
            payload_part_values.append(values)
        return [
            self.create_payload_safe_from_parts(parts)
            for parts in zip(*payload_part_values)
        ]

    def create_payload_safe_from_parts(
        self, parts: Sequence[int]
    ) -> "MultiPartPayload.PayloadInstance":
        """Pack per-part values into one integer and build a payload from it.

        The last part is the least significant digit; each earlier part is
        weighted by the product of the sizes of the parts after it. Parts are
        paired with sizes starting from the end, so if the two sequences differ
        in length the surplus leading entries of the longer one are ignored.

        Args:
            parts: Value of each part, in the same order as ``part_sizes``.

        Returns:
            Whatever ``create_payload_safe`` returns for the packed integer.
        """
        packed = 0
        multiplier = 1
        for part, size in zip(reversed(parts), reversed(self.part_sizes)):
            packed += part * multiplier
            multiplier *= size

        return self.create_payload_safe(packed)

    def get_value_augmented_payload(
        self, new_last_part_size: int
    ) -> "MultiPartPayload":
        """New payload with more values encoded in the final part."""
        return self.__class__(self.part_sizes[:-1] + [new_last_part_size])

    def get_size_augmented_payload(self, new_last_part_size: int) -> "MultiPartPayload":
        """New payload with more parts."""
        return self.__class__(self.part_sizes + [new_last_part_size])

    def serialize(self) -> JSON:
        """Return the configuration as a dict with the single key ``part_sizes``."""
        return {
            "part_sizes": self.part_sizes,
        }

    def __str__(self) -> str:
        """Return ``mb_`` followed by the part sizes joined with ``-``."""
        return "mb_" + "-".join([str(p) for p in self.part_sizes])

    def get_payload_size(self) -> int:
        """Return the product of all part sizes."""
        return prod(self.part_sizes)

    @classmethod
    def from_json(cls, config: dict[str, Any]) -> "MultiPartPayload":
        """Build a payload from a dict holding a ``part_sizes`` entry.

        Args:
            config: Mapping with a ``"part_sizes"`` key, as produced by
                ``serialize``.

        Raises:
            KeyError: If ``config`` has no ``"part_sizes"`` key.
        """
        return cls(config["part_sizes"])

    class PayloadInstance(IntegerPayload.PayloadInstance):
        """A concrete payload value together with the part sizes used to split it."""

        def __init__(self, payload: int, part_sizes: list[int]):
            """Store the packed integer and the sizes needed to unpack it.

            Args:
                payload: The packed integer value, forwarded to the base class.
                part_sizes: Number of distinct values of each part, in order.
            """
            super().__init__(payload)
            self.part_sizes = part_sizes

        def __repr__(self) -> str:
            """Return a debug string showing the part sizes and packed value."""
            return f"{self.part_sizes}-PartPayloadInstance(payload={self.payload})"

        def as_parts(self) -> list[int]:
            """Unpack the integer into one value per part, in ``part_sizes`` order.

            Inverse of the mixed-radix packing: digits are peeled off starting
            with the last (least significant) part.
            """
            digits = []
            remaining = self.payload
            for size in reversed(self.part_sizes):
                remaining, digit = divmod(remaining, size)
                digits.append(digit)
            # Digits were collected least significant first; restore part
            # order once instead of paying for insert(0, ...) on every step.
            digits.reverse()
            return digits
