from jax import Array

from medium_rl.envs.proxies.proxy import AMPProxy
from medium_rl.envs.sequence_env import SequenceEnv

# Token vocabulary. The position of each entry in the list is its token id.
# Ids 0-2 are special tokens; ids 3-22 are the 20 standard one-letter
# amino-acid codes.
AMP_ALPHABET = [
    "CLS",  # Alternatively used as BOS
    "PAD",
    "EOS",
    "A",
    "C",
    "D",
    "E",
    "F",
    "G",
    "H",
    "I",
    "K",
    "L",
    "M",
    "N",
    "P",
    "Q",
    "R",
    "S",
    "T",
    "V",
    "W",
    "Y",
]
# Reverse lookup: token string -> token id (its index in AMP_ALPHABET).
AMP_DICT = {token: idx for idx, token in enumerate(AMP_ALPHABET)}


class AMPSequence(SequenceEnv):
    """Sequence environment over the ``AMP_ALPHABET`` token vocabulary.

    The vocabulary consists of three special tokens (CLS/PAD/EOS) followed by
    the 20 one-letter amino-acid codes. The reward for a batch of sequences
    comes from an ``AMPProxy`` instance created at construction time.
    """

    name = "AMP"

    # Vocabulary tables shared with the module-level constants.
    alphabet = AMP_ALPHABET
    num_tokens = len(alphabet)
    # NOTE(review): this attribute name shadows the ``dict`` builtin inside the
    # class body. It is kept as-is because it is part of the public attribute
    # surface (possibly read by SequenceEnv) — confirm before renaming.
    dict = AMP_DICT

    # Token ids of the special tokens.
    CLS = AMP_DICT["CLS"]
    PAD = AMP_DICT["PAD"]
    EOS = AMP_DICT["EOS"]

    def __init__(self, min_len: int, max_len: int, **kwargs):
        """Create the environment and its reward proxy.

        Args:
            min_len: Minimum sequence length, forwarded to ``SequenceEnv``.
            max_len: Maximum sequence length, forwarded to ``SequenceEnv``.
            **kwargs: Accepted for call-site compatibility but not used here.
        """
        super().__init__(min_len, max_len)
        self.proxy = AMPProxy()

    def get_rewards(
        self,
        token_seq: Array,  # [B, T], batch of sequence of tokens
    ):
        """Score a batch of token sequences.

        Args:
            token_seq: Batch of token-id sequences (annotated as [B, T]).

        Returns:
            Whatever ``self.proxy.evaluate`` returns for ``token_seq``,
            passed through unchanged.
        """
        return self.proxy.evaluate(token_seq)
