from jax import Array

from medium_rl.envs.proxies.proxy import GFPProxy
from medium_rl.envs.sequence_env import SequenceEnv

# Token vocabulary for GFP sequences. The three special tokens come first,
# followed by the 20 standard one-letter amino-acid codes in alphabetical
# order. A token's id is its index in this list.
GFP_ALPHABET = [
    "CLS",  # Alternatively used as BOS
    "PAD",
    "EOS",
    "A",
    "C",
    "D",
    "E",
    "F",
    "G",
    "H",
    "I",
    "K",
    "L",
    "M",
    "N",
    "P",
    "Q",
    "R",
    "S",
    "T",
    "V",
    "W",
    "Y",
]
# Reverse lookup: token string -> integer id (its position in GFP_ALPHABET).
GFP_DICT = {token: idx for idx, token in enumerate(GFP_ALPHABET)}


class GFPSequence(SequenceEnv):
    """Sequence environment over the GFP token alphabet.

    Token ids come from ``GFP_DICT``. Rewards are delegated to a
    ``GFPProxy`` instance that is created when the environment is built.
    """

    name = "GFP"

    alphabet = GFP_ALPHABET
    num_tokens = len(alphabet)
    # Shadows the builtin ``dict`` inside this class body only. The attribute
    # name is kept because external code may read ``env.dict``.
    dict = GFP_DICT

    # Integer ids of the special tokens, looked up from GFP_DICT.
    CLS = GFP_DICT["CLS"]
    PAD = GFP_DICT["PAD"]
    EOS = GFP_DICT["EOS"]

    def __init__(self, min_len: int, max_len: int, **kwargs):
        """Build the environment and its reward proxy.

        Args:
            min_len: Passed positionally to ``SequenceEnv.__init__``
                (presumably the minimum sequence length).
            max_len: Passed positionally to ``SequenceEnv.__init__``
                (presumably the maximum sequence length).
            **kwargs: Accepted for call-site compatibility. Currently
                ignored and not forwarded to the base class.
        """
        super().__init__(min_len, max_len)
        self.proxy = GFPProxy()

    def get_rewards(
        self,
        token_seq: Array,  # [B, T], batch of sequence of tokens
    ):
        """Score a batch of token sequences with the GFP proxy.

        Args:
            token_seq: Batch of token sequences with shape [B, T].

        Returns:
            Whatever ``self.proxy.evaluate`` returns for ``token_seq``.
        """
        return self.proxy.evaluate(token_seq)
