import os
from pathlib import Path
from typing import List, Optional, Sequence, Union

import torch
from sentencepiece import SentencePieceProcessor, SentencePieceTrainer

# Adapted from: https://github.com/meta-llama/llama/blob/main/llama/tokenizer.py

class LlamaTokenizer:
    """SentencePiece tokenizer for LLaMA.

    Wraps a trained SentencePiece model and adds optional BOS/EOS markers,
    truncation and padding on top of plain encoding. ``encode`` returns an
    int64 ``torch.Tensor`` when ``output_type == 'torch'`` and a ``list`` of
    ints otherwise; ``decode`` accepts either form.
    """

    def __init__(self, model_path: Path, output_type: str = 'torch') -> None:
        """Load the SentencePiece model stored at ``model_path``.

        Args:
            model_path: Path to a trained SentencePiece model file.
            output_type: ``'torch'`` makes :meth:`encode` return an int64
                tensor; any other value makes it return a ``list`` of ints.
        """
        self.processor = SentencePieceProcessor(model_file=str(model_path))
        self.bos_id = self.processor.bos_id()
        self.eos_id = self.processor.eos_id()
        # Padding reuses the EOS id rather than ``self.processor.pad_id()``.
        # NOTE(review): presumably because the LLaMA model defines no usable
        # pad piece — confirm before switching to the processor's pad id.
        self.pad_id = self.eos_id
        # Aliases so code written against HF tokenizers keeps working.
        self.bos_token_id = self.bos_id
        self.eos_token_id = self.eos_id
        self.output_type = output_type

    @property
    def vocab_size(self) -> int:
        """Number of pieces in the loaded SentencePiece vocabulary."""
        return self.processor.vocab_size()

    def encode(
        self,
        string: str,
        bos: bool = True,
        eos: bool = True,
        max_length: int = -1,
        pad: bool = False,
        device: Optional[torch.device] = None,
    ) -> Union[torch.Tensor, List[int]]:
        """Convert ``string`` into token ids.

        Args:
            string: Text to tokenize.
            bos: Prepend ``self.bos_id``.
            eos: Append ``self.eos_id``.
            max_length: If positive, truncate to at most this many ids.
                Truncation is applied after BOS/EOS are added, so an input
                longer than ``max_length`` loses its EOS id.
            pad: When ``max_length`` is positive, right-pad with
                ``self.pad_id`` up to exactly ``max_length`` ids.
            device: Device for the returned tensor (tensor output only).

        Returns:
            An int64 tensor if ``self.output_type == 'torch'``, else a list.
        """
        tokens = self.processor.encode(string)
        if bos:
            tokens = [self.bos_id] + tokens
        if eos:
            tokens = tokens + [self.eos_id]
        if max_length > 0:
            tokens = tokens[:max_length]
            if pad and len(tokens) < max_length:
                tokens = tokens + [self.pad_id] * (max_length - len(tokens))

        if self.output_type == 'torch':
            return torch.tensor(tokens, dtype=torch.int64, device=device)
        return tokens

    def decode(self, tokens: Union[torch.Tensor, Sequence[int]]) -> str:
        """Convert token ids back into text.

        Args:
            tokens: Token ids, either as a tensor (or any object exposing
                ``tolist()``, such as a NumPy array) or as a plain sequence
                of ints. Both output forms of :meth:`encode` are accepted.
        """
        # Tensors/arrays expose tolist(); the plain lists that encode()
        # returns for output_type != 'torch' do not, so fall back to list().
        to_list = getattr(tokens, "tolist", None)
        ids = to_list() if callable(to_list) else list(tokens)
        return self.processor.decode(ids)

    @staticmethod
    def train(input: str, destination: str, vocab_size: int = 32000) -> None:
        """Train a SentencePiece model and save it under ``destination``.

        Args:
            input: Training corpus, passed unchanged as SentencePiece's
                ``input`` option.
            destination: Output directory; the model prefix becomes
                ``<destination>/tokenizer``.
            vocab_size: Size of the vocabulary to learn.
        """
        model_prefix = os.path.join(destination, "tokenizer")
        SentencePieceTrainer.Train(input=input, model_prefix=model_prefix, vocab_size=vocab_size)
