# For each per-temperature JSONL results file, load the entries, compute the average similarity and quality metrics, and plot them.

import argparse
import json
import math
import os

import matplotlib.pyplot as plt

def parse_args():
    """Parse the process command line and return the resulting namespace.

    The parser declares no options of its own, so the returned namespace
    carries no attributes.
    """
    cli = argparse.ArgumentParser(
        description='Get embeddings and compute similarity between completions',
    )
    namespace = cli.parse_args()
    return namespace

# Parse the command line up front (the parser currently defines no options).
args = parse_args()

# Sampling temperatures evaluated for each approach. Every value selects one
# `{approach}_temp_{value}.jsonl` file in each results directory.
options_baseline = [0.0, 0.1, 0.15, 0.25, 0.5, 0.75, 1.0, 1.1, 1.2, 1.25, 1.3]
# SimpleStrat ("ss") uses the same temperatures, kept as an independent list.
options_ss = list(options_baseline)


def _read_jsonl(path):
    """Load a JSON Lines file into a list, one parsed object per non-blank line.

    Blank lines (for example a trailing empty line at the end of the file)
    are skipped instead of being passed to ``json.loads``, which would raise.
    """
    with open(path, 'r', encoding='utf-8') as f:
        return [json.loads(line) for line in f if line.strip()]


def _mean_field(entries, key, path):
    """Return the arithmetic mean of ``entry[key]`` over ``entries``.

    Raises:
        ValueError: if ``entries`` is empty. ``path`` is included in the
            message so the offending file can be identified.
    """
    if not entries:
        raise ValueError(f"No entries found in {path}")
    return sum(entry[key] for entry in entries) / len(entries)


def get_results(approach):
    """Aggregate per-temperature metrics for one approach.

    For every temperature configured for ``approach`` this reads
    ``final_data/embeddings/{approach}_temp_{option}.jsonl`` (field
    ``similarity``) and ``final_data/quality_judge/{approach}_temp_{option}.jsonl``
    (fields ``spelling_error`` and ``fully_formatted``), averages each field
    over the file's entries, and prints one summary line per temperature.

    Args:
        approach: ``"ss"`` or ``"baseline"``; selects the temperature list and
            the file name prefix.

    Returns:
        A list of ``(option, avg_similarity, avg_spelling_error,
        avg_formatting_error)`` tuples, in the order of the temperature list.

    Raises:
        NotImplementedError: if ``approach`` is not a known approach.
        ValueError: if one of the result files contains no entries.
    """
    if approach == "ss":
        options = options_ss
    elif approach == "baseline":
        options = options_baseline
    else:
        raise NotImplementedError(f"Approach {approach} not implemented")

    results = []
    for option in options:
        embeddings_path = f'final_data/embeddings/{approach}_temp_{option}.jsonl'
        quality_path = f'final_data/quality_judge/{approach}_temp_{option}.jsonl'
        entries = _read_jsonl(embeddings_path)
        entries_q = _read_jsonl(quality_path)

        # compute averages
        avg_similarity = _mean_field(entries, 'similarity', embeddings_path)
        avg_spelling_error = _mean_field(entries_q, 'spelling_error', quality_path)
        # NOTE(review): this value is averaged from the `fully_formatted`
        # field although it is named (and later plotted as `1 - value`) like
        # an error rate -- confirm the polarity of that field.
        avg_formatting_error = _mean_field(entries_q, 'fully_formatted', quality_path)

        results.append((option, avg_similarity, avg_spelling_error, avg_formatting_error))

        # print results, preceded by a column header before the first row
        if option == options[0]:
            print("similarity, spelling_error, formatting_error")
        print(f"Average similarity for {option}: {avg_similarity:.2f}, {avg_spelling_error:.2f}, {avg_formatting_error:.2f}")
    return results

# Aggregate the per-temperature metrics for both approaches; SimpleStrat
# ("ss") is read and printed first, then the baseline.
results_ss, results_baseline = get_results("ss"), get_results("baseline")

def convert_to_distance(sim):
    """Convert a cosine similarity into the distance ``sqrt(2 * (1 - sim))``.

    This is the Euclidean distance between two unit-length vectors whose
    cosine similarity is ``sim``.

    Args:
        sim: cosine similarity, nominally in ``[-1, 1]``.

    Returns:
        The corresponding non-negative distance as a float.

    Raises:
        ValueError: if ``sim`` exceeds 1 by more than rounding noise.
    """
    squared = 2 * (1 - sim)
    # Similarities of (near-)identical vectors can land a hair above 1.0 due
    # to floating-point rounding, which would make the radicand negative and
    # math.sqrt raise. Treat that tiny overshoot as zero distance; clearly
    # invalid inputs still fail as before.
    if squared < 0 and sim <= 1 + 1e-6:
        squared = 0.0
    return math.sqrt(squared)


# Line colors shared by every figure: blue for the baseline, orange for
# SimpleStrat.
orange = '#F07823'
blue = '#5DBAE8'

# Set to False to skip writing the figures.
plot = True
if plot:
    # savefig does not create missing directories, so make sure the output
    # directory exists before the first figure is written.
    os.makedirs('plots', exist_ok=True)

    # Diversity (x-axis of both trade-off figures): the distance derived from
    # the average similarity at each temperature.
    recall_ss = [convert_to_distance(avg_similarity) for _, avg_similarity, _, _ in results_ss]
    recall_baseline = [convert_to_distance(avg_similarity) for _, avg_similarity, _, _ in results_baseline]

    # ---- Figure 1: format adherence vs. diversity ----
    # NOTE(review): get_results averages the `fully_formatted` field for this
    # value; if that field is 1 for well-formatted outputs, `1 - value` is an
    # error rate rather than adherence -- confirm the field's polarity.
    precision_ss = [1 - avg_formatting_error for _, _, _, avg_formatting_error in results_ss]
    precision_baseline = [1 - avg_formatting_error for _, _, _, avg_formatting_error in results_baseline]
    plt.plot(recall_baseline, precision_baseline, 'o-', color=blue, label='GPT-4o')
    plt.plot(recall_ss, precision_ss, 'o-', color=orange, label='SimpleStrat (GPT-4o)')

    # Remove the top and right borders of the graph.
    axes = plt.gca()
    axes.spines['top'].set_visible(False)
    axes.spines['right'].set_visible(False)

    # Call out the first data point of each curve (temperature 0).
    plt.annotate('T=0', (recall_baseline[0], precision_baseline[0]),
                 textcoords="offset points", xytext=(0, 10), ha='center',
                 fontsize=10, color='black')
    plt.annotate('T=0', (recall_ss[0], precision_ss[0]),
                 textcoords="offset points", xytext=(0, 10), ha='center',
                 fontsize=10, color='black')

    # Call out temperature 1. Its position is looked up in each approach's
    # own temperature list instead of being hard-coded, so the label stays on
    # the right point if the lists change.
    if 1.0 in options_baseline:
        t1_baseline = options_baseline.index(1.0)
        plt.annotate('T=1', (recall_baseline[t1_baseline], precision_baseline[t1_baseline]),
                     textcoords="offset points", xytext=(-20, -20), ha='center',
                     fontsize=10, color='black',
                     arrowprops=dict(arrowstyle='-|>', facecolor='black'))
    if 1.0 in options_ss:
        t1_ss = options_ss.index(1.0)
        plt.annotate('T=1', (recall_ss[t1_ss], precision_ss[t1_ss]),
                     textcoords="offset points", xytext=(20, -20), ha='center',
                     fontsize=10, color='black',
                     arrowprops=dict(arrowstyle='->', facecolor='black'))

    # Only the baseline's last point (highest temperature) is labeled.
    plt.annotate('T=1.3', (recall_baseline[-1], precision_baseline[-1]),
                 textcoords="offset points", xytext=(-30, 0), ha='center',
                 fontsize=10, color='black')

    plt.xlabel('Diversity - Cosine Distance')
    plt.ylabel('Quality - Format Adherence')
    plt.legend()
    plt.savefig('plots/writing_prompts_formatting_recall.pdf')
    plt.close()

    # ---- Figure 2: (1 - spelling error rate) vs. diversity ----
    precision_ss = [1 - avg_spelling_error for _, _, avg_spelling_error, _ in results_ss]
    precision_baseline = [1 - avg_spelling_error for _, _, avg_spelling_error, _ in results_baseline]
    plt.plot(recall_baseline, precision_baseline, 'o-', color=blue, label='GPT-4o')
    plt.plot(recall_ss, precision_ss, 'o-', color=orange, label='SimpleStrat (GPT-4o)')
    plt.xlabel('Diversity - Cosine Distance')
    plt.ylabel('Quality - (1 - Spelling Error Rate)')
    plt.title('Quality-Diversity Tradeoff')
    plt.legend()
    plt.savefig('plots/writing_prompts_spelling_recall.pdf')
    plt.close()

    # The temperature figures leave out the two highest temperatures
    # (the last two entries of each list).
    options_ss = options_ss[:-2]
    options_baseline = options_baseline[:-2]
    results_ss = results_ss[:-2]
    results_baseline = results_baseline[:-2]

    # ---- Figure 3: format adherence vs. temperature ----
    # Each curve is plotted against its own approach's temperature list.
    plt.plot(options_baseline, [1 - avg_formatting_error for _, _, _, avg_formatting_error in results_baseline], 'o-', color=blue, label='Baseline')
    plt.plot(options_ss, [1 - avg_formatting_error for _, _, _, avg_formatting_error in results_ss], 'o-', color=orange, label='SimpleStrat')
    plt.xlabel('Temperature')
    plt.ylabel('Format Adherence (% with No Errors)')
    plt.title('Formatting Error vs Temperature')
    plt.legend()
    plt.savefig('plots/writing_prompts_temperature_formatting.pdf')
    plt.close()

    # ---- Figure 4: diversity vs. temperature ----
    plt.plot(options_baseline, [convert_to_distance(avg_similarity) for _, avg_similarity, _, _ in results_baseline], 'o-', color=blue, label='Baseline')
    plt.plot(options_ss, [convert_to_distance(avg_similarity) for _, avg_similarity, _, _ in results_ss], 'o-', color=orange, label='SimpleStrat')
    plt.xlabel('Temperature')
    plt.ylabel('Diversity - Cosine Distance')
    plt.title('Temperature-Diversity Tradeoff')
    plt.legend()
    # NOTE(review): this figure shows diversity, yet the file name says
    # "spelling". The name is kept unchanged so existing references to the
    # PDF keep working -- confirm whether it should be renamed.
    plt.savefig('plots/writing_prompts_temperature_spelling.pdf')
    plt.close()