import copy
import logging
from typing import Sequence, Dict, Tuple, Optional

import numpy as np
import transformers

from rtfm.alpaca_utils import tokenizer_and_embedding_resize
from rtfm.hf_utils import fetch_auth_token
from rtfm.special_tokens import (
    DEFAULT_PAD_TOKEN,
    IGNORE_INDEX,
)


def fetch_tokenizer(
    pretrained_model_name_or_path: str,
    model_max_length: int,
    use_fast_tokenizer: bool,
    use_auth_token=None,
):
    """Load a right-padding tokenizer via ``transformers.AutoTokenizer.from_pretrained``.

    Args:
        pretrained_model_name_or_path: Hub id or local path of the tokenizer.
        model_max_length: Value forwarded as ``model_max_length``.
        use_fast_tokenizer: Forwarded as ``use_fast``.
        use_auth_token: Forwarded unchanged as ``use_auth_token``; never logged.

    Returns:
        Whatever ``AutoTokenizer.from_pretrained`` returns for these kwargs.
    """
    tokenizer_kwargs = {
        "pretrained_model_name_or_path": pretrained_model_name_or_path,
        "cache_dir": None,
        "model_max_length": model_max_length,
        "padding_side": "right",
        "use_auth_token": use_auth_token,
        "use_fast": use_fast_tokenizer,
    }
    # The auth token is a credential; stdout/log output is routinely persisted or
    # shared, so only ever emit a redacted copy of the kwargs.
    loggable_kwargs = dict(tokenizer_kwargs)
    if loggable_kwargs["use_auth_token"] is not None:
        loggable_kwargs["use_auth_token"] = "<redacted>"
    logging.info("fetching tokenizer with kwargs %s", loggable_kwargs)
    return transformers.AutoTokenizer.from_pretrained(**tokenizer_kwargs)


def prepare_tokenizer(
    model,
    pretrained_model_name_or_path: str,
    model_max_length: int,
    use_fast_tokenizer: bool,
    serializer_tokens_embed_fn: Optional[str] = None,
    serializer_tokens: Optional[Dict[str, str]] = None,
    tokenizer=None,
) -> Tuple[transformers.PreTrainedTokenizer, transformers.AutoModelForCausalLM]:
    """Configure a tokenizer for training and resize the model's embeddings to match.

    For ``yujiepan/llama-2-tiny-random`` the tokenizer is always fetched from
    ``meta-llama/Llama-2-7b-hf``, and any ``tokenizer`` argument is ignored. For
    every other checkpoint the caller must pass a ready ``tokenizer``.

    The pad token is aliased to the EOS token and ``model_max_length`` is stored
    on the tokenizer. Then ``tokenizer_and_embedding_resize`` is called with
    ``serializer_tokens`` as ``other_tokens_dict`` and
    ``other_tokens_are_special_tokens=True``.

    Args:
        model: Model forwarded to ``tokenizer_and_embedding_resize``.
        pretrained_model_name_or_path: Checkpoint name, used for logging and to
            detect the tiny test checkpoint.
        model_max_length: Maximum sequence length set on the tokenizer.
        use_fast_tokenizer: Forwarded to ``fetch_tokenizer`` (tiny checkpoint only).
        serializer_tokens_embed_fn: Forwarded as ``embed_fn``; must not be None.
        serializer_tokens: Forwarded as ``other_tokens_dict``.
        tokenizer: Pre-built tokenizer; required unless the tiny checkpoint is used.

    Returns:
        ``(tokenizer, model)`` after configuration and embedding resize.

    Raises:
        AssertionError: if ``serializer_tokens_embed_fn`` is None, if no tokenizer
            is available, or if the tokenizer has no EOS or BOS token.
    """
    logging.info("setting up tokenizer %s", pretrained_model_name_or_path)

    # Validate cheap arguments before any (possibly remote) tokenizer fetch.
    assert (
        serializer_tokens_embed_fn is not None
    ), "Must provide serializer_tokens_embed_fn."

    if pretrained_model_name_or_path == "yujiepan/llama-2-tiny-random":
        # Substitute the Llama-2-7b tokenizer for this test checkpoint.
        tokenizer = fetch_tokenizer(
            pretrained_model_name_or_path="meta-llama/Llama-2-7b-hf",
            model_max_length=model_max_length,
            use_fast_tokenizer=use_fast_tokenizer,
            use_auth_token=fetch_auth_token(),
        )
    else:
        assert (
            tokenizer is not None
        ), f"Must provide a tokenizer for {pretrained_model_name_or_path}."

    # Check required tokens before mutating the tokenizer below.
    assert tokenizer.eos_token is not None
    assert tokenizer.bos_token is not None

    # Reuse EOS as the padding token.
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id
    tokenizer.model_max_length = model_max_length

    special_tokens_dict = {}
    # Defensive fallback; not expected to trigger once EOS is known to be set.
    if tokenizer.pad_token is None:
        logging.info("no pad token detected; adding pad token")
        special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN

    tokenizer_and_embedding_resize(
        special_tokens_dict=special_tokens_dict,
        tokenizer=tokenizer,
        model=model,
        other_tokens_dict=serializer_tokens,
        other_tokens_are_special_tokens=True,
        embed_fn=serializer_tokens_embed_fn,
    )

    return tokenizer, model


def unmasked_token_idxs(tokens):
    """Return the flattened (raveled) indices of entries in ``tokens`` that are not IGNORE_INDEX."""
    keep_mask = tokens != IGNORE_INDEX
    # Ravel first so multi-dimensional input yields positions in the flat array.
    (flat_positions,) = np.nonzero(np.ravel(keep_mask))
    return flat_positions


def _tokenize_fn(
    strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer
) -> Dict:
    """Tokenize each string independently, truncating to ``tokenizer.model_max_length``.

    Returns:
        A dict with:
        - ``input_ids`` and ``labels``: the *same* list object, holding one 1-D
          id tensor per input string.
        - ``input_ids_lens`` and ``labels_lens``: the *same* list object,
          holding the token count of each string.
    """
    tokenized_list = [
        tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        )
        for text in strings
    ]
    input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
    # Each string is encoded in its own call, so "longest" padding never adds pad
    # tokens and the true length is the full sequence length. Counting
    # ids != pad_token_id instead undercounts whenever the pad id also appears as
    # a real token, e.g. when pad_token is aliased to eos_token as
    # prepare_tokenizer does.
    input_ids_lens = labels_lens = [len(ids) for ids in input_ids]
    return dict(
        input_ids=input_ids,
        labels=labels,
        input_ids_lens=input_ids_lens,
        labels_lens=labels_lens,
    )


def preprocess(
    sources: Sequence[str],
    targets: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    """Tokenize prompt/response pairs and mask the prompt prefix in the labels.

    Each source is concatenated with its paired target and the combined text is
    tokenized. The sources are also tokenized on their own to measure how many
    leading tokens belong to the prompt. Those leading positions in a deep copy
    of the ids are overwritten with ``IGNORE_INDEX``.

    Returns:
        ``{"input_ids": [...], "labels": [...]}``.
    """
    full_texts = [prompt + response for prompt, response in zip(sources, targets)]
    full_encoding = _tokenize_fn(full_texts, tokenizer)
    prompt_encoding = _tokenize_fn(sources, tokenizer)

    input_ids = full_encoding["input_ids"]
    labels = copy.deepcopy(input_ids)
    prompt_lengths = prompt_encoding["input_ids_lens"]
    for label_seq, n_prompt in zip(labels, prompt_lengths):
        # Prompt positions get IGNORE_INDEX (presumably skipped by the downstream loss).
        label_seq[:n_prompt] = IGNORE_INDEX
    return {"input_ids": input_ids, "labels": labels}
