import json
from transformers import AutoTokenizer

def get_tokenizer(path, trust_remote_code=True):
    """Load a tokenizer via ``transformers.AutoTokenizer.from_pretrained``.

    Args:
        path: Model identifier or local directory handed to ``from_pretrained``.
        trust_remote_code: Forwarded to ``from_pretrained``. Defaults to ``True``
            to keep the previous hard-coded behaviour; pass ``False`` for
            checkpoints you do not trust, because ``True`` lets custom code
            shipped with the checkpoint execute locally.

    Returns:
        Whatever ``AutoTokenizer.from_pretrained`` returns for *path*.
    """
    return AutoTokenizer.from_pretrained(path, trust_remote_code=trust_remote_code)

def encode_chat(tokenizer, messages, chat_template_args=None, completions=False):
    """Turn a conversation into token ids ready for generation.

    Args:
        tokenizer: Object exposing ``encode`` and ``apply_chat_template``
            (presumably the tokenizer returned by ``get_tokenizer``).
        messages: Sequence of message dicts; each message's text is read from
            its ``'content'`` key.
        chat_template_args: Optional mapping of extra keyword arguments
            forwarded to ``apply_chat_template``. ``None`` (the default) means
            no extra arguments.
        completions: When true, skip the chat template entirely and encode
            only the content of the last message as a raw prompt.

    Returns:
        The result of ``tokenizer.encode`` on the prompt text.
    """
    if completions:
        return tokenizer.encode(messages[-1]['content'], add_special_tokens=False)
    # None instead of a shared mutable ``{}`` default; also lets callers pass
    # chat_template_args=None explicitly without ``**None`` raising TypeError.
    if chat_template_args is None:
        chat_template_args = {}
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        **chat_template_args,
    )
    # The rendered template is expected to carry any special tokens it needs,
    # so encode() is told not to add them a second time.
    return tokenizer.encode(prompt, add_special_tokens=False)

def decode_chat(tokenizer, out_tokens):
    """Convert generated token ids back into text.

    Args:
        tokenizer: Object exposing a ``decode`` method.
        out_tokens: Token ids to decode, passed through unchanged.

    Returns:
        The value produced by ``tokenizer.decode(out_tokens)``.
    """
    decoded_text = tokenizer.decode(out_tokens)
    return decoded_text

def read_json(path):
    """Load and return the JSON document stored at *path*.

    Args:
        path: Filesystem path to a JSON file, or ``None``.

    Returns:
        The parsed JSON value, or an empty dict when *path* is ``None``.
    """
    if path is None:
        return {}
    # JSON text is UTF-8 (RFC 8259); don't rely on the platform's locale
    # encoding, which can misdecode non-ASCII content.
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)

def postprocess_base(text):
    """Identity post-processor: hand back *text* exactly as received."""
    unchanged = text
    return unchanged

def postprocess_gptoss(text):
    """Extract the final-channel answer from raw gpt-oss output text.

    Keeps the text after the LAST ``<|channel|>final<|message|>`` header
    (or the whole input when that header is absent), then cuts it at the
    first ``<|end|>``, then at the first ``<|return|>``, then at the first
    ``<|channel|>``.

    NOTE(review): when no final header exists (e.g. output that only has an
    analysis section), the whole input is kept before truncation, so the
    result may be empty or contain leading non-final text.
    """
    _, _, answer = text.rpartition("<|channel|>final<|message|>")
    for stop_marker in ("<|end|>", "<|return|>", "<|channel|>"):
        # partition() yields the full string as the head when the marker is missing.
        head, _, _ = answer.partition(stop_marker)
        answer = head
    return answer