import json
import numpy as np
from sklearn.manifold import TSNE
from sentence_transformers import SentenceTransformer
import matplotlib.pyplot as plt

def plot_tsne_embeddings(
    embeddings_list,
    labels,
    colors,
    markers,
    tsne_params=None,
    title=None,
    output_path="mbpp_tsne_plot.png",
):
    """Project several embedding groups into one shared 2-D t-SNE space and save a scatter plot.

    All groups are stacked and fitted with a single t-SNE run so that every
    point lands in the same coordinate system. Fitting a separate t-SNE per
    group would produce independent layouts whose axes cannot be compared,
    which makes overlaying them on one set of axes meaningless.

    Args:
        embeddings_list: one array-like per group; the groups are stacked
            row-wise with ``np.vstack``, so they must share the same feature
            width.
        labels: legend label for each group.
        colors: matplotlib color for each group.
        markers: matplotlib marker style for each group.
        tsne_params: keyword arguments forwarded to ``sklearn.manifold.TSNE``.
            Defaults to a 2-D, PCA-initialised, seeded configuration.
        title: optional plot title.
        output_path: file the figure is written to.

    Raises:
        ValueError: if ``embeddings_list`` is empty or the four per-group
            sequences differ in length.
    """
    embeddings = [np.atleast_2d(emb) for emb in embeddings_list]
    labels, colors, markers = list(labels), list(colors), list(markers)
    if not embeddings:
        raise ValueError("embeddings_list must contain at least one group")
    if not len(embeddings) == len(labels) == len(colors) == len(markers):
        # zip() would otherwise silently drop the unmatched groups.
        raise ValueError(
            "embeddings_list, labels, colors and markers must have the same length"
        )

    if tsne_params is None:
        # The iteration count is left at scikit-learn's default (1000): the
        # keyword was renamed from ``n_iter`` to ``max_iter`` in 1.5, so naming
        # either one fails or warns on some versions.
        tsne_params = {
            'n_components': 2,
            'random_state': 42,
            'perplexity': 40,
            'init': 'pca',
            'learning_rate': 'auto',
        }

    # One joint fit over all groups, then slice the projection back into
    # per-group chunks in input order.
    projected = TSNE(**tsne_params).fit_transform(np.vstack(embeddings))
    split_points = np.cumsum([emb.shape[0] for emb in embeddings])[:-1]
    groups_2d = np.split(projected, split_points)

    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        for emb_2d, lbl, col, m in zip(groups_2d, labels, colors, markers):
            ax.scatter(
                emb_2d[:, 0],
                emb_2d[:, 1],
                c=col,
                marker=m,
                label=lbl,
                alpha=0.7,
                edgecolors='w',
                linewidths=0.5,
            )
        ax.set_xlabel('t-SNE Dimension 1')
        ax.set_ylabel('t-SNE Dimension 2')
        if title:
            ax.set_title(title)
        ax.legend()
        fig.tight_layout()
        fig.savefig(output_path, dpi=300, bbox_inches="tight")
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
    print(f"t-SNE plot saved as '{output_path}'")



# === Load prompts ===
def load_prompts_llama(file_path):
    """Return the ``generated_instruction`` field of every record in a JSON file.

    Args:
        file_path: path to a JSON file whose top level is presumably a list of
            objects, each carrying a ``generated_instruction`` key.

    Returns:
        List of instruction values, in file order.

    Raises:
        KeyError: if a record has no ``generated_instruction`` key.
    """
    # Decode as UTF-8 explicitly: JSON is UTF-8 by spec, while the platform
    # default encoding (e.g. cp1252 on Windows) can fail on non-ASCII text.
    with open(file_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    return [item["generated_instruction"] for item in data]

def load_prompts_gpt(file_path):
    """Return the ``instruction`` field of every record in a JSON file.

    Args:
        file_path: path to a JSON file whose top level is presumably a list of
            objects with an optional ``instruction`` key.

    Returns:
        One entry per record, in file order; a record without an
        ``instruction`` key contributes an empty string.
    """
    # Decode as UTF-8 explicitly: JSON is UTF-8 by spec, while the platform
    # default encoding (e.g. cp1252 on Windows) can fail on non-ASCII text.
    with open(file_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    return [item.get("instruction", "") for item in data]

# Read both prompt corpora from disk.
llama_prompts = load_prompts_llama("synthetic_mbpp_prompts.json")
gpt_prompts = load_prompts_gpt("synthetic_gpt.json")

# === Embed prompts ===
# Both corpora go through the same encoder instance, LLaMA prompts first.
model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
llama_embeddings, gpt_embeddings = (
    model.encode(corpus, convert_to_numpy=True)
    for corpus in (llama_prompts, gpt_prompts)
)

# === Plot ===
# Draw both embedding sets on a single t-SNE scatter plot.
plot_tsne_embeddings(
    embeddings_list=[llama_embeddings, gpt_embeddings],
    labels=["LLaMA2-generated", "GPT-generated"],
    colors=["red", "blue"],
    markers=["o", "s"],
    title="t-SNE: LLaMA2 vs GPT Generated Prompts",
)

from sklearn.metrics.pairwise import cosine_distances


def mean_pairwise_cosine_distance(embeddings):
    """Return the mean cosine distance over all distinct pairs of rows.

    Self-comparisons (the zero diagonal of the pairwise matrix) are excluded.
    Averaging over the full n x n matrix would shrink the score by a factor of
    (n - 1) / n, which biases comparisons between sets of different sizes.

    Returns ``nan`` when fewer than two rows are given, since no pair exists.
    """
    n = len(embeddings)
    if n < 2:
        return float("nan")
    distances = cosine_distances(embeddings)
    # The matrix is symmetric, so the strict upper triangle holds each
    # distinct pair exactly once.
    return float(distances[np.triu_indices(n, k=1)].mean())


llama_diversity = mean_pairwise_cosine_distance(llama_embeddings)
gpt_diversity = mean_pairwise_cosine_distance(gpt_embeddings)
print(f"LLaMA2 Diversity: {llama_diversity:.4f}")
print(f"GPT Diversity: {gpt_diversity:.4f}")


