import json
import os
from functools import lru_cache
from typing import List

import pandas as pd

from utils import question_to_category
def is_equal(A, B, category=None):
    """Return True if answers ``A`` and ``B`` match case-insensitively.

    Surrounding whitespace is ignored. For the ``'athlete'`` category only the
    last whitespace-separated token of each string is compared, so
    "LeBron James" matches "james".
    """
    if category == 'athlete':
        # str.split() with no argument ignores leading/trailing/repeated
        # whitespace; split(" ") yields an empty last token for "Name " and
        # made any two such strings compare equal.
        a_tokens, b_tokens = A.split(), B.split()
        a_last = a_tokens[-1].lower() if a_tokens else ""
        b_last = b_tokens[-1].lower() if b_tokens else ""
        return a_last == b_last
    return A.strip().lower() == B.strip().lower()


def check_recall(ans: List[str], answer_equiv_gt: List[str], category=None):
    """Return True if any predicted answer matches any equivalent ground-truth answer.

    Args:
        ans: iterable of predicted answer strings.
        answer_equiv_gt: group of ground-truth strings that all count as the
            same answer.
        category: forwarded to ``is_equal`` so category-specific matching
            (e.g. surname-only for ``'athlete'``) actually takes effect.

    Returns:
        False when either input is empty.
    """
    return any(
        is_equal(gt, answer, category=category)
        for answer in ans
        for gt in answer_equiv_gt
    )
    
    
def get_recall(ans: List[str], GT: List[List[str]], category: str = None):
    """Return the fraction of ground-truth answer groups matched by ``ans``.

    Each element of ``GT`` is one group of equivalent answers; a group counts
    as recalled when ``check_recall(ans, group, category=category)`` is true.
    An empty ``GT`` raises ZeroDivisionError.
    """
    hits = 0
    groups_seen = 0
    for answer_group in GT:
        groups_seen += 1
        # check_recall returns a bool, which adds as 0 or 1.
        hits += check_recall(ans, answer_group, category=category)
    return hits / groups_seen

@lru_cache(maxsize=None)
def _load_prompt_map(csv_path):
    """Parse ``csv_path`` once and map its ``idx`` column to its ``prompt`` column."""
    prompts_df = pd.read_csv(csv_path)
    # Duplicate idx values resolve to the last row, as dict construction does.
    return dict(zip(prompts_df['idx'].tolist(), prompts_df['prompt'].tolist()))


def idx_to_prompt(idx, csv_path='prompts_final.csv'):
    """Return the prompt whose ``idx`` column equals ``idx`` in ``csv_path``.

    The CSV is parsed once per path and cached, so repeated lookups no longer
    re-read the file.

    Raises:
        KeyError: if ``idx`` does not appear in the CSV.
    """
    return _load_prompt_map(csv_path)[idx]

def get_answers_from_file(file):
    """Parse a JSON Lines file into a list of records.

    Each non-blank line of ``file`` is decoded with ``json.loads``. Blank lines
    (e.g. a trailing empty line) are skipped instead of raising
    ``json.JSONDecodeError``. The file is read as UTF-8 regardless of locale.
    """
    with open(file, 'r', encoding='utf-8') as f:
        return [json.loads(line) for line in f if line.strip()]

# Ground-truth answers, keyed by prompt text (aggregate_recall reads
# GT[prompt]['answers']).
with open("../CoverageQA.json", 'r') as gt_file:
    GT = json.load(gt_file)

# NOTE(review): module-level default temperature; no function in this file
# reads it, and the temperature sweep loop rebinds T.
T = 1.0



def aggregate_recall(entries):
    """Compute mean recall per category, plus an overall ``'total'``, for one completions file.

    Args:
        entries: parsed JSONL records, each with an ``'idx'`` key matching the
            ``idx`` column of ``prompts_final.csv`` and a ``'completions'``
            iterable of answer strings. If several records share an idx, the
            first one is used.

    Returns:
        dict mapping each category name and ``'total'`` to the average recall
        over that category's prompts, or NaN if the category had no prompts.

    Raises:
        KeyError: if a prompt idx has no record in ``entries``, or a prompt's
            category is not one of the tracked keys.
    """
    recall = {'park': [], 'physic': [], 'country': [], 'instrument': [], 'chemical': [], 'athlete': [], 'total': []}

    # Read the prompt table once and build the idx -> prompt map locally.
    prompts_df = pd.read_csv('prompts_final.csv')
    idxs = prompts_df['idx'].tolist()
    prompt_by_idx = dict(zip(idxs, prompts_df['prompt'].tolist()))

    # Index completions by idx once instead of scanning every entry per prompt.
    # setdefault keeps the first record for an idx ("first match wins").
    completions_by_idx = {}
    for e in entries:
        completions_by_idx.setdefault(e['idx'], e['completions'])

    for i in idxs:
        prompt = prompt_by_idx[i]
        cat = question_to_category(prompt)
        # The former guard `cat != 'athelete'` was misspelled and never excluded
        # anything; athletes are scored like every other category (excluding
        # them would leave recall['athlete'] empty, which the plots rely on).
        if i not in completions_by_idx:
            raise KeyError(f"no completions found for idx {i!r}")
        ans = set(completions_by_idx[i])
        recall_data = get_recall(ans, GT[prompt]['answers'], category=cat)
        recall[cat].append(recall_data)
        recall['total'].append(recall_data)

    # Average each list; an empty category yields NaN instead of ZeroDivisionError.
    return {k: (sum(v) / len(v) if v else float('nan')) for k, v in recall.items()}

ablation = True

# Imported once here rather than on every pass of the category loop.
import matplotlib.pyplot as plt

# Sampling temperatures swept for every model.
TEMPERATURES = [0.0, 0.05, 0.1, 0.15, 0.25, 0.5, 0.75, 1.0, 1.1, 1.2, 1.25, 1.3, 1.5]
# Categories to plot; 'total' aggregates every prompt.
CATEGORIES = ['chemical', 'country', 'instrument', 'park', 'physic', 'athlete', 'total']
# Anthropic baseline completions are only loaded up to this temperature
# (presumably no files exist above it).
ANTHROPIC_MAX_T = 1.0


def _file_recall(path):
    """Load one JSONL completions file and return its per-category recall dict."""
    return aggregate_recall(get_answers_from_file(path))


# aggregate_recall scores every category in one pass, so each file is
# evaluated once per temperature rather than once per (category, temperature).
recalls_by_T = {}
for T in TEMPERATURES:
    results = {
        'ss': _file_recall(f"./final_data/ss_completions_temp_{T}.jsonl"),
        'baseline': _file_recall(f"./final_data/baseline_completions_temp_{T}.jsonl"),
    }
    if T <= ANTHROPIC_MAX_T:
        results['anthropic'] = _file_recall(f"./final_data/baseline_completions_temp_{T}_anth.jsonl")
    if ablation:
        results['ablation'] = _file_recall(f"./final_data/ss_completions_temp_{T}_abl.jsonl")
    recalls_by_T[T] = results

# savefig fails if the output directory is missing.
os.makedirs('./plots', exist_ok=True)
abl = "_abl" if ablation else ""

for cat in CATEGORIES:
    temperatures = list(recalls_by_T)
    ss_recalls = [recalls_by_T[t]['ss'][cat] for t in temperatures]
    b_recalls = [recalls_by_T[t]['baseline'][cat] for t in temperatures]
    anth_t = [t for t in temperatures if 'anthropic' in recalls_by_T[t]]
    anth_recalls = [recalls_by_T[t]['anthropic'][cat] for t in anth_t]

    plt.figure(figsize=(10, 6))
    plt.plot(temperatures, ss_recalls, label='SS Recall', marker='o')
    plt.plot(temperatures, b_recalls, label='Baseline Recall', marker='o')
    plt.plot(anth_t, anth_recalls, label='Anthropic Recall', marker='o')
    if ablation:
        abl_recalls = [recalls_by_T[t]['ablation'][cat] for t in temperatures]
        plt.plot(temperatures, abl_recalls, label='SS Ablation Recall', marker='o')

    plt.xlabel('Temperature (T)')
    plt.ylabel('Average Recall')
    plt.title('Average Recall vs Temperature')
    plt.legend()
    plt.grid(True)

    plt.savefig(f'./plots/recall_vs_T_{cat}{abl}.png')
    # Free the figure before drawing the next category.
    plt.close()



        
