from typing import List
import torch
import copy
from types import SimpleNamespace
import torch.nn.functional as F


"""
This file contains most of the logic for the CandidateRoute class, which is used in the Beam-Search algorithm.
Also contains the multiple-sample top-p sampling algorithm.
"""

class Candidate:
    """
    Plain value holder for one proposed next token of a beam route.
    CandidateRoute.add_candidate reads all three attributes.
    """

    def __init__(self, probs, indices, score):
        """
        :param probs: probability of the proposed token
        :param indices: id of the proposed token
        :param score: score assigned to the route once this token is taken
        """
        self.probs, self.indices, self.score = probs, indices, score


class CandidateRoute:

    def __init__(self, num_beams, tokens, safety_alpha, neg_embed, safety_lambda,
                 length_beta, top_p=0.9, temperature=0.6, embedder=None, tokenizer=None, last_x_words=100,
                 aggregation_mode="log", negative_embedding_normalized=None):
        self.candidates: List[Candidate] = []
        self.num_beams = num_beams
        self.token_tensor = copy.deepcopy(tokens)
        self.beam_total_prob = 0
        self.safety_alpha = safety_alpha
        self.length = 1
        self.length_beta = length_beta
        self.score = 1
        self.embedder = embedder
        self.use_embedder = embedder is not None
        self.unsafe_embed = neg_embed
        self.aggregated_probs = 0 if aggregation_mode == "log" else 1
        self.safety_lambda = safety_lambda
        self.top_p = top_p
        self.temperature = temperature
        self.last_x_words = last_x_words
        self.aggregation_mode = aggregation_mode
        self.unsafe_embed_norm = negative_embedding_normalized



    def add_candidate(self, candidate: Candidate, cur_pos):
        self.candidates.append(candidate)
        # check if indices fit in token_tensor, if not do nothing
        if cur_pos >= self.token_tensor.shape[1]:
            return
        self.token_tensor[:, cur_pos] = candidate.indices
        # manage cache of probabilities
        self.length += 1
        self.beam_total_prob = candidate.probs
        if self.aggregation_mode == "log":
            self.aggregated_probs = self.aggregated_probs + torch.log(candidate.probs).item()
        elif self.aggregation_mode == "multiply":
            self.aggregated_probs = self.aggregated_probs * candidate.probs
        # self.aggregated_probs = self.aggregated_probs * candidate.probs
        self.score = candidate.score

    def remove_padding_str(self, string):
        tmp_str = string.replace("<pad>", "").replace("<s>", "").replace("</s>", "").replace("<unk>", "")
        tmp_str = tmp_str.replace("<<SYS>>", "").replace("<</SYS>>", "").replace("[INST]", "").replace("[/INST]", "")
        tmp_str = tmp_str.replace("### HUMAN:", "").replace("### RESPONSE", "")
        tmp_str = " ".join(tmp_str.split())
        return tmp_str.strip()



    def inference_legacy(self, prev_pos, cur_pos, model, warming_up=True):
        # check if indices fit in token_tensor, if not do nothing
        if self.token_tensor.shape[1] <= (cur_pos + 1):
            indices = torch.tensor([[self.tokenizer.eos_id] * self.num_beams], device="cuda")
            probability = torch.tensor([[self.beam_total_prob] * self.num_beams], device="cuda")
            safety_penalty = torch.tensor([[0] * self.num_beams], device="cuda")
            total_scores = self.score + safety_penalty
            return probability, indices, safety_penalty, total_scores
        else:
            logits = model.forward(self.token_tensor[:, 0:cur_pos], prev_pos)
            if self.top_p > 0:
                probs = torch.softmax(logits[:, -1] / self.temperature, dim=-1)
                next_tokens = self.sample_top_p_multi(probs, self.num_beams)
            else:
                probs = torch.softmax(logits[:, -1], dim=-1)
                next_tokens = torch.topk(probs, self.num_beams, dim=-1)
            next_tokens_embed = []
            if self.safety_alpha > 0 and not warming_up:
                copy_token_tensor = copy.deepcopy(self.token_tensor)
                base_tokens_string = self.tokenizer.decode(copy_token_tensor[0, 0:cur_pos].tolist())
                base_tokens_string = self.remove_padding_str(base_tokens_string)
                stacked_token_tensor = torch.stack([copy_token_tensor * 1 for _ in range(self.num_beams)],
                                                   dim=0).squeeze(1)
                look_back = self.length - 1
                for beam in range(self.num_beams):
                    if self.use_embedder:
                        copy_token_tensor[:, cur_pos] = next_tokens.indices[0][beam]
                        tokens_string = base_tokens_string + " " + self.tokenizer.decode(next_tokens.indices[0][beam].tolist())
                        next_tokens_embed.append(self.embedder.embed(tokens_string))

                    else:
                        copy_token_tensor[:, cur_pos] = next_tokens.indices[0][beam]
                        _, outputs = model.forward_2(copy_token_tensor[:, look_back:cur_pos + 1], prev_pos)
                next_tokens_embed.append(outputs[-1][:, -1, :])
                next_tokens_embed = torch.stack(next_tokens_embed).squeeze(1)
                safety_penalty = torch.cosine_similarity(next_tokens_embed,
                                                         self.unsafe_embed).unsqueeze(0)
                safety_penalty = (safety_penalty + 1) / 2
                safety_penalty = safety_penalty * (1 / (self.length ** self.length_beta))
                # print("length beta penalty", (1 / (self.length ** self.length_beta)))
                safety_penalty = torch.log(safety_penalty)

            else:
                safety_penalty = torch.tensor([[0] * self.num_beams], device="cuda")
            probability = (torch.log(next_tokens.values) + self.aggregated_probs) * (1 / (self.length ** 0.7))
            total_scores = (1 - self.safety_alpha) * probability - self.safety_alpha * safety_penalty
            return next_tokens.values, next_tokens.indices, safety_penalty, total_scores

    def get_default_scores(self, tokenizer):
        indices = torch.tensor([[tokenizer.eos_id] * self.num_beams], device="cuda")
        probability = torch.tensor([[self.beam_total_prob] * self.num_beams], device="cuda")
        safety_penalty = torch.tensor([[0] * self.num_beams], device="cuda")
        total_scores = self.score + safety_penalty
        return probability, indices, safety_penalty, total_scores

    def get_next_tokens(self, logits):
        probs = torch.softmax(logits[:, -1], dim=-1)
        # next_tokens = torch.topk(probs, self.num_beams, dim=-1)
        if self.top_p > 0:
            next_tokens = self.sample_top_p_multi(probs, self.num_beams)
        else:
            next_tokens = torch.topk(probs, self.num_beams, dim=-1)
        return next_tokens

    def get_token_embeddings_external(self, tokenizer, copy_token_tensor, next_tokens_embed, next_tokens, cur_pos, look_back):
        base_tokens_string = tokenizer.decode(copy_token_tensor[0, look_back:cur_pos].tolist())
        base_tokens_string = self.remove_padding_str(base_tokens_string)
        for beam in range(self.num_beams):
            copy_token_tensor[:, cur_pos] = next_tokens.indices[0][beam]
            beam_str = base_tokens_string + " " + tokenizer.decode(next_tokens.indices[0][beam].tolist())
            next_tokens_embed.append(self.embedder.embed(beam_str))
        next_tokens_embed = torch.stack(next_tokens_embed).squeeze(1)
        return next_tokens_embed

    def get_token_embeddings_internal(self, model, copy_token_tensor, next_tokens, cur_pos, prev_pos, look_back):
        stacked_token_tensor = torch.stack([copy_token_tensor * 1 for _ in range(self.num_beams)], dim=0).squeeze(1)
        # fill stacked tensor with next tokens
        stacked_token_tensor[:, cur_pos] = next_tokens.indices[0]
        _, outputs = model.forward_2(stacked_token_tensor[:, look_back:cur_pos + 1], prev_pos)
        next_tokens_embed = outputs[-1]
        return next_tokens_embed

    def get_safety_penalty(self, next_tokens_embed):
        next_tokens_embed_norm = F.normalize(next_tokens_embed, p=2, dim=1).to("cuda")

        safety_penalty = torch.max((1 - torch.matmul(next_tokens_embed_norm, self.unsafe_embed_norm.T)), dim=1).values
        safety_penalty = safety_penalty / 2
        safety_penalty = safety_penalty.unsqueeze(0)
        return safety_penalty

    def get_lookback(self, cur_pos):
        if self.last_x_words == 0:
            look_back = 0
        else:
            start_of_answer_index = cur_pos - self.length
            look_back = max(cur_pos - self.last_x_words, start_of_answer_index)
        return look_back
    def get_probability(self, next_tokens, next_tokens_embed):
        if self.aggregation_mode == "log":
            probability = (torch.log(next_tokens.values) + self.aggregated_probs)
        else:
            probability = (next_tokens.values * self.aggregated_probs)
        return probability

    def inference(self, prev_pos, cur_pos, model, warming_up=True, tokenizer=None):
        # check if indices fit in token_tensor, if not do nothing
        if self.token_tensor.shape[1] <= (cur_pos + 1):
            return self.get_default_scores(tokenizer)
        else:
            logits = model.forward(self.token_tensor[:, 0:cur_pos], prev_pos)
            next_tokens = self.get_next_tokens(logits)
            next_tokens_embed = []
            if self.safety_alpha > 0 and not warming_up:
                copy_token_tensor = copy.deepcopy(self.token_tensor)
                look_back = self.get_lookback(cur_pos)
                if self.embedder is not None:
                    next_tokens_embed = self.get_token_embeddings_external(tokenizer, copy_token_tensor, next_tokens_embed, next_tokens, cur_pos, look_back)
                else:
                   next_tokens_embed = self.get_token_embeddings_internal(model, copy_token_tensor, next_tokens, cur_pos, prev_pos, look_back)

                safety_penalty = self.get_safety_penalty(next_tokens_embed)
            else:
                safety_penalty = torch.tensor([[0] * self.num_beams], device="cuda")
            probability = self.get_probability(next_tokens, next_tokens_embed)
            total_scores = (1 - self.safety_alpha) * probability + self.safety_alpha * safety_penalty
            return next_tokens.values, next_tokens.indices, safety_penalty, total_scores


    def sample_top_p_multi(self, probs, num_beams):
        """
        Samples multiple tokens using the nucleus sampling (top-p sampling) algorithm.
        :param probs: the probability vector to sample tokens from
        :param num_beams: number of tokens to sample
        :return:
        """
        probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
        probs_sum = torch.cumsum(probs_sort, dim=-1)
        mask = probs_sum - probs_sort > self.top_p
        probs_sort[mask] = 0.0
        probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
        next_tokens = torch.multinomial(probs_sort, num_samples=num_beams)
        next_tokens = torch.gather(probs_idx, -1, next_tokens)
        next_tokens_probs = torch.gather(probs, -1, next_tokens)
        next_token_dict = SimpleNamespace(indices= next_tokens, values=next_tokens_probs)
        return next_token_dict

