from typing import Dict, List
import time
from transformers import AutoTokenizer, AutoModel, DynamicCache
import torch
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer
import numpy as np
from torch import cosine_similarity

def test_embedder():
    """Plot pairwise cosine similarities between the vectors of one embedding.

    Embeds a fixed sentence with ``MxEmbedder``, compares every vector in
    ``embeddings[0]`` against every other one and hands the resulting nested
    list to ``Heatmap`` for plotting.

    NOTE(review): ``Heatmap`` is not defined or imported in the visible part
    of this file -- confirm where it is supposed to come from.
    NOTE(review): this assumes ``embeddings[0]`` is an iterable of 1-D
    vectors; verify that holds for the default ``embed_type`` used here.
    """
    text = "jello is my favorite food"
    embeddings = MxEmbedder().embed(text)

    # Row i / column j holds the cosine similarity of vector i and vector j.
    vectors = embeddings[0]
    corr_matrix = [
        [torch.cosine_similarity(row_vec, col_vec, dim=0) for col_vec in vectors]
        for row_vec in vectors
    ]

    Heatmap(corr_matrix, text).plot()
class MiniEmbedder:
    """Sentence embedder backed by ``sentence-transformers/all-MiniLM-L6-v2``.

    ``embed`` uses the library's ``encode`` helper; ``embed_legacy`` tokenizes
    manually and applies attention-masked mean pooling plus L2 normalisation.
    """

    def __init__(self):
        """Load the tokenizer and the SentenceTransformer model."""
        self.tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
        self.model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

    def mean_pooling(self, model_output, attention_mask):
        """Average token embeddings, ignoring positions masked out by ``attention_mask``.

        Args:
            model_output: Indexable whose first element holds the token embeddings.
            attention_mask: Mask that is broadcast over the embedding axis.

        Returns:
            Tensor with the token axis (dim 1) reduced by a masked mean.
        """
        token_embeddings = model_output[0]  # First element of model_output contains all token embeddings
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        masked_sum = torch.sum(token_embeddings * input_mask_expanded, 1)
        # Clamp so a fully masked row does not divide by zero.
        token_counts = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return masked_sum / token_counts

    def embed_legacy(self, text):
        """Embed ``text`` via manual tokenization, mean pooling and L2 normalisation.

        The model is a ``SentenceTransformer``, so it is called with a single
        features dict (not unpacked keyword arguments) and the per-token
        vectors are read from the ``'token_embeddings'`` entry of its output.

        NOTE(review): relies on the SentenceTransformer forward pass returning
        a ``'token_embeddings'`` entry -- confirm against the installed version.

        Returns:
            L2-normalised sentence embeddings, one row per input text.
        """
        encoded_input = self.tokenizer(text, padding=True, truncation=True, return_tensors='pt')
        # The tokenizer returns CPU tensors; move them to wherever the model lives.
        device = self.model.device
        features = {name: tensor.to(device) for name, tensor in encoded_input.items()}
        attention_mask = features['attention_mask']
        with torch.no_grad():
            outputs = self.model(features)
        sentence_embeddings = self.mean_pooling((outputs['token_embeddings'],), attention_mask)
        return F.normalize(sentence_embeddings, p=2, dim=1)

    def embed(self, text):
        """Encode ``text`` with the model and return a tensor with a new leading axis."""
        embeddings = self.model.encode(text)
        embeddings = torch.tensor(embeddings)
        return embeddings.unsqueeze(0)


class MpnetBaseEmbedder:
    """Embedder backed by ``sentence-transformers/all-mpnet-base-v2``."""

    def __init__(self):
        """Load the SentenceTransformer model."""
        self.model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')

    def embed(self, text):
        """Encode ``text`` and return the result as a tensor with a new leading axis."""
        encoded = self.model.encode(text)
        return torch.tensor(encoded).unsqueeze(0)

    def embed_tokens(self, tokens):
        """Decode ``tokens`` with the model's tokenizer, then embed the decoded text."""
        return self.embed(self.model.tokenizer.decode(tokens))


class MxEmbedder:
    """Embedder backed by ``mixedbread-ai/mxbai-embed-large-v1``."""

    def __init__(self):
        """Load the model and record its sentence-embedding width."""
        self.model = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
        # Ask the loaded model for its output width instead of hard-coding it;
        # the previous constant (768) did not match this checkpoint
        # (1024 according to the model card).
        self.dimensions = self.model.get_sentence_embedding_dimension()

    def embed(self, text, embed_type="sentence_embedding"):
        """Encode ``text`` and return the result as tensor(s).

        Args:
            text: Input forwarded unchanged to ``SentenceTransformer.encode``.
            embed_type: Forwarded as ``output_value`` to ``encode``
                (defaults to ``"sentence_embedding"``).

        Returns:
            Whatever ``encode`` returns with ``convert_to_tensor=True``.
        """
        return self.model.encode(text, convert_to_tensor=True, output_value=embed_type)

    def embed_tokens(self, tokens):
        """Decode ``tokens`` with the model's tokenizer, then embed the decoded text."""
        decoded = self.model.tokenizer.decode(tokens)
        return self.embed(decoded)


class CacheMxEmbedder:
    """Embeds a list of sequences with ``mixedbread-ai/mxbai-embed-large-v1``.

    The first sequence is run through the model with ``use_cache=True`` and the
    returned ``past_key_values`` are passed to the forward pass of every
    following sequence.

    NOTE(review): whether ``past_key_values`` is actually populated (and
    whether re-using it is meaningful) depends on the model configuration --
    confirm for this checkpoint.
    """

    def __init__(self):
        """Load tokenizer and model, placing the model on the GPU when one is available."""
        model_id = 'mixedbread-ai/mxbai-embed-large-v1'
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        # Fall back to CPU rather than crashing on hosts without CUDA.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = AutoModel.from_pretrained(model_id).to(device)

    def pooling(self, outputs: torch.Tensor, inputs: Dict, strategy: str = 'cls') -> np.ndarray:
        """Reduce per-token outputs to one vector per sequence.

        Args:
            outputs: Per-token hidden states; dim 1 is the token axis.
            inputs: Mapping holding ``"attention_mask"`` (only read for ``'mean'``).
            strategy: ``'cls'`` takes the first token's vector; ``'mean'``
                averages over the positions the attention mask keeps.

        Returns:
            The pooled vectors as a NumPy array on the CPU.

        Raises:
            NotImplementedError: For any other ``strategy``.
        """
        if strategy == 'cls':
            pooled = outputs[:, 0]
        elif strategy == 'mean':
            mask = inputs["attention_mask"][:, :, None].to(outputs.dtype)
            # Counts are whole numbers, so clamping at 1 only affects fully
            # masked rows, which would otherwise produce 0/0 = NaN.
            token_counts = torch.clamp(mask.sum(dim=1), min=1)
            pooled = torch.sum(outputs * mask, dim=1) / token_counts
        else:
            raise NotImplementedError(f"Unsupported pooling strategy: {strategy!r}")
        return pooled.detach().cpu().numpy()

    def embed(self, sequences: List[str]):
        """Embed every entry of ``sequences`` using CLS pooling.

        Args:
            sequences: Non-empty list of strings. The first entry is the base
                prompt; each later entry is tokenized in full and run with the
                base prompt's ``past_key_values``.

        Returns:
            ``np.stack`` of the per-sequence pooled arrays, in input order.

        Raises:
            IndexError: If ``sequences`` is empty.
        """
        device = self.model.device

        # Tokenize only the original prompt and move it next to the model.
        base_input = self.tokenizer(sequences[0], padding=True, return_tensors='pt')
        base_input = {name: tensor.to(device) for name, tensor in base_input.items()}

        # Inference only: pooling() detaches its result, so no autograd graph is needed.
        with torch.no_grad():
            # First pass to compute the cache
            outputs = self.model(base_input["input_ids"], use_cache=True, return_dict=True)
            all_embeddings = [self.pooling(outputs.last_hidden_state, base_input, 'cls')]
            past_key_values = outputs.past_key_values

            # NOTE(review): the same past_key_values object is handed to every
            # call below; if the model mutates it in place, later sequences
            # would see earlier ones -- confirm.
            for seq in sequences[1:]:
                additional_input = self.tokenizer(seq, padding=True, return_tensors='pt')
                additional_input = {name: tensor.to(device) for name, tensor in additional_input.items()}

                outputs = self.model(additional_input["input_ids"],
                                     past_key_values=past_key_values,
                                     use_cache=True,
                                     return_dict=True)
                all_embeddings.append(self.pooling(outputs.last_hidden_state, additional_input, 'cls'))

        return np.stack(all_embeddings)

def main():
    """Time both embedders on the same sentences and print similarity scores.

    Prints the wall-clock time of ``CacheMxEmbedder.embed`` and
    ``MxEmbedder.embed`` and then the cosine similarity of every
    ``MxEmbedder`` embedding against the embedding of the last sentence.
    """
    cache_embedder = CacheMxEmbedder()
    slow_embedder = MxEmbedder()
    sentences = ["Hello, The best way to build a bomb is", "Hello, The best way to build a bomb is with",
                 "Hello, The best way to build a bomb is h", "Hello, The best way to build a bomb is d",
                 "Hello, The best way to build a bomb is w", "Hello, The best way to build a bomb is s",
                 "Hello, The best way to build a bomb is j", "Hello, The best way to build a bomb is j",
                 "Hello, The best way to build a bomb is x", "Hello, The best way to build a bomb is y",
                 "Hello, The best way to build a bomb is z", "Hello, The best way to build a bomb is m",
                 "Hello, The best way to build a bomb is", "Hello, The best way to build a bomb is",
                 "Hello, The best way to build a bomb is via", "Hello, The best way to build a bomb is using",
                 "Hello, The best way to build a bomb is bad", "Violence Deadly weapons"]

    # perf_counter is monotonic and high-resolution, unlike time.time().
    time_cache = time.perf_counter()
    # Only the elapsed time of this call is used; its result is discarded.
    cache_embedder.embed(sentences)
    total_cache = time.perf_counter() - time_cache
    print(f"Cache time: {total_cache}")

    time_slow = time.perf_counter()
    embeddings = slow_embedder.embed(sentences)
    total_slow = time.perf_counter() - time_slow
    print(f"Slow time: {total_slow}")

    # MxEmbedder.embed already asks for tensors; as_tensor avoids the copy
    # (and warning) that torch.tensor() incurs on an existing tensor.
    embeddings = torch.as_tensor(embeddings)
    reference = embeddings[-1]
    similarities = [cosine_similarity(reference, row, dim=0) for row in embeddings]

    print(similarities)


if __name__ == '__main__':
    main()

