
import argparse
import fcntl
import json
import os
import threading
from concurrent.futures import ThreadPoolExecutor
from typing import List, Tuple

import numpy as np
import openai
from numpy.linalg import norm

def parse_args():
    """Parse the script's command-line options.

    Options:
        --temperature / -t (float, required): temperature of the completions.
        --approach (str, default "ss"): approach used to generate completions.
        --max-workers (int, default 50): size of the worker pool.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    cli = argparse.ArgumentParser(
        description='Get embeddings and compute similarity between completions',
    )
    cli.add_argument(
        '--temperature', '-t',
        type=float,
        required=True,
        help='Temperature parameter for completions',
    )
    cli.add_argument(
        '--approach',
        type=str,
        default="ss",
        help='approach used to generate completions',
    )
    cli.add_argument(
        '--max-workers',
        type=int,
        default=50,
        help='Maximum number of concurrent workers (default: 50)',
    )
    namespace = cli.parse_args()
    return namespace

# Parse the command line once at start-up; T is the temperature that selects
# which completions file is read and how the output file is named.
args = parse_args()
T = args.temperature


from utils import filter_seen_solutions, write_result

# Read the raw JSONL lines of the completions generated with this approach
# and temperature.
with open(f'final_data/{args.approach}_completions_temp_{T}.jsonl', 'r') as f:
    baseline_completions_0 = f.readlines()

# Decode one JSON record per line. Blank lines (for example a trailing empty
# line at the end of the file) are skipped instead of crashing json.loads.
solutions = [
    json.loads(completion)
    for completion in baseline_completions_0
    if completion.strip()
]



def cosine_similarity(embeddings) -> float:
    """Return the cosine similarity of a pair of vectors.

    Args:
        embeddings: Indexable pair; items 0 and 1 are the two vectors
            (any array-like of numbers of equal length).

    Returns:
        The cosine of the angle between the two vectors as a built-in
        ``float``. If either vector has zero norm the similarity is
        undefined; 0.0 is returned rather than NaN.
    """
    a = np.asarray(embeddings[0], dtype=float)
    b = np.asarray(embeddings[1], dtype=float)
    denominator = norm(a) * norm(b)
    # Guard the division: a zero vector would otherwise produce NaN (and a
    # RuntimeWarning), which does not serialise to valid JSON downstream.
    if denominator == 0:
        return 0.0
    return float(np.dot(a, b) / denominator)

def get_embedding(solution1, solution2, model="text-embedding-3-small") -> Tuple[List[float], List[float]]:
    """Embed two texts with a single embeddings API request.

    Args:
        solution1: First text to embed.
        solution2: Second text to embed.
        model: Name of the embedding model passed to the API.

    Returns:
        A 2-tuple ``(embedding_of_solution1, embedding_of_solution2)``.
    """
    # A fresh client is built per call, using whatever configuration
    # `openai.OpenAI()` picks up by default.
    client = openai.OpenAI()
    response = client.embeddings.create(
        input=[solution1, solution2],
        model=model,
    )
    # Each returned item records the position of the input it belongs to;
    # order by that field instead of relying on the order of the list.
    ordered = sorted(response.data, key=lambda item: item.index)
    return ordered[0].embedding, ordered[1].embedding
def process_solution(solution):
    """Score how similar the two completions of one record are.

    Args:
        solution: Mapping with a 'prompt' entry and a 'completions' entry
            that unpacks into exactly two items.

    Returns:
        A dict with the record's 'prompt' and the 'similarity' of the
        embeddings of its two completions.
    """
    prompt_text = solution['prompt']
    first_completion, second_completion = solution['completions']
    score = cosine_similarity(get_embedding(first_completion, second_completion))
    return dict(prompt=prompt_text, similarity=score)
    
def process_and_write(solution, output_file):
    """Compute the similarity record for `solution` and pass it to `write_result`.

    Args:
        solution: Record accepted by `process_solution`.
        output_file: Destination forwarded unchanged to `write_result`.
    """
    write_result(process_solution(solution), output_file)
    

    
# Destination of the per-prompt similarity records for this approach/temperature.
output_file = f'final_data/embeddings/{args.approach}_temp_{T}.jsonl'

# Narrow the work list using the project helper, which is given the output
# path (presumably to drop records already written there; see utils).
solutions = filter_seen_solutions(solutions, output_file)

with ThreadPoolExecutor(max_workers=args.max_workers) as pool:
    # Queue one task per remaining record.
    pending = []
    for record in solutions:
        pending.append(pool.submit(process_and_write, record, output_file))

    # Block on every task in submission order; result() re-raises any
    # exception raised inside the worker.
    for task in pending:
        task.result()
    