# %%
from imports import *
from my_plotly import *
from utils import load_json
from fancy_einsum import einsum
from weights_composer import get_ov, re_get_single_component
import numpy as np
import plotly.io as pio
import sys

# Truthy when `sys.ps1` is set; otherwise falls back to the interpreter's -i flag.
# show_plot() below uses this to decide whether to open figures.
is_interactive = bool(getattr(sys, 'ps1', sys.flags.interactive))

# Target head and SVD component of its OV matrix (earlier choices kept for reference).
#layer, head, comp_idx = 7,9,6
#layer, head, comp_idx = 8,6,2
layer, head, comp_idx = 8,10,1

# Dataset under datasets/<data_name>.json; also used as the prefix for every result file.
#data_name = 'laundry_list_v3'
data_name = 'laundry_list_250_9objs'
rich.print(layer, head, comp_idx, data_name)

def show_plot(fig):
    """Show `fig` when running interactively; otherwise just print a notice.

    Reads the module-level `is_interactive` flag at call time.
    """
    if not is_interactive:
        print("Not in interactive mode, not showing plot")
        return
    fig.show()

def fig_to_json(fig, path, **kwargs):
    """Write plotly figure `fig` as JSON to `path`.

    Any extra keyword arguments are forwarded unchanged to `pio.write_json`.
    """
    writer = pio.write_json
    writer(fig, path, **kwargs)

def fig_from_json(path):
    """Read back a plotly figure stored as JSON at `path` (see fig_to_json)."""
    loaded_fig = pio.read_json(path)
    return loaded_fig

# %%
# Load GPT-2 small through TransformerLens with the two weight-processing flags below.
model = HookedTransformer.from_pretrained("gpt2-small",  fold_value_biases=True, refactor_factored_attn_matrices=True)
# Cache per-head attention outputs; presumably required by cache.stack_head_results
# in the evaluation loop further down — confirm against TransformerLens docs.
model.set_use_attn_result(True)
# %%
# Model dimensions used later (d_head reshapes cached query vectors for PCA).
n_heads= model.cfg.n_heads
n_lays= model.cfg.n_layers
d_head = model.cfg.d_head
# %%
# Hand-written sanity-check example in the same schema as the dataset entries:
# "text" prompt, "n_objs", "query_idx" into "objects".
# NOTE(review): the prompt contains "buy a a cup" (doubled article) — confirm intended.
ex = {
    "text": " Today, when I go to the store, I will buy a a cup, a bowl, a banana, a pear, a jug, and a bottle. When I go, I will get the cup, the bowl, the banana, the pear, the jug, and then the",
    "n_objs": 6,
    "query_idx": 5,
    "objects": [
      "cup",
      "bowl",
      "banana",
      "pear",
      "jug",
      "bottle"
    ]
  }
# %%

# Full dataset: iterated by index and by element later, so presumably a list of
# dicts in the same schema as `ex` above — TODO confirm.
data = load_json(f'datasets/{data_name}.json')
# %%
# Forward pass on the demo prompt with a BOS token prepended.
logits = model(ex["text"], prepend_bos=True)
# %%
def topk(logits, k=10):
    """Return the k highest-scoring tokens, as strings, at the last position.

    Assumes `logits` is (batch, pos, vocab) with batch size 1 — TODO confirm
    for other callers; indexing `[:, -1]` picks the final position.
    """
    top_ids = logits.topk(k).indices
    final_pos_ids = top_ids[:, -1]
    return model.to_str_tokens(final_pos_ids)
# %%
print(topk(logits))
# %%
objs = ex["objects"]
query_obj = objs[ex["query_idx"]]
label = ' '+query_obj
#print index of label in the top-10 tokens at the last position
top_tokens = topk(logits)
if label in top_tokens:
    print(top_tokens.index(label))
else:
    # list.index would raise ValueError here and abort the rest of the script.
    print(f"{label!r} is not in the top 10: {top_tokens}")
# %%
def object_in_highest_position(logits, objs, k=10):
    """Return whichever of `objs` ranks highest among the top-k predicted tokens.

    Only the top `k` tokens (as returned by `topk`) are examined, so an object
    ranked below k is never reported.

    Args:
        logits: Model output, passed straight through to `topk`.
        objs: Candidate token strings (callers pass them with a leading space).
        k: Number of top tokens to inspect; defaults to 10 as before.

    Returns:
        The best-ranked element of `objs`, or None if none of them is in the
        top k. On equal rank the element listed first in `objs` wins.
    """
    top_labels = topk(logits, k)
    candidates = [obj for obj in objs if obj in top_labels]
    # min() returns the first minimal element, matching the old strict-'<' scan.
    return min(candidates, key=top_labels.index, default=None)

def run_example(ex):
    """Run the model on one dataset example and find its top-ranked listed object.

    Args:
        ex: dict with at least "text" (the prompt) and "objects" (list of words).

    Returns:
        (highest_obj, cache): the element of " " + each object that ranks
        highest in the model's top-10 predictions according to
        `object_in_highest_position` (None if none appears), and the
        activation cache returned by `run_with_cache`.
    """
    logits, cache = model.run_with_cache(ex["text"], prepend_bos=True)
    # Leading space presumably matches how these words tokenize mid-sentence.
    all_objs = [' '+obj for obj in ex["objects"]]
    return object_in_highest_position(logits, all_objs), cache

# Running totals filled in by the evaluation loop below.
num_correct = 0
from collections import defaultdict
# n_objs -> list of per-example correctness booleans, in dataset order.
num_correct_by_num_obj = defaultdict(lambda: [])


# %%
# SVD of the OV matrix of head L{layer}H{head} (get_ov from weights_composer).
# `direction` is column comp_idx of the third SVD factor — presumably the
# comp_idx-th right singular vector; confirm against get_ov/.svd() conventions.
ov = get_ov(model, layer, head)
u,s,v = ov.svd()
direction = v[:, comp_idx]
print(direction.shape, layer, head, comp_idx)
# Filled per example with `direction` scaled by the head output's projection onto it.
inhib_vecs =[]

def project_into_direction(activ, direction):
    """Dot each batch row of `activ` with `direction`.

    Args:
        activ: tensor whose trailing dims (everything after dim 0) equal
            `direction.shape`.
        direction: tensor to project onto (not normalised here).

    Returns:
        Tensor of shape (batch,) holding the per-row projections.

    Raises:
        ValueError: if the trailing shape of `activ` does not match
            `direction`. This replaces an `assert`, which is skipped under `python -O`.
    """
    if activ.shape[1:] != direction.shape:
        raise ValueError(
            f"activ trailing shape {tuple(activ.shape[1:])} does not match "
            f"direction shape {tuple(direction.shape)}")
    proj = einsum("batch ..., ... -> batch", activ, direction)
    return proj

# Evaluate every dataset example. For each one, record:
#   - whether the model's top-ranked listed object is the query object,
#   - which object index it picked,
#   - head L{layer}H{head}'s final-position output projected onto `direction`.
n_objs_data = []        # n_objs of each example, in dataset order
query_idx_to_pred = {}  # query_idx -> list of predicted object indices
n_no_obj_in_topk = 0    # examples where no listed object reached the top 10
for ex in track(data):
    objs = ex["objects"]
    query_idx = ex["query_idx"]
    label = ' '+objs[query_idx]
    n_objs_data.append(ex["n_objs"])
    highest_obj, cache = run_example(ex)
    if highest_obj is None:
        # No listed object in the top 10. Count the example as incorrect, but keep
        # it out of the prediction histogram (None.strip() used to crash here).
        n_no_obj_in_topk += 1
    else:
        pred_idx = objs.index(highest_obj.strip())
        query_idx_to_pred.setdefault(query_idx, []).append(pred_idx)
    correct = highest_obj == label
    num_correct += correct
    num_correct_by_num_obj[ex["n_objs"]].append(correct)

    # Stack per-head outputs and pick out the target head by its "L{layer}H{head}" label.
    results, labels = cache.stack_head_results(layer=layer+1, return_labels=True)
    head_idx_label = labels.index(f"L{layer}H{head}")
    # [None, 0, -1]: batch element 0, last position, with a leading dim of size 1 kept.
    activation = results[head_idx_label][None, 0, -1]
    projected_dir = project_into_direction(activation, direction)
    inhib_vecs.append(direction*projected_dir) #scale the direction by its amount
print(f"Num correct: {num_correct}")
if n_no_obj_in_topk:
    print(f"No listed object in the top 10 for {n_no_obj_in_topk} examples")
inhib_vecs = torch.stack(inhib_vecs)
import json

#save results by n objs (JSON turns the int keys into strings)
with open(f'laundry_list_results/{data_name}_results.json', 'w') as f:
    json.dump(num_correct_by_num_obj, f)

with open(f'laundry_list_results/{data_name}_query_idx_to_pred.json', 'w') as f:
    json.dump(query_idx_to_pred, f)

# %%

# Reload the query_idx -> predicted-indices map. The JSON round-trip turns the
# int keys into strings, which is why it is indexed with '1' and str(query_idx).
with open(f'laundry_list_results/{data_name}_query_idx_to_pred.json', 'r') as f:
    preds = json.load(f)

# NOTE(review): raises KeyError if no example had query index 1.
print(preds['1'])

# %%
#create a heatmap where each row is a query index and each column is a prediction
#the value is the fraction of that query index's examples that made that prediction
# The (square) matrix is sized from the largest index actually present. Sizing it by
# len(preds) overflowed whenever a query/predicted index was >= the number of keys.
query_idxs = sorted(int(k) for k in preds)
all_pred_idxs = [p for v in preds.values() for p in v]
n_idx = max(query_idxs + all_pred_idxs, default=-1) + 1
heat = np.zeros((n_idx, n_idx))
for query_idx in query_idxs:
    pred = preds[str(query_idx)]
    for p in pred:
        heat[query_idx, p] += 1
    row_total = np.sum(heat[query_idx])
    if row_total > 0:  # avoid 0/0 -> NaN for a query index with no predictions
        heat[query_idx] /= row_total
fig = px.imshow(heat, labels={'x':'Predicted object index', 'y':'Query index'}, title=f'{data_name} Predicted object index distribution')

# %%
fig_to_json(fig, f"exp_site/results/laundry_list/{data_name}_heatmap.json")
show_plot(fig)

# %%
# Last column of `heat`: for each query index, the fraction of predictions that
# picked the highest object index in the matrix.
pred_last = heat[:, -1]
last_obj_idx = heat.shape[1] - 1  # was hard-coded as 20 in the title
fig = px.bar(x=list(range(len(pred_last))), y=list(pred_last), labels={'x':'Query index', 'y':'Probability of predicting last object'}, title=f'{data_name} Probability of predicting object {last_obj_idx}')
show_plot(fig)
# %%
print(preds.keys())
from collections import Counter
print("SUM", sum([len(v) for v in preds.values()]))
#create a chart for each query index showing distribution of predictions
for query_idx in sorted(int(i) for i in preds.keys()):
    pred = preds[str(query_idx)]
    if not pred:
        continue  # nothing to plot, and max() below would fail on an empty list
    cntr_preds = Counter(pred)
    # One bar per object index 0..max predicted, in index order. Counter returns 0
    # for indices never predicted, so no separate (out-of-order) zero-fill is needed.
    x_idxs = list(range(max(cntr_preds) + 1))
    fig = px.bar(x=x_idxs, y=[cntr_preds[j] for j in x_idxs], labels={'x':'Predicted object index', 'y':'Count'}, title=f'Predictions for query index {query_idx}')

    show_plot(fig)
    #fig.write_image(f"{data_name}_{layer}_{head}_query_idx_{query_idx}.pdf")


# %%
# Reload per-n_objs correctness lists. After the JSON round-trip the keys are
# strings, which is why int(k) is applied here and str(n) is used further down.
with open(f'laundry_list_results/{data_name}_results.json') as f:
    num_correct_by_num_obj = json.load(f)
# %%
#plot with plotly a bar chart of the results
# Accuracy per number of objects = mean of that group's correctness booleans.
accuracy_by_num_obj = {int(k):sum(v)/len(v) for k,v in num_correct_by_num_obj.items()}
print(accuracy_by_num_obj)
# %%
# Report the actual number of evaluated examples instead of a hard-coded N=5000.
n_examples = sum(len(v) for v in num_correct_by_num_obj.values())
fig = px.bar(x=list(accuracy_by_num_obj.keys()), y=list(accuracy_by_num_obj.values()), labels={'x':'Number of objects', 'y':'Accuracy'}, title=f'Accuracy of predicting the last object<br>above all others (N={n_examples})')
# %%
show_plot(fig)
# %%
#save fig as pdf:
fig.write_image(f"laundry_list_results/{data_name}_accuracy.pdf")
# %%
device = 'cuda' if torch.cuda.is_available() else 'cpu'
def projQ(comp_vecs, q_layer=9, q_head=9):
    """Map vectors through the selected OV component, then into a head's query space.

    Steps, all run under torch.no_grad() because only the numbers are kept:
      1. multiply `comp_vecs` by the matrix that `re_get_single_component`
         builds from the SVD of get_ov(model, layer, head) for component
         `comp_idx` (presumably that single SVD component — confirm);
      2. apply block `q_layer`'s ln1;
      3. apply W_Q / b_Q of head `q_head` in that block.

    Args:
        comp_vecs: array-like or tensor of row vectors; presumably
            (n, d_model) — TODO confirm.
        q_layer: receiving layer. Defaults to 9, the previously hard-coded
            value that matches the cache['query', 9, 'attn'] lookup used
            alongside this function.
        q_head: receiving head. Defaults to 9, the previously hard-coded head.

    Returns:
        numpy array of the projected query vectors, on CPU.
    """
    with torch.no_grad():
        # as_tensor avoids torch.tensor()'s copy-and-warn when given a tensor.
        comp_vecs = torch.as_tensor(comp_vecs).to(device)
        out_vec = re_get_single_component(*get_ov(model, layer, head).svd(), comp_idx)
        proj = torch.matmul(comp_vecs, out_vec.AB)
        q_block = model.blocks[q_layer]
        proj = q_block.ln1(proj)
        proj = torch.matmul(proj, q_block.attn.W_Q[q_head]) + q_block.attn.b_Q[q_head]
    return proj.detach().cpu().numpy()
# %%
# Fit a 3-component PCA on head 9 of layer 9's cached query vectors, flattened to
# rows of length d_head. Then express the inhibition vectors, mapped into that
# query space by projQ, in the same basis.
# NOTE(review): `cache` here is whatever the last iteration of the evaluation loop
# left behind, so the PCA basis is fit on a single example's queries — confirm intended.
from sklearn.decomposition import PCA
pca = PCA(3)
query_pca = pca.fit_transform(cache['query', 9, 'attn'][:, :, 9, :].view(-1, d_head).detach().cpu().numpy())
print(inhib_vecs.shape)
comp_vecs_pca = pca.transform(projQ(inhib_vecs))
# Persist so later cells can reload the coordinates without rerunning the model.
np.save(f"laundry_list_results/{data_name}_{layer}_{head}_{comp_idx}_comp_vecs_pca.npy", comp_vecs_pca)
#scatter(comp_vecs_pca[:,0], comp_vecs_pca[:,1], comp_vecs_pca[:,2])
# %%
import plotly.graph_objects as go
# PCA'd inhibition vectors, coloured and labelled by each example's n_objs_data value.
scat = go.Scatter3d(x=comp_vecs_pca[:,0], y=comp_vecs_pca[:,1], z=comp_vecs_pca[:,2], mode='markers+text', marker=dict(color=n_objs_data), text = [f'{n}' for n in n_objs_data])
# PCA'd queries of head 9.9. NOTE(review): mode includes 'text' but no text is set.
query_scat = go.Scatter3d(x=query_pca[:,0], y=query_pca[:,1], z=query_pca[:,2], mode='markers+text', name='9.9 Queries')

fig = go.Figure(data=[query_scat, scat ])
show_plot(fig)
fig_to_json(fig, f"laundry_list_results/{data_name}_{layer}_{head}_pca.json")
# %%
# For each n_objs value, draw a 3-D scatter of the PCA'd inhibition vectors, split
# into examples the model got right (green) and wrong (red), plus black reference axes.
n_objs_data = [ex["n_objs"] for ex in data]
n_objs_numbers = sorted(set(n_objs_data))
n_objs_idxs = {n:[] for n in n_objs_numbers}
n_obj_scats = []
for n in n_objs_numbers:
    # Dataset positions of the n-object examples, in dataset order. The per-n
    # correctness list was appended in the same order, so they pair up 1:1.
    n_objs_idxs[n] = [i for i, m in enumerate(n_objs_data) if m == n]
    print(n, n_objs_data.count(n),len(n_objs_idxs[n]))
    correct_or_incorrect = num_correct_by_num_obj[str(n)]  # str keys after JSON reload
    print(correct_or_incorrect.count(True), correct_or_incorrect.count(False))
    if len(correct_or_incorrect) != len(n_objs_idxs[n]):
        raise ValueError(
            f"n_objs={n}: {len(correct_or_incorrect)} correctness entries but "
            f"{len(n_objs_idxs[n])} examples; results file does not match the dataset")
    correct_idxs =   [idx for idx, cor in zip(n_objs_idxs[n], correct_or_incorrect) if cor]
    incorrect_idxs = [idx for idx, cor in zip(n_objs_idxs[n], correct_or_incorrect) if not cor]
    print("COMP VECS correct count", comp_vecs_pca[correct_idxs,0].shape,comp_vecs_pca[incorrect_idxs,0].shape)
    n_obj_scats.append(go.Scatter3d(x=comp_vecs_pca[correct_idxs,0], y=comp_vecs_pca[correct_idxs,1], z=comp_vecs_pca[correct_idxs,2],       marker=dict(color='green'), mode='markers', name=f'{n} correct'))
    n_obj_scats.append(go.Scatter3d(x=comp_vecs_pca[incorrect_idxs,0], y=comp_vecs_pca[incorrect_idxs,1], z=comp_vecs_pca[incorrect_idxs,2], marker=dict(color='red'),mode='markers', name=f'{n} incorrect'))


fig = go.Figure(data=n_obj_scats)
# Black segments from -1 to 1 along the y, x and z axes through the origin.
# They are hidden from the legend so they don't show up as anonymous "trace N" entries.
for xs, ys, zs in (([0, 0], [-1, 1], [0, 0]),
                   ([-1, 1], [0, 0], [0, 0]),
                   ([0, 0], [0, 0], [-1, 1])):
    fig.add_trace(go.Scatter3d(
        x=xs,
        y=ys,
        z=zs,
        mode='lines',
        line=dict(color='black', width=4),
        showlegend=False,
    ))
fig_to_json(fig, f"laundry_list_results/{data_name}_per_obj_pca_{layer}_{head}_{comp_idx}.json")
# %%
show_plot(fig)
# %%
# Reload previously saved figures so these cells can be rerun without the model.
# When the file is run top to bottom, this first reload is overwritten by the next
# one before anything shows it.
print(f"{data_name}_per_obj_pca_{layer}_{head}_{comp_idx}.json")
fig = fig_from_json(f"laundry_list_results/{data_name}_per_obj_pca_{layer}_{head}_{comp_idx}.json")

# %%
print(data_name)
fig = fig_from_json(f"laundry_list_results/{data_name}_{layer}_{head}_pca.json")
# %%
show_plot(fig)

# %%
# Reload the dataset and the saved PCA coordinates so this cell runs standalone.
data = load_json(f'datasets/{data_name}.json')
#n_objs_data = [ex["n_objs"] for ex in data]
# NOTE: despite its name, n_objs_data now holds each example's query index.
n_objs_data = [example['query_idx'] for example in data]
comp_vecs_pca = np.load(f"laundry_list_results/{data_name}_{layer}_{head}_{comp_idx}_comp_vecs_pca.npy")
from collections import defaultdict
# Per group key, count the examples whose PCA coordinate 1 is positive.
high_states = {key: 0 for key in set(n_objs_data)}
for vec, key in zip(comp_vecs_pca, n_objs_data):
    high_states[key] += int(vec[1] > 0)


# %%
#create a bar chart for the high states
# The keys of high_states come from n_objs_data, which the preceding cell fills
# with each example's query_idx. So the x axis is the query index, not the
# number of objects as previously labelled.
fig = px.bar(x=list(high_states.keys()), y=list(high_states.values()), labels={'x':'Query index', 'y':'Number of high states'}, title='Number of high states by query index')
# %%
show_plot(fig)
# %%
from scipy.stats import pearsonr
# Pearson correlation between each example's group key (currently its query index,
# as set above) and its PCA coordinate 1.
print(pearsonr(n_objs_data, comp_vecs_pca[:,1]))
# %%
# Bare expressions below only display anything when run as notebook cells.
len(n_objs_data), comp_vecs_pca.shape
# %%
# %%
num_correct_by_num_obj.keys()

