import os

from transformers import AutoTokenizer
from functools import partial
import torch


def get_tokenizer_chinese(context_length: int, size: str):
    """Build a fixed-length Chinese tokenizer callable and report its vocab size.

    The tokenizer is loaded with ``AutoTokenizer.from_pretrained`` from the
    ``zh_tokenizers/<size>`` directory located next to this module.

    Args:
        context_length: Length passed as ``max_length``; inputs are padded
            (on the right) or truncated to this many tokens.
        size: Tokenizer variant to load; one of ``'base'``, ``'large'`` or
            ``'huge'``.

    Returns:
        A ``(tokenize, vocab_size)`` tuple. ``tokenize(x)`` runs the loaded
        tokenizer on ``x`` and returns its ``input_ids`` wrapped in
        ``torch.LongTensor``. ``vocab_size`` is the loaded tokenizer's
        ``vocab_size`` plus 10.

    Raises:
        AssertionError: If ``size`` is not one of the supported variants.
    """
    valid_sizes = ('base', 'large', 'huge')
    # Explicit raise instead of an ``assert`` statement: asserts are removed
    # under ``python -O``, which would let an unknown size fall through and
    # silently load the 'huge' tokenizer. The exception type is kept as
    # AssertionError so existing callers see the same failure.
    if size not in valid_sizes:
        raise AssertionError(
            'size must be one of {}, got {!r}'.format(valid_sizes, size)
        )

    script_dir = os.path.dirname(os.path.abspath(__file__))
    # The directory name equals the (validated) size, so no branching is needed.
    model_path = os.path.join(script_dir, 'zh_tokenizers/{}'.format(size))

    base_tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        padding_side='right',
    )

    # NOTE(review): 10 extra ids are added on top of the tokenizer's own
    # vocabulary; the reason is not visible in this file (presumably room for
    # extra special tokens) -- confirm against the model's embedding size.
    vocab_size = base_tokenizer.vocab_size + 10

    def wrapper_tokenizer(x):
        """Tokenize ``x`` to padded/truncated ``input_ids`` as a LongTensor."""
        encoded = base_tokenizer(
            x,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
            return_attention_mask=False,
            max_length=context_length,
        )
        # Only the token ids are exposed; the LongTensor wrapping is kept
        # exactly as in the original implementation.
        return torch.LongTensor(encoded['input_ids'])

    return wrapper_tokenizer, vocab_size
