#!/usr/bin/env python
# coding: utf-8

# In[2]:


import os

# The notebook may be started from a nested directory. When the current
# directory has no "models" entry, move two levels up (presumably to the
# repository root -- confirm against the project layout).
launched_from_root = "models" in os.listdir(".")
if not launched_from_root:
    os.chdir("../..")


# In[49]:


from collections import defaultdict

# One line per task: "<task name>  <source feature id>  ->  <target feature id>".
cool_feats = """present_simple_gerund          8446    -> 15554
present_simple_past_perfect    19628   -> 15356
plural_singular                29228   -> 2930
algo_last                      29228   -> 8633
location_country               11459   -> 7967
location_continent             11459   -> 19260
person_profession              26436   -> 18416
football_player_position       19916   -> 9790
present_simple_past_simple     21327   -> 15356
es_en                          31123   -> 5579
fr_en                          31123   -> 16490
it_en                          31123   -> 5579
country_capital                13529   -> 11173
antonyms                       11050   -> 11618
singular_plural                1322    -> 32417
person_language                1132    -> 11172
algo_second                    32115   -> 1878
algo_first                     32115   -> 6756
location_religion              3466    -> 9178
en_fr                          7928    -> 26987
en_it                          7928    -> 26987
location_language              10884   -> 11172
en_es                          99      -> 26987"""

# detectors: source feature id -> names of the tasks listed with that source.
# executors: target feature id -> names of the tasks listed with that target.
# Several tasks may share one feature id, hence the list values.
detectors = defaultdict(list)
executors = defaultdict(list)
for row in cool_feats.splitlines():
    # Everything before the first space is the task name; the remainder holds
    # the two ids around "->" (int() ignores the surrounding padding).
    task_name, _, id_pair = row.partition(" ")
    source, target = map(int, id_pair.split("->"))
    detectors[source].append(task_name)
    executors[target].append(task_name)


# In[23]:


# Notebook display cell: echoes the two feature-id -> task-name mappings.
# As a plain script statement this expression has no effect.
detectors, executors


# In[51]:


# This cell uses json and pandas, but the file's first `import json` /
# `import pandas as pd` statements only appear in a later cell. Import them
# here so the script also runs top to bottom without a NameError.
import json
import pandas as pd

# Load the results file: one JSON record per line.
with open("redacted/attn_results_4.jsonl") as f:
    results = [json.loads(line) for line in f]

results = pd.DataFrame(results)
# Keep only the rows of a single task for inspection.
task_results = results[results["task"] == "person_profession"]

# Notebook display of the filtered rows (no effect when run as a script).
task_results
    


# In[47]:


import json
import pandas as pd
import numpy as np

from plotly.subplots import make_subplots
import plotly.express as px
import plotly.graph_objects as go

# Read the JSONL results file: one record per line.
with open("redacted/attn_results_4.jsonl") as f:
    results = [json.loads(line) for line in f]
print(results)

results = pd.DataFrame(results)

# Distinct tasks and feature ids, in order of first appearance.
all_tasks = results["task"].unique().tolist()
all_detectors = results["source"].unique().tolist()
all_executors = results["target"].unique().tolist()

# Smallest square grid that has one cell per task.
side = int(np.ceil(np.sqrt(len(all_tasks))))

print(side)

fig = make_subplots(rows=side, cols=side, subplot_titles=all_tasks)

# Tick labels: the first task name registered for each feature id.
executor_labels = [executors[feature][0] for feature in all_executors]
detector_labels = [detectors[feature][0] for feature in all_detectors]

for idx, task in enumerate(all_tasks):
    task_results = results[results["task"] == task]

    # Rows are detectors, columns are executors; pairs without a record stay 0.
    heatmap = np.zeros((len(all_detectors), len(all_executors)))
    for j, detector in enumerate(all_detectors):
        from_detector = task_results[task_results["source"] == detector]
        for k, executor in enumerate(all_executors):
            matches = from_detector[from_detector["target"] == executor]
            if len(matches) > 0:
                # Only the first matching record contributes.
                heatmap[j, k] = np.mean(matches["proportions"].values[0])

    # Fill the grid row by row (subplot indices are 1-based).
    grid_row, grid_col = divmod(idx, side)
    fig.add_trace(
        go.Heatmap(z=heatmap, x=executor_labels, y=detector_labels),
        row=grid_row + 1,
        col=grid_col + 1,
    )

# Title the axes of every cell in the grid, including unused ones.
for r in range(1, side + 1):
    for c in range(1, side + 1):
        fig.update_xaxes(title_text="Executor", row=r, col=c)
        fig.update_yaxes(title_text="Detector", row=r, col=c)

fig.update_layout(height=1500, width=1500)

# fig.write_image("redacted/data/linear_connections.png")
fig


# In[ ]:


# load_tasks is a project helper from sprint.task_vector_utils. Later cells
# call tasks.keys() to obtain the list of task names, so the return value is
# used as a mapping keyed by task name.
from sprint.task_vector_utils import load_tasks

tasks = load_tasks()


# In[52]:


import json
import pandas as pd
import numpy as np
import plotly.express as px

# Load the data
with open("redacted/attn_results_4.jsonl") as f:
    results = [json.loads(line) for line in f]

results = pd.DataFrame(results)

# Replace the feature ids with the comma-joined names of the tasks that are
# registered for them in the detectors / executors mappings.
results["source"] = results["source"].map(detectors).apply(lambda x: ", ".join(x))
results["target"] = results["target"].map(executors).apply(lambda x: ", ".join(x))

_all_tasks = results["task"].unique().tolist()

# Axis order follows the key order of the loaded task mapping (not sorted).
all_tasks = list(tasks.keys())

# Tasks from the loader that have no rows in the results file.
print(
    [x for x in all_tasks if x not in _all_tasks]
)

all_detectors = all_tasks  # same order as the tasks
all_executors = all_tasks  # mirrors the detector order

# Prepare data for scatter plot
scatter_data = {
    'task': [],
    'detector': [],
    'executor': [],
    'value': []
}

for task in all_tasks:
    task_results = results[results["task"] == task]

    # Normalise by the largest mean proportion observed for this task.
    max_value = task_results["proportions"].apply(lambda x: np.mean(x)).max()

    for detector in all_detectors:
        # The detector mask does not depend on the executor: compute it once.
        detector_mask = task_results["source"].str.contains(detector)
        for executor in all_executors:
            entry = task_results[detector_mask & task_results["target"].str.contains(executor)]
            # Only the first matching record is used; missing pairs count as 0.
            value = np.mean(entry["proportions"].values[0]) / max_value if len(entry) > 0 else 0

            # Label of this point. When the detector feature is shared between
            # the current task and another task, the point is labelled with the
            # detector's task. The label is kept in its own variable: assigning
            # to the loop variable `task` made the relabel persist for every
            # remaining (detector, executor) pair of the same task.
            point_task = task
            if len(entry) > 0:
                detector_tasks = entry["source"].values[0]
                if task in detector_tasks and task != detector:
                    point_task = detector
            scatter_data['task'].append(point_task)
            scatter_data['detector'].append(detector)
            scatter_data['executor'].append(executor)
            scatter_data['value'].append(value)

# Create DataFrame for scatter plot
scatter_df = pd.DataFrame(scatter_data)

# Marker size encodes the normalised value, colour encodes the task label.
fig = px.scatter(
    scatter_df,
    x="executor",
    y="detector",
    size="value",
    color="task",
    hover_data={"task": True, "detector": True, "executor": True, "value": True},
    title="Scatter Plot of Task, Detector, and Executor Proportions"
)

# Force the x-axis categories into the same order as the detector list.
fig.update_layout(
    height=1000,
    width=1000,
    xaxis_categoryorder='array',
    xaxis_categoryarray=all_detectors
)


fig.write_image("images/fig1.svg")
# Display the plot
fig.show()


# In[21]:


import json
import pandas as pd
import numpy as np
import plotly.express as px

# Load the data
with open("redacted/attn_results_4.jsonl") as f:
    results = [json.loads(line) for line in f]

results = pd.DataFrame(results)

# Replace the feature ids with the comma-joined names of the tasks that are
# registered for them in the detectors / executors mappings.
results["source"] = results["source"].map(detectors).apply(lambda x: ", ".join(x))
results["target"] = results["target"].map(executors).apply(lambda x: ", ".join(x))

# Alphabetical task order, shared by both axes.
all_tasks = sorted(results["task"].unique().tolist())
all_detectors = list(all_tasks)
all_executors = list(all_tasks)

# Prepare data for scatter plot: one point per (task, executor) pair.
scatter_data = {
    'task': [],
    'executor': [],
    'value': []
}

for task in all_tasks:
    task_results = results[results["task"] == task]

    # Rows whose detector label mentions the task itself. The mask does not
    # depend on the executor, so it is computed once per task.
    own_detector = task_results["source"].str.contains(task)

    for executor in all_executors:
        entry = task_results[own_detector & task_results["target"].str.contains(executor)]
        # Raw mean proportion of the first matching record, 0 when there is
        # none. The per-task maximum used to be computed here but was
        # immediately replaced by 1, so the values were never normalised; the
        # dead computation and the division by 1 have been removed.
        value = np.mean(entry["proportions"].values[0]) if len(entry) > 0 else 0

        scatter_data['task'].append(task)
        scatter_data['executor'].append(executor)
        scatter_data['value'].append(value)

# Create DataFrame for scatter plot
scatter_df = pd.DataFrame(scatter_data)

# Marker size encodes the mean proportion of each (task, executor) pair.
fig = px.scatter(
    scatter_df,
    y="task",
    x="executor",
    size="value",
    hover_data={"task": True, "executor": True, "value": True},
    title="Scatter Plot of Task, Detector, and Executor Proportions"
)

# Force the x-axis categories into the same alphabetical order as the tasks.
fig.update_layout(
    height=1000,
    width=1000,
    xaxis_categoryorder='array',
    xaxis_categoryarray=all_detectors
)

# Display the plot
fig.show()


# In[22]:


import json
import pandas as pd
import numpy as np
import plotly.express as px

# Load the data
with open("redacted/attn_results_4.jsonl") as f:
    results = [json.loads(line) for line in f]

results = pd.DataFrame(results)

# Replace the feature ids with the comma-joined names of the tasks that are
# registered for them in the detectors / executors mappings.
results["source"] = results["source"].map(detectors).apply(lambda x: ", ".join(x))
results["target"] = results["target"].map(executors).apply(lambda x: ", ".join(x))

# Alphabetical task order, shared by both axes.
all_tasks = sorted(results["task"].unique().tolist())
all_detectors = list(all_tasks)
all_executors = list(all_tasks)

# Prepare data for heatmap: rows are tasks, columns are executors.
heatmap_data = np.zeros((len(all_tasks), len(all_executors)))

for i, task in enumerate(all_tasks):
    task_results = results[results["task"] == task]

    # Rows whose detector label mentions the task itself. The mask does not
    # depend on the executor, so it is computed once per task.
    own_detector = task_results["source"].str.contains(task)

    for j, executor in enumerate(all_executors):
        entry = task_results[own_detector & task_results["target"].str.contains(executor)]
        # Raw mean proportion of the first matching record; cells without a
        # record keep their initial 0. The per-task maximum used to be
        # computed here but was immediately replaced by 1, so the dead
        # computation and the division by 1 have been removed.
        if len(entry) > 0:
            heatmap_data[i, j] = np.mean(entry["proportions"].values[0])


# Reverse the row order; the y labels below are reversed to match.
heatmap_data = np.flip(heatmap_data, axis=0)

# Create heatmap with sorted x and y labels
fig = px.imshow(
    heatmap_data,
    labels=dict(x="Executor", y="Task", color="Proportion Value"),
    x=all_executors,
    y=all_tasks[::-1],
    title="Heatmap of Task, Detector, and Executor Proportions",
    color_continuous_scale="Viridis"  # Or any other preferred color scale
)

# Figure size in pixels.
fig.update_layout(
    height=1000,
    width=1000
)

# Display the plot
fig.show()


# In[112]:


import json
import pandas as pd
import numpy as np
import plotly.express as px
import plotly.io as pio

# Load the data
with open("redacted/attn_results_4.jsonl") as f:
    results = [json.loads(line) for line in f]

results = pd.DataFrame(results)

# Replace the feature ids with the comma-joined names of the tasks that are
# registered for them in the detectors / executors mappings.
results["source"] = results["source"].map(detectors).apply(lambda x: ", ".join(x))
results["target"] = results["target"].map(executors).apply(lambda x: ", ".join(x))

# Task order comes from the loaded task mapping and is shared by both axes.
all_tasks = list(tasks.keys())
all_detectors = all_tasks
all_executors = all_tasks

# One bucket per (detector, executor) cell; each bucket collects the
# normalised values contributed by the individual tasks.
heatmap_data = [[[] for _ in all_tasks] for _ in all_detectors]

for task in all_tasks:
    task_results = results[results["task"] == task]

    # Largest mean proportion of this task, used as the normaliser.
    max_value = task_results["proportions"].apply(lambda x: np.mean(x)).max()

    for k, detector in enumerate(all_detectors):
        detector_mask = task_results["source"].str.contains(detector)
        for j, executor in enumerate(all_executors):
            entry = task_results[detector_mask & task_results["target"].str.contains(executor)]
            if len(entry) == 0:
                continue
            # Only the first matching record is used.
            value = np.mean(entry["proportions"].values[0]) / max_value
            # Zero (and non-positive) values are not added to the bucket.
            if value > 0:
                heatmap_data[k][j].append(value)

# Average each bucket; empty buckets become 0.
heatmap_data = np.array(
    [[np.mean(bucket if bucket else [0]) for bucket in bucket_row] for bucket_row in heatmap_data]
)

# Reverse the row order; the y labels below are reversed to match.
heatmap_data = heatmap_data[::-1]

fig = px.imshow(
    heatmap_data,
    labels=dict(x="Executor", y="Detector", color="Effect strength"),
    x=all_executors,
    y=all_tasks[::-1],
    color_continuous_scale="Blues"  # Or any other preferred color scale
)

# Final figure size, font and margins.
fig.update_layout(
    width=400,
    height=300,
    font_family="Serif",
    font_size=7,
    margin_l=5,
    margin_t=5,
    margin_b=5,
    margin_r=5,
)

pio.write_image(fig, "redacted/images/detectors-executors.pdf", width=400, height=300)

# Display the plot
fig.show()


# In[17]:


# Notebook shell escape (exported form of `!ls`): lists the working directory.
# NOTE(review): get_ipython is neither defined nor imported in this file; it is
# presumably supplied by the IPython/Jupyter runtime, so this line is expected
# to fail when the file is run as a plain Python script -- confirm.
get_ipython().system('ls')


# In[117]:


import json

import plotly.express as px
import plotly.io as pio

with open("redacted/detector_heatmap_l11.json") as f:
    data = json.load(f)

# Only the first 40 feature columns are plotted.
x_labels = [str(feature) for feature in data["features"]][:40]
# Rows are drawn in reverse order, so the task labels are reversed too.
y_labels = [str(task_name) for task_name in data["task_names"]][::-1]

# Reverse the rows and keep the first 40 columns.
heatmap = np.array(data["heatmap"])[::-1, :40]
# NOTE(review): cells equal to exactly 0.2 are blanked; the meaning of this
# constant is not visible in this file -- confirm with the data producer.
heatmap[heatmap == 0.2] = 0

fig = px.imshow(
    heatmap,
    x=x_labels,
    y=y_labels,
    width=1000,
    height=600,
    color_continuous_scale="Blues",
    labels=dict(x="Feature", y="Task", color="Effect strength"),
)

# Final size, font and margins (this replaces the size passed to imshow).
fig.update_layout(
    width=520,
    height=250,
    font_family="Serif",
    font_size=7,
    margin_l=5,
    margin_t=5,
    margin_b=5,
    margin_r=5,
)

pio.write_image(fig, "redacted/images/detector_heatmap_l11.pdf", width=520, height=250)

fig.show()


# In[118]:


import json

import plotly.express as px
import plotly.io as pio

with open("redacted/executor_heatmap_l12.json") as f:
    data = json.load(f)

# Only the first 40 feature columns are plotted.
x_labels = [str(feature) for feature in data["features"]][:40]
# Rows are drawn in reverse order, so the task labels are reversed too.
y_labels = [str(task_name) for task_name in data["task_names"]][::-1]

# Reverse the rows and keep the first 40 columns.
heatmap = np.array(data["heatmap"])[::-1, :40]
# NOTE(review): cells equal to exactly 0.2 are blanked; the meaning of this
# constant is not visible in this file -- confirm with the data producer.
heatmap[heatmap == 0.2] = 0

fig = px.imshow(
    heatmap,
    x=x_labels,
    y=y_labels,
    width=1000,
    height=600,
    color_continuous_scale="Blues",
    labels=dict(x="Feature", y="Task", color="Effect strength"),
)

# Final size, font and margins (this replaces the size passed to imshow).
fig.update_layout(
    width=520,
    height=250,
    font_family="Serif",
    font_size=7,
    margin_l=5,
    margin_t=5,
    margin_b=5,
    margin_r=5,
)

pio.write_image(fig, "redacted/images/executor_heatmap_l12.pdf", width=520, height=250)

fig.show()


# In[23]:


import json

import pandas as pd
import plotly.express as px
import plotly.io as pio

# Each results file holds one JSON record per line.
with open("cleanup_results_final.jsonl") as f:
    tv_results = [json.loads(line) for line in f]

with open("cleanup_results_final_ito_5.jsonl") as f:
    tv_results_extra = [json.loads(line) for line in f]

df = pd.DataFrame(tv_results)
df_extra = pd.DataFrame(tv_results_extra)

# Join the two runs on (layer, task); columns that exist in both keep their
# name for the first run and get an "_extra" suffix for the second run.
df = pd.merge(df, df_extra, on=["layer", "task"], suffixes=("", "_extra"))

print(df.columns)

# Loss column -> legend label used in the plot.
name_map = {
    "loss": "Cleaning",
    "recon_loss": "SAE reconstruction",
    "ito_loss": "ITO (20)",
    "tv_loss": "Original Task vector",
    "ito_loss_extra": "ITO (5)"
}

# Express every loss as its relative change with respect to "zero_loss".
zero_loss = df["zero_loss"]
for column in name_map:
    df[column] = (df[column] - zero_loss) / zero_loss

df = df.rename(columns=name_map)

losses = list(name_map.values())

df = df[["layer", "task"] + losses]

# Average over all tasks, per layer.
df_avg = df.groupby("layer", as_index=False)[losses].mean()

fig = px.line(
    df_avg,
    x="layer",
    y=losses,
    labels=dict(
        value="Average relative loss change",
        layer="Layer",
        task="Task",
        variable="Reconstruction type",
    ),
)

fig.update_layout(
    width=300,
    height=200,
    font_family="Serif",
    font_size=10,
    margin_l=5,
    margin_t=20,
    margin_b=5,
    margin_r=5,
)

pio.write_image(fig, "redacted/images/layerwise_performance.pdf", width=300, height=200)

fig.show()


# In[4]:


import json
import jax
from redacted.utils.load_sae import get_redacted_it_sae_suite

import pandas as pd

# Each results file holds one JSON record per line.
with open("cleanup_results_final.jsonl") as f:
    tv_results = [json.loads(line) for line in f]

with open("cleanup_results_final_ito_5.jsonl") as f:
    tv_results_extra = [json.loads(line) for line in f]

first_run = pd.DataFrame(tv_results)
df_extra = pd.DataFrame(tv_results_extra)

# Join the two runs on (layer, task); columns that exist in both keep their
# name for the first run and get an "_extra" suffix for the second run.
df = pd.merge(first_run, df_extra, on=["layer", "task"], suffixes=("", "_extra"))

print(df)

# def sae_to_threshold(sae):
#     s = jax.nn.softplus(sae["s_gate"]) * sae["scaling_factor"]
 
#     threshold = np.maximum(0, sae["b_gate"] - sae["b_enc"] * s)
#     return threshold


# layers = df["layer"].unique().tolist()

#     thresholds = {
#         l: sae_to_threshold(get_redacted_it_sae_suite(l)) for l in layers
#     }
#     def l0(x, l):
#         threshold = thresholds[l]
#         return np.sum(x > threshold)



# df["l0"] = df[["layer", "weights"]].apply(lambda x: l0(x["weights"], x["layer"]), axis=1)

# df

