import torch
import numpy as np
import torch
from transformers import LlamaTokenizer

class ContinuousScalarTokenizer:
    """Continuous scalar tokenizer, adapted from DB1's implementation.

    Float scalars are compressed with a mu-law transform and then mapped
    linearly onto ``num_continuous_bin`` integer bins.  Two binning schemes
    are supported, selected by ``ver``:

    * ``'v1'`` -- the mu-law output is clamped to [-1, 1] and that interval
      is split into bins.
    * ``'v2'`` -- the mu-law output is expected to lie in [0, 1] already and
      that interval is split into bins.

    Args:
        ver: Binning scheme, ``'v1'`` or ``'v2'``.
        num_continuous_bin: Number of integer bins.
        mu: mu-law compression parameter.
        M: mu-law parameter; ``|x| == M`` maps to magnitude 1 after the
            transform.
    """

    def __init__(
        self, ver:str='v1', num_continuous_bin:int=1024, mu:float=100.0, M:float=256.0
    ):
        self.num_continuous_bin = num_continuous_bin
        self.ver = ver
        self.mu = mu
        self.M = M

    @staticmethod
    def _as_tensor(x):
        """Return ``x`` as a torch tensor; existing tensors pass through untouched."""
        if isinstance(x, torch.Tensor):
            return x
        if isinstance(x, np.ndarray):
            # Copy first so the tensor never aliases the caller's array.
            return torch.from_numpy(x.copy()).float()
        # Lists, tuples and Python scalars.
        return torch.as_tensor(x, dtype=torch.float32)

    def discretize(self, x, is_action: bool):
        """Map float scalars to integer bin indices.

        The mu-law transform is always applied.  Action values are not
        supported by this tokenizer, so ``is_action`` must be false.

        Args:
            x: Tensor, ndarray, or array-like of float scalars.
            is_action: Must be false.

        Returns:
            Int tensor of bin indices in ``[0, num_continuous_bin - 1]``.

        Raises:
            AssertionError: If ``is_action`` is true, or (``'v2'`` only) if
                the mu-law output falls outside [0, 1] by more than 1e-2.
            NotImplementedError: If ``ver`` is neither ``'v1'`` nor ``'v2'``.
        """
        x = self._as_tensor(x)
        assert is_action == False, 'The action of the COP is an integer'

        x_mu_lawed = (
            torch.sign(x)
            * torch.log(torch.abs(x) * self.mu + 1.0)
            / torch.log(torch.tensor(self.mu * self.M + 1.0))
        )
        last_bin = self.num_continuous_bin - 1

        if self.ver == 'v1':
            # Discretize over the [-1, 1] range.
            x = torch.clamp(x_mu_lawed, -1, 1)
            x = ((x + 1) / 2 * self.num_continuous_bin).int()
            x = torch.clamp(x, 0, last_bin).int()
        elif self.ver == 'v2':
            # Discretize over the [0, 1] range.
            x = x_mu_lawed
            # min()/max() raise on empty tensors, so only range-check non-empty input.
            if x.numel() > 0:
                assert 0-1e-2 <= x.min() <= x.max() <= 1+1e-2
            x = (x * self.num_continuous_bin).int()
            x = torch.clamp(x, 0, last_bin).int()
        else:
            raise NotImplementedError
        return x

    def decode(self, x, is_action: bool):
        """Reconstruct float scalars from bin indices.

        Each index is mapped to the lower edge of its bin, using the
        interval that matches ``ver`` ([0, 1] for ``'v2'``, [-1, 1]
        otherwise).  Unless ``is_action`` is true, the mu-law transform is
        then inverted.

        Args:
            x: Tensor, ndarray, or array-like of bin indices.  Out-of-range
                indices are clipped to ``[0, num_continuous_bin - 1]`` after
                printing a warning.
            is_action: If true, skip the inverse mu-law step and return the
                linearly rescaled value.

        Returns:
            Float tensor of reconstructed values.
        """
        x = self._as_tensor(x)
        if x.numel() > 0 and (x.max() >= self.num_continuous_bin or x.min() < 0):
            print(
                "Warning of exceeded range of discrete number to recontruct, "
                "by default values will be cliped, min: {}, max:{}".format(
                    x.min(), x.max()
                )
            )
            # Clip with torch so the result stays a tensor.
            x = torch.clamp(x, 0, self.num_continuous_bin - 1)

        x = x.float() / self.num_continuous_bin
        if self.ver != 'v2':
            # 'v1' bins cover [-1, 1]; 'v2' bins already cover [0, 1].
            x = x * 2 - 1
        if not is_action:
            # Inverse of the mu-law transform applied in discretize().
            x = torch.sign(x) * ((1 + self.M * self.mu) ** torch.abs(x) - 1) / self.mu

        return x
    
'''
class MDPLlamaTokenizer(LlamaTokenizer):
    def __init__(self, **kwargs):
        self.name = "MDPTokenizer"
        self.version = kwargs['tokenizer_ver']
        self.num_discrete_values = kwargs['num_discrete_values']
        self.num_continous_values = kwargs['num_continous_values']

        self.special_tokens = {
            "<|>": self.num_discrete_values + self.num_continous_values,
            "<X>": self.num_discrete_values + self.num_continous_values + 1
        }

        super().__init__(
            vocab_file=None,
            bos_token=None,
            eos_token=None,
            unk_token=None,
            pad_token=None,
            add_bos_token=False,
            add_eos_token=False,
            add_prefix_space=False,
            legacy=False,
            **kwargs,
        )

    def get_spm_processor(self, from_slow=False):
        pass

    @property
    def vocab_size(self):
        """Returns vocab size"""
        return self.num_discrete_values + self.num_continous_values + len(self.special_tokens)
    
    def get_vocab(self):
        """Returns vocab as a dict"""
        vocab = {i: i for i in range(self.num_discrete_values + self.num_continous_values)}
        vocab.update(self.special_tokens)
        return vocab
    
    def _tokenize(self, text, **kwargs):
        return self.tokenizer.tokenize(text, encode_special_tokens=self.encode_special_tokens)

if __name__ == "__main__":
    tokenizer_config = {
        'seq_length': 1024,                         # GPT 输入 token 序列长度
        'num_discrete_values': 1024,                # 离散值对应的 token 个数
        'num_continous_values': 1024,               # 连续值对应的 token 个数
        'tokenizer_ver': 'v1',                      # v1 在 [-1,1] 区间离散化，v2 在 [0,1] 区间离散化
        'discretize_mu': 100,                       # ContinuousScalarTokenizer 的 mu-law 参数
        'discretize_M': 256,                        # ContinuousScalarTokenizer 的 mu-law 参数
        'use_prefix': False,                        # 构造样本时是否前置 prefix 序列  
        'use_prompt': True,                         # 构造样本时是否前置 prompt 序列  
        'prompt_prob': 0.25,                        # 若要设置 prompt 序列，设置的概率
        'prompt_ratio': 0.5,                        # prompt 占完整序列长度的比例
        'prompt_at_final_transition_prob': 0.5      # 使用 end of an episode 作为 prompt 的概率
    }

    my_tokenizer = MDPLlamaTokenizer(**tokenizer_config)
'''