#
# This file uses code from lm-vocab-trimmer by Asahi Ushio, which is licensed
# under the MIT License.
#
# Copyright (c) 2023 asahi417
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

"""Prune the vocabulary of Llama models"""

import json
import os
import tempfile
from collections import defaultdict
from itertools import chain
from typing import Dict, List, Optional, Tuple, Union

import torch
from datasets import load_dataset
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedTokenizerFast,
)

from efficient_heads.pipeline import GenerationPipeline


def update_token_freq(tokens: List[str], fq: Dict[str, int]) -> Dict[str, int]:
    """Update frequencies for tokens.

    The mapping is modified in place and the same object is returned. Items
    that are not yet present start from a count of zero, so a plain ``dict``
    works just as well as a ``defaultdict(int)``. Any hashable item can be
    counted (the caller in this module passes integer token ids).

    :param tokens:
        List of tokens.
    :param fq:
        A dictionary with frequencies for each token.
    :return:
        The updated dictionary of frequencies.
    """
    for w in tokens:
        # ``get`` avoids a KeyError when ``fq`` is not a defaultdict
        fq[w] = fq.get(w, 0) + 1
    return fq


# pylint: disable=too-many-locals
def mine_vocab(
    tokenizer: PreTrainedTokenizerFast,
    vocab_size: int = 32000,
    dataset_name: str = "tatsu-lab/alpaca",
    dataset_column: str = "text",
    dataset_split: str = "train",
    chunk: int = 1000,
) -> Dict[str, int]:
    """Select the tokens to keep by counting their use on a dataset.

    The dataset is tokenized in batches, token ids are counted, and the most
    frequent ids below ``tokenizer.vocab_size`` are kept. If fewer than
    ``vocab_size`` distinct ids were seen, the result is topped up with the
    lowest-numbered tokens that are not selected yet.

    :param tokenizer:
        A pretrained tokenizer to use.
    :param vocab_size:
        The target vocabulary size, defaults to 32000.
    :param dataset_name:
        A dataset to use, defaults to 'tatsu-lab/alpaca'. Names containing
        'xnli' are loaded with the 'all_languages' configuration.
    :param dataset_column:
        The column in the dataset to use, defaults to 'text'
    :param dataset_split:
        The dataset split to use, defaults to 'train'.
    :param chunk:
        The number of texts tokenized per batch, defaults to 1000.
    :return:
        A mapping from each kept token string to its id in ``tokenizer``.
    """
    # Ids at or above this value are treated as special tokens and skipped
    first_special_id = tokenizer.vocab_size
    is_xnli = "xnli" in dataset_name

    if is_xnli:
        dataset = load_dataset(
            dataset_name, "all_languages", split=dataset_split
        )
    else:
        dataset = load_dataset(dataset_name, split=dataset_split)

    counts = defaultdict(int)
    pending = []

    for record in tqdm(dataset):
        value = record[dataset_column]
        if is_xnli:
            # The column holds a mapping here; every value is one text
            pending.extend(value.values())
        else:
            pending.append(value)
        if len(pending) < chunk:
            continue
        encoded = tokenizer(pending)["input_ids"]
        counts = update_token_freq(list(chain.from_iterable(encoded)), counts)
        pending = []

    # Count whatever is left over after the last full batch
    if pending:
        encoded = tokenizer(pending)["input_ids"]
        counts = update_token_freq(list(chain.from_iterable(encoded)), counts)

    # (token, frequency, id) for every counted non-special id
    ranked = [
        (tokenizer.convert_ids_to_tokens(int(token_id)), count, int(token_id))
        for token_id, count in counts.items()
        if int(token_id) < first_special_id
    ]

    # Most frequent first (stable for ties), cut at the target size
    ranked.sort(key=lambda entry: entry[1], reverse=True)
    new_vocab = {token: token_id for token, _, token_id in ranked[:vocab_size]}

    # Top up with unseen tokens in increasing id order until the target size
    missing = vocab_size - len(new_vocab)
    added = 0

    for token_id in range(tokenizer.vocab_size):
        if added == missing:
            break
        candidate = tokenizer.convert_ids_to_tokens(token_id)
        if candidate in new_vocab:
            continue
        new_vocab[candidate] = token_id
        added += 1

    return new_vocab


# pylint: disable=too-many-locals
def remap_token_ids(
    vocab: Dict[str, int], special_token_start_id: int, num_special_tokens: int
) -> Dict[int, int]:
    """Build a map from old token ids to new, gap-free token ids.

    Kept tokens are renumbered from 0 in increasing order of their old id,
    e.g. a vocabulary ``{"a": 1, "b": 4}`` yields ``{1: 0, 4: 1}``. The special
    tokens are then appended directly after them: old id
    ``special_token_start_id + i`` maps to ``len(vocab) + i``.

    :param vocab:
        The vocabulary dict with the original ids as values and tokens as keys.
    :param special_token_start_id:
        The old id of the first special token, e.g., 128000.
    :param num_special_tokens:
        The number of special tokens to add at the end.
    :return: A mapping from the old ids to new continuous ids.
    """
    # Only the ids matter here; their sorted order defines the new numbering
    old_ids_in_order = sorted(vocab.values())
    new_id_map = {
        old_id: new_id for new_id, old_id in enumerate(old_ids_in_order)
    }

    # Special tokens keep their relative order after the regular tokens
    first_new_special_id = len(vocab)
    new_id_map.update(
        (special_token_start_id + offset, first_new_special_id + offset)
        for offset in range(num_special_tokens)
    )

    return new_id_map


def _gather_embedding_rows(
    old_weight: torch.Tensor, index_array: List[Optional[int]]
) -> torch.Tensor:
    """Build a weight matrix whose rows follow ``index_array``.

    Row ``i`` of the result is a copy of ``old_weight[index_array[i]]``. Rows
    whose source is ``None`` or lies outside ``old_weight`` are drawn from a
    normal distribution (mean 0.0, std 0.02) instead.

    :param old_weight:
        The original 2-D weight matrix.
    :param index_array:
        For every new row, the old row index to copy, or ``None``.
    :return:
        A new matrix with the dtype and device of ``old_weight``.
    """
    num_old_rows, width = old_weight.shape
    gathered = torch.zeros(
        (len(index_array), width),
        dtype=old_weight.dtype,
        device=old_weight.device,
    )
    for row, source in enumerate(index_array):
        if source is not None and source < num_old_rows:
            gathered[row] = old_weight[source]
        else:
            gathered[row] = torch.empty(
                width,
                dtype=old_weight.dtype,
                device=old_weight.device,
            ).normal_(mean=0.0, std=0.02)
    return gathered


def prune_model_vocabulary(
    model: AutoModelForCausalLM,
    new_vocab: Dict[str, int],
    special_token_start_id: int,
    num_special_tokens: int,
):
    """Shrink the input and output embeddings of a model to a new vocabulary.

    The rows belonging to the kept tokens and to the special tokens are copied
    into freshly created layers, which replace the model's input embedding
    and, if the model has one, its output embedding. The model config is
    updated afterwards through ``update_llama_config``.

    :param model:
        A pretrained transformer model.
    :param new_vocab:
        The new vocabulary (token string to old token id).
    :param special_token_start_id:
        The starting id for special tokens.
    :param num_special_tokens:
        The number of special tokens.
    :return:
        The updated model.
    """
    id_map = remap_token_ids(
        new_vocab, special_token_start_id, num_special_tokens
    )
    input_weight = model.get_input_embeddings().weight.data
    num_old_rows, _ = input_weight.shape

    # index_array[new_id] is the old row to copy; None requests a random row
    index_array = [None] * len(id_map)
    for old_id, new_id in id_map.items():
        index_array[new_id] = old_id if old_id < num_old_rows else None

    # Swap in the trainable, re-ordered input embedding
    model.set_input_embeddings(
        torch.nn.Embedding.from_pretrained(
            _gather_embedding_rows(input_weight, index_array), freeze=False
        )
    )

    # Apply the same row selection to the output layer when there is one
    output_layer = model.get_output_embeddings()
    if output_layer is not None:
        output_weight = output_layer.weight.data
        new_output_weight = _gather_embedding_rows(output_weight, index_array)
        pruned_head = torch.nn.Linear(
            output_weight.shape[1], len(index_array), bias=False
        )
        pruned_head.weight.data = new_output_weight
        model.set_output_embeddings(pruned_head)

    return update_llama_config(model, len(new_vocab), num_special_tokens)


def get_new_vocab_and_merges(
    mined_vocab: Dict[str, int],
    new_vocab_size: int,
    tokenizer: PreTrainedTokenizerFast,
) -> Tuple[Dict[str, int], List[Union[Tuple[str, str], str]]]:
    """Get a new vocabulary and merges to update the tokenizer.

    Tokens are ordered by their old id, cut to ``new_vocab_size`` and
    renumbered from 0. A merge rule of the old tokenizer is kept only if both
    of its parts and their concatenation are still in the new vocabulary.

    Merge rules are accepted in both serialized forms: 2-element pairs
    (``["t", "h"]``) and single space-separated strings (``"t h"``). Pairs are
    returned as tuples; string rules are returned unchanged so the output stays
    in the representation the backend produced.

    :param mined_vocab:
        The mined vocabulary with indices from the old tokenizer.
    :param new_vocab_size:
        The size of the new vocabulary.
    :param tokenizer:
        The old tokenizer.
    :return: A new vocabulary and list of merges for the updated tokenizer.
    """
    old_state = json.loads(tokenizer.backend_tokenizer.model.__getstate__())

    # Sorting by old id keeps the relative token order after renumbering
    custom_tokens = sorted(mined_vocab, key=mined_vocab.get)[:new_vocab_size]
    custom_token_set = set(custom_tokens)

    new_merges = []
    # A missing or null "merges" entry simply means there is nothing to keep
    for merge_rule in old_state.get("merges") or []:
        if isinstance(merge_rule, str):
            parts = merge_rule.split(" ")
        elif isinstance(merge_rule, (list, tuple)):
            parts = list(merge_rule)
        else:
            continue
        if len(parts) != 2:
            continue

        a, b = parts
        # Keep the rule only if "a", "b" and the merged token "ab" all survive
        if (
            a in custom_token_set
            and b in custom_token_set
            and a + b in custom_token_set
        ):
            if isinstance(merge_rule, str):
                new_merges.append(merge_rule)
            else:
                new_merges.append((a, b))

    # Create custom vocabulary while preserving order
    vocab = {token: i for i, token in enumerate(custom_tokens)}

    return vocab, new_merges


def update_llama_config(
    model: AutoModelForCausalLM,
    vocab_size: int,
    num_special_tokens: int,
    bos_token_offset: int = 0,
    eos_token_offsets: Tuple[int, ...] = (1, 8, 9),
) -> AutoModelForCausalLM:
    """Update the configuration of the model with custom tokens.

    Sets ``vocab_size``, ``bos_token_id`` and ``eos_token_id`` on
    ``model.config`` and the two token ids on ``model.generation_config``.
    Special tokens are assumed to start at id ``vocab_size``; the BOS and EOS
    ids are expressed as offsets from that position.

    :param model: The Llama model to update.
    :param vocab_size: Target vocabulary size.
    :param num_special_tokens: The number of special tokens.
    :param bos_token_offset:
        Position of the BOS token among the special tokens, defaults to 0.
    :param eos_token_offsets:
        Positions of the EOS tokens among the special tokens. The default
        ``(1, 8, 9)`` is the value previously hardcoded from the tokenizer.
    :return: The model with updated config.
    """

    def special_token_ids() -> Dict[str, Union[int, List[int]]]:
        # Built anew on every call so both configs get their own list
        return {
            "bos_token_id": vocab_size + bos_token_offset,
            "eos_token_id": [
                vocab_size + offset for offset in eos_token_offsets
            ],
        }

    model.config.update(
        {"vocab_size": vocab_size + num_special_tokens, **special_token_ids()}
    )
    model.generation_config.update(**special_token_ids())
    return model


def update_tokenizer_config(
    pretrained_path: str, old_vocab_size: int, new_vocab_size: int
) -> None:
    """Renumber the added tokens in ``tokenizer_config.json``.

    Every key of the ``added_tokens_decoder`` mapping is an id stored as a
    string; each one is moved by ``new_vocab_size - old_vocab_size`` (e.g.
    128000 becomes 32000 when going from 128000 to 32000 regular tokens).
    The file is rewritten in place, and an empty ``added_tokens_decoder`` is
    written if the file did not contain one.

    :param pretrained_path: The path to the tokenizer.
    :param old_vocab_size: The old vocabulary size.
    :param new_vocab_size: The new vocabulary size.
    """
    config_file = os.path.join(pretrained_path, "tokenizer_config.json")
    with open(config_file, encoding="utf-8") as handle:
        config = json.load(handle)

    # Negative when the vocabulary shrinks, e.g. 32000 - 128000 = -96000
    offset = new_vocab_size - old_vocab_size

    # Same entries in the same order, only the id keys are moved
    config["added_tokens_decoder"] = {
        str(int(token_id) + offset): entry
        for token_id, entry in config.get("added_tokens_decoder", {}).items()
    }

    with open(config_file, "w", encoding="utf-8") as handle:
        json.dump(config, handle, indent=2, ensure_ascii=False)


def _shift_template_special_ids(
    processor: dict, old_vocab_size: int, shift_amount: int
) -> None:
    """Shift the special-token ids of one ``TemplateProcessing`` entry.

    Only ids at or above ``old_vocab_size`` are moved; the entry is modified
    in place.
    """
    special_tokens = processor.get("special_tokens") or {}
    # Each value is a dict that may carry an "ids" array
    for token_info in special_tokens.values():
        if "ids" in token_info:
            token_info["ids"] = [
                _id + shift_amount if _id >= old_vocab_size else _id
                for _id in token_info["ids"]
            ]


def update_tokenizer(
    pretrained_path: str,
    old_vocab_size: int,
    new_vocab_size: int,
    new_vocab: Dict[str, int],
    tokenizer: PreTrainedTokenizerFast,
) -> None:
    """Update the tokenizer files under ``pretrained_path`` in place.

    ``tokenizer.json`` is rewritten with shifted ids for the added tokens and
    for the special tokens of any ``TemplateProcessing`` post-processor, and
    with the pruned vocabulary and merges. ``tokenizer_config.json`` is then
    updated through ``update_tokenizer_config``.

    :param pretrained_path: The path to the tokenizer.
    :param old_vocab_size: Old vocabulary size.
    :param new_vocab_size: New vocabulary size.
    :param new_vocab: The new vocabulary dictionary.
    :param tokenizer: The original tokenizer.
    """
    path_to_file = os.path.join(pretrained_path, "tokenizer.json")
    with open(path_to_file, encoding="utf-8") as f:
        data = json.load(f)

    # 128000 -> 32000 => offset = 32000 - 128000 = -96000
    shift_amount = new_vocab_size - old_vocab_size

    # Move every added token to its new id
    for token_info in data.get("added_tokens") or []:
        if "id" in token_info:
            token_info["id"] = token_info["id"] + shift_amount

    # "post_processor" may be null in tokenizer.json, hence the "or {}"
    post_processor = data.get("post_processor") or {}
    # A Sequence nests its steps under "processors"; any other post-processor
    # is treated as the single step to inspect
    processors = post_processor.get("processors") or [post_processor]

    for processor in processors:
        if processor.get("type") == "TemplateProcessing":
            _shift_template_special_ids(
                processor, old_vocab_size, shift_amount
            )

    pruned_vocab, pruned_merges = get_new_vocab_and_merges(
        new_vocab, new_vocab_size, tokenizer
    )

    # Update the models "vocab" and "merges" field
    data["model"]["vocab"] = pruned_vocab
    data["model"]["merges"] = pruned_merges

    # Write out the updated JSON
    with open(path_to_file, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=2, ensure_ascii=False)

    update_tokenizer_config(
        pretrained_path,
        old_vocab_size=old_vocab_size,
        new_vocab_size=new_vocab_size,
    )


def vocab_prune_model(
    model_name_or_path: str,
    vocab_size: int,
    output_dir: str,
    dataset_name: str = "tatsu-lab/alpaca",
    dataset_column: str = "text",
    dataset_split: str = "train",
) -> None:
    """Prune a model and its tokenizer and write both to ``output_dir``.

    :param model_name_or_path:
        A transformers model path or HuggingFace model id.
    :param vocab_size:
        The target vocabulary size. Note that it will add the special tokens
        automatically so you do not have to specify it as 32256, rather 32000.
    :param output_dir:
        Output directory to save the model and tokenizer.
    :param dataset_name:
        A dataset to use for vocab pruning, e.g., "tatsu-lab/alpaca".
    :param dataset_column:
        The dataset column to use for pruning, e.g., "text".
    :param dataset_split:
        The dataset split to use, e.g., "train", "val".
    """
    model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

    # Special tokens are the added tokens, numbered from the base vocab size
    base_vocab_size = tokenizer.vocab_size
    special_token_count = len(tokenizer.added_tokens_decoder)

    mined_vocab = mine_vocab(
        tokenizer,
        vocab_size=vocab_size,
        dataset_name=dataset_name,
        dataset_column=dataset_column,
        dataset_split=dataset_split,
    )

    # Write the original tokenizer files first, then patch them on disk
    tokenizer.save_pretrained(output_dir)
    update_tokenizer(
        pretrained_path=output_dir,
        old_vocab_size=base_vocab_size,
        new_vocab_size=vocab_size,
        new_vocab=mined_vocab,
        tokenizer=tokenizer,
    )

    pruned_model = prune_model_vocabulary(
        model, mined_vocab, base_vocab_size, special_token_count
    )
    pruned_model.save_pretrained(output_dir)
    print(f"Pruned model and tokenizer saved to: {output_dir}")


def get_pruned_pipeline(
    model_id: str = "meta-llama/Llama-3.2-1B-Instruct",
    vocab_size: int = 32000,
    cache_dir: Optional[str] = None,
    device_map: str = "cuda",
):
    """
    Get the generation pipeline for vocabulary pruned models.

    With a ``cache_dir`` the pruned model and tokenizer are loaded from it and
    ``model_id`` and ``vocab_size`` are not used. Without one, the model is
    pruned into a temporary directory and loaded from there.

    :param model_id:
        The model id, defaults to "meta-llama/Llama-3.2-1B-Instruct"
    :param vocab_size:
        The vocabulary size.
    :param cache_dir:
        The cache directory if a pruned model already exists.
    :param device_map:
        The device the model should be loaded at.

    :return: A generation pipeline for vocabulary pruned models.
    """

    def load_pruned(directory: str):
        # Tokenizer first, then the model in bfloat16 on the requested device
        loaded_tokenizer = AutoTokenizer.from_pretrained(directory)
        loaded_model = AutoModelForCausalLM.from_pretrained(
            directory,
            torch_dtype=torch.bfloat16,
            device_map=device_map,
        )
        return loaded_tokenizer, loaded_model

    if cache_dir:
        tokenizer, model = load_pruned(cache_dir)
    else:
        # Load while the temporary directory still exists
        with tempfile.TemporaryDirectory() as temp_cache_dir:
            vocab_prune_model(
                model_id,
                vocab_size=vocab_size,
                output_dir=temp_cache_dir,
            )
            tokenizer, model = load_pruned(temp_cache_dir)

    return GenerationPipeline(
        model.model,
        model.lm_head,
        tokenizer=tokenizer,
    )


# This can be merged with the other function to do in place model creation
def vocab_prune_model_and_tokenizer(
    model_id: str,
    vocab_size: int,
    dataset_name: str = "tatsu-lab/alpaca",
    dataset_column: str = "text",
    dataset_split: str = "train",
    device: str = "cuda",
) -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
    """Prune the model and tokenizer and return them.

    Nothing is kept on disk: the tokenizer files are only written to a
    temporary directory so that they can be patched and loaded again.

    :param model_id:
        A transformers model path or HuggingFace model id.
    :param vocab_size:
        The target vocabulary size. Note that it will add the special tokens
        automatically so you do not have to specify it as 32256, rather 32000.
    :param dataset_name:
        A dataset to use for vocab pruning, e.g., "tatsu-lab/alpaca".
    :param dataset_column:
        The dataset column to use for pruning, e.g., "text".
    :param dataset_split:
        The dataset split to use, e.g., "train", "val".
    :param device:
        The device to use "cuda".
    :return:
        The pruned model and the reloaded, pruned tokenizer.
    """
    model = AutoModelForCausalLM.from_pretrained(model_id).to(device)
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # Special tokens are the added tokens, numbered from the base vocab size
    base_vocab_size = tokenizer.vocab_size
    special_token_count = len(tokenizer.added_tokens_decoder)

    mined_vocab = mine_vocab(
        tokenizer,
        vocab_size=vocab_size,
        dataset_name=dataset_name,
        dataset_column=dataset_column,
        dataset_split=dataset_split,
    )

    pruned_model = prune_model_vocabulary(
        model, mined_vocab, base_vocab_size, special_token_count
    )

    # Round-trip the tokenizer through disk to apply the file-level patches
    with tempfile.TemporaryDirectory() as scratch_dir:
        tokenizer.save_pretrained(scratch_dir)
        update_tokenizer(
            pretrained_path=scratch_dir,
            old_vocab_size=base_vocab_size,
            new_vocab_size=vocab_size,
            new_vocab=mined_vocab,
            tokenizer=tokenizer,
        )
        pruned_tokenizer = AutoTokenizer.from_pretrained(scratch_dir)

    return pruned_model, pruned_tokenizer


class VocabPruningHead(torch.nn.Module):
    """
    Vocab pruning head that gives tokens by retokenizing with a base tokenizer.

    This is needed since vocab pruning creates a new tokenizer where the
    position of the tokens are changed due to dropping of old tokens. If we
    want to compare the generated next tokens with a baseline, we need to get
    the tokens that the original tokenizer would have produced for the
    generated text.
    """

    def __init__(
        self,
        model_name_or_path: str,
        vocab_size: int,
        dataset_name="tatsu-lab/alpaca",
        dataset_column="text",
        dataset_split="train",
        device="cuda",
    ):
        """Prune the model and keep its output layer and both tokenizers.

        :param model_name_or_path:
            A transformers model path or HuggingFace model id.
        :param vocab_size:
            The target vocabulary size (without special tokens).
        :param dataset_name:
            A dataset to use for vocab pruning.
        :param dataset_column:
            The dataset column to use for pruning.
        :param dataset_split:
            The dataset split to use.
        :param device:
            The device the pruned model is created on.
        """
        super().__init__()
        pruned_model, pruned_tokenizer = vocab_prune_model_and_tokenizer(
            model_name_or_path,
            vocab_size,
            dataset_name,
            dataset_column,
            dataset_split,
            device=device,
        )
        # Module.to() converts in place, so the head below is bfloat16
        pruned_model.to(torch.bfloat16)
        self.pruned_tokenizer = pruned_tokenizer
        self.original_tokenizer = AutoTokenizer.from_pretrained(
            model_name_or_path
        )
        # Only the output layer of the pruned model is kept
        self.head = pruned_model.lm_head

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return the logits of the pruned head for ``hidden_states``."""
        return self.head(hidden_states)

    @torch.no_grad()
    def get_next_token(
        self,
        hidden_states: torch.Tensor,
        do_sample: bool = False,
        temperature: float = 1.0,
    ) -> torch.Tensor:
        """
        Returns the next predicted token (according to the original tokenizer).

        The token is picked from the pruned head at the last position, decoded
        to text with the pruned tokenizer and tokenized again with the original
        tokenizer. This is done separately for every sequence in the batch.

        :param hidden_states:
            The output of the model body, indexed as (batch, position, hidden).
        :param do_sample:
            Whether to sample the next token according to probabilities,
            or simply return the most probable.
        :param temperature:
            The temperature to use in the softmax over token probabilities.
            Only relevant when `do_sample` is ``True``.
        :return:
            One original-tokenizer id per batch row, on the device of
            ``hidden_states``.
        """
        # Get logits from pruned head, only for the last position
        logits = self.forward(hidden_states)[:, -1, :]  # (B, V_pruned)

        if do_sample:
            probs = torch.softmax(logits / temperature, dim=-1)
            next_token_id = torch.multinomial(probs, num_samples=1)  # (B, 1)
        else:
            next_token_id = torch.argmax(
                logits, dim=-1, keepdim=True
            )  # (B, 1)

        original_ids = []
        for row_ids in next_token_id.tolist():
            # Decode to text using the pruned tokenizer
            generated_text = self.pruned_tokenizer.decode(row_ids)

            # Retokenize using original tokenizer
            retokenized = self.original_tokenizer(
                generated_text, return_tensors="pt"
            ).input_ids
            # Index 1 is used as the generated token; this assumes the
            # original tokenizer prepends one token (e.g. BOS) - TODO confirm
            original_ids.append(retokenized[:, 1:2])

        return torch.cat(original_ids, dim=0).to(hidden_states.device)


def get_vocab_pruning_model_and_tokenizer(
    model_id: str,
    vocab_size: int,
    dataset_name: str = "tatsu-lab/alpaca",
    dataset_column: str = "text",
    dataset_split: str = "train",
    device: str = None,
):
    """Get a new model and tokenizer with VocabPruningHead as the lm_head.

    The head is built on ``device``. The model body is loaded separately with
    ``from_pretrained(model_id)`` and only its ``lm_head`` is replaced.

    :param model_id:
        A transformers model path or HuggingFace model id.
    :param vocab_size:
        The target vocabulary size (without special tokens).
    :param dataset_name:
        A dataset to use for vocab pruning.
    :param dataset_column:
        The dataset column to use for pruning.
    :param dataset_split:
        The dataset split to use.
    :param device:
        The device for the head. If ``None``, "cuda" is used when available
        and "cpu" otherwise.
    :return:
        The model with the replaced head and the pruned tokenizer.
    """
    if device is not None:
        head_device = device
    else:
        head_device = "cuda" if torch.cuda.is_available() else "cpu"

    pruning_head = VocabPruningHead(
        model_id,
        vocab_size,
        dataset_name,
        dataset_column,
        dataset_split,
        device=head_device,
    )

    model = AutoModelForCausalLM.from_pretrained(model_id)
    model.lm_head = pruning_head

    return model, pruning_head.pruned_tokenizer
