from jax import Array

from medium_rl.envs.proxies.proxy import UTRProxy
from medium_rl.envs.sequence_env import SequenceEnv

# Vocabulary for UTR sequences. A token's id is its position in this list,
# so the order must not change: three special tokens first, then the four
# sequence letters.
UTR_ALPHABET = ["CLS", "PAD", "EOS", "A", "C", "G", "T"]

# Reverse lookup: token string -> integer id (its index in UTR_ALPHABET).
UTR_DICT = {token: index for index, token in enumerate(UTR_ALPHABET)}


class UTRSequence(SequenceEnv):
    """Sequence environment over the UTR token alphabet, scored by a ``UTRProxy``.

    The class attributes expose the vocabulary (``alphabet``, ``dict``,
    ``num_tokens``) and the integer ids of the three special tokens
    (``CLS``, ``PAD``, ``EOS``).
    """

    name = "UTR"

    # Vocabulary size, token list, and token -> id lookup.
    num_tokens = len(UTR_ALPHABET)
    alphabet = UTR_ALPHABET
    dict = UTR_DICT

    # Integer ids of the special (non-letter) tokens, looked up by name.
    CLS, PAD, EOS = (UTR_DICT[token] for token in ("CLS", "PAD", "EOS"))

    def __init__(self, min_len: int, max_len: int, **kwargs):
        """Initialise the base environment and attach the reward proxy.

        Args:
            min_len: Passed as the first positional argument to
                ``SequenceEnv.__init__``.
            max_len: Passed as the second positional argument to
                ``SequenceEnv.__init__``.
            **kwargs: Accepted for call-site compatibility; not used here
                and not forwarded to the base class.
        """
        super().__init__(min_len, max_len)
        self.proxy = UTRProxy()

    def get_rewards(self, token_seq: Array):
        """Score a batch of token sequences with the proxy.

        Args:
            token_seq: Batch of token-id sequences, shape ``[B, T]``.

        Returns:
            Whatever ``self.proxy.evaluate`` returns for ``token_seq``.
        """
        return self.proxy.evaluate(token_seq)
