import os

import fire
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.manifold import TSNE


# Rank whose value colours the points at each supported taxonomy level: the
# rank immediately above it. "class" is the highest rank handled here, so it
# has no parent and every class is coloured by its own name instead.
_PARENT_RANK = {
    "class": None,
    "order": "class",
    "family": "order",
    "genus": "family",
    "species": "genus",
}

# scikit-learn's documented default perplexity for TSNE. It must stay below
# the number of samples, so it is capped for small inputs.
_DEFAULT_PERPLEXITY = 30.0


def plot_taxonomy_text_clusters(
    parquet_path,
    taxonomy_level="family",
    llm_type="Llama-2-7b-hf",
    s=None,
    labels=False,
):
    """
    Project per-taxon text embeddings to 2-D with t-SNE and save a scatter plot.

    The embeddings are read from
    ``data/text_embeddings/{taxonomy_level}_{llm_type}.npy`` and the figure is
    written to ``result_plots/{taxonomy_level}_{llm_type}.png`` (the output
    directory is created if needed). Each point is one taxon, coloured by the
    value of its parent rank taken from the parquet table.

    parquet_path: string (parquet file holding one column per taxonomy rank)
    taxonomy_level: string (one of 'class', 'order', 'family', 'genus', 'species')
    llm_type: string (one of 'Llama-2-7b-hf', 'Llama-2-13b-hf', 'Llama-2-70b-hf')
    s: number or None (marker size forwarded to ``scatter``; None keeps the
        matplotlib default)
    labels: bool (whether to annotate the points with the name in the plot)

    Raises ValueError if ``taxonomy_level`` is not supported, if fewer than two
    embeddings are available, or if an embedded taxon is absent from the
    parquet table.
    """
    # Fail fast, before any file is read, instead of hitting an unbound
    # variable halfway through the run.
    if taxonomy_level not in _PARENT_RANK:
        raise ValueError(
            f"Unsupported taxonomy_level {taxonomy_level!r}; "
            f"expected one of {list(_PARENT_RANK)}"
        )
    color_rank = _PARENT_RANK[taxonomy_level] or taxonomy_level

    obs = pd.read_parquet(parquet_path)
    obs = obs.drop_duplicates(subset=[taxonomy_level])
    # After de-duplication there is one row per taxon, so a dict replaces a
    # full DataFrame scan per taxon with a constant-time lookup.
    rank_of_taxon = dict(zip(obs[taxonomy_level], obs[color_rank]))

    # NOTE(security): allow_pickle=True unpickles the file contents, which can
    # execute arbitrary code. Only load embedding files from a trusted source.
    text_embeddings = np.load(
        f"data/text_embeddings/{taxonomy_level}_{llm_type}.npy", allow_pickle=True
    )
    # The file stores a mapping wrapped in a 0-d object array; [()] unwraps it.
    embeddings_by_taxon = text_embeddings[()]
    names = list(embeddings_by_taxon.keys())
    if len(names) < 2:
        raise ValueError(
            f"Need at least two embeddings to run t-SNE, found {len(names)}"
        )

    missing = [name for name in names if name not in rank_of_taxon]
    if missing:
        raise ValueError(
            f"{len(missing)} embedded taxa are missing from column "
            f"{taxonomy_level!r} of {parquet_path!r}, e.g. {missing[:5]}"
        )

    # Each stored embedding is averaged over its first two axes to obtain a
    # single vector per taxon.
    data = np.array(
        [np.mean(embeddings_by_taxon[name], axis=(0, 1)) for name in names]
    )
    color = [rank_of_taxon[name] for name in names]
    # Integer code per distinct colour value, usable with a colormap.
    color_codes = np.unique(color, return_inverse=True)[1]

    # Keep the default perplexity whenever it is valid; shrink it only when
    # there are too few points (likely at 'class' or 'order' level).
    perplexity = float(min(_DEFAULT_PERPLEXITY, len(names) - 1))
    tsne = TSNE(n_components=2, perplexity=perplexity)
    proj = tsne.fit_transform(data)

    fig, ax = plt.subplots()
    try:
        # s=None is scatter's own default, so no separate branch is needed.
        ax.scatter(proj[:, 0], proj[:, 1], c=color_codes, cmap="jet", s=s)
        if labels:
            for name, (x, y) in zip(names, proj):
                ax.annotate(name, (x, y), fontsize=3)

        os.makedirs("result_plots", exist_ok=True)
        fig.savefig(f"result_plots/{taxonomy_level}_{llm_type}.png", dpi=300)
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)


if __name__ == "__main__":
    # Hand the plotting function to python-fire as the command-line entry point.
    fire.Fire(component=plot_taxonomy_text_clusters)
