
from filter.Embedder import MiniEmbedder, MpnetBaseEmbedder, MxEmbedder
from filter.Visualizer import PCA_TSNE
import torch
import pandas as pd
from filter import Generator
from llama_test import embed
from Params import *
import os

def test_embedder(embedder: MiniEmbedder):
    """Print pairwise cosine similarities for three hard-coded sentences, then exit.

    The first sentence is a factual statement about New Horizons. The other two
    share an identical prefix and differ only in their final fragment
    ("New" vs "Voy").

    Args:
        embedder: object whose ``embed(text)`` method is the only thing called;
            its result is passed to ``torch.cosine_similarity`` and the
            similarity is read with ``.item()``, so presumably a (1, dim)
            tensor — TODO confirm for each embedder type.

    Raises:
        SystemExit: always, with status 0, once the results are printed.
    """
    import sys

    new_horizons = "New Horizons is the fastest space craft in the world, reaching a top speed of about 36,000 miles per hour (58,000 km/h)."
    # NOTE(review): the backslash continuation keeps the next line's leading
    # spaces inside the string literal. It is left unchanged so results stay
    # comparable with earlier runs.
    sent_1 = "As a helpful AI assistant, I'd be happy to answer your question! 🚀 \
                    The fastest spacecraft in the world is currently the New"
    sent_2 = "As a helpful AI assistant, I'd be happy to answer your question! 🚀 \
                    The fastest spacecraft in the world is currently the Voy"
    sentence_1_embed = embedder.embed(new_horizons)
    sentence_2_embed = embedder.embed(sent_1)
    neg_sentence_embed = embedder.embed(sent_2)
    sim_1_2 = torch.cosine_similarity(sentence_1_embed, sentence_2_embed)
    sim_1_neg = torch.cosine_similarity(sentence_1_embed, neg_sentence_embed)
    sim_2_neg = torch.cosine_similarity(sentence_2_embed, neg_sentence_embed)
    print("sentence_1: ", new_horizons)
    print("sentence_2: ", sent_1)
    print("neg_sentence: ", sent_2)
    print(f"Similarity between 1 and 2: {sim_1_2.item()}")
    print(f"Similarity between 1 and neg: {sim_1_neg.item()}")
    print(f"Similarity between 2 and neg: {sim_2_neg.item()}")

    # The bare exit() builtin is injected by the site module for interactive
    # sessions only; sys.exit is the supported way to stop a program.
    sys.exit(0)

def check_embeddings(embed_pos_1, embed_pos_2, embed_neg, load_file=True, name_1="pos_1", name_2="pos_2", name_neg="neg"):
    """Print cosine similarities and L1 distances between three embeddings.

    One vector is taken from each input and reshaped to ``(1, -1)``:
      * ``load_file=True``: each argument is a path; the vector is
        ``torch.load(path)[-1][-2]``, and the path's base name is used as its label.
      * ``load_file=False``: each argument is already indexable (e.g. a tensor);
        the vector is ``arg[-1][-1]``, labelled with ``name_1``/``name_2``/``name_neg``.

    Prints three "Similarity between ..." lines (cosine similarity), then three
    "Sum of ..." lines. The "Sum of" lines report the sum of absolute
    element-wise differences, i.e. the L1 distance.

    Args:
        embed_pos_1, embed_pos_2, embed_neg: paths readable by ``torch.load``
            or in-memory embeddings, depending on ``load_file``.
        load_file: whether the first three arguments are file paths.
        name_1, name_2, name_neg: labels used only when ``load_file`` is false.
    """
    sources = (embed_pos_1, embed_pos_2, embed_neg)
    if load_file:
        # NOTE(review): loaded files use index [-1][-2], while in-memory inputs
        # use [-1][-1]. Confirm this asymmetry is intended.
        vectors = [torch.load(path)[-1][-2].reshape(1, -1) for path in sources]
        # basename handles the platform's path separators, unlike split("/").
        names = [os.path.basename(path) for path in sources]
    else:
        vectors = [emb[-1][-1].reshape(1, -1) for emb in sources]
        names = [name_1, name_2, name_neg]

    # Pair order: (pos_1, pos_2), (pos_1, neg), (pos_2, neg).
    pairs = ((0, 1), (0, 2), (1, 2))
    similarities = [torch.cosine_similarity(vectors[i], vectors[j]) for i, j in pairs]
    distances = [torch.sum(torch.abs(vectors[i] - vectors[j])) for i, j in pairs]

    for (i, j), similarity in zip(pairs, similarities):
        print(f"Similarity between {names[i]} and {names[j]}: {similarity.item()}")
    for (i, j), distance in zip(pairs, distances):
        # .item() on every line; the original printed the first one as "tensor(...)".
        print(f"Sum of {names[i]} and {names[j]}: {distance.item()}")


def _load_memory_tensors(directory, tag):
    """Load every file under *directory* (recursively) as a tensor for plotting.

    Returns ``(entries, words)``, both in ``os.walk`` order:
      * ``entries``: ``(tensor, display_name, tag)`` tuples. Each tensor has
        ``.squeeze(0)`` applied. ``display_name`` is the filename with
        ``"embedding_memory_"`` and ``".pt"`` removed.
      * ``words``: the last ``_``-separated segment of each display name.
    """
    entries = []
    words = []
    for dirpath, _dirnames, filenames in os.walk(directory):
        # NOTE(review): os.walk/listdir order is not guaranteed, so the words
        # printed by the caller may not follow file-creation order.
        for filename in filenames:
            tensor = torch.load(os.path.join(dirpath, filename)).squeeze(0)
            display_name = filename.replace("embedding_memory_", "").replace(".pt", "")
            words.append(display_name.split("_")[-1])
            entries.append((tensor, display_name, tag))
    return entries, words


def test_visualizer(model_dir=None, tokenizer_path=None, use_llama=True,
                    memory_dirs=(("prompts/tsne_journey_6", 0), ("prompts/tsne_journey_5", 1))):
    """Embed ``NEGATIVE_PROMPT`` and plot it against saved memory embeddings.

    Saved tensors under each ``memory_dirs`` directory are loaded with the
    directory's tag. The prompt embeddings get tag 2. All vectors are stacked
    into a DataFrame, reduced with ``PCA_TSNE.fit_transform``, then
    ``fit_transform_tsne``, and plotted with per-point labels and tags.

    Args:
        model_dir: checkpoint directory passed to ``Generator.build`` (llama only).
        tokenizer_path: tokenizer path passed to ``Generator.build`` (llama only).
        use_llama: embed prompts with the llama ``Generator`` via ``embed``;
            otherwise use ``MxEmbedder``.
        memory_dirs: ``(directory, tag)`` pairs of saved embeddings to overlay.
            The default reproduces the original hard-coded directories.
    """
    if use_llama:
        embedder = Generator.build(
            ckpt_dir=model_dir,
            tokenizer_path=tokenizer_path,
            max_seq_len=100,
            max_batch_size=8,
            clean_llama=True
        )
    else:
        embedder = MxEmbedder()
    prompt_list = NEGATIVE_PROMPT

    tensor_list_memory = []
    words = []
    for directory, tag in memory_dirs:
        entries, dir_words = _load_memory_tensors(directory, tag)
        tensor_list_memory.extend(entries)
        words.extend(dir_words)
    # Each word is followed by one space, matching the original output.
    print("".join(f"{word} " for word in words))

    tensor_list = []
    label_list = []
    tag_list = []
    for prompt in prompt_list:
        save_name = prompt.replace(" ", "_")
        if use_llama:
            embedding = embed(generator=embedder, prompt=prompt, use_saved=False, save_name=save_name,
                              directory="tsne_embeds", use_first_embed=False, get_filename=False)
            # Keep only the final position along the first two axes
            # (presumably the last token of the last batch item — TODO confirm).
            embedding = embedding[-1, -1, :].cpu()
        else:
            embedding = embedder.embed(prompt)
            embedding = torch.tensor(embedding).squeeze(0)
        tensor_list.append(embedding)
        label_list.append(prompt)
        tag_list.append(2)
    for tensor, filename, method in tensor_list_memory:
        tensor_list.append(tensor)
        label_list.append(filename)
        tag_list.append(method)

    pca_matrix = pd.DataFrame(torch.stack(tensor_list))
    visualizer = PCA_TSNE(2, perplexity=50, n_iter=400)
    pca = visualizer.fit_transform(pca_matrix)
    tsne = visualizer.fit_transform_tsne(pca)
    visualizer.plot(tsne, axis_1='tsne_0', axis_2='tsne_1', fig_size=(40, 40), labels=label_list, tags=tag_list, font_size=14)

