import sys, os, pickle
import torch
from src.entity.embedder.STEmbedder import STEmbedder
from src.entity.embedder.GPTEmbedder import GPTEmbedder

from src.entity.datasets.Dataset import Dataset
from src.entity.datasets.GPQA import GPQA

import matplotlib.pyplot as plt


import argparse

# Embedder back-ends selectable through ``--emb_type``.
EMBEDDER_TYPES = {"gpt", "st"}

# Dataset modes selectable through ``--mode``.
# NOTE(review): the original literal listed "question" twice, which a set
# silently collapses to a single entry. If a third mode was intended, add it
# here after confirming which modes the embedders actually support.
MODES = {"question", "question + answer"}

def start_embedding(dataset:Dataset, emb_type, mode, output_dir):
    """Create embeddings for every problem of *dataset* with the chosen embedder.

    Args:
        dataset: Object whose ``problems`` attribute holds the items to embed.
        emb_type: Name of the embedder back-end; one of ``EMBEDDER_TYPES``.
        mode: Forwarded unchanged to the embedder's ``create_embeddings``.
        output_dir: Forwarded unchanged to the embedder's ``create_embeddings``.

    Raises:
        NotImplementedError: If ``emb_type`` is ``"st"``; that back-end is
            currently disabled in this function.
        ValueError: If ``emb_type`` is not a known embedder type.
    """
    problems = dataset.problems

    if emb_type == "gpt":
        embedder = GPTEmbedder()
    elif emb_type == "st":
        # The STEmbedder instantiation was commented out, which previously
        # left ``embedder`` unbound and surfaced as an UnboundLocalError
        # below. Fail with an explicit message until it is re-enabled.
        # embedder = STEmbedder()
        raise NotImplementedError(
            "Embedder type 'st' is currently disabled; use 'gpt' instead."
        )
    else:
        raise ValueError("Invalid embedder type. Available types are: {}".format(", ".join(sorted(EMBEDDER_TYPES))))

    embedder.create_embeddings(problems=problems, mode=mode, output_dir=output_dir)
    

def calculate_similarity(embeddings_dir:str, output_path:str='similarity_matrix.png'):
    """Compute the pairwise cosine-similarity matrix of pickled embeddings.

    Every regular file directly inside *embeddings_dir* is unpickled and
    treated as one embedding. Files are processed in sorted name order so the
    row/column order of the result is reproducible. The matrix is also
    rendered as a heat map and written to *output_path*.

    Args:
        embeddings_dir: Directory containing one pickled embedding per file.
        output_path: Where the heat-map image is saved. Defaults to
            ``'similarity_matrix.png'`` in the current working directory.

    Returns:
        The similarity matrix as a tensor; entry ``[i, j]`` is the cosine
        similarity between the i-th and j-th embedding in sorted file order.

    Raises:
        ValueError: If *embeddings_dir* contains no files.
    """
    embeddings = []
    # os.listdir returns entries in arbitrary order; sort for determinism.
    for name in sorted(os.listdir(embeddings_dir)):
        path = os.path.join(embeddings_dir, name)
        if not os.path.isfile(path):
            # Sub-directories cannot be opened as embedding files.
            continue
        # SECURITY: pickle.load can execute arbitrary code. Only point this
        # function at directories whose contents you produced yourself.
        with open(path, "rb") as f:
            embedding = pickle.load(f)
        # as_tensor is a no-op for tensors and converts arrays/lists.
        embeddings.append(torch.as_tensor(embedding))

    if not embeddings:
        raise ValueError("No embedding files found in: {}".format(embeddings_dir))

    # presumably each file holds a (1, dim) tensor, hence the squeeze — TODO confirm
    embeddings = torch.stack(embeddings).squeeze(1)
    if not embeddings.is_floating_point():
        embeddings = embeddings.float()

    # Cosine similarity is the dot product of L2-normalised vectors; without
    # the normalisation the matmul only yields raw dot products.
    embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=-1)
    similarity_matrix = torch.matmul(embeddings, embeddings.T)

    # Draw on a dedicated figure and close it afterwards so repeated calls do
    # not pile colorbars onto (or leak) the implicit global pyplot figure.
    fig, ax = plt.subplots()
    image = ax.imshow(similarity_matrix.detach().cpu().numpy(), cmap='viridis')
    fig.colorbar(image, ax=ax)
    fig.savefig(output_path)
    plt.close(fig)
    return similarity_matrix

if __name__ == "__main__":
    # Command-line entry point: either generate embeddings for the GPQA
    # dataset, or (with --similarity) compute the similarity matrix of the
    # embeddings already stored in --output_dir.
    parser = argparse.ArgumentParser(description="Generate embeddings for text files using either a GPT model or a sentence transformer.")
    # Sets have no stable order, so sort the choices to keep the help and
    # error messages deterministic.
    parser.add_argument("-m", "--mode", type=str, choices=sorted(MODES), default='question', help="Mode of the dataset to generate embeddings for.")
    # type=int keeps CLI-supplied sizes consistent with the integer default;
    # with type=str a user-supplied size reached GPQA as a string.
    parser.add_argument("-s", "--size", type=int, default=10000, help="Size of the dataset to generate embeddings for.")
    parser.add_argument("-o", "--output_dir", type=str, required=True,
                        help="Path where the generated embeddings should be saved "
                             "(with --similarity: path the embeddings are read from).")
    parser.add_argument("--emb_type", type=str, choices=sorted(EMBEDDER_TYPES), default="gpt",
                        help="Type of embedder to use. Available options: {}".format(", ".join(sorted(EMBEDDER_TYPES))))
    parser.add_argument("--similarity", action="store_true", help="Calculate similarity between embeddings.")

    args = parser.parse_args()
    if args.similarity:
        # In similarity mode the output directory doubles as the input.
        embeddings_dir = args.output_dir
        calculate_similarity(embeddings_dir)
    else:
        dataset = GPQA(size=args.size)
        start_embedding(dataset=dataset, emb_type=args.emb_type, mode=args.mode, output_dir=args.output_dir)