"""
action_tokenizer.py

Extension class; wraps base LLM/VLM tokenizer with logic to discretize and tokenize continuous robot actions.
"""

from typing import List, Union

import numpy as np
from transformers import PreTrainedTokenizer


class ActionTokenizer:
    """
    Converts continuous robot actions to/from strings of discrete `<ACTION_i>` tokens.

    Each action dimension is clipped to `[min_action, max_action]`, assigned to one of `bins` uniformly spaced
    bin edges, and rendered as the token `<ACTION_i>` with `i` in `[1, bins]`. Decoding maps token IDs back to
    the center of the corresponding bin interval.
    """

    def __init__(
            self, tokenizer: "PreTrainedTokenizer", bins: int = 256, min_action: float = -1, max_action: float = 1
    ) -> None:
        """
        Discretizes continuous robot actions into N bins per dimension and maps them to `<ACTION_i>` tokens.

        :param tokenizer: Base LLM/VLM tokenizer; only its `convert_tokens_to_ids` method is used here.
        :param bins: Number of bin edges for each continuous value; we adopt a uniform binning strategy.
        :param min_action: Minimum action value (for clipping, setting lower bound on bin interval).
        :param max_action: Maximum action value (for clipping, setting upper bound on bin interval).
        """
        self.tokenizer, self.n_bins, self.min_action, self.max_action = tokenizer, bins, min_action, max_action

        # Create Uniform Bins + Compute Bin Centers (`n_bins` edges => `n_bins - 1` intervals/centers)
        self.bins = np.linspace(min_action, max_action, self.n_bins)
        self.bin_centers = (self.bins[:-1] + self.bins[1:]) / 2.0

        # Map each action token's vocabulary ID to a zero-based bin ID (`<ACTION_1>` -> 0, ..., `<ACTION_N>` -> N-1).
        # NOTE(review): nothing here registers the `<ACTION_i>` tokens with the tokenizer -- presumably they are
        #               added to the vocabulary beforehand; confirm against the caller.
        token_list = [f'<ACTION_{i}>' for i in range(1, self.n_bins + 1)]
        self.token_id_to_bin_id = {self.tokenizer.convert_tokens_to_ids(token): i for i, token in enumerate(token_list)}
        self.token_ids = list(self.token_id_to_bin_id.keys())

    def __call__(self, action: np.ndarray) -> Union[str, List[str]]:
        """
        Clip & bin actions, then render them as concatenated `<ACTION_i>` tokens.

        :param action: A scalar, a 1-D action vector, or a 2-D batch of action vectors.
        :return: One token string for a scalar or 1-D input; a list with one token string per row otherwise.
        """
        action = np.clip(action, a_min=float(self.min_action), a_max=float(self.max_action))

        # After clipping, digitization yields indices in [1, n_bins]. `atleast_1d` promotes the 0-d result of a
        # scalar action so that it is encoded as a single token instead of failing on iteration.
        discretized_action = np.atleast_1d(np.digitize(action, self.bins))

        # Handle single element vs. batch
        if discretized_action.ndim == 1:
            return ''.join(f'<ACTION_{i}>' for i in discretized_action)
        return [''.join(f'<ACTION_{i}>' for i in row) for row in discretized_action]

    def decode_token_ids_to_actions(self, action_token_ids: np.ndarray) -> np.ndarray:
        """
        Returns continuous actions for discrete action token IDs.

        NOTE =>> Because the actions are discretized w.r.t. the bin edges (and not the bin centers), digitization
                 returns indices between [1, # bins], inclusive, which `self.token_id_to_bin_id` shifts to
                 [0, # bins - 1]. There are, however, only (# bins - 1) bin intervals.

                 Therefore, if a token maps to the last possible index, we map it to the last bin interval.

        EXAMPLE =>> Let's say self.bins has 256 values. Then self.bin_centers has 255 values, while bin IDs lie in
                    [0, 255]. The ID 255 would cause an out-of-bounds error if used to index into
                    self.bin_centers, so bin IDs are clipped to [0, 255 - 1].

        Token IDs that do not belong to any action token are treated as bin ID 0 (i.e., the first bin center).

        :param action_token_ids: Array of token IDs of any shape (may be empty).
        :return: Array of bin-center values with the same shape as `action_token_ids`.
        """
        action_token_ids = np.asarray(action_token_ids)
        to_bin_id = self.token_id_to_bin_id.get

        # The explicit integer dtype keeps the result usable as an index array even when the input is empty
        # (a bare `np.array([])` would default to float64 and break the lookup below).
        discretized_actions = np.array(
            [to_bin_id(token_id, 0) for token_id in action_token_ids.ravel().tolist()], dtype=np.int64
        )
        discretized_actions = np.clip(discretized_actions, a_min=0, a_max=self.bin_centers.shape[0] - 1)
        discretized_actions = discretized_actions.reshape(action_token_ids.shape)

        return self.bin_centers[discretized_actions]

    @property
    def vocab_size(self) -> int:
        """Number of distinct action tokens (equal to the number of bin edges)."""
        return self.n_bins
