import itertools

from jax import Array

from medium_rl.envs.proxies.proxy import BitSeqProxy
from medium_rl.envs.sequence_env import SequenceEnv


class BitSequence(SequenceEnv):
    """Bit-sequence environment whose actions append fixed-width bit substrings.

    Each non-special token encodes ``substring_len`` bits, so a completed sequence of
    ``max_len`` bits is ``max_len // substring_len`` tokens long (plus BOS/EOS).
    Rewards are delegated to a ``BitSeqProxy``.

    See: https://arxiv.org/pdf/2201.13259
    """

    name = "BitSequence"

    def __init__(
        self,
        min_len: int,
        max_len: int,
        substring_len: int,
        num_modes: int,
        mode_seed: int = 0,
        **kwargs,
    ):
        """
        Args:
            min_len: Minimum sequence length in bits. It is converted to a token count by
                floor division by ``substring_len``. NOTE(review): a value not divisible
                by ``substring_len`` is silently rounded down; confirm this is intended.
            max_len: The 'n' described in the paper (i.e. the length of completed bit
                sequences). Must be divisible by ``substring_len``.
            substring_len: Length of added bit substrings (i.e. the 'k' described in the
                paper). Must be positive.
            num_modes: Number of modes, forwarded to ``BitSeqProxy``.
            mode_seed: Seed forwarded to ``BitSeqProxy``. It presumably controls how the
                modes are generated; TODO confirm.
            **kwargs: Accepted and ignored.

        Raises:
            ValueError: If ``substring_len`` is not positive, or if ``max_len`` is not
                divisible by ``substring_len``.

        See: https://arxiv.org/pdf/2201.13259
        """
        # Validate with explicit raises rather than `assert`, which is stripped under
        # `python -O` and would let a non-divisible max_len silently truncate.
        if substring_len <= 0:
            raise ValueError(f"substring_len must be positive, got {substring_len}")
        if max_len % substring_len != 0:
            raise ValueError(f"{max_len} must be divisible by {substring_len}")

        # NOTE(review): self.max_len stores the *bit* length. If SequenceEnv.__init__
        # below also assigns self.max_len (e.g. the token length), this value will be
        # overwritten. Confirm which meaning downstream code expects.
        self.max_len = max_len
        self.substring_len = substring_len
        self.num_modes = num_modes
        self.mode_seed = mode_seed

        # Special tokens come first, so their ids are fixed at 0, 1 and 2. They are
        # followed by all 2**substring_len bit strings in lexicographic order
        # ("00", "01", "10", "11" for substring_len=2).
        special_tokens = [
            "CLS",  # Alternatively used as BOS
            "PAD",
            "EOS",
        ]
        bit_tokens = ["".join(bits) for bits in itertools.product("01", repeat=substring_len)]
        self.alphabet = special_tokens + bit_tokens
        self.num_tokens = len(self.alphabet)
        # Token string -> token id.
        self.dict = {token: idx for idx, token in enumerate(self.alphabet)}

        self.CLS = self.dict["CLS"]
        self.PAD = self.dict["PAD"]
        self.EOS = self.dict["EOS"]

        # The base class works in token units: convert bit lengths to token counts and
        # reserve +2 positions for BOS/EOS.
        super().__init__(min_len // substring_len + 2, max_len // substring_len + 2)
        self.proxy = BitSeqProxy(num_modes, max_len, self.alphabet, mode_seed)

    def get_rewards(
        self,
        token_seq: Array,  # [B, T], batch of sequence of tokens
    ):
        """Score a batch of token sequences with the proxy.

        Args:
            token_seq: Batch of token ids, shaped [B, T].

        Returns:
            The result of ``self.proxy.evaluate(token_seq)``. This is presumably one
            reward per sequence ([B]); TODO confirm against ``BitSeqProxy``.
        """
        return self.proxy.evaluate(token_seq)
