#!/usr/bin/env python
# coding: utf-8

# In[1]:


import os

# If the current directory has no "models" entry, assume we were launched
# from a nested notebook folder and move two levels up.
# NOTE(review): presumably "models" marks the project root — confirm.
if all(entry != "models" for entry in os.listdir(os.curdir)):
    os.chdir("../..")


# In[24]:


import json

# The metrics file holds one JSON object per line (JSON Lines).
# - Encoding is explicit so parsing does not depend on the platform locale.
# - Blank / whitespace-only lines (e.g. a trailing newline written twice)
#   are skipped instead of making json.loads raise.
with open("task_pair_metrics_3.jsonl", encoding="utf-8") as f:
    lines = f.readlines()
results = [json.loads(line) for line in lines if line.strip()]


# In[25]:


first_task = "antonyms"

# Records whose task pair starts with `first_task` (i.e. first_task paired
# with every second task it was evaluated against).
task_results = list(filter(lambda rec: rec["task_pair"][0] == first_task, results))

len(task_results)


# In[26]:


# Project-local helpers; `tasks` is keyed by task name (its keys are used
# as the task list below).
from sprint.task_vector_utils import (
    ICLRunner,
    load_tasks,
)

tasks = load_tasks()


# In[27]:


task_names = [*tasks.keys()]  # every loaded task name, in load order


# In[28]:


result_dict = dict((tuple(rec["task_pair"]), rec) for rec in results)  # (task_a, task_b) -> record; later duplicates win


# In[29]:


_self_pair_record = result_dict[(first_task, first_task)]  # first_task evaluated against itself
faithfulness_task, n_nodes_task = _self_pair_record["faithfullness"], _self_pair_record["n_nodes_counts"]



# In[30]:


# Print the orig_metric / zero_metric stored for each task's self-pair
# record; tasks with no such record are reported instead.
for second_task in task_names:
    self_pair = (second_task, second_task)
    try:
        orig_metric = result_dict[self_pair]["orig_metric"]
        zero_metric = result_dict[self_pair]["zero_metric"]
    except KeyError:
        print(f"No results for {second_task}")
    else:
        print(
            f"{second_task}: orig_metric: {orig_metric}, zero_metric: {zero_metric}"
        )


# In[ ]:





# In[32]:


import plotly.express as px

first_task = "antonyms"

faithfulness_task = result_dict[(first_task, first_task)]["faithfullness"]
n_nodes_task = result_dict[(first_task, first_task)]["n_nodes_counts"]

# Faithfulness curves for `first_task` paired with each task. Each node count
# is re-expressed as its distance from that curve's maximum count.
# NOTE(review): that presumably counts removed nodes — confirm the x title.
fig = px.line(x=[max(n_nodes_task) - x for x in n_nodes_task], y=faithfulness_task)
# Give the self-pair baseline a legend entry like the other curves.
fig.update_traces(name=first_task, showlegend=True)

for second_task in task_names:
    if second_task == first_task:
        # Already drawn as the baseline line above; don't plot it twice.
        continue
    pair_result = result_dict.get((first_task, second_task))
    if pair_result is None:
        # Skip pairs without a record rather than aborting the whole plot.
        print(f"No results for ({first_task}, {second_task})")
        continue
    faithfulness_second = pair_result["faithfullness"]
    n_nodes_second = pair_result["n_nodes_counts"]
    fig.add_scatter(
        x=[max(n_nodes_second) - x for x in n_nodes_second],
        y=faithfulness_second,
        name=second_task,
    )

fig.update_xaxes(title="Number of nodes")
fig.update_yaxes(title="Faithfulness")

fig


# In[35]:


import numpy as np

# For each first task, the comparison step is the first index at which its
# own (self-pair) faithfulness drops below this threshold.
FAITHFULNESS_THRESHOLD = 0.5

# NOTE(review): " " looks like a stray/placeholder task name — confirm.
task_names = [x for x in task_names if x not in ["person_profession", " "]]
n_tasks = len(task_names)
# NaN (not 0) marks cells without data so they aren't read as faithfulness 0.
heatmap_data = np.full((n_tasks, n_tasks), np.nan)

for i, first_task in enumerate(task_names):
    self_result = result_dict.get((first_task, first_task))
    if self_result is None:
        print(f"No results for {first_task}")
        continue
    _faith = self_result["faithfullness"]
    idx = next(
        (k for k, x in enumerate(_faith) if x < FAITHFULNESS_THRESHOLD), None
    )
    if idx is None:
        # No crossing point: leave the whole row as NaN instead of crashing.
        print(f"{first_task}: faithfulness never drops below {FAITHFULNESS_THRESHOLD}")
        continue
    for j, second_task in enumerate(task_names):
        pair_result = result_dict.get((first_task, second_task))
        if pair_result is None:
            continue
        faith = pair_result["faithfullness"]
        # Curves may differ in length; only read the step if it exists.
        if idx < len(faith):
            heatmap_data[i, j] = faith[idx]

fig = px.imshow(heatmap_data, x=task_names, y=task_names)

fig

