import os
from openai import OpenAI
import numpy as np
import json
import argparse

parser = argparse.ArgumentParser(description="Process keypoint embeddings")
# Required: the script walks this directory and also writes its outputs into it.
parser.add_argument("--paper_directory", type=str, required=True, help="Path to the paper directory")
parser.add_argument("--embedding_model", type=str, default="text-embedding-3-small", help="Embedding model to use")

args = parser.parse_args()

# Fail fast with a usage message instead of walking nothing and then
# crashing later when the outputs are saved into a missing directory.
if not os.path.isdir(args.paper_directory):
	parser.error(f"--paper_directory is not an existing directory: {args.paper_directory}")

# The API key is taken from the OPENAI_API_KEY environment variable.
openai_key = os.getenv("OPENAI_API_KEY")
client = OpenAI(api_key=openai_key)

embedding_model = args.embedding_model
paper_directory = args.paper_directory

def parse_key_points(file_path):
	"""Read keypoints from a text file, one per line.

	Each line is stripped of leading/trailing whitespace, and lines that are
	empty after stripping are skipped.

	Args:
		file_path: Path of the text file to read.

	Returns:
		list[str]: The non-empty stripped lines, in file order.
	"""
	# Explicit UTF-8 so the result does not depend on the platform's locale
	# encoding (e.g. cp1252 on Windows).
	with open(file_path, "r", encoding="utf-8") as file:
		stripped_lines = (line.strip() for line in file)
		return [keypoint for keypoint in stripped_lines if keypoint]

def get_embedding(keypoints_list, batch_size=2048):
	"""Embed strings with the OpenAI embeddings endpoint.

	Uses the module-level ``client`` and ``embedding_model`` (the latter comes
	from the ``--embedding_model`` option, default text-embedding-3-small).

	Args:
		keypoints_list: Strings to embed (a single string is also accepted).
			The API rejects empty strings.
		batch_size: Maximum number of strings sent per request; the API
			documents a limit of 2048 items per input array.

	Returns:
		list: One embedding per input string, in input order. An empty input
		returns ``[]`` without calling the API (the endpoint rejects empty input).

	Raises:
		ValueError: If ``batch_size`` is not positive.
	"""
	if batch_size < 1:
		raise ValueError("batch_size must be a positive integer")

	# Keep accepting a bare string (which the API also accepts) instead of
	# slicing it into characters below.
	if isinstance(keypoints_list, str):
		keypoints_list = [keypoints_list]
	else:
		keypoints_list = list(keypoints_list)

	keypoints_embeddings = []
	for start in range(0, len(keypoints_list), batch_size):
		response = client.embeddings.create(
			model=embedding_model,
			input=keypoints_list[start:start + batch_size],
		)
		# Order by the returned index so embeddings line up with the inputs.
		batch = sorted(response.data, key=lambda item: item.index)
		keypoints_embeddings.extend(item.embedding for item in batch)

	return keypoints_embeddings

# Traverse every file ending in _keypoints.txt below paper_directory, embed its
# keypoints, and save the embeddings plus aligned metadata into paper_directory.


embeddings = []
# metadata[i] is (level, reviewer_id, keypoint) for the vector embeddings[i].
metadata = []

for root, dirs, files in os.walk(paper_directory):
	# os.walk yields entries in arbitrary filesystem order; sorting dirs in place
	# (and files below) makes the output row order deterministic.
	dirs.sort()
	for filename in sorted(files):
		if not filename.endswith("_keypoints.txt"):
			continue
		file_path = os.path.join(root, filename)
		print(f"Processing file: {file_path}")

		# Level: the highest N in 1-4 for which "levelN" appears in root;
		# files matching none of them are labelled level5.
		level = "level5"
		for x in range(4, 0, -1):
			if f"level{x}" in root:
				level = f"level{x}"
				break

		# Reviewer id is the filename prefix before the first underscore.
		reviewer_id = filename.split("_")[0]
		keypoints = parse_key_points(file_path)
		if not keypoints:
			# Avoid an embeddings request with an empty input list, which the API rejects.
			print(f"Skipping file with no keypoints: {file_path}")
			continue

		# Extend both lists together so metadata rows stay aligned with vectors.
		metadata.extend([(level, reviewer_id, kp) for kp in keypoints])
		embeddings.extend(get_embedding(keypoints))

embeddings = np.array(embeddings)

np.save(os.path.join(paper_directory, "keypoint_embeddings.npy"), embeddings)
with open(os.path.join(paper_directory, "keypoint_metadata.json"), "w") as f:
	json.dump(metadata, f, indent=4)
