"""tokenizers for transformer. Extends """

import logging
import pickle as pkl
import unicodedata

import torch

from os import PathLike

# from pathlib import Path
from typing import List, Optional, Tuple, Union

from tokenizers import Tokenizer, models, normalizers, pre_tokenizers, decoders, trainers


# Module-level logger used by the tokenizer class below; the logger name is a fixed string
# ("vqt2g_logger") rather than ``__name__``.
_LOG = logging.getLogger("vqt2g_logger")


class VQT2GTokenizerBPE:
    """BPE tokenizer for the TTG model: text tokens followed by graph codebook tokens.

    Token id layout used by this class (``V`` = ``text_vocab_size``, ``G`` = ``graph_vocab_size``):

    * ``0 .. V-1``: BPE text tokens
    * ``V``: ``<|startoftext|>``
    * ``V+1``: ``<|startofgraph|>``
    * ``V+2 .. V+1+G``: graph tokens ``<|graph_0|>`` .. ``<|graph_{G-1}|>``

    A full sequence built by :meth:`tg_pair` looks like
    ``[<|startoftext|>, text tokens (padded), <|startofgraph|>, graph tokens]``.

    NOTE(review): the special/graph ids are computed arithmetically from ``text_vocab_size``
    instead of being read back from the trained tokenizer. This presumably assumes the trained
    BPE vocabulary has exactly ``text_vocab_size`` entries - TODO confirm, in particular for
    graph-only runs (``text_vocab_size == 0``), see the comment in ``__init__``.
    """

    def __init__(
        self,
        text_descriptions: List[str],
        text_vocab_size: int = 1024,
        graph_vocab_size: int = 256,
        special_tokens: Optional[List[str]] = None,
        max_graph_len: int = 128,
        normalise_text: bool = False,
    ):
        """BPE tokenizer for text + graph data

        Args:
            text_descriptions: Tokenizer training data - list of captions for graphs
            text_vocab_size: Number of BPE tokens for vocabulary. ``0`` switches the tokenizer
                to graph-only mode, in which input text is ignored when encoding
            graph_vocab_size: Number of codes in the GVQVAE codebook
            special_tokens: Optional additional special tokens to add
            max_graph_len: Number of codebook codes to represent a graph
            normalise_text: Whether to lowercase text and do some normalising

        """

        # Set text stuff
        self.text_vocab_size = text_vocab_size
        self.text_special_tokens = special_tokens
        self.normalise_text = normalise_text

        self.no_texts = text_vocab_size == 0  # If it's graph-only then ignore any input text

        # Train BPE tok on texts
        if self.normalise_text:
            text_descriptions = [self._normalise_text(i) for i in text_descriptions]

        # Change this so it doesn't add the default 256 tokens in graph-only runs?
        self.train_bpe_tokenizer(text_descriptions)

        # Set graph stuff, make graph token list
        self.graph_vocab_size = graph_vocab_size
        self._make_graph_tokens(self.graph_vocab_size)
        self.max_graph_len = max_graph_len

        self.text_start_token = "<|startoftext|>"
        self.graph_start_token = "<|startofgraph|>"

        # Add text/graph start, then all graph tokens
        self.tokenizer.add_tokens(
            [
                self.text_start_token,
                self.graph_start_token,
                *self.graph_tokens,
            ]
        )

        self.textstart_id = self.text_vocab_size  # Token id for <|startoftext|> token
        self.graphstart_id = self.text_vocab_size + 1  # Token id for <|startofgraph|> token
        self.min_graph_token_id = self.text_vocab_size + 2  # token ids from here are graph tokens

        # Everything below the first graph token (text tokens and both start tokens)
        self.text_token_ids = list(range(self.min_graph_token_id))

        self.graph_token_ids = list(
            range(self.min_graph_token_id, self.min_graph_token_id + self.graph_vocab_size)
        )

        self.set_max_text_tokens(text_descriptions)
        _LOG.info(f"Max number of text tokens in a caption: {self.max_text_token_len}")

        # Position in sequence that graph-start-indicator token should be
        # (one <|startoftext|> token, then `max_text_token_len` text tokens)
        self.graphstart_tok_pos = self.max_text_token_len + 1
        # Total sequence length: text tokens + graph tokens + the two start tokens
        self.total_len = self.max_text_token_len + self.max_graph_len + 2

        # For `bad_words_ids` parameter when generating with huggingface
        # Sometimes it'll generate a text token anyway, unsure why.
        ### The second list is likely unnecessary
        self.bad_ids = [[idx] for idx in self.text_token_ids] + [self.text_token_ids]

    def __len__(self):
        """Vocabulary size as reported by the underlying tokenizer"""
        return self.tokenizer.get_vocab_size()

    def _make_graph_tokens(self, num_tokens: int):
        """Create the graph token strings ``<|graph_0|>`` .. ``<|graph_{num_tokens-1}|>``"""
        self.graph_tokens = [f"<|graph_{i}|>" for i in range(num_tokens)]

    def vocab(self):
        """Return the underlying tokenizer's vocabulary"""
        return self.tokenizer.get_vocab()

    def id_to_token(self, token_id: int) -> str:
        """Convert id to the corresponding token"""
        return self.tokenizer.id_to_token(token_id)

    def token_to_id(self, token: str) -> int:
        """Convert a token to its corresponding id"""
        return self.tokenizer.token_to_id(token)

    def save_tokeniser(self, out_file: Union[str, PathLike]) -> None:
        """Saves tokenizer object to pickle file

        Args:
            out_file: Path of the pickle file to write (opened in binary write mode)
        """
        with open(out_file, "wb") as f:
            pkl.dump(self, f)

    def train_bpe_tokenizer(self, text_list: List[str]) -> None:
        """Train BPE tokenizer

        Builds a byte-level BPE tokenizer with NFKC normalisation and trains it on `text_list`,
        storing the result in `self.tokenizer`.

        Args:
            text_list: Texts to train the BPE merges on
        """
        self.tokenizer = Tokenizer(models.BPE())
        self.tokenizer.normalizer = normalizers.NFKC()
        self.tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel()
        self.tokenizer.decoder = decoders.ByteLevel()
        # NOTE(review): "[PAD]" is not added to the vocabulary (see the commented-out special
        # token below), so padding presumably uses the library's default pad id - TODO confirm
        self.tokenizer.enable_padding(pad_token="[PAD]")

        if self.text_special_tokens is None:
            self.text_special_tokens = []
        bpe_trainer = trainers.BpeTrainer(
            vocab_size=self.text_vocab_size,
            initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
            special_tokens=list(self.text_special_tokens),  # + ["[PAD]"],
        )

        self.tokenizer.train_from_iterator(text_list, trainer=bpe_trainer)

    def _normalise_text(self, text: str) -> str:
        """Remove non-ascii characters from text and lowercase it

        The text is NFD-decomposed first, so accented letters keep their ascii base letter and
        only the combining marks are dropped.
        """
        decomposed = unicodedata.normalize("NFD", text)
        ascii_only = decomposed.encode("ascii", "ignore").decode("ascii")
        return ascii_only.lower()

    def set_max_text_tokens(self, text_list: List[str]) -> None:
        """Tokenise the text data, find longest it needs to be

        Sets `self.max_text_token_len` to the largest number of tokens any text in `text_list`
        encodes to (0 for an empty list). In graph-only mode it is always 0.

        Args:
            text_list: Texts to measure
        """
        if self.no_texts:
            # `encode_text` emits no text tokens in graph-only mode, so the text part of every
            # sequence has length 0 regardless of the captions supplied
            self.max_text_token_len = 0
            return
        if self.normalise_text:
            text_list = [self._normalise_text(i) for i in text_list]
        encs = self.tokenizer.encode_batch(text_list)
        # Take the max explicitly instead of relying on batch padding having made every
        # encoding the same length; `default` covers an empty `text_list`
        self.max_text_token_len = max((len(enc.ids) for enc in encs), default=0)

    def encode_text(self, text: str, add_specials: bool = True, pad: bool = True) -> List[int]:
        """Encode text using the trained BPE tokenzer

        Args:
            text: Text to encode. Ignored in graph-only mode
            add_specials: Whether to wrap the text tokens as
                ``[<|startoftext|>, *tokens, <|startofgraph|>]``
            pad: Whether to pad the text tokens up to `self.max_text_token_len`

        Returns:
            List of token ids, with the text part truncated to `self.max_text_token_len`
        """
        if self.no_texts:  ### Added this 1 Sept, haven't tested
            tokens = []
        else:
            if self.normalise_text:
                text = self._normalise_text(text)
            enc = self.tokenizer.encode(text)
            if pad:
                enc.pad(self.max_text_token_len)
            tokens = enc.ids
        # If encoded text is too long cut the end off
        if len(tokens) > self.max_text_token_len:
            _LOG.debug(f"Truncated text input of {len(tokens)} tokens to {self.max_text_token_len}")
            tokens = tokens[: self.max_text_token_len]
        if add_specials:
            tokens = [self.textstart_id, *tokens, self.graphstart_id]
        return tokens

    def tg_pair(self, text: str, graph_encoded: Union[list, torch.Tensor]) -> List[int]:
        """Convert text (as string) and graph (as codebook ids) into single
        token stream

        Args:
            text: Caption for the graph
            graph_encoded: Codebook ids of the graph, as a list or a tensor of any shape
                (tensors are flattened)

        Returns:
            Text token ids (with start tokens) followed by the graph token ids
        """
        text_tokens = self.encode_text(text, add_specials=True)
        if isinstance(graph_encoded, torch.Tensor):
            # reshape (not view) so non-contiguous tensors can be flattened too
            graph_encoded = graph_encoded.cpu().reshape(-1).tolist()
        graph_tokens = [self.to_graph_token_id(i) for i in graph_encoded]

        return text_tokens + graph_tokens

    def tg_dataset(self, texts, graphs):
        """Tokenise each text + graph

        Args:
            texts: Sequence of captions
            graphs: Sequence of encoded graphs, same length as `texts`

        Returns:
            List with one token id list (see `tg_pair`) per text/graph pair
        """
        assert len(texts) == len(graphs), "texts and graphs must have the same length"
        return [self.tg_pair(t, g) for t, g in zip(texts, graphs)]

    def to_graph_token_id(self, codebook_id):
        """Codebook id to graph token id"""
        return codebook_id + self.min_graph_token_id

    def from_graph_token_id(self, token_id):
        """Graph token id to codebook id

        Raises:
            ValueError: If `token_id` is not one of the graph token ids, i.e. not in
                ``[min_graph_token_id, min_graph_token_id + graph_vocab_size)``
        """
        # Exclusive upper bound; mirrors `to_graph_token_id` over codebook ids 0..G-1
        upper = self.min_graph_token_id + self.graph_vocab_size
        if not self.min_graph_token_id <= token_id < upper:
            raise ValueError(
                f"Tried to convert '{token_id}' to a graph token, but graph tokens must be in "
                f"({self.min_graph_token_id}-{upper - 1})"
            )
        return token_id - self.min_graph_token_id

    def _select_graph_tokens(self, token_ids):
        """Check the sequence is as expected, then remove text and special tokens

        Layout problems (unexpected length or start tokens) are only logged as warnings;
        a non-graph token in the graph part is an error.

        Args:
            token_ids: Full token id sequence (text part + graph part)

        Returns:
            The token ids after the expected graph start position

        Raises:
            ValueError: If any of the returned tokens is not a graph token id
        """

        # Warn if the number of tokens is not what's expected
        seq_len = len(token_ids)
        if seq_len != self.total_len:
            _LOG.warning(f"Got {seq_len} tokens when decoding graph, expected {self.total_len}")
        # It should start with start-of-text indicator token
        if token_ids[0] != self.textstart_id:
            _LOG.warning(f"Expected first token is {self.textstart_id} but got {token_ids[0]}")
        # It should have graph start token at expected graph start position
        gs_token = token_ids[self.graphstart_tok_pos]
        if gs_token != self.graphstart_id:
            _LOG.warning(
                f"Expected graph indicator token id {self.graphstart_id} at position "
                f"{self.graphstart_tok_pos} but found {gs_token}"
            )
        # Keep only the graph tokens
        graph_tokens = token_ids[self.graphstart_tok_pos + 1 :]

        # If invalid tokens it'll fail. Same valid range as `from_graph_token_id`
        upper = self.min_graph_token_id + self.graph_vocab_size
        if any(not self.min_graph_token_id <= tok < upper for tok in graph_tokens):
            raise ValueError(
                f"Non-graph token in graph tokens, all must be between {self.min_graph_token_id}-"
                f"{upper - 1}. Got tokens: {graph_tokens}"
            )
        return graph_tokens

    def decode_graph(self, token_ids: Union[List[int], torch.Tensor]) -> torch.Tensor:
        """Take gen'd seq, grab graph tokens, check they're graph tokens,
        convert to codebook ids

        Args:
            token_ids: Full generated sequence, as a list or tensor of token ids

        Returns:
            Tensor of codebook ids for the graph part of the sequence

        Raises:
            ValueError: If the graph part contains a token that is not a graph token
        """

        if isinstance(token_ids, torch.Tensor):
            token_ids = token_ids.tolist()

        graph_tokens = self._select_graph_tokens(token_ids)

        # Convert from token ids to codebook ids
        graph_vector = [self.from_graph_token_id(t) for t in graph_tokens]
        return torch.tensor(graph_vector)
