#!/usr/bin/env python
# coding: utf-8

# In[1]:


import os

# If the current directory has no `models` entry, move two levels up
# (presumably to the project root, where `models/` lives).
if os.listdir(".").count("models") == 0:
    os.chdir("../..")


# In[2]:


import json

# The results file is JSON Lines: one JSON object per line.
# An explicit encoding avoids depending on the platform's locale default.
with open("cleanup_results_final_detectors_2.jsonl", encoding="utf-8") as f:
    lines = f.readlines()
    # Skip blank lines (e.g. a trailing empty line at EOF); json.loads would
    # raise JSONDecodeError on them.
    results = [json.loads(line) for line in lines if line.strip()]


# In[3]:


# task_name = "antonyms"

# results = [r for r in results if r["task"] == task_name]


import pandas as pd

df = pd.DataFrame(results)

# Reshape to long format: one row per (layer, task, loss_type) with the value
# stored under "loss value", which is what plotly's `color=` grouping expects.
melted_df = pd.melt(
    df,
    id_vars=["layer", "task"],
    value_vars=["loss", "tv_loss", "ito_loss", "recon_loss"],
    var_name="loss_type",
    value_name="loss value",
)


# In[4]:


melted_df.head()


# In[5]:


import plotly.express as px

task_name = "country_capital"

# One line per loss type, plotted against layer, for the selected task only.
task_rows = melted_df[melted_df["task"] == task_name]
px.line(task_rows, x="layer", y="loss value", color="loss_type")


# In[6]:


task_names = df["task"].unique()
layers = sorted(df["layer"].unique())


def _losses_for_task(name):
    """Return {loss column -> numpy array of that column} for the rows of one task."""
    rows = df[df["task"] == name]
    return {column: rows[column].to_numpy() for column in ["loss", "tv_loss", "ito_loss"]}


def _min_max_scale(values):
    """Rescale an array linearly by its own minimum and maximum."""
    lowest = values.min()
    return (values - lowest) / (values.max() - lowest)


# Raw per-task loss arrays, in the row order of `df`.
task_losses = {name: _losses_for_task(name) for name in task_names}

# Per-task, per-loss min-max scaling; only "loss" and "tv_loss" are normalised.
normalized_losses = {
    name: {column: _min_max_scale(task_losses[name][column]) for column in ["loss", "tv_loss"]}
    for name in task_names
}


# In[7]:


# Rows are layers, columns are tasks, values are the normalised tv_loss.
tv_loss_by_task = {name: normalized_losses[name]["tv_loss"] for name in task_names}
heatmap = pd.DataFrame(tv_loss_by_task, index=layers)
px.imshow(heatmap)


# In[8]:


# Notebook-only setup: `get_ipython` exists only inside an IPython kernel, so
# this cell fails when the file is run as a plain script.
# Load the autoreload extension and set it to mode 2.
get_ipython().run_line_magic('load_ext', 'autoreload')
get_ipython().run_line_magic('autoreload', '2')
import penzai
import jax_smi
jax_smi.initialise_tracking()
from penzai import pz
# Register penzai's `pz.ts` display hooks and enable its interactive context
# (see the penzai documentation for the exact effect of each call).
pz.ts.register_as_default()
pz.ts.register_autovisualize_magic()
pz.enable_interactive_context()


# In[9]:


from redacted.llama import LlamaTransformer
# Load the model from a local GGUF file; the path is relative to the working
# directory set in the first cell. The keyword arguments are passed through to
# the project loader as-is.
llama = LlamaTransformer.from_pretrained("models/gemma-2b-it.gguf", from_type="gemma", load_eager=True, device_map="tpu:0")


# In[10]:


from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("alpindale/gemma-2b")
# Pad on the right so that real tokens keep the same positions regardless of
# how much padding a row needs.
tokenizer.padding_side = "right"


# In[11]:


from sprint.task_vector_utils import load_tasks, ICLRunner
# `tasks` is used below as a mapping: task name -> mapping of pairs
# (`tasks[task_name].items()` is what the steering cells iterate over).
tasks = load_tasks()


# In[12]:


# task_names = ["en_es", "antonyms", "person_profession", "es_en", "present_simple_gerund", "present_simple_past_simple", "person_profession", "person_language", "country_capital", "football_player_position"]
# task_name = task_names[1]
# NOTE: this rebinds `task_names` (previously the unique tasks found in the
# results DataFrame) to every task known to `load_tasks()`, as a plain list.
task_names = list(tasks.keys())


# In[13]:


import jax.numpy as jnp
import jax


from sprint.task_vector_utils import ICLRunner, logprob_loss
from redacted.llama import LlamaBlock
from redacted.sampling import sample, jit_wrapper


def _tap_block_input(index, block):
    """Prefix `block` with a side-output tap tagged ``resid_pre_<index>``."""
    tap = pz.de.TellIntermediate.from_config(tag=f"resid_pre_{index}")
    return pz.nn.Sequential([tap, block])


def _is_resid_tag(tag):
    """Select only the side outputs emitted by the taps inserted above."""
    return tag.startswith("resid_pre")


# Copy of the model in which every LlamaBlock is preceded by a tap on its input.
get_resids = llama.select().at_instances_of(LlamaBlock).apply_with_selected_index(_tap_block_input)
# Collect the tapped values as a second return value, then JIT-wrap the result.
get_resids = pz.de.CollectingSideOutputs.handling(get_resids, tag_predicate=_is_resid_tag)
get_resids_call = jit_wrapper.Jitted(get_resids)



def tokenized_to_inputs(input_ids, attention_mask):
    """Convert tokenizer output into the model's input structure.

    Args:
        input_ids: Token ids, convertible with ``jnp.asarray``. The two axes
            are named ("batch", "seq") below.
        attention_mask: Accepted so that callers can splat the tokenizer's
            output dict (``tokenized_to_inputs(**tokenized)``), but not used:
            the model inputs are built from the token ids alone.

    Returns:
        The result of ``llama.inputs.from_basic_segments`` on the named,
        sharded token array.
    """
    # The previous version also converted and sharded `attention_mask` onto the
    # device and then discarded it; that dead work has been removed.
    token_sharding = jax.sharding.NamedSharding(llama.mesh, jax.sharding.PartitionSpec("dp", "sp"))
    token_array = jax.device_put(jnp.asarray(input_ids), token_sharding)
    token_array = pz.nx.wrap(token_array, "batch", "seq").untag("batch").tag("batch")

    return llama.inputs.from_basic_segments(token_array)


# In[14]:


from sprint.icl_sfc_utils import AblatedModule
# Layer index used by every later cell: which residual stream is steered,
# which results rows are selected, and which SAE is loaded.
layer = 11
# mask_name = "arrow"


# In[15]:


from redacted.utils.load_sae import get_redacted_it_sae_suite


# Later cells index this as a mapping and read its "W_dec" entry, taking one
# row per feature (`sae["W_dec"][feature]`).
sae = get_redacted_it_sae_suite(layer=layer)


# In[16]:


# Scratch inspection of a single result record.
# NOTE(review): the layer is hard-coded to 11 here rather than using `layer`.
tmp = [x for x in results if x["task"] == "location_language" and x["layer"] == 11]


# In[17]:


import jax
import numpy as np

# Raises IndexError if no record matched the filter above.
# NOTE(review): this reads the "gw" field, whereas the feature-collection cell
# reads "weights" — confirm that both keys exist in the results file.
weights = np.array(tmp[0]["gw"])

# Ten largest entries and their indices (displayed as the cell output).
jax.lax.top_k(weights, 10)


# In[ ]:





# In[18]:


import numpy as np

# Collect every feature index that has a strictly positive weight in any
# result record for the chosen layer, across all tasks.
features = []

for task_name in task_names:
    for record in results:
        if record["task"] != task_name or record["layer"] != layer:
            continue

        record_weights = np.array(record["weights"])
        # s = jax.nn.softplus(sae["s_gate"]) * sae["scaling_factor"]
        # threshold = jnp.maximum(0, sae["b_gate"] - sae["b_enc"] * s)
        # Zero out the non-positive weights, then keep the indices that survive.
        positive_only = record_weights * (record_weights > 0)

        # print(threshold)

        features.extend(np.nonzero(positive_only)[0].tolist())

# De-duplicate; the resulting order is whatever set iteration yields.
features = list(set(features))

len(features)


# In[19]:


import dataclasses
from tqdm.auto import tqdm
from functools import partial
from redacted.utils.activation_manipulation import add_vector

# task name -> [per-feature steered losses, unsteered baseline loss];
# filled by the loop at the end of this cell.
task_losses_positive = {}

# Number of example pairs kept per prompt, runner batch size, and maximum
# sequence length handed to ICLRunner.
n_few_shots, batch_size, max_seq_len = 20, 12, 256
seed = 10

# Prompt template passed to ICLRunner; `{}` is the placeholder it contains.
prompt = "Follow the pattern:\n{}"

def make_taker(llama, layer):
    """Build a jitted copy of `llama` that starts from residuals at `layer`.

    Blocks with index below `layer`, every embedding lookup, and the first
    ConstantRescale are replaced by identities, so the returned callable can
    be fed activations in place of token ids.
    """
    def keep_from_layer(index, block):
        if index < layer:
            return pz.nn.Identity()
        return block

    def to_identity(_):
        return pz.nn.Identity()

    truncated = llama.select().at_instances_of(LlamaBlock).apply_with_selected_index(keep_from_layer)
    without_embedding = truncated.select().at_instances_of(pz.nn.EmbeddingLookup).apply(to_identity)
    first_rescale = without_embedding.select().at_instances_of(pz.nn.ConstantRescale).pick_nth_selected(0)

    return jit_wrapper.Jitted(first_rescale.apply(to_identity))

# Model suffix that consumes residual activations at `layer` instead of tokens.
taker = make_taker(llama, layer)

# Token count of the prompt template as encoded by the tokenizer; used below
# to exclude the leading positions from the steering mask.
prompt_length = len(tokenizer.encode(prompt))

# For every task: steer the residual stream at one position per prompt with
# each candidate feature direction and record the resulting loss, followed by
# the unsteered baseline loss.
for task_name in tqdm(task_names):

    # Hard-coded token ids — presumably the separator, padding and newline ids
    # of this tokenizer; TODO confirm.
    sep = 3978
    pad = 0
    newline = 108


    pairs = list(tasks[task_name].items())

    # NOTE(review): `n_shot` is computed here but never used in this cell; the
    # runner below is constructed with a literal `n_shot=1`.
    n_shot = n_few_shots - 1
    if task_name.startswith("algo"):
        n_shot = 8

    runner = ICLRunner(task_name, pairs, batch_size=batch_size, n_shot=1, max_seq_len=max_seq_len, seed=seed, prompt=prompt, vector_type="detector")

    # Keep at most `n_few_shots` pairs per evaluation prompt.
    tokenized = runner.get_tokens([
        x[:n_few_shots] for x in runner.eval_pairs
    ], tokenizer)

    inputs = tokenized_to_inputs(**tokenized)
    tokens = tokenized["input_ids"]

    # Second return value holds the collected "resid_pre_*" side outputs.
    _, all_resids = get_resids_call(inputs)

    # Norm of the steering vector added to the residual stream.
    scale = 25

    # Assumes the side outputs are ordered by block index, so that entry
    # `layer` is the input of block `layer` — TODO confirm.
    resids = all_resids[layer].value.unwrap("batch", "seq", "embedding")

    # Mark the position immediately before each newline token (the roll shifts
    # the newline mask one step left; note that jnp.roll wraps around), then
    # clear everything inside the first `prompt_length` positions.
    mask = tokens == newline
    mask = jnp.roll(mask, -1, axis=-1)
    mask = mask.at[:, :prompt_length].set(False)

    # Pick the largest marked column index in each row. Unmarked columns
    # contribute 0, so a row with no marked position falls back to index 0.
    col_indices = jnp.arange(mask.shape[1])
    col_indices_broadcasted = mask * col_indices
    sorted_indices = jnp.sort(col_indices_broadcasted, axis=1, descending=True)
    positions = sorted_indices[:, :1]

    def steer_with_direction(direction):
        """Add `direction` (rescaled to norm `scale`) at `positions` and return the loss."""
        direction = direction / jnp.linalg.norm(direction)
        direction = direction * scale

        # Per batch row, add the vector at that row's selected position(s).
        modified = jax.vmap(lambda a, b: a.at[b].add(direction))(
            resids, positions
        )
        modified = pz.nx.wrap(modified, "batch", "seq", "embedding")

        # Feed the edited residuals where the model normally receives tokens;
        # `taker` has its embedding and earlier blocks replaced by identities.
        _inputs = dataclasses.replace(inputs, tokens=modified)
        logits = taker(_inputs).unwrap("batch", "seq", "vocabulary")

        return logprob_loss(logits, tokens, sep=sep, pad_token=pad, n_first=2)

    # First element: one loss per feature, in the order of `features`.
    task_losses_positive[task_name] = [[steer_with_direction(sae["W_dec"][feature]).tolist() for feature in tqdm(features)]]

    # Second element: loss of the unmodified model on the same inputs.
    logits = llama(inputs)

    logits = logits.unwrap("batch", "seq", "vocabulary")

    task_losses_positive[task_name].append(logprob_loss(logits, tokens, sep=sep, pad_token=pad, n_first=2).tolist())


# In[18]:


import dataclasses
from tqdm.auto import tqdm
from functools import partial
from redacted.utils.activation_manipulation import add_vector

# task name -> [per-feature (loss, accuracy) tuples, (baseline loss, accuracy)];
# filled by the loop at the end of this cell.
negative_task_losses = {}

# Number of example pairs kept per prompt, runner batch size, and maximum
# sequence length handed to ICLRunner.
n_few_shots, batch_size, max_seq_len = 20, 16, 256
seed = 10

# Prompt template passed to ICLRunner; `{}` is the placeholder it contains.
prompt = "Follow the pattern:\n{}"

# NOTE(review): this definition is dead code — `calc_acc` is redefined right
# after it with a different implementation, and nothing calls it in between.
def calc_acc(tokens, sep, logits, runner):
    """Decode-based accuracy: is the target string found in the text predicted after the last `sep`?

    Reads the globals `batch_size` and `tokenizer`.

    Args:
        tokens: 2-D array of token ids (rows are batch entries).
        sep: Token id of the separator.
        logits: Per-row logits; the argmax over the last axis is decoded.
        runner: Provides `eval_pairs` (the target is the second item of the
            last pair of each row) and `eval_batch_size`.

    Returns:
        Number of rows whose decoded prediction contains the target, divided
        by `runner.eval_batch_size`.
    """
    # (row indices, column indices) of every separator token.
    arrow_pos = jnp.nonzero(tokens == sep)
    arrow_pos_single = []
    for i in range(batch_size):
        # Column of the last separator in row i.
        arrow_pos_single.append(arrow_pos[1][arrow_pos[0] == i].max())

    arrow_pos_single = np.array(arrow_pos_single)

    hits = 0

    for i, (ap, l) in enumerate(zip(arrow_pos_single, logits)):
        l = l.argmax(-1)
        tgt = runner.eval_pairs[i][-1][1]
        # Decode the three predictions starting at the separator position and
        # look for the target inside the repr of the decoded text.
        hits += int(tgt in repr(tokenizer.decode(l[ap:ap+3])))
    return hits / runner.eval_batch_size

def calc_acc(tokens, sep, logits, runner):
    """Next-token accuracy restricted to positions whose next token is `sep`.

    Args:
        tokens: 2-D array of token ids.
        sep: Token id that selects which positions are scored.
        logits: Array of logits with the vocabulary on the last axis.
        runner: Unused; kept for call compatibility.

    Returns:
        Number of positions where the next token equals `sep` and the argmax
        prediction matches it, divided by the number of such positions.
    """
    # NOTE(review): as written this scores whether the separator itself is
    # predicted, not the token following it — confirm that this is intended.
    predicted_next = logits.argmax(-1)[:, :-1]
    actual_next = tokens[:, 1:]

    scored = actual_next == sep
    correct_and_scored = (actual_next == predicted_next) * scored

    return correct_and_scored.sum() / scored.sum()




def make_taker(llama, layer):
    """Build a jitted copy of `llama` that starts from residuals at `layer`.

    Blocks with index below `layer`, every embedding lookup, and the first
    ConstantRescale are replaced by identities, so the returned callable can
    be fed activations in place of token ids.
    """
    def keep_from_layer(index, block):
        if index < layer:
            return pz.nn.Identity()
        return block

    def to_identity(_):
        return pz.nn.Identity()

    truncated = llama.select().at_instances_of(LlamaBlock).apply_with_selected_index(keep_from_layer)
    without_embedding = truncated.select().at_instances_of(pz.nn.EmbeddingLookup).apply(to_identity)
    first_rescale = without_embedding.select().at_instances_of(pz.nn.ConstantRescale).pick_nth_selected(0)

    return jit_wrapper.Jitted(first_rescale.apply(to_identity))

# Model suffix that consumes residual activations at `layer` instead of tokens.
taker = make_taker(llama, layer)

# For every task: subtract each candidate feature direction at the separator
# positions and record (loss, accuracy), followed by the unsteered baseline.
for task_name in tqdm(task_names):

    # Hard-coded token ids — presumably the separator and padding ids of this
    # tokenizer; TODO confirm.
    sep = 3978
    pad = 0


    pairs = list(tasks[task_name].items())

    # Use `n_few_shots - 1` shots, except for tasks whose name starts with "algo".
    n_shot = n_few_shots - 1
    if task_name.startswith("algo"):
        n_shot = 16

    runner = ICLRunner(task_name, pairs, batch_size=batch_size, n_shot=n_shot, max_seq_len=max_seq_len, seed=seed, prompt=prompt)

    # Unlike the positive-steering cell, this one tokenizes `train_pairs`.
    tokenized = runner.get_tokens([
        x[:n_few_shots] for x in runner.train_pairs
    ], tokenizer)

    inputs = tokenized_to_inputs(**tokenized)
    train_tokens = tokenized["input_ids"]

    # Second return value holds the collected "resid_pre_*" side outputs.
    _, all_resids = get_resids_call(inputs)

    # Norm of the steering vector added to the residual stream.
    scale = 30

    # Assumes the side outputs are ordered by block index, so that entry
    # `layer` is the input of block `layer` — TODO confirm.
    resids = all_resids[layer].value.unwrap("batch", "seq", "embedding")

    # Column indices of the separator tokens, largest first; non-separator
    # columns contribute 0 and therefore sort to the end.
    mask = train_tokens == sep
    col_indices = jnp.arange(mask.shape[1])
    col_indices_broadcasted = mask * col_indices
    sorted_indices = jnp.sort(col_indices_broadcasted, axis=1, descending=True)

    # Number of separators in the first row.
    # NOTE(review): every row is assumed to contain the same number of
    # separators; a row with fewer would pick up filler zeros and be steered
    # at position 0 as well.
    k = jnp.sum(mask[0]).astype(int)

    positions = sorted_indices[:, :k]

    def steer_with_direction(direction):
        """Add `direction` (rescaled to norm `scale`) at `positions`; return (loss, accuracy)."""
        direction = direction / jnp.linalg.norm(direction)
        direction = direction * scale

        # Per batch row, add the vector at that row's selected positions.
        modified = jax.vmap(lambda a, b: a.at[b].add(direction))(
            resids, positions
        )
        modified = pz.nx.wrap(modified, "batch", "seq", "embedding")

        # Feed the edited residuals where the model normally receives tokens;
        # `taker` has its embedding and earlier blocks replaced by identities.
        _inputs = dataclasses.replace(inputs, tokens=modified)
        logits = taker(_inputs).unwrap("batch", "seq", "vocabulary")

        acc = calc_acc(train_tokens, sep, logits, runner)

        return logprob_loss(logits, train_tokens, sep=sep, pad_token=pad, n_first=2), acc

    # First element: one (loss, accuracy) tuple per feature, steering with the
    # negated decoder direction. Values are stored as returned (not `.tolist()`).
    negative_task_losses[task_name] = [[steer_with_direction(-sae["W_dec"][feature]) for feature in tqdm(features)]]

    # Second element: (loss, accuracy) of the unmodified model on the same inputs.
    logits = llama(inputs)

    logits = logits.unwrap("batch", "seq", "vocabulary")

    acc = calc_acc(train_tokens, sep, logits, runner)
    negative_task_losses[task_name].append((logprob_loss(logits, train_tokens, sep=sep, pad_token=pad, n_first=2), acc))


# In[21]:


import plotly.express as px
import pandas as pd
import numpy as np

# Task x feature heatmap of how positive steering changes the loss, with the
# columns ordered by how many tasks a feature affects.
normalized_losses = {}

# Nothing is excluded in this cell; list feature ids here to drop them.
drop_features = []
drop_ids = [features.index(feature) for feature in drop_features]

features_dropped = [feature for feature in features if feature not in drop_features]

for task_name, losses in task_losses_positive.items():
    # Each entry is [per-feature steered losses, unsteered baseline loss].
    base_loss = losses[1]
    losses = losses[0]
    losses = np.array(losses)
    losses = np.delete(losses, drop_ids)

    # Cap at the baseline: steering that makes the loss worse counts as "no change".
    losses = np.minimum(losses, base_loss)

    # Relative change with respect to the baseline (non-positive when base_loss > 0).
    losses = (losses - base_loss) / base_loss

    max_loss = np.max(losses)
    min_loss = np.min(losses)

    # Min-max scale per task; the per-task maximum maps to exactly 1.
    # NOTE(review): this is 0/0 (NaN) when all values are equal, e.g. when no
    # feature lowers the loss for this task.
    losses = (losses - min_loss) / (max_loss - min_loss)

    # mean_loss = np.mean(losses - base_loss)

    # base_acc = losses[1][1]

    # accs = [base_acc - loss[1] for loss in losses[0]]
    normalized_losses[task_name] = losses


# Rows follow `task_names`, columns follow `features_dropped`.
# NOTE(review): raises KeyError if `task_losses_positive` lacks any task that
# is listed in `task_names`.
heatmap = np.zeros((len(task_names), len(features_dropped)))

for i, task_name in enumerate(task_names):
    for j, feature in enumerate(features_dropped):
        heatmap[i, j] = normalized_losses[task_name][j]

# heatmap /= np.mean(heatmap, axis=0, keepdims=True)



# heatmap = np.where(heatmap > 0, np.log(heatmap), -10)
# heatmap[np.isnan(heatmap)] = np.min(heatmap[np.isfinite(heatmap)])
# heatmap[np.isinf(heatmap)] = np.max(heatmap[np.isfinite(heatmap)])

# heatmap = np.clip(heatmap, -5, 5)

# Despite the name, this is a count: for each feature, the number of tasks
# whose value differs from 1 (the per-task maximum).
avg_heatmap = np.sum(heatmap != 1, axis=0)

# Ascending by that count.
sorted_idx = np.argsort(avg_heatmap)

heatmap = heatmap[:, sorted_idx]


fig = px.imshow(heatmap, x=[str(features_dropped[x]) for x in sorted_idx], y=task_names)

fig.show()


# In[20]:


import plotly.express as px
import pandas as pd
import numpy as np

# Task x feature heatmap of the relative loss improvement from positive
# steering, with both axes reordered for display.
normalized_losses = {}

# Nothing is excluded in this cell; list feature ids here to drop them.
drop_features = []
drop_ids = [features.index(feature) for feature in drop_features]

features_dropped = [feature for feature in features if feature not in drop_features]

for task_name, losses in task_losses_positive.items():
    # Each entry is [per-feature steered losses, unsteered baseline loss].
    base_loss = losses[1]
    losses = losses[0]
    losses = np.array(losses)
    losses = np.delete(losses, drop_ids)

    # Cap at the baseline: steering that makes the loss worse counts as "no change".
    losses = np.minimum(losses, base_loss)

    # Relative improvement over the baseline (non-negative when base_loss > 0).
    losses = (base_loss - losses) / base_loss

    max_loss = np.max(losses)
    min_loss = np.min(losses)

    # Min-max scale per task.
    # NOTE(review): this is 0/0 (NaN) when all values are equal.
    losses = (losses - min_loss) / (max_loss - min_loss)

    # Floor at 0.2; the upper bound of 10 is never reached after the scaling above.
    losses = np.clip(losses, 0.2, 10)

    # mean_loss = np.mean(losses - base_loss)

    # base_acc = losses[1][1]

    # accs = [base_acc - loss[1] for loss in losses[0]]
    normalized_losses[task_name] = losses


# Rows follow `task_names`, columns follow `features_dropped`.
# NOTE(review): raises KeyError if `task_losses_positive` lacks any task that
# is listed in `task_names`.
heatmap = np.zeros((len(task_names), len(features_dropped)))

for i, task_name in enumerate(task_names):
    for j, feature in enumerate(features_dropped):
        heatmap[i, j] = normalized_losses[task_name][j]

# heatmap /= np.mean(heatmap, axis=0, keepdims=True)



# heatmap = np.where(heatmap > 0, np.log(heatmap), -10)
# heatmap[np.isnan(heatmap)] = np.min(heatmap[np.isfinite(heatmap)])
# heatmap[np.isinf(heatmap)] = np.max(heatmap[np.isfinite(heatmap)])

# heatmap = np.log(1 + heatmap * 40)

# heatmap = np.clip(heatmap, -5, 5)

# avg_heatmap = np.sum(heatmap != 1, axis=0)

# Despite the name, this is the per-feature maximum across tasks.
avg_heatmap  = np.max(heatmap, axis=0)

# Columns: features with the largest maximum first.
sorted_idx = np.argsort(-avg_heatmap)

heatmap = heatmap[:, sorted_idx]

# min_heatmap = np.min(heatmap, axis=1)

# Rows: order tasks by the column index of their strongest feature
# (the variable is named `min_pos` but holds an argmax).
min_pos = np.argmax(heatmap, axis=1)

y_sorted_idx = np.argsort(min_pos)

heatmap = heatmap[y_sorted_idx]

# Only the first 100 columns are plotted; `heatmap` itself keeps all of them.
fig = px.imshow(heatmap[:, :100], x=[str(features_dropped[x]) for x in sorted_idx][:100], y=[task_names[x] for x in y_sorted_idx], width=2000, height=600, aspect="auto", color_continuous_scale="Blues", title="Positive steering with detectors on L11")

fig.show()


# In[21]:


# Persist the reordered heatmap together with its column (feature) and row
# (task) labels, in the same order as the matrix.
with open("redacted/detector_heatmap_l11.json", "w") as f:
    export = {
        "heatmap": heatmap.tolist(),
        "features": [features_dropped[column] for column in sorted_idx],
        "task_names": [task_names[row] for row in y_sorted_idx],
    }
    json.dump(export, f)


# In[31]:


import dataclasses
from tqdm.auto import tqdm
from functools import partial
from redacted.utils.activation_manipulation import add_vector

# NOTE: this resets `task_losses_positive`; after this cell it only holds the
# scale sweep computed below, not the per-feature results of the earlier cell.
task_losses_positive = {}

# Number of example pairs kept per prompt, runner batch size, and maximum
# sequence length handed to ICLRunner.
n_few_shots, batch_size, max_seq_len = 20, 12, 256
seed = 10

# Prompt template passed to ICLRunner; `{}` is the placeholder it contains.
prompt = "Follow the pattern:\n{}"

def make_taker(llama, layer):
    """Build a jitted copy of `llama` that starts from residuals at `layer`.

    Blocks with index below `layer`, every embedding lookup, and the first
    ConstantRescale are replaced by identities, so the returned callable can
    be fed activations in place of token ids.
    """
    def keep_from_layer(index, block):
        if index < layer:
            return pz.nn.Identity()
        return block

    def to_identity(_):
        return pz.nn.Identity()

    truncated = llama.select().at_instances_of(LlamaBlock).apply_with_selected_index(keep_from_layer)
    without_embedding = truncated.select().at_instances_of(pz.nn.EmbeddingLookup).apply(to_identity)
    first_rescale = without_embedding.select().at_instances_of(pz.nn.ConstantRescale).pick_nth_selected(0)

    return jit_wrapper.Jitted(first_rescale.apply(to_identity))

# Model suffix that consumes residual activations at `layer` instead of tokens.
taker = make_taker(llama, layer)

# Token count of the prompt template as encoded by the tokenizer; used below
# to exclude the leading positions from the steering mask.
prompt_length = len(tokenizer.encode(prompt))

# Sweep the steering norm for a single feature on a single task.
for task_name in tqdm(["antonyms"]):

    # Hard-coded token ids — presumably the separator, padding and newline ids
    # of this tokenizer; TODO confirm.
    sep = 3978
    pad = 0
    newline = 108


    pairs = list(tasks[task_name].items())

    # NOTE(review): `n_shot` is computed here but never used in this cell; the
    # runner below is constructed with a literal `n_shot=1`.
    n_shot = n_few_shots - 1
    if task_name.startswith("algo"):
        n_shot = 8

    runner = ICLRunner(task_name, pairs, batch_size=batch_size, n_shot=1, max_seq_len=max_seq_len, seed=seed, prompt=prompt, vector_type="detector")

    # Keep at most `n_few_shots` pairs per evaluation prompt.
    tokenized = runner.get_tokens([
        x[:n_few_shots] for x in runner.eval_pairs
    ], tokenizer)

    inputs = tokenized_to_inputs(**tokenized)
    tokens = tokenized["input_ids"]

    # Second return value holds the collected "resid_pre_*" side outputs.
    _, all_resids = get_resids_call(inputs)

    # scale = 25

    # Assumes the side outputs are ordered by block index, so that entry
    # `layer` is the input of block `layer` — TODO confirm.
    resids = all_resids[layer].value.unwrap("batch", "seq", "embedding")

    # Mark the position immediately before each newline token (the roll shifts
    # the newline mask one step left; note that jnp.roll wraps around), then
    # clear everything inside the first `prompt_length` positions.
    mask = tokens == newline
    mask = jnp.roll(mask, -1, axis=-1)
    mask = mask.at[:, :prompt_length].set(False)

    # Pick the largest marked column index in each row. Unmarked columns
    # contribute 0, so a row with no marked position falls back to index 0.
    col_indices = jnp.arange(mask.shape[1])
    col_indices_broadcasted = mask * col_indices
    sorted_indices = jnp.sort(col_indices_broadcasted, axis=1, descending=True)
    positions = sorted_indices[:, :1]

    # The single feature whose decoder direction is swept.
    feature = 11050

    def steer_with_direction(direction, scale):
        """Add `direction` (rescaled to norm `scale`) at `positions` and return the loss."""
        direction = direction / jnp.linalg.norm(direction)
        direction = direction * scale

        # Per batch row, add the vector at that row's selected position(s).
        modified = jax.vmap(lambda a, b: a.at[b].add(direction))(
            resids, positions
        )
        modified = pz.nx.wrap(modified, "batch", "seq", "embedding")

        # Feed the edited residuals where the model normally receives tokens;
        # `taker` has its embedding and earlier blocks replaced by identities.
        _inputs = dataclasses.replace(inputs, tokens=modified)
        logits = taker(_inputs).unwrap("batch", "seq", "vocabulary")

        return logprob_loss(logits, tokens, sep=sep, pad_token=pad, n_first=2)

    # First element: one loss per scale, for 100 log-spaced scales from 1 to 100.
    task_losses_positive[task_name] = [[steer_with_direction(sae["W_dec"][feature], scale).tolist() for scale in tqdm(np.logspace(0, 2, 100))]]

    # Second element: loss of the unmodified model on the same inputs.
    logits = llama(inputs)

    logits = logits.unwrap("batch", "seq", "vocabulary")

    task_losses_positive[task_name].append(logprob_loss(logits, tokens, sep=sep, pad_token=pad, n_first=2).tolist())


# In[32]:


# Loss as a function of steering norm; the x values must match the scales
# used when the sweep was computed (100 log-spaced points from 1 to 100).
sweep_scales = np.logspace(0, 2, 100)
px.line(x=sweep_scales, y=task_losses_positive["antonyms"][0])


# In[46]:


heatmap.shape


# In[45]:


sorted_idx


# In[29]:


import plotly.express as px
import pandas as pd
import numpy as np

# Task x feature heatmap of the loss change under negative steering, with one
# feature excluded and the columns ordered by their spread across tasks.
normalized_losses = {}

# `features.index` raises ValueError if this id is not among `features`.
drop_features = [22113]
drop_ids = [features.index(feature) for feature in drop_features]

features_dropped = [feature for feature in features if feature not in drop_features]

for task_name, losses in negative_task_losses.items():
    # Each entry is [per-feature (loss, accuracy) tuples, (baseline loss, accuracy)].
    base_loss = losses[1]
    # Keep only the loss component of each tuple.
    losses = [x[0].tolist() for x in losses[0]]
    losses = np.array(losses)
    losses = np.delete(losses, drop_ids)

    # Absolute change with respect to the baseline loss.
    losses = losses - base_loss[0]

    max_loss = np.max(losses)
    min_loss = np.min(losses)

    # Min-max scale per task.
    # NOTE(review): this is 0/0 (NaN) when all values are equal.
    losses = (losses - min_loss) / (max_loss - min_loss)

    # mean_loss = np.mean(losses - base_loss)

    # base_acc = losses[1][1]

    # accs = [base_acc - loss[1] for loss in losses[0]]
    normalized_losses[task_name] = losses


# Rows follow `task_names`, columns follow `features_dropped`.
heatmap = np.zeros((len(task_names), len(features_dropped)))

for i, task_name in enumerate(task_names):
    for j, feature in enumerate(features_dropped):
        heatmap[i, j] = normalized_losses[task_name][j]

# Divide each column by its mean across tasks, so that values are relative to
# the feature's average effect. A column whose mean is 0 divides by zero.
heatmap /= np.mean(heatmap, axis=0, keepdims=True)



# heatmap = np.where(heatmap > 0, np.log(heatmap), -10)
# heatmap[np.isnan(heatmap)] = np.min(heatmap[np.isfinite(heatmap)])
# heatmap[np.isinf(heatmap)] = np.max(heatmap[np.isfinite(heatmap)])

# heatmap = np.clip(heatmap, -5, 5)

# Columns: features with the largest standard deviation across tasks first.
std_heatmap = np.std(heatmap, axis=0)

sorted_idx = np.argsort(-std_heatmap)

# From here on, column j of `heatmap` is feature `features_dropped[sorted_idx[j]]`.
heatmap = heatmap[:, sorted_idx]

labels = [str(features_dropped[x]) for x in sorted_idx]

fig = px.imshow(heatmap, x=labels, y=task_names)

fig.show()


# In[120]:


features_dropped.index(7491)


# In[121]:


# Per-column mean of whatever `heatmap` currently holds (set by the most
# recently executed heatmap cell).
mean_heatmap = np.mean(heatmap, axis=0)

# Column indices whose mean across tasks is below 0.9.
drop_features = np.where(mean_heatmap < 0.9)[0]

# Map the column indices to feature ids.
# NOTE(review): this is only correct if the columns of `heatmap` are still in
# `features_dropped` order. If `heatmap` comes from a cell that reorders its
# columns (`heatmap = heatmap[:, sorted_idx]`), column x is feature
# `features_dropped[sorted_idx[x]]` and this lookup picks the wrong ids.
drop_features = [features_dropped[x] for x in drop_features]

drop_features


# In[115]:


heatmap.shape


# In[116]:


(22 + 0.1)/23


# In[127]:


import plotly.express as px
import pandas as pd
import numpy as np

# Task x feature heatmap for positive steering, excluding the features chosen
# in `drop_features` by an earlier cell (it is not defined in this cell).
normalized_losses = {}
drop_ids = [features.index(feature) for feature in drop_features]

features_dropped = [feature for feature in features if feature not in drop_features]

for task_name, losses in task_losses_positive.items():
    # Each entry is [per-feature steered losses, unsteered baseline loss].
    base_loss = losses[1]
    losses = losses[0]
    losses = np.array(losses)
    losses = np.delete(losses, drop_ids)


    # Cap at 150% of the baseline so that extreme degradations do not dominate
    # the scaling below.
    losses = np.minimum(losses, base_loss * 1.5)

    # Relative change with respect to the baseline.
    losses = (losses - base_loss) / base_loss

    max_loss = np.max(losses)
    min_loss = np.min(losses)

    # Min-max scale per task; the small epsilon avoids 0/0 when all values are
    # equal (so the maximum maps to slightly less than 1).
    losses = (losses - min_loss) / ((max_loss - min_loss) + 1e-6)

    # mean_loss = np.mean(losses - base_loss)

    # base_acc = losses[1][1]

    # accs = [base_acc - loss[1] for loss in losses[0]]
    normalized_losses[task_name] = losses


# Rows follow `task_names`, columns follow `features_dropped`.
# NOTE(review): raises KeyError if `task_losses_positive` lacks any task that
# is listed in `task_names`.
heatmap = np.zeros((len(task_names), len(features_dropped)))

for i, task_name in enumerate(task_names):
    for j, feature in enumerate(features_dropped):
        heatmap[i, j] = normalized_losses[task_name][j]

# heatmap /= np.mean(heatmap, axis=0, keepdims=True)



# heatmap = np.where(heatmap > 0, np.log(heatmap), -10)
# heatmap[np.isnan(heatmap)] = np.min(heatmap[np.isfinite(heatmap)])
# heatmap[np.isinf(heatmap)] = np.max(heatmap[np.isfinite(heatmap)])

# heatmap = np.clip(heatmap, -5, 5)

# Compress the range for display: log(1 + x).
heatmap = np.log(heatmap + 1)


fig = px.imshow(heatmap, x=[str(x) for x in features_dropped], y=task_names, color_continuous_scale='YlGnBu')

fig.show()


# In[125]:


task_losses_positive["present_simple_past_perfect"]


# In[84]:


mean_heatmap


# In[71]:


import plotly.express as px
import pandas as pd
import numpy as np

# Task x feature heatmap of the loss change under negative steering; columns
# stay in `features_dropped` order (no sorting in this cell).
normalized_losses = {}

# Nothing is excluded in this cell; list feature ids here to drop them.
drop_features = []
drop_ids = [features.index(feature) for feature in drop_features]

features_dropped = [feature for feature in features if feature not in drop_features]

for task_name, losses in negative_task_losses.items():
    # Each entry is [per-feature (loss, accuracy) tuples, (baseline loss, accuracy)].
    base_loss = losses[1]
    # Keep only the loss component of each tuple.
    losses = [x[0].tolist() for x in losses[0]]
    losses = np.array(losses)
    losses = np.delete(losses, drop_ids)

    # Absolute change with respect to the baseline loss.
    losses = losses - base_loss[0]

    max_loss = np.max(losses)
    min_loss = np.min(losses)

    # Min-max scale per task.
    # NOTE(review): this is 0/0 (NaN) when all values are equal.
    losses = (losses - min_loss) / (max_loss - min_loss)

    # mean_loss = np.mean(losses - base_loss)

    # base_acc = losses[1][1]

    # accs = [base_acc - loss[1] for loss in losses[0]]
    normalized_losses[task_name] = losses


# Rows follow `task_names`, columns follow `features_dropped`.
heatmap = np.zeros((len(task_names), len(features_dropped)))

for i, task_name in enumerate(task_names):
    for j, feature in enumerate(features_dropped):
        heatmap[i, j] = normalized_losses[task_name][j]

# Divide each column by its mean across tasks, so that values are relative to
# the feature's average effect. A column whose mean is 0 divides by zero.
heatmap /= np.mean(heatmap, axis=0, keepdims=True)



# heatmap = np.where(heatmap > 0, np.log(heatmap), -10)
# heatmap[np.isnan(heatmap)] = np.min(heatmap[np.isfinite(heatmap)])
# heatmap[np.isinf(heatmap)] = np.max(heatmap[np.isfinite(heatmap)])

# heatmap = np.clip(heatmap, -5, 5)


fig = px.imshow(heatmap, x=[str(x) for x in features_dropped], y=task_names)

fig.show()


# In[ ]:





# In[ ]:




