"""Running visualization based on the cluster embeddings
"""

import os
# import umap
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import click
from typing import Text, Any, Dict, Tuple
import ujson as json
from sentence_transformers import SentenceTransformer


# Sentence encoder shared by the whole module; it is instantiated once, at import time,
# and used by _emb_visualization to embed cluster names.
# NOTE(review): device='cuda' is hard-coded, so importing this module presumably fails on
# hosts without a CUDA-capable GPU -- confirm whether a CPU fallback is wanted.
model = SentenceTransformer('paraphrase-MiniLM-L6-v2', device='cuda')


def _md_cell(value: Any) -> Text:
    """Format *value* so it is safe inside a single Markdown table cell."""
    # A raw "|" would start a new column and a raw newline would end the row.
    return str(value).replace("|", "\\|").replace("\n", " ")


def _emb_visualization(item: Dict[Text, Any]) -> Tuple[Any, Text]:
    """Plot a 2-D t-SNE projection of an item's cluster names and tabulate their similarities.

    Args:
        item: Record holding a ``'clusters'`` list (each entry stores its name
            under ``'cluster'``) and a ``'question'`` used in the plot title.

    Returns:
        A ``(figure, table)`` pair. ``figure`` is the matplotlib figure with one
        annotated point per cluster; the caller is responsible for saving and
        closing it. ``table`` is a Markdown table of pairwise cosine
        similarities between the cluster-name embeddings, with ``|`` and
        newlines in names escaped so they cannot break the table layout.
    """
    names = [cluster['cluster'] for cluster in item['clusters']]

    embeddings = model.encode(names, convert_to_numpy=True, normalize_embeddings=True, show_progress_bar=False)

    # t-SNE requires perplexity to be strictly below the number of samples,
    # hence the "- 0.5" cap for items with few clusters.
    reducer = TSNE(n_components=2, random_state=0, perplexity=min(30, len(names) - 0.5))
    projected = reducer.fit_transform(embeddings)

    # Scatter the projected points and label each one with its cluster name.
    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111)
    ax.scatter(projected[:, 0], projected[:, 1], s=10)
    for name, (x_pos, y_pos) in zip(names, projected):
        ax.annotate(name, (x_pos, y_pos))
    # Set the title on the axes directly instead of via pyplot's "current axes".
    ax.set_title(f"t-SNE visualization of \"{item['question']}\"")

    # The embeddings are L2-normalised, so the dot product is the cosine similarity.
    sim = embeddings @ embeddings.T

    # Assemble the Markdown table: header, divider, then one row per cluster.
    cells = [_md_cell(name) for name in names]
    header = "|Cluster|" + "".join(f"{cell}|" for cell in cells)
    divider = "|---|" + "---|" * len(cells)
    rows = [
        f"|{cell}|" + "".join(f"{value:.2f}|" for value in sim_row)
        for cell, sim_row in zip(cells, sim)
    ]
    sim_table = "\n".join([header, divider, *rows]) + "\n"

    return fig, sim_table


@click.command()
@click.option("--input-path", type=str, required=True, help="Path to the input file.")
@click.option("--output-directory", type=str, required=True, help="Path to the output directory.")
@click.option("--index", type=int, help="Index of the cluster to visualize.", multiple=True)
def main(
    input_path,
    output_directory,
    index
):
    """Write a t-SNE plot and a similarity table for each selected record.

    The input file is read as JSON Lines (one JSON object per line). For every
    --index given, the record at that position is visualized and the results
    are written to <output-directory>/<index>.png and
    <output-directory>/<index>.md.
    """
    # NOTE: click shows the docstring above as the command's --help text.

    with open(input_path, 'r', encoding='utf-8') as file_:
        data = [json.loads(line) for line in file_]
    selected_data = [data[i] for i in index]

    # exist_ok avoids the check-then-create race of a separate existence test.
    os.makedirs(output_directory, exist_ok=True)

    for item_index, item in zip(index, selected_data):
        fig, sim_table = _emb_visualization(item)
        # Write UTF-8 explicitly: the input is UTF-8, so cluster names may
        # contain characters the platform default encoding cannot represent.
        table_path = os.path.join(output_directory, f"{item_index}.md")
        with open(table_path, "w", encoding="utf-8") as file_:
            file_.write(sim_table)
        # Save this exact figure rather than pyplot's "current" one, then
        # release it so figures do not accumulate across iterations.
        fig.savefig(os.path.join(output_directory, f"{item_index}.png"))
        plt.close(fig)
        

# Invoke the click command only when this file is executed as a script.
if __name__ == "__main__":
    main()