import torch

from megatron.training.tokenizer.tokenizer import _HuggingFaceTokenizer
from megatron.training.global_vars import _ensure_var_is_not_initialized, _ensure_var_is_initialized

_GLOBAL_RM_TOKENIZERS = None
_GLOBAL_ACTOR_TOKENIZER = None


def set_global_variables(args, build_tokenizer=True):
    """Populate this module's tokenizer globals from the parsed arguments.

    Args:
        args: Parsed argument namespace; must not be ``None``. It is handed
            unchanged to the reward-model and actor tokenizer builders.
        build_tokenizer: When ``False``, nothing is built and the globals
            are left untouched.
    """
    assert args is not None
    if not build_tokenizer:
        return
    # Reward-model tokenizers are built before the actor tokenizer.
    _build_rm_tokenizers(args)
    _build_actor_tokenizer(args)


def _build_rm_tokenizers(args):
    """Build one Hugging Face tokenizer per reward-model tokenizer entry.

    Stores the result in the module-level ``_GLOBAL_RM_TOKENIZERS`` and
    returns it.

    Args:
        args: Argument namespace. ``args.rm_tokenizer_models`` is either
            ``None`` or an iterable whose items are each passed to
            ``_HuggingFaceTokenizer``. When it is not ``None``,
            ``args.tokenizer_type`` must equal ``"HuggingFaceTokenizer"``.

    Returns:
        The list of built tokenizers, or an empty list when
        ``args.rm_tokenizer_models`` is ``None``.

    Raises:
        AssertionError: If ``args.tokenizer_type`` is not
            ``"HuggingFaceTokenizer"`` while reward-model tokenizers are
            requested. ``_ensure_var_is_not_initialized`` also guards against
            building twice.
    """
    global _GLOBAL_RM_TOKENIZERS
    _ensure_var_is_not_initialized(_GLOBAL_RM_TOKENIZERS, 'rm_tokenizers')

    if args.rm_tokenizer_models is None:
        # Store an empty list rather than None so that the "initialized" check
        # in the getter passes. Indexing the empty list then raises IndexError.
        _GLOBAL_RM_TOKENIZERS = []
        return _GLOBAL_RM_TOKENIZERS

    # Reward-model tokenizers are always loaded with _HuggingFaceTokenizer, so
    # the main tokenizer type is required to be HF too. Presumably this keeps
    # the actor and RM vocabularies handled the same way (TODO: confirm).
    assert args.tokenizer_type == "HuggingFaceTokenizer", (
        "reward-model tokenizers require tokenizer_type == 'HuggingFaceTokenizer', "
        f"got {args.tokenizer_type!r}"
    )

    _GLOBAL_RM_TOKENIZERS = [_HuggingFaceTokenizer(p) for p in args.rm_tokenizer_models]
    return _GLOBAL_RM_TOKENIZERS


def get_rm_tokenizer(rm_idx=0):
    """Return the reward-model tokenizer stored at position ``rm_idx``.

    Args:
        rm_idx: Index into the list built by ``_build_rm_tokenizers``
            (defaults to the first reward model).

    Returns:
        The tokenizer object at that index. Indexing follows normal list
        semantics, so an out-of-range index raises ``IndexError``.
    """
    tokenizers = _GLOBAL_RM_TOKENIZERS
    _ensure_var_is_initialized(tokenizers, 'rm_tokenizers')
    return tokenizers[rm_idx]


def _build_actor_tokenizer(args):
    """Build the actor's Hugging Face tokenizer, if one is configured.

    Stores the result in the module-level ``_GLOBAL_ACTOR_TOKENIZER`` and
    returns it.

    Args:
        args: Argument namespace. ``args.actor_tokenizer_model`` is either
            ``None`` or the value passed to ``_HuggingFaceTokenizer``. When it
            is not ``None``, ``args.tokenizer_type`` must equal
            ``"HuggingFaceTokenizer"``.

    Returns:
        The built tokenizer, or ``None`` when ``args.actor_tokenizer_model``
        is ``None``. In that case the global stays ``None``.

    Raises:
        AssertionError: If ``args.tokenizer_type`` is not
            ``"HuggingFaceTokenizer"`` while an actor tokenizer is requested.
            ``_ensure_var_is_not_initialized`` also guards against building
            twice.
    """
    global _GLOBAL_ACTOR_TOKENIZER
    _ensure_var_is_not_initialized(_GLOBAL_ACTOR_TOKENIZER, 'actor_tokenizers')

    if args.actor_tokenizer_model is None:
        # The global was just checked to be unset. It stays None, so
        # get_actor_tokenizer() will fail its initialized check. The original
        # code had the same behavior.
        return None

    assert args.tokenizer_type == "HuggingFaceTokenizer", (
        "actor tokenizer requires tokenizer_type == 'HuggingFaceTokenizer', "
        f"got {args.tokenizer_type!r}"
    )

    _GLOBAL_ACTOR_TOKENIZER = _HuggingFaceTokenizer(args.actor_tokenizer_model)
    return _GLOBAL_ACTOR_TOKENIZER


def get_actor_tokenizer():
    """Return the actor tokenizer built by ``_build_actor_tokenizer``.

    NOTE(review): when ``args.actor_tokenizer_model`` was ``None``, the
    global is left as ``None``. ``_ensure_var_is_initialized`` presumably
    rejects ``None``, so the call fails with a "not initialized" assertion.
    Confirm this is the intended behavior for "no actor tokenizer configured".
    """
    tokenizer = _GLOBAL_ACTOR_TOKENIZER
    _ensure_var_is_initialized(tokenizer, 'actor_tokenizers')
    return tokenizer
