import json
import pandas as pd
from utils import question_to_category
from typing import List   
def is_equal(A, B, category=None):
    """Return True when the answers *A* and *B* match, ignoring case.

    For the ``'athlete'`` category only the last space-separated token of
    each string is compared; for every other category (including ``None``)
    the complete strings are compared.
    """
    if category != 'athlete':
        return A.lower() == B.lower()

    # Athlete answers are matched on their final token only.
    last_a = A.split(" ")[-1]
    last_b = B.split(" ")[-1]
    return last_a.lower() == last_b.lower()
    
def check_precision(ans: List[str], answer_gt: List[str], category = None):
    """Return True if *ans* matches at least one string in *answer_gt*.

    *ans* is handed unchanged to ``is_equal`` together with each ground-truth
    string in turn, so it is compared as a single answer; *category* is
    forwarded to ``is_equal`` as well. An empty *answer_gt* yields False.
    """
    for a_gt in answer_gt:
        if is_equal(ans, a_gt, category=category):
            return True
    return False
    
    
def get_precision(ans: List[str], GT: List[List[str]], category: str = None):
    """Return the fraction of answers in *ans* that match a ground-truth string.

    *GT* is a list of lists; its inner lists are flattened into one pool and
    each answer is checked against that pool with ``check_precision``.
    When *ans* is empty the result is 1.
    """
    # Collapse the nested ground-truth lists into one flat pool.
    accepted = []
    for group in GT:
        accepted.extend(group)

    if len(ans) == 0:
        return 1

    hits = 0
    for candidate in ans:
        if check_precision(candidate, accepted, category=category):
            hits += 1
    return hits / len(ans)

def idx_to_prompt(idx):
    """Return the prompt stored for *idx* in ``prompts_final.csv``.

    The CSV is read from a path relative to the current working directory
    and must contain an ``idx`` column and a ``prompt`` column. If the same
    ``idx`` value occurs on several rows, the last row wins.

    Args:
        idx: Value to look up in the ``idx`` column.

    Returns:
        The ``prompt`` value of the matching row.

    Raises:
        KeyError: If *idx* does not occur in the ``idx`` column.
    """
    # Parse the CSV exactly once per call.
    prompts_df = pd.read_csv('prompts_final.csv')
    prompt_by_idx = dict(
        zip(prompts_df['idx'].tolist(), prompts_df['prompt'].tolist())
    )
    return prompt_by_idx[idx]

def get_answers_from_file(file):
    """Load a JSON Lines file and return its records as a list.

    Args:
        file: Path of a text file holding one JSON document per line.

    Returns:
        A list with one decoded object per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # Read as UTF-8 explicitly instead of relying on the platform default.
    with open(file, 'r', encoding='utf-8') as f:
        # Blank lines (for example a trailing empty line) carry no record and
        # would make json.loads fail, so they are skipped.
        return [json.loads(line) for line in f if line.strip()]

# Ground-truth data, loaded once at import time. aggregate_precision indexes
# it as GT[prompt]['answers'].
GT_PATH = "../CoverageQA.json"
# Read as UTF-8 explicitly so non-ASCII answers do not depend on the
# platform's default encoding.
with open(GT_PATH, 'r', encoding='utf-8') as f:
    GT = json.load(f)

# Presumably a default sampling temperature (TODO confirm); none of the
# functions defined in this file read it.
T = 1.0



def aggregate_precision(entries):
    """Average per-prompt precision by category for one set of completions.

    For every ``idx`` listed in ``prompts_final.csv`` the first element of
    *entries* carrying that ``idx`` is scored with ``get_precision`` against
    ``GT[prompt]['answers']`` (its completions are de-duplicated first). The
    score is recorded under the prompt's category, as returned by
    ``question_to_category``, and under ``'total'``.

    Args:
        entries: Iterable of mappings with an ``'idx'`` key and a
            ``'completions'`` key.

    Returns:
        A dict mapping ``'park'``, ``'physic'``, ``'country'``,
        ``'instrument'``, ``'chemical'``, ``'athlete'`` and ``'total'`` to
        the mean precision of the prompts recorded under that key, or NaN
        when no prompt was recorded under it.

    Raises:
        IndexError: If a scored prompt idx has no element in *entries*.
        KeyError: If a prompt is missing from ``GT`` or its category is not
            one of the keys listed above.
    """
    precision = {
        name: []
        for name in ('park', 'physic', 'country', 'instrument',
                     'chemical', 'athlete', 'total')
    }

    # Read the prompt table once instead of re-parsing it for every idx.
    prompts_df = pd.read_csv('prompts_final.csv')
    idxs = prompts_df['idx'].tolist()
    prompt_by_idx = dict(zip(idxs, prompts_df['prompt'].tolist()))

    # Keep only the first entry seen for each idx, so the loop below does a
    # dictionary lookup instead of scanning every entry per prompt.
    first_entry_by_idx = {}
    for entry in entries:
        first_entry_by_idx.setdefault(entry['idx'], entry)

    for i in idxs:
        prompt = prompt_by_idx[i]
        cat = question_to_category(prompt)
        # NOTE(review): 'athelete' is spelled differently from the 'athlete'
        # key used elsewhere in this file, so unless question_to_category
        # emits this exact spelling nothing is ever skipped here. Left
        # unchanged because correcting it would change which prompts are
        # scored -- confirm the intent.
        if cat == 'athelete':
            continue
        if i not in first_entry_by_idx:
            raise IndexError(f"no completions entry for prompt idx {i!r}")
        completions = set(first_entry_by_idx[i]['completions'])
        score = get_precision(completions, GT[prompt]['answers'], category=cat)
        precision[cat].append(score)
        precision['total'].append(score)

    # A key with no recorded prompts has no meaningful mean; report NaN
    # rather than dividing by zero.
    return {
        name: (sum(scores) / len(scores) if scores else float('nan'))
        for name, scores in precision.items()
    }

# When True, the ablation runs ("..._abl.jsonl") are scored and plotted too.
ablation = True

CATEGORIES = ['chemical', 'country', 'instrument', 'park', 'physic', 'athlete', 'total']
TEMPERATURES = [0.0, 0.05, 0.1, 0.15, 0.25, 0.5, 0.75, 1.0, 1.1, 1.2, 1.25, 1.3, 1.5]

# Score every run once per temperature. aggregate_precision already returns
# the values for all categories, so the files do not need to be reloaded and
# re-scored for each category that is plotted.
precision_vs_T = []   # (T, SS precision dict, baseline precision dict)
anthropic_vs_T = []   # (T, Anthropic precision dict), only for T <= 1.0
ablation_vs_T = []    # (T, ablation precision dict), only when ablation is on
for T in TEMPERATURES:
    file_ss = f"./final_data/ss_completions_temp_{T}.jsonl"
    file_B = f"./final_data/baseline_completions_temp_{T}.jsonl"
    entries_ss = get_answers_from_file(file_ss)
    entries_B = get_answers_from_file(file_B)

    # The "_anth" files are only read for temperatures up to 1.0.
    with_anthropic = T <= 1.0
    if with_anthropic:
        file_anth = f"./final_data/baseline_completions_temp_{T}_anth.jsonl"
        entries_anth = get_answers_from_file(file_anth)
    if ablation:
        f_abl = f"./final_data/ss_completions_temp_{T}_abl.jsonl"
        entries_abl = get_answers_from_file(f_abl)

    precision_vs_T.append(
        (T, aggregate_precision(entries_ss), aggregate_precision(entries_B))
    )
    if with_anthropic:
        anthropic_vs_T.append((T, aggregate_precision(entries_anth)))
    if ablation:
        ablation_vs_T.append((T, aggregate_precision(entries_abl)))

import matplotlib.pyplot as plt

abl_tag = "_abl" if ablation else ""

# One figure per category, all drawn from the scores computed above.
for cat in CATEGORIES:
    # Extract the series for this category.
    temperatures = [t for t, _, _ in precision_vs_T]
    ss_series = [ss[cat] for _, ss, _ in precision_vs_T]
    baseline_series = [base[cat] for _, _, base in precision_vs_T]
    anth_t = [t for t, _ in anthropic_vs_T]
    anth_series = [scores[cat] for _, scores in anthropic_vs_T]

    plt.figure(figsize=(10, 6))
    plt.plot(temperatures, ss_series, label='SS Precision', marker='o')
    plt.plot(temperatures, baseline_series, label='Baseline Precision', marker='o')
    plt.plot(anth_t, anth_series, label='Anthropic Precision', marker='o')
    if ablation:
        abl_t = [t for t, _ in ablation_vs_T]
        abl_series = [scores[cat] for _, scores in ablation_vs_T]
        plt.plot(abl_t, abl_series, label='SS Ablation Precision', marker='o')

    # Add labels, title, and legend
    plt.xlabel('Temperature (T)')
    plt.ylabel('Average Precision')
    plt.title('Average Precision vs Temperature')
    plt.legend()
    plt.grid(True)

    # Save the figure to disk, then release it.
    plt.savefig(f'./plots/precision_vs_T_{cat}{abl_tag}.png')
    plt.close()



        
