import re
import os
import json
import dataclasses as dc

from litgpt.tokenizer import Tokenizer as LitTokenizer
from tokenizers import Tokenizer as HFTokenizer
from tokenizers.models import BPE
from tokenizers.decoders import BPEDecoder
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.trainers import BpeTrainer

from pathlib import Path
from enum import StrEnum
from copy import deepcopy
from typing import Literal, Iterable, Mapping, Any

from torch import Tensor, device


def _listfiles(root: str | Path):
    """
    Recursively lists all files in a directory.
    
    Args:
    root (str): The root directory to start listing files from.
    
    Returns:
    List[str]: A list of file paths relative to the root directory.
    """
    file_list = []
    
    for dirpath, _, filenames in os.walk(root):
        for filename in filenames:
            # Construct the relative file path
            relative_path = os.path.relpath(os.path.join(dirpath, filename), root)
            file_list.append(relative_path)
    
    return file_list


class Vocabulary:

    def __init__(self,
        *tokens: str,
        **kwtokens: str | None
    ):
        
        self._kwtokens: dict[str, str] = {}
        self.tokens: list[str] = []

        self.add_kwtokens(kwtokens)
        self.add_tokens(tokens)
    
    def add_tokens(self, tokens: Iterable[str]):
        for token in tokens:
            if token not in self.tokens:
                self.tokens.append(token)
        
    def add_kwtokens(self, kwtokens: Mapping[str, str | None], override: bool = False):
        for name, token in kwtokens.items():
            if token is None:
                continue
            if token not in self.tokens:
                self.tokens.append(token)
            if override or name not in self._kwtokens:
                self._kwtokens[name] = token
    
    def __len__(self):
        return len(self.tokens)
    
    def __repr__(self) -> str:
        token2kw = {v: k for k, v in self._kwtokens.items()}
        items = []
        for tk in self.tokens:
            if (kw := token2kw.get(tk)):
                items.append('%s: %s' % (kw, tk))
            else:
                items.append(tk)
        return '{' + (', '.join(items)) + '}'
    
    def incorporate(self, other: 'Vocabulary', override_kwtokens: bool = False):
        self.add_kwtokens(other._kwtokens, override_kwtokens)
        self.add_tokens(other.tokens)

    @staticmethod
    def combine(*parts: 'Vocabulary', override_kwtokens: bool = False):
        tokens = Vocabulary()
        for part in parts:
            tokens.add_kwtokens(part._kwtokens, override_kwtokens)
        for part in parts:
            tokens.add_tokens(part.tokens)
        return tokens

    def __getattr__(self, name: str) -> str | None:
        return self._kwtokens.get(name)
    
    def __getitem__(self, idx: str | int) -> str:
        if isinstance(idx, str):
            return self._kwtokens[idx]
        else:
            return self.tokens[idx]

    def copy(self):
        return Vocabulary(*self.tokens, **self._kwtokens)
    
    def update_tokenizer_config(self, d: dict):
        for k, token in self._kwtokens.items():
            d[k + '_token'] = token

    def add_to_tokenizer(self, tokenizer: HFTokenizer, special=True):
        if special:
            tokenizer.add_special_tokens(self.tokens)
        else:
            tokenizer.add_tokens(self.tokens)

    def add_to_checkpoint(self, checkpoint_dir: str | Path, outdir: str | Path | None = None, special=True):
        if isinstance(checkpoint_dir, str):
            checkpoint_dir = Path(checkpoint_dir)
        
        if outdir is None:
            outdir = checkpoint_dir
        else:
            if isinstance(outdir, str):
                outdir = Path(outdir)
            outdir.mkdir(parents=True, exist_ok=True)

        with open(checkpoint_dir / "tokenizer_config.json", 'rt') as f:
            config = json.load(f)
            assert isinstance(config, dict)
        
        self.update_tokenizer_config(config)
        tokenizer: HFTokenizer = HFTokenizer.from_file(str(checkpoint_dir / "tokenizer.json"))
        self.add_to_tokenizer(tokenizer)

        tokenizer.save(str(outdir / "tokenizer.json"))
        with open(outdir / "tokenizer_config.json", 'wt') as f:
            json.dump(config, f, indent=4)

    @staticmethod
    def from_tokenizer_config(checkpoint_dir: str | Path):
        if isinstance(checkpoint_dir, str):
            checkpoint_dir = Path(checkpoint_dir)

        with open(checkpoint_dir / 'tokenizer_config.json', 'rt') as f:
            tokenizer_config: dict[str, Any] = json.load(f)
            assert isinstance(tokenizer_config, dict)

        return Vocabulary(**{
            k.removesuffix('_token'): v
            for k, v in tokenizer_config.items() if k.endswith('_token')
        })


def train_tokenizer(
    files: list[Path | str] | str,
    root: str | Path = './',
    out_dir: Path | str | None = None,
    alg: Literal['BPE'] = 'BPE',
    special_tokens: Vocabulary | None = None,
    tokenizer_class: str | None = None,
    kwargs: dict | None = None
) -> HFTokenizer:
    """
    Train a tokenizer using the specified algorithm and parameters.

    Args:
        files (list[Path | str] | str): Training files relative to ``root``, or a
            regular expression. A pattern is matched with ``re.match``, that is
            anchored at the start, against every file path under ``root``
            relative to ``root``, using the platform path separator.
        root (str | Path, optional): Directory that ``files`` are relative to.
            Defaults to './'.
        out_dir (Path | str | None, optional): If given, ``tokenizer.json`` and
            ``tokenizer_config.json`` are written there (the directory is
            created if needed). Defaults to None, meaning nothing is saved.
        alg (Literal['BPE'], optional): Tokenization algorithm. Only 'BPE' is
            supported. Defaults to 'BPE'.
        special_tokens (Vocabulary | None, optional): Special tokens handed to the
            trainer. Its ``unk`` role, if set, becomes the model's unknown token,
            and its roles are recorded as ``<role>_token`` entries in the saved
            config. Defaults to an empty vocabulary.
        tokenizer_class (str | None, optional): ``tokenizer_class`` value written
            to the config file. Defaults to ``alg + "Tokenizer"``.
        kwargs (dict | None, optional): Extra keyword arguments per component,
            under the keys ``'model'``, ``'trainer'``, ``'decoder'`` and
            ``'pre_tokenizer'``.

    Raises:
        ValueError: If no training files are given or the pattern matches none.
        NotImplementedError: If the specified algorithm is not supported.

    Returns:
        HFTokenizer: The trained tokenizer instance.
    """
    root = Path(root)

    if isinstance(files, str):  # a regex pattern
        pattern = re.compile(files)
        # os.walk order depends on the filesystem; sort so training input order is reproducible.
        files = sorted(f for f in _listfiles(root) if pattern.match(f))

    if not files:
        raise ValueError(f"No training files found under {str(root)!r}.")

    # A None default avoids one Vocabulary instance shared by every call.
    if special_tokens is None:
        special_tokens = Vocabulary()

    kwargs = kwargs or {}
    model_kwargs = kwargs.get('model', {})
    trainer_kwargs = kwargs.get('trainer', {})
    decoder_kwargs = kwargs.get('decoder', {})
    pre_tokenizer_kwargs = kwargs.get('pre_tokenizer', {})

    if alg == 'BPE':
        tokenizer = HFTokenizer(BPE(unk_token=special_tokens.unk, **model_kwargs))
        trainer = BpeTrainer(
            special_tokens=special_tokens.tokens,  # type: ignore
            **trainer_kwargs  # type: ignore
        )
        tokenizer.decoder = BPEDecoder(**decoder_kwargs)  # type: ignore
        tokenizer.pre_tokenizer = Whitespace(**pre_tokenizer_kwargs)  # type: ignore
    else:
        raise NotImplementedError(f"'{alg}' is not a supported tokenization algorithm.")

    tokenizer.train(files=[str(root / f) for f in files], trainer=trainer)

    if out_dir is not None:
        out_dir = Path(out_dir)
        out_dir.mkdir(parents=True, exist_ok=True)
        tokenizer.save(str(out_dir / 'tokenizer.json'))
        tokenizer_config = {'tokenizer_class': tokenizer_class or alg + "Tokenizer"}
        special_tokens.update_tokenizer_config(tokenizer_config)
        with (out_dir / 'tokenizer_config.json').open('wt', encoding='utf-8') as f:
            json.dump(tokenizer_config, f, indent=4)

    return tokenizer


class Tokenizer(LitTokenizer):
    """litgpt ``Tokenizer`` restricted to the Hugging Face ``tokenizers`` backend.

    Construction fails unless the base class loaded an ``HFTokenizer`` processor
    with backend ``"huggingface"``. On top of that it adds control over skipping
    special tokens when decoding, and a vocabulary size that counts added tokens.
    """

    def __init__(self, checkpoint_dir: Path | str) -> None:
        """Load the tokenizer from ``checkpoint_dir`` via litgpt.

        Raises:
            ValueError: If the loaded backend is not a Hugging Face ``tokenizers``
                ``Tokenizer``.
        """
        super().__init__(checkpoint_dir)
        # Explicit check rather than `assert`: asserts are stripped under `python -O`.
        if not (isinstance(self.processor, HFTokenizer) and self.backend == "huggingface"):
            raise ValueError(
                f"{type(self).__name__} requires a Hugging Face 'tokenizers' backend, "
                f"but the checkpoint at {str(checkpoint_dir)!r} loaded backend {self.backend!r}."
            )
        # Narrow the inherited attribute types for type checkers.
        self.processor: HFTokenizer
        self.backend: Literal["huggingface"]

    def decode(self, tensor: Tensor, skip_special_tokens: bool = True) -> str:
        """Decode token ids to text.

        Args:
            tensor: A 0-d tensor (single id) or a tensor whose ``tolist()`` gives a
                flat list of ids. Presumably 1-d in practice; a batch is not
                handled here.
            skip_special_tokens: Drop special tokens from the output (default).
        """
        ids = [tensor.item()] if tensor.ndim == 0 else tensor.tolist()
        return self.processor.decode(ids, skip_special_tokens=skip_special_tokens)

    @property
    def vocab_size(self):
        """Vocabulary size, counting tokens added on top of the trained vocabulary."""
        return self.processor.get_vocab_size(with_added_tokens=True)
