# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import os
import time
from typing import List, Literal, Optional, Tuple, TypedDict
from filter.Configuration import Configuration
from filter.CandidateRoute import CandidateRoute, Candidate
import numpy as np
import copy
import torch
# from fairscale.nn.model_parallel.initialize import (
#     get_model_parallel_rank,
#     initialize_model_parallel,
#     model_parallel_is_initialized,
# )

from filter.Model import ModelBuilder
from transformers import DynamicCache

# Closed set of role strings accepted in the ``role`` field of a Message.
Role = Literal["system", "user", "assistant"]




class Message(TypedDict):
    """Typed dictionary for one chat message: a ``role`` and its text ``content``."""
    role: Role
    content: str


class CompletionPrediction(TypedDict, total=False):
    """Typed dictionary for a text-completion result.

    Declared with ``total=False``, so neither key is mandatory.
    """
    generation: str
    tokens: List[str]  # not required


class ChatPrediction(TypedDict, total=False):
    """Typed dictionary for a chat result whose ``generation`` is a Message.

    Declared with ``total=False``, so neither key is mandatory.
    """
    generation: Message
    tokens: List[str]  # not required


# A dialog is an ordered list of messages.
Dialog = List[Message]

# Marker strings that open/close an instruction span ...
B_INST = "[INST]"
E_INST = "[/INST]"
# ... and a system-prompt span (these two carry their surrounding newlines).
B_SYS = "<<SYS>>\n"
E_SYS = "\n<</SYS>>\n\n"

# The four markers in bare form (no newlines), and the message meant for
# prompts that contain any of them.
SPECIAL_TAGS = [B_INST, E_INST, "<<SYS>>", "<</SYS>>"]
UNSAFE_ERROR = "Error: special tags are not allowed as part of the prompt."


def remove_padding_str(string, last_x_words=100):
    """
    Strip tokenizer/prompt marker strings from ``string`` and normalise whitespace.

    The markers are removed one after another, in a fixed order, so a marker
    that only appears once an earlier one has been deleted is removed as well.
    Afterwards every run of whitespace is collapsed to a single space and
    leading/trailing whitespace disappears.

    Args:
        string (str): Text to clean.
        last_x_words (int): Accepted for call compatibility; not used.

    Returns:
        str: The cleaned text.
    """
    markers = (
        "<pad>", "<s>", "</s>", "<unk>",
        "<<SYS>>", "<</SYS>>", "[INST]", "[/INST]",
    )
    cleaned = string
    for marker in markers:
        cleaned = cleaned.replace(marker, "")
    # split() with no argument drops leading/trailing whitespace and splits on
    # any whitespace run, so the joined result needs no further stripping.
    return " ".join(cleaned.split())


class Generator:
    @staticmethod
    def build(
        max_seq_len: int,
        seed: int = 1,
        model_card: str = "georgesung/llama2_7b_chat_uncensored",
    ) -> "Generator":
        """
        Seed the random number generators, load a model and wrap it in a Generator.

        Args:
            max_seq_len (int): Passed to ``ModelBuilder.build_model`` together
                with ``model_card``.
            seed (int): Seed applied to torch (CPU and, when available, CUDA)
                and to numpy.
            model_card (str): Identifier of the model to load; also stored on
                the returned instance.

        Returns:
            Generator: An instance whose ``model`` and ``tokenizer`` attributes
                both refer to the single object returned by
                ``ModelBuilder.build_model``.

        Note:
            When CUDA is available, device 0 is selected and cuDNN is switched
            to deterministic, non-benchmark mode. The elapsed load time is
            printed to stdout.

        """

        # seed must be the same in all processes
        torch.manual_seed(seed)
        np.random.seed(seed)
        if torch.cuda.is_available():
            torch.cuda.set_device(0)
            torch.cuda.manual_seed(seed)
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False

        # The timer covers builder construction as well as the model load.
        load_started = time.time()
        loaded = ModelBuilder().build_model(model_card, max_seq_len)
        elapsed = time.time() - load_started
        print(f"Loaded in {elapsed:.2f} seconds")

        # The same object serves as both model and tokenizer.
        return Generator(loaded, loaded, model_card)

    def __init__(self, model, tokenizer, model_card):
        """
        Store the collaborators and reset the per-instance bookkeeping fields.

        Args:
            model: Object stored as ``self.model``.
            tokenizer: Object stored as ``self.tokenizer``.
            model_card: Identifier stored as ``self.model_card``.
        """
        # Collaborators supplied by the caller.
        self.model, self.tokenizer = model, tokenizer
        self.model_card = model_card
        self.model_type = 'hf'

        # Bookkeeping: starts empty / at zero.
        self.beams: List[CandidateRoute] = []
        self.input_index = self.warmup = 0


    @torch.inference_mode()
    def get_embeddings(self, tokens):
        """
        Feed a token sequence to ``self.model.forward_2`` and return two of its outputs.

        Args:
            tokens: Sequence of token ids accepted by ``torch.tensor`` as a 1-D
                integer sequence.

        Returns:
            tuple: ``(first, final)`` where, for ``out = forward_2(batch, 0)``,
                ``first`` is ``out[-1][0]`` and ``final`` is ``out[-1][-1]``.
                Presumably ``out[-1]`` holds per-layer embeddings - TODO confirm
                against the model wrapper.
        """
        seq_len = len(tokens)
        # One-row batch on the GPU: allocated filled with the pad id, then the
        # whole row is overwritten with the real ids.
        batch = torch.full((1, seq_len), self.tokenizer.pad_id, dtype=torch.long, device="cuda")
        batch[0, :seq_len] = torch.tensor(tokens, dtype=torch.long, device="cuda")

        per_layer = self.model.forward_2(batch, 0)[-1]
        first_layer, final_layer = per_layer[0], per_layer[-1]
        return first_layer, final_layer

    @torch.inference_mode()
    def chat_completion_new(self, dialogs, config):
        """
        Generate one assistant reply for the first dialog in ``dialogs``.

        The decoding path is picked from ``config``:
          * ``safety_alpha == 0``: sampling through ``self.model.model.generate``.
          * otherwise, ``use_cache`` truthy: ``generate_method_kv_cache``.
          * otherwise: ``generate_tokens_method`` on a prompt re-tokenized with
            ``padding="max_length"``.

        Args:
            dialogs: Sequence of dialogs; only ``dialogs[0]`` is used. The
                ``content`` of its last message is handed to the KV-cache path.
            config: Settings object; ``safety_alpha``, ``use_cache``, ``max_gen``,
                ``top_p`` and ``temperature`` are read here.

        Returns:
            tuple: ``([{"generation": Message}], num_generated_tokens)`` where the
                count is the number of tokens after the prompt.
        """
        dialog = dialogs[0]
        inputs = self.tokenizer.dialogs_to_input(dialog)
        original_prompt = dialog[-1]["content"]
        prompt_tokens = inputs.input_ids.to("cuda")
        # Prompt length is measured on the unpadded encoding.
        num_tokens = len(prompt_tokens[0])
        stop_strings = self.model.stop_string

        if config.safety_alpha == 0:
            sampling_kwargs = dict(
                inputs=prompt_tokens,
                max_new_tokens=config.max_gen,
                tokenizer=self.tokenizer.tokenizer,
                stop_strings=stop_strings,
                do_sample=True,
                top_p=config.top_p,
                temperature=config.temperature,
                eos_token_id=self.model.terminators,
            )
            generation_tokens = self.model.model.generate(**sampling_kwargs)
        elif config.use_cache:
            generation_tokens = self.generate_method_kv_cache(inputs, config, num_tokens, original_prompt)
        else:
            padded_inputs = self.tokenizer.dialogs_to_input(dialog, padding="max_length")
            padded_prompt = padded_inputs.input_ids.to("cuda")
            generation_tokens = self.generate_tokens_method(
                self.model, padded_prompt, config, num_tokens, stop_strings
            )

        # only take the tokens in the generation that are not part of the prompt
        new_tokens = generation_tokens[0, num_tokens:]
        reply_text = self.tokenizer.decode(new_tokens)
        if stop_strings is not None:
            reply_text = reply_text.replace(stop_strings, "")

        reply = Message(role="assistant", content=reply_text)
        return [{"generation": reply}], len(new_tokens)

    @torch.inference_mode()
    def text_completion(self, dialogs, config):
        """
        Continue the raw text of the last message of the first dialog.

        The text is tokenized with padding to ``config.max_seq``; the real prompt
        length is taken from the attention mask. With ``safety_alpha == 0`` the
        unpadded prompt goes through ``self.model.model.generate``; otherwise
        ``generate_tokens_method`` is used on the padded prompt.

        Args:
            dialogs: Sequence of dialogs; only ``dialogs[0][-1]["content"]`` is read.
            config: Settings object; ``max_seq``, ``safety_alpha`` and ``max_gen``
                are read here.

        Returns:
            list: ``[{"generation": Message}]`` holding the decoded continuation.
        """
        input_text = dialogs[0][-1]["content"]
        encoded = self.tokenizer.tokenizer(
            input_text, return_tensors='pt', padding="max_length", max_length=config.max_seq
        )
        prompt_tokens = encoded["input_ids"].to("cuda")
        attention_mask = encoded["attention_mask"].to("cuda")
        # Count of attended positions in the first row = unpadded prompt length.
        num_tokens = attention_mask.sum(dim=1)[0].item()

        if config.safety_alpha != 0:
            generation_tokens = self.generate_tokens_method(
                self.model, prompt_tokens, config, num_tokens, stop_string="\nHuman: "
            )
        else:
            generation_tokens = self.model.model.generate(
                inputs=prompt_tokens[:, :num_tokens],
                max_new_tokens=config.max_gen,
                use_cache=False,
                pad_token_id=self.tokenizer.pad_id,
                stop_strings="Human:",
                tokenizer=self.tokenizer.tokenizer,
            )

        # only take the tokens in the generation that are not part of the prompt
        continuation = generation_tokens[0, num_tokens:]
        reply = Message(role="assistant", content=self.tokenizer.decode(continuation.tolist()))
        return [{"generation": reply}]



    def get_candidates_strings(self, candidates):
        return self.tokenizer.decode([[candidate[2].item()] for candidate in candidates])



    def top_p_method(self, model, token_tensor, prev_pos, cur_pos, config, start_prompt_index=None, multi=False):
        """
        Pick the next token by mixing model probability with a safety score.

        ``config.beams`` candidate tokens are drawn with ``sample_top_p_multi``.
        Each candidate is decoded, appended to the cleaned text of the last
        ``config.last_x_words`` positions before ``cur_pos``, embedded with
        ``config.embedder`` and compared with ``config.negative_embedding_tensor``.
        The candidate score is::

            (1 - safety_alpha) * p(token) + safety_alpha * (1 - max_cos_sim) / 2

        and the first candidate with the highest score wins.

        ``token_tensor`` is only read. Earlier versions deep-copied the whole
        tensor on every call and wrote each candidate into the copy, which was
        never looked at again; that copy and the dead writes are gone.

        Args:
            model: Object exposing ``forward_legacy(tokens)`` returning logits.
            token_tensor: Token id tensor; columns ``[0, cur_pos)`` are fed to
                the model.
            prev_pos: Unused; kept for call compatibility.
            cur_pos (int): Position the chosen token is meant for.
            config: Settings object; ``beams``, ``top_p``, ``safety_alpha``,
                ``negative_embedding_tensor``, ``last_x_words``, ``temperature``
                and ``embedder`` are read.
            start_prompt_index: Unused; kept for call compatibility.
            multi: Unused; kept for call compatibility.

        Returns:
            The winning entry of the sampled candidates (first batch row).
        """
        safety_alpha = config.safety_alpha
        neg_embed = config.negative_embedding_tensor
        # last_x defines how many positions are taken into consideration for the embedding
        lookback = max(cur_pos - config.last_x_words, 0)

        logits = model.forward_legacy(token_tensor[:, 0:cur_pos])
        probs = torch.softmax(logits[:, -1] / config.temperature, dim=-1)
        # get config.beams candidates for tokens via top_p sampling (first batch row)
        candidates = sample_top_p_multi(probs, config.top_p, config.beams)[0]

        context = self.tokenizer.decode(token_tensor[0, lookback:cur_pos].tolist())
        context = remove_padding_str(context, config.last_x_words)

        # for each token we consider, we add it to the current sentence,
        # then calculate how 'unsafe' the sentence is with the new token
        scored = []
        for candidate in candidates:
            sentence = context + self.tokenizer.decode(candidate.tolist())
            embedded_sentence = config.embedder.embed(sentence)
            # map cosine similarity from [-1, 1] to a safety score in [0, 1]
            # (1 = least similar to the negative embeddings)
            safety = (1 - torch.cosine_similarity(embedded_sentence, neg_embed).max().item()) / 2
            # integrate the safety score with the probability of the token
            score = (1 - safety_alpha) * probs[:, candidate] + safety_alpha * safety
            scored.append((score, candidate))

        # max() keeps the first of equally scored candidates
        return max(scored, key=lambda pair: pair[0])[1]

    def get_token_scores(self, candidate_embeddings, candidate_probs, candidate_tokens, config):
        candidate_scores = torch.zeros(config.beams)
        if config.operation_mode == "dynamic":
            config.safety_alpha = candidate_probs.max().item()
        for i, candidate in enumerate(candidate_embeddings):
            cosine_similarity = (1 - torch.cosine_similarity(candidate,
                                                             config.negative_embedding_tensor).max().item()) / 2
            candidate_scores[i] = (1 - config.safety_alpha) * candidate_probs[:, i] + config.safety_alpha * cosine_similarity
        chosen_token = candidate_tokens[:, torch.argmax(candidate_scores)]
        return chosen_token

    def iterate_beams_kv_cache(self, candidate_tokens, candidate_probs, original_prompt_str, config):
        candidate_strings_to_embed = []
        for candidate in candidate_tokens[0]:
            string_to_add = original_prompt_str + self.tokenizer.decode(candidate)
            candidate_strings_to_embed.append(string_to_add)
        candidate_embeddings = config.embedder.embed(candidate_strings_to_embed,
                                                     embed_type=config.embed_type)
        if config.embed_type == "token_embeddings":
            candidate_embeddings = [tensor[-1] for tensor in candidate_embeddings]

        chosen_token = self.get_token_scores(candidate_embeddings, candidate_probs, candidate_tokens, config)
        return chosen_token

    def sample_token_kv_cache(self, past_key_values, previous_tokens, output_probabilities, original_prompt_str, config):
        """
        Sample candidate tokens for the current step and return the chosen one.

        ``config.beams`` candidates are drawn with ``sample_top_p_multi``.
        With ``config.lookahead == 0`` they are scored via
        ``iterate_beams_kv_cache``. Otherwise every candidate is appended to its
        own copy of ``previous_tokens``, each row is greedily extended by
        ``config.lookahead`` tokens, and the decoded rows are embedded and scored
        with ``get_token_scores``.

        The lookahead batch is built with one row per candidate. Previously the
        prefix was repeated ``candidate_tokens.shape[0]`` times (the batch size,
        not the candidate count) and all candidates were concatenated onto that
        single row, so only one sequence was ever generated and scored.

        Args:
            past_key_values: KV cache of the outer decoding loop. It is not
                passed to the lookahead ``generate`` call: it is sized for the
                single-row loop and ``generate`` would extend it in place.
            previous_tokens: Tensor of ids generated so far; assumed to have a
                single row - TODO confirm if batching is ever introduced.
            output_probabilities: Next-token probabilities for the current step.
            original_prompt_str (str): Prompt text used by the no-lookahead path.
            config: Settings object; ``top_p``, ``beams`` and ``lookahead`` are
                read here.

        Returns:
            The chosen token with a leading dimension added, moved to ``"cuda"``.
        """
        candidate_tokens = sample_top_p_multi(output_probabilities, config.top_p, config.beams)
        candidate_probs = torch.gather(output_probabilities, -1, candidate_tokens)
        if config.lookahead == 0:
            chosen_token = self.iterate_beams_kv_cache(candidate_tokens, candidate_probs, original_prompt_str, config)
        else:
            num_candidates = candidate_tokens.shape[-1]
            # one row per candidate: [prefix..., candidate]
            prefixes = previous_tokens.repeat(num_candidates, 1)
            continuations = torch.cat([prefixes, candidate_tokens.reshape(-1, 1)], dim=1)
            generated_outputs = self.model.model.generate(
                continuations,
                # every row has the same length and no padding
                attention_mask=torch.ones_like(continuations),
                do_sample=False,
                max_new_tokens=config.lookahead,
            )
            decoded_outputs = self.tokenizer.tokenizer.batch_decode(generated_outputs, skip_special_tokens=True)
            embedded_outputs = config.embedder.embed(decoded_outputs, token_embeddings=False)
            chosen_token = self.get_token_scores(embedded_outputs, candidate_probs, candidate_tokens, config)
        return chosen_token.unsqueeze(0).to("cuda")


    def generate_method_kv_cache(self, inputs, config, original_num_tokens, original_prompt_str):
        """
        Decode token by token with a KV cache, choosing each token via ``sample_token_kv_cache``.

        The first forward pass consumes the whole prompt; every later pass feeds
        only the token chosen in the previous step together with the grown
        attention mask and the next cache position.

        Changes from the earlier version:
          * the loop stops once a token from ``self.model.terminators`` has been
            produced (the terminator is kept in the output), matching the plain
            sampling path which passes the same ids as ``eos_token_id``;
          * the early return yields a 2-D ``(1, n)`` tensor on the input device,
            like every other return of this method, so callers can index it as
            ``[0, k:]``;
          * new tensors are created on the device of the input ids.

        Args:
            inputs: Object exposing ``input_ids`` and ``attention_mask`` tensors.
            config: Settings object; ``max_gen``, ``max_seq`` and ``temperature``
                are read here.
            original_num_tokens (int): Number of prompt tokens.
            original_prompt_str (str): Prompt text forwarded to the sampler.

        Returns:
            torch.Tensor: Prompt ids followed by the generated ids.
        """
        generated_ids = inputs.input_ids.to("cuda")
        attention_mask = inputs.attention_mask.to("cuda")
        device = generated_ids.device
        terminators = self.model.terminators
        max_generation = min(config.max_gen, config.max_seq - original_num_tokens)

        if max_generation == config.max_seq:
            # Only reachable with an empty prompt and max_gen >= max_seq:
            # answer with a lone terminator.
            eos = torch.tensor([[terminators[0]]], dtype=generated_ids.dtype, device=device)
            return torch.cat([generated_ids, eos], dim=-1)

        past_key_values = DynamicCache()
        cache_position = torch.arange(generated_ids.shape[1], dtype=torch.int64, device=device)
        step_inputs = {"input_ids": generated_ids, "attention_mask": attention_mask}
        for _ in range(max_generation):
            outputs = self.model.model(**step_inputs, cache_position=cache_position,
                                       past_key_values=past_key_values, use_cache=True)
            token_probs = torch.softmax(outputs.logits[:, -1] / config.temperature, dim=-1)
            chosen_token = self.sample_token_kv_cache(past_key_values, generated_ids, token_probs,
                                                      original_prompt_str, config)

            generated_ids = torch.cat([generated_ids, chosen_token], dim=-1)
            if chosen_token.item() in terminators:
                break

            # Next step: only the new token is fed; the mask covers prompt + generated.
            attention_mask = torch.cat([attention_mask, attention_mask.new_ones((1, 1))], dim=-1)
            step_inputs = {"input_ids": chosen_token, "attention_mask": attention_mask}
            cache_position = cache_position[-1:] + 1

        return generated_ids

    def generate_tokens_method(self, model, tokens, config, original_num_tokens, stop_string=None):
        max_gen_len = config.max_gen
        max_seq_len = config.max_seq
        max_generation_allowed = min(original_num_tokens + max_gen_len, max_seq_len)
        if max_generation_allowed == max_seq_len:
            return tokens
        if stop_string is not None:
            stop_strings_tokens_lens = self.tokenizer.tokenizer(stop_string, return_tensors='pt', return_length=True).length
            stop_string_len = stop_strings_tokens_lens[0]
        else:
            stop_string_len = 0
        for cur_pos in range(original_num_tokens, max_generation_allowed):
            next_token = self.top_p_method(model, tokens, None, cur_pos, config)
            tokens[:, cur_pos] = next_token
            if stop_string is not None:
                if cur_pos >= stop_string_len:
                    curr_string = self.tokenizer.decode(tokens[0, cur_pos - stop_string_len:cur_pos].tolist())
                    if curr_string == stop_string:
                        break
        # only take the tokens in the generation that are not part of the prompt
        generation_tokens = tokens[:, :cur_pos - stop_string_len]
        return generation_tokens



def sample_top_p_multi(probs, p, num_beams):
    """
    Draw ``num_beams`` distinct tokens per row with top-p (nucleus) sampling.

    The distribution is sorted in descending order and entries outside the
    nucleus are zeroed. The ``num_beams`` most likely entries are always kept,
    and any of them that is exactly zero is lifted to a tiny positive value,
    so that sampling without replacement always has at least ``num_beams``
    non-zero entries to draw from. Without this, a peaked distribution (top
    token probability above ``p``) leaves a nucleus smaller than ``num_beams``,
    which ``torch.multinomial(..., replacement=False)`` does not support.
    When the nucleus already holds ``num_beams`` or more non-zero entries the
    result distribution is the same as plain top-p.

    Args:
        probs (torch.Tensor): Probability distribution tensor (last dim = vocabulary).
        p (float): Probability threshold for top-p sampling.
        num_beams (int): Number of distinct tokens to draw per row.

    Returns:
        torch.Tensor: Sampled token indices, ``num_beams`` per row.
    """
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    mask = probs_sum - probs_sort > p
    # never mask the num_beams most likely tokens
    mask[..., :num_beams] = False
    probs_sort[mask] = 0.0
    # guarantee strictly positive mass on the kept head (view: writes through)
    head = probs_sort[..., :num_beams]
    head.masked_fill_(head <= 0, torch.finfo(probs_sort.dtype).eps)
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_tokens = torch.multinomial(probs_sort, num_samples=num_beams, replacement=False)
    # map positions in the sorted order back to vocabulary indices
    return torch.gather(probs_idx, -1, next_tokens)


def sample_top_p(probs, p):
    """
    Draw one token per row with top-p (nucleus) sampling.

    Args:
        probs (torch.Tensor): Probability distribution tensor.
        p (float): Probability threshold for top-p sampling.

    Returns:
        torch.Tensor: Sampled token indices, one per row.

    Note:
        An entry stays eligible while the probability mass of the entries
        ranked before it does not exceed ``p``; the surviving entries are
        renormalized before sampling.

    """
    ranked_probs, ranked_idx = probs.sort(dim=-1, descending=True)
    # mass accumulated strictly before each ranked entry
    mass_before = ranked_probs.cumsum(dim=-1) - ranked_probs
    ranked_probs.masked_fill_(mass_before > p, 0.0)
    ranked_probs.div_(ranked_probs.sum(dim=-1, keepdim=True))
    position = torch.multinomial(ranked_probs, num_samples=1)
    # translate the position in the ranking back to a vocabulary index
    return ranked_idx.gather(-1, position)
