# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.

import json
from pathlib import Path
from typing import Optional, Union, List
import os

import torch


class Tokenizer:
    """Thin wrapper around a Hugging Face tokenizer loaded from a checkpoint directory.

    The wrapper exposes the BOS/EOS/PAD ids, converts between strings and
    ``torch.Tensor`` token sequences, and optionally supports a ``[DEL]``
    special token that acts like a backspace: each occurrence removes the
    token that precedes it when delete tokens are resolved.

    Attributes set by ``__init__``:
        processor: the underlying Hugging Face tokenizer object.
        backend: always ``"huggingface"``.
        bos_id, eos_id, pad_id: special token ids, or ``None`` when unknown.
        delete_id: id of ``[DEL]``, or ``None`` when ``use_delete_token`` is False.

    ``sep_id`` is only set once ``add_seperator_token`` has been called.
    """

    def __init__(self, checkpoint_dir: Union[Path, str], use_delete_token: bool=False) -> None:
        """Load the tokenizer found in ``checkpoint_dir``.

        Args:
            checkpoint_dir: directory holding the tokenizer files. If the path
                does not exist it is handed to ``AutoTokenizer.from_pretrained``
                and the directory containing the resolved vocab file is used.
            use_delete_token: when True, register ``[DEL]`` as an additional
                special token and store its id in ``delete_id``.

        Raises:
            NotADirectoryError: the path does not exist and could not be
                resolved through ``AutoTokenizer.from_pretrained``.
            NotImplementedError: the directory has no ``tokenizer.json`` and is
                not an ``open_llama`` checkpoint, or ``use_delete_token`` is
                requested for an ``open_llama`` checkpoint.
        """
        checkpoint_dir = Path(checkpoint_dir)
        if not checkpoint_dir.exists():
            missing_msg = f"The checkpoint directory does not exist: {str(checkpoint_dir)}"
            try:
                from transformers import AutoTokenizer

                temp_processor = AutoTokenizer.from_pretrained(str(checkpoint_dir))
            except OSError as err:
                raise NotADirectoryError(missing_msg) from err
            vocab_file = getattr(temp_processor, "vocab_file", None)
            if vocab_file is None:
                # Without a vocab file there is no local directory to point at.
                raise NotADirectoryError(missing_msg)
            checkpoint_dir = Path(os.path.dirname(vocab_file))

        self.bos_id = None
        self.eos_id = None
        self.pad_id = None

        # tokenizer.json takes precedence; open_llama checkpoints fall back to
        # the LlamaTokenizer class.
        if (checkpoint_dir / "tokenizer.json").is_file():
            from transformers import AutoTokenizer

            tokenizer_cls = AutoTokenizer
        elif "open_llama" in str(checkpoint_dir):
            from transformers import LlamaTokenizer

            tokenizer_cls = LlamaTokenizer
        else:
            raise NotImplementedError(
                f"No tokenizer.json found and not an open_llama checkpoint: {str(checkpoint_dir)}"
            )

        self.processor = tokenizer_cls.from_pretrained(
            str(checkpoint_dir), add_bos_token=False, add_eos_token=False
        )
        self.backend = "huggingface"
        self._load_special_token_ids(checkpoint_dir)

        ### Adding a delete token
        self.delete_id = None
        if use_delete_token:
            if "open_llama" in str(checkpoint_dir):
                raise NotImplementedError("The delete token is not supported for open_llama checkpoints")

            self.processor.add_special_tokens({"additional_special_tokens": ["[DEL]"]}, replace_additional_special_tokens=False)
            print(f"Added the delete token, special tokens are now: {self.processor.all_special_tokens}, with ids: {self.processor.all_special_ids}")
            self.delete_id = self.processor.convert_tokens_to_ids("[DEL]")

    def _load_special_token_ids(self, checkpoint_dir: Path) -> None:
        """Fill ``bos_id``/``eos_id``/``pad_id`` from the processor, then generation_config.json."""
        # The processor's own ids are only trusted when a tokenizer_config.json
        # sits next to the tokenizer files.
        if (checkpoint_dir / "tokenizer_config.json").is_file():
            self.bos_id = self.processor.bos_token_id
            self.eos_id = self.processor.eos_token_id
            self.pad_id = self.processor.pad_token_id

        generation_config_path = checkpoint_dir / "generation_config.json"
        if generation_config_path.is_file():
            with open(generation_config_path) as fp:
                config = json.load(fp)
            # Only fill in ids that are still unknown.
            if self.bos_id is None:
                self.bos_id = config.get("bos_token_id")
            if self.eos_id is None:
                self.eos_id = config.get("eos_token_id")
            if self.pad_id is None:
                self.pad_id = config.get("pad_token_id")  # idk if this will always work

    @property
    def vocab_size(self) -> int:
        """Return the processor's ``vocab_size`` (see ``__len__`` for the count including added tokens)."""
        return self.processor.vocab_size

    def __len__(self) -> int:
        """Return ``len(self.processor)``."""
        # https://stackoverflow.com/questions/67412925/what-is-the-difference-between-lentokenizer-and-tokenizer-vocab-size#:~:text=Size%20of%20the%20base%20vocabulary%20(without%20the%20added%20tokens).&text=So%20you%20can%20clearly%20see,plus%20the%20len(added_tokens_encoder)%20.
        return len(self.processor)

    def add_seperator_token(self):
        """Register ``[SEP]`` as an additional special token and store its id in ``sep_id``."""
        self.processor.add_special_tokens({"additional_special_tokens": ["[SEP]"]}, replace_additional_special_tokens=False)
        print(f"Added the seperator token, special tokens are now: {self.processor.all_special_tokens}, with ids: {self.processor.all_special_ids}")
        self.sep_id = self.processor.convert_tokens_to_ids("[SEP]")

    def encode(
        self,
        string: str,
        device: Optional[torch.device] = None,
        bos: Optional[bool] = None,
        eos: bool = False,
        max_length: int = -1,
    ) -> torch.Tensor:
        """Tokenize ``string`` into a 1-D ``torch.int`` tensor.

        Args:
            string: text to tokenize.
            device: device the resulting tensor is created on.
            bos: when truthy, prepend ``bos_id``.
            eos: when True, append ``eos_id``.
            max_length: when positive, keep only the first ``max_length``
                tokens. Truncation happens after BOS/EOS are added, so a
                trailing EOS can be cut off.

        Raises:
            NotImplementedError: ``bos`` / ``eos`` was requested but the
                corresponding id is unknown.
        """
        tokens = self.processor.encode(string)

        if bos:
            if self.bos_id is None:
                raise NotImplementedError("This tokenizer does not have a defined a bos token")
            tokens = [self.bos_id] + tokens
        if eos:
            # Fail early with a clear error instead of putting None in the tensor data.
            if self.eos_id is None:
                raise NotImplementedError("This tokenizer does not have a defined eos token")
            tokens = tokens + [self.eos_id]
        if max_length > 0:
            tokens = tokens[:max_length]
        return torch.tensor(tokens, dtype=torch.int, device=device)

    def remove_delete_tokens_list(self, tokens: List) -> List:
        """Return a new list with delete tokens resolved.

        Each occurrence of ``delete_id`` is dropped together with the closest
        preceding token that is still present. A delete token with nothing
        left to remove is simply dropped.
        """
        kept = []
        for token in tokens:
            if token != self.delete_id:
                kept.append(token)
            elif kept:
                # Drop the most recent surviving token; the delete token itself is never kept.
                kept.pop()
        return kept

    def decode(self, tensor: torch.Tensor, skip_special_tokens: bool = False, remove_deleted_tokens: bool = False) -> str:
        """Convert a tensor of token ids (0-D or 1-D) back into a string.

        Args:
            tensor: token ids to decode.
            skip_special_tokens: forwarded to the processor's ``decode``; also
                triggers delete-token resolution when a delete token exists.
            remove_deleted_tokens: resolve delete tokens before decoding while
                leaving the other special tokens visible.
        """
        # remove delete is offered seperately so we can still see BOS and EOS without deletes
        tokens = [tensor.item()] if tensor.ndim == 0 else tensor.tolist()

        if (self.delete_id is not None) and (skip_special_tokens or remove_deleted_tokens): # preprocess to remove the deleted tokens
            tokens = self.remove_delete_tokens_list(tokens)

        return self.processor.decode(tokens, skip_special_tokens=skip_special_tokens)

    def remove_delete_tokens(self, input_tensor_or_list):
        """Resolve delete tokens in a list or tensor of token ids.

        Lists yield a new list; tensors yield a new 1-D tensor with the input's
        device and dtype (a 0-D tensor is treated as a single-token sequence).

        Raises:
            AttributeError: the tokenizer has no delete token.
            TypeError: the input is neither a list nor a ``torch.Tensor``.
        """
        if self.delete_id is None:
            raise AttributeError("no delete token present but the remove deletes function is being called")

        if isinstance(input_tensor_or_list, list):
            return self.remove_delete_tokens_list(input_tensor_or_list)

        if isinstance(input_tensor_or_list, torch.Tensor):
            if input_tensor_or_list.ndim == 0:
                tokens = [input_tensor_or_list.item()]
            else:
                tokens = input_tensor_or_list.tolist()

            return torch.tensor(
                self.remove_delete_tokens_list(tokens),
                device=input_tensor_or_list.device,
                dtype=input_tensor_or_list.dtype,
            )

        raise TypeError("Unsupported input type to remove deleted tokens")
