# %%
import os
#os.environ["TRANSFORMERS_CACHE"] = "~/.cache/huggingface/transformers"
from imports import *
from jack_plotly import *
from weights_composer import re_get_single_component, get_ov
from transformer_lens.utils import composition_scores
from transformer_lens import FactoredMatrix
import plotly.io as pio
from scipy.stats import zscore
from pprint import pprint

# %%
# Load the model under study; the commented line is the Pythia alternative.
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    fold_value_biases=True,
    refactor_factored_attn_matrices=True,
)
#model = HookedTransformer.from_pretrained("EleutherAI/pythia-160m")
# %%
# Sizes read from the model config and reused by the cells below.
n_heads, n_lays, d_head = (
    model.cfg.n_heads,
    model.cfg.n_layers,
    model.cfg.d_head,
)

def get_qk(l, h):
    """Return the ``QK`` entry for head ``h`` of layer ``l`` of the global ``model``."""
    attention = model.blocks[l].attn
    return attention.QK[h]
# %%
# --- Source head -----------------------------------------------------------
src_layer, src_head = 8,6 #gpt 2 small
#src_layer, src_head = 4, 11 #pythia
src_usv = model.blocks[src_layer].attn.OV[src_head].svd()

# --- Destination head ------------------------------------------------------
dest_layer, dest_head = 9,9
# When True, `right` is one component rebuilt from the destination SVD
# (selected by comp_idx); when False, `right` is the full destination matrix.
decomp = True
# Alternative presets kept for reference:
#dest_layer, dest_head, comp_idx = 7,9,6 #gpt 2 small
#dest_layer, dest_head, comp_idx = 8,6,2 #gpt 2 small
#dest_layer, dest_head, comp_idx = 8,10,1 #gpt 2 small
#dest_layer, dest_head, comp_idx = 6,6,2 #pythia
#dest_layer, dest_head = 6, 6 #pythia
#dest_usv = model.blocks[dest_layer].attn.OV[dest_head].svd()
dest_usv = model.blocks[dest_layer].attn.QK[dest_head].svd()

#right = get_ov(model, dest_layer, dest_head)
#right = get_qk(dest_layer, dest_head)
#right = model.blocks[dest_layer].attn.QK[dest_head].T
#dest_usv = get_ov(model, dest_layer, dest_head).svd()
#dest_usv = get_qk(dest_layer, dest_head).svd()
comp_idx = 0 # gpt2 small
#comp_idx = 2 # pythia
if decomp:
    print('decomposing right side to comp', comp_idx)
    right = re_get_single_component(*dest_usv, comp_idx)
else:
    # Previously `right` was left unbound on this path, so the heatmap cell
    # raised NameError whenever decomp was False. Fall back to the same QK
    # matrix that dest_usv was computed from.
    right = model.blocks[dest_layer].attn.QK[dest_head]


# Q composition
#right = get_qk(dest_layer, dest_head)
#right = right.svd()
#comp_idx=0
#right = re_get_single_component(*right, comp_idx)

#right = model.blocks[dest_layer].attn.QK[dest_head]

# K composition
#right = model.blocks[dest_layer].attn.QK[dest_head].T

scores =[]
# Build a heatmap of composition scores for every head in the layers before dest_layer.
def exhaustive_heatmap(dest_layer, right):
    """Score every component of every earlier head against ``right``.

    For each layer below ``dest_layer`` and each of the ``n_heads`` heads, the
    head's ``OV`` entry is decomposed with ``.svd()``; component indices
    ``0 .. d_head - 1`` are rebuilt with ``re_get_single_component`` and scored
    against ``right`` via ``composition_scores``. Each cell stores the largest
    per-component score for that head.

    Prints each layer index as it is processed.

    Returns:
        A ``(dest_layer, n_heads)`` NumPy array of maximum scores.
    """
    grid = np.zeros((dest_layer, n_heads))
    for layer_idx in range(dest_layer):
        print(layer_idx)
        for head_idx in range(n_heads):
            usv = model.blocks[layer_idx].attn.OV[head_idx].svd()
            grid[layer_idx, head_idx] = max(
                composition_scores(re_get_single_component(*usv, k), right).item()
                for k in range(d_head)
            )
    return grid

def simple_heatmap(dest_layer, right):
    """Score each earlier head's whole ``OV`` entry against ``right``.

    Unlike the per-component sweep, this only looks at layers and heads: one
    ``composition_scores`` call per head, with no SVD decomposition.

    Prints each layer index as it is processed.

    Returns:
        A ``(dest_layer, n_heads)`` NumPy array of scores.
    """
    grid = np.zeros((dest_layer, n_heads))
    for layer_idx in range(dest_layer):
        print(layer_idx)
        for head_idx in range(n_heads):
            head_ov = model.blocks[layer_idx].attn.OV[head_idx]
            # A single score per head, so it is stored directly.
            grid[layer_idx, head_idx] = composition_scores(head_ov, right).item()
    return grid
# %%
# Choose between the per-component sweep (exh=True) and the whole-matrix sweep.
exh = False
if exh:
    heatmap = exhaustive_heatmap(dest_layer, right)
else:
    heatmap = simple_heatmap(dest_layer, right)
# %%
title = f'Composition Scores to {dest_layer}.{dest_head}'
if decomp:
    title = f'Composition Scores to {dest_layer}.{dest_head}.{comp_idx}'
# Keep the figure under its own name. Rebinding `heatmap` to the figure broke
# the z-score cells below (`.size` / `.flatten()`) and made np.save receive a
# figure instead of the score array.
heatmap_fig = imshow(heatmap, return_fig=True, xaxis='Head', yaxis='Layer', title=title)
# %%
heatmap_fig.show()
# %%
if decomp:
    pio.write_image(heatmap_fig, f'gpt2_comp_to_{dest_layer}.{dest_head}.{comp_idx}.pdf', format='pdf')
else:
    pio.write_image(heatmap_fig, f'gpt2_comp_to_{dest_layer}.{dest_head}.pdf', format='pdf')
# %%

scatter(np.arange(heatmap.size), zscore(heatmap.flatten()))
# %%
# Reshape back to the array's own shape rather than a hard-coded 12 heads.
imshow(zscore(heatmap.flatten()).reshape(heatmap.shape), xaxis='Head', yaxis='Layer', title=f'Composition Scores for Comp. Matrices to {dest_layer}.{dest_head}')
# %%
from pprint import pprint
#pprint(heatmap)
# Build the output path from the sweep type and (optionally) the component.
if exh:
    print("Yes exh")
    fname = f'exp_site/results/composition_scores/gpt2_small_exh_to_{dest_layer}_{dest_head}'
else:
    fname = f'exp_site/results/composition_scores/gpt2_small_to_{dest_layer}_{dest_head}'
if decomp:
    print("yes decomp", comp_idx)
    fname += f'_{comp_idx}'
fname+= '.npy'
print("Saving", fname)
np.save(fname, heatmap)
# %%
#fname = f'exp_site/results/composition_scores/gpt2_small_exh_to_{dest_layer}_{dest_head}'
#heatmap = np.load(fname)
#imshow(zscore(heatmap.flatten()).reshape(dest_layer,12), xaxis='Head', yaxis='Layer', title=f'Composition Scores for Comp. Matrices to {dest_layer}.{dest_head}')


# %%
# Source head for the component sweeps below; `left` is what get_ov returns for it.
src_layer = 3 #gpt 2 small
src_head = 0
left = get_ov(model, src_layer, src_head)
#left = re_get_single_component(*left.svd(), 0)
def rightside_heatmap(left, dest_usv):
    """Score ``left`` against each single component rebuilt from ``dest_usv``.

    Returns:
        A list with ``d_head`` entries; entry ``k`` is the scalar from
        ``composition_scores(left, component_k)``.
    """
    return [
        composition_scores(left, re_get_single_component(*dest_usv, k)).item()
        for k in range(d_head)
    ]

def leftside_heatmap(src_usv, right):
    """Score each single component rebuilt from ``src_usv`` against ``right``.

    Returns:
        A list with ``d_head`` entries; entry ``k`` is the scalar from
        ``composition_scores(component_k, right)``.
    """
    return [
        composition_scores(re_get_single_component(*src_usv, k), right).item()
        for k in range(d_head)
    ]

#dest_usv = model.blocks[dest_layer].attn.QK[dest_head].T.svd()
#dest_usv = get_ov(model , dest_layer,dest_head).svd()
print("DEST USV FROM", dest_layer, dest_head)
#dest_usv = get_qk(dest_layer, dest_head).svd()
dest_usv=get_ov(model,dest_layer, dest_head).svd()
#dest_usv = get_qk(dest_layer, dest_head).T.svd()

# Right-side sweep: the whole source matrix against each destination component.
scores = rightside_heatmap(left, dest_usv)
# Kept under its own name so the later save cell can still reach it after
# `fig` is reused for the left-side plot.
rightside_fig = scatter(np.arange(d_head),scores, return_fig=True,title=f'Composition Scores for {src_layer}.{src_head} to Components of {dest_layer}.{dest_head}')
rightside_fig.show()

src_usv = get_ov(model, src_layer,src_head).svd()

# Destination component used for the left-side sweep. Defined once so the
# matrix, the plot title and the output filename cannot drift apart.
dest_comp_idx = 6
right = get_ov(model, dest_layer,dest_head)
#right = get_qk(dest_layer, dest_head)
#right = get_qk(dest_layer, dest_head).T
right = re_get_single_component(*right.svd(), dest_comp_idx)

#right = get_qk(dest_layer, dest_head).T
#dest_usv=get_ov(model,dest_layer, dest_head).svd()
#right = re_get_single_component(*dest_usv, 6)

# Left-side sweep: each source component against the chosen destination component.
scores = leftside_heatmap(src_usv, right)

from utils import fig_to_json
fig = scatter(np.arange(d_head),scores, return_fig=True, yaxis='Composition Score', xaxis="Component", title=f'Composition Scores for {src_layer}.{src_head} Components to {dest_layer}.{dest_head}.{dest_comp_idx} Values')
fig.show()
pio.write_image(fig, f'{src_layer}.{src_head}comps_to_{dest_layer}.{dest_head}.{dest_comp_idx}_values.pdf', format='pdf')
#fig_to_json(fig, 'exp_site/results/composition_scores/3_0_comps_to_7_9_6_leftside.json')
#fig.show()
# %%
# This filename describes the right-side sweep (source head -> destination
# components). Previously `fig` had already been overwritten by the left-side
# plot, so the wrong figure was written under this name.
# NOTE(review): confirm the right-side figure is the intended content here.
pio.write_image(rightside_fig, f'{src_layer}.{src_head}_to_{dest_layer}.{dest_head}comps_values.pdf', format='pdf')

# Component indices ordered from highest to lowest left-side score.
pprint(np.argsort(scores)[::-1])
# %%
scatter(np.arange(d_head),dest_usv[1])

# %%

def set_src(lay, hed, comp_type='V'):
    """Return the SVD of the source matrix for head ``hed`` in layer ``lay``.

    Args:
        lay: Layer index into ``model.blocks``.
        hed: Head index into that layer's ``attn.OV``.
        comp_type: Which source matrix to decompose. Only ``'V'`` is
            implemented; it selects the head's ``OV`` entry.

    Returns:
        Whatever ``.svd()`` returns for the selected ``OV`` entry.

    Raises:
        ValueError: If ``comp_type`` is not ``'V'``. Previously this path fell
            through to ``return src_usv`` with the name unbound and failed
            with an unhelpful ``UnboundLocalError``.
    """
    if comp_type != 'V':
        raise ValueError(
            f"unsupported comp_type {comp_type!r}; only 'V' is implemented"
        )
    return model.blocks[lay].attn.OV[hed].svd()

src_usv = set_src(3,3)

#right = get_ov(model, 5,5)
right = get_qk(5,5).T
#comp_idx=2
#right = re_get_single_component(*right.svd(), comp_idx)
# One composition score per rebuilt source component, against `right`.
scores = [
    composition_scores(re_get_single_component(*src_usv, k), right).item()
    for k in range(d_head)
]
zscores = zscore(scores)
# Markers are placed at the z-scored values; the raw scores ride along as point text.
fig = go.Figure()
fig.add_trace(
    go.Scatter(
        x=np.arange(len(scores)),
        y=zscores,
        mode='markers',
        text=scores,
    )
)
fig.show()

# %%
scatter(np.arange(len(src_usv[1])), src_usv[1])
# %%
# Component indices ordered from highest to lowest z-score.
components_sorted = np.argsort(zscores)[::-1]
pprint(components_sorted)
# %%

# Text histogram of the raw scores, using d_head bins over [-0.05, 1.0].
counts, bin_edges = np.histogram(scores, bins=d_head, range=(-.05, 1.0))
fig = tpl.figure()
fig.hist(counts, bin_edges, orientation="horizontal", force_ascii=False)
fig.show()
#fig = tpl.figure()
#x = np.arange(len(scores))
#fig.plot(x, scores, width=60, height=20)
#fig.show()
# %%
# Each cell below runs the model on one prompt with activation caching and
# plots an attention pattern from the cache. The integer argument (4 or 6) is
# presumably a layer index -- TODO confirm against plot_attn_pattern_from_cache.

# Repeated-sequence prompt.
text = ' a b c d a b c d a b c' #' a b c a b c a b c a b c a b c a b c a b'
_, cache = model.run_with_cache(text)
plot_attn_pattern_from_cache(cache, 4, tokens=model.to_str_tokens(text))
# %%
# Prompt with two names where one is repeated.
text = 'Then, John and Mary went to the store. John gave a drink to'
_, cache = model.run_with_cache(text)
plot_attn_pattern_from_cache(cache, 6, tokens=model.to_str_tokens(text))
# %%
# Natural-language prompt that repeats the phrase "cows and pigs".
#text = 'Neurons are cells found in the brain. They are made up of a cell body, dendrites, and an axon. Neurons are responsible for transmitting information throughout the body.'
text ='At the farm, I counted the cows and pigs. The cows and pigs are looking pretty plump'
_, cache = model.run_with_cache(text)
plot_attn_pattern_from_cache(cache, 6, tokens=model.to_str_tokens(text))
# %%
# Prompt with no repeated phrases.
text = "Photosynthetic communicative channels for discombobulating the effects of music, art, and STEM education"
_, cache = model.run_with_cache(text)
plot_attn_pattern_from_cache(cache, 6, tokens=model.to_str_tokens(text))

# %%
# Matrix returned by get_ov for layer 9, head 8; its `.AB` product is scored below.
ov = get_ov(model, 9, 8)
def metric(mat):
    """Return the trace of ``mat`` divided by the sum of its absolute entries.

    The numerator stays whatever ``mat.trace()`` returns, while the denominator
    is converted to a Python scalar with ``.item()``, so the result has the
    type of the trace.
    """
    total_magnitude = mat.abs().sum().item()
    diagonal_sum = mat.trace()
    return diagonal_sum / total_magnitude
metric(ov.AB)

# %%
ov = get_ov(model, 9, 6)
metric(ov.AB)
# %%
# Head indices in layer 9 that are labelled "Mover" in the printed output.
movers = [0,6,7,9]
scores = []
# Iterate over the model's actual head count rather than a literal 12, so the
# sweep stays correct if a model with a different number of heads is loaded.
for i in range(n_heads):
    ov = get_ov(model, 9, i)
    s = metric(ov.AB)
    if i in movers:
        print(s, i, "Mover")
    else:
        print(s, i)
    scores.append(s)
# %%
# Print the heads again, ordered from lowest to highest metric value.
scoreso = np.argsort(scores)
print(scoreso)
for i in scoreso:
    if i in movers:
        print(scores[i].item(), '\t', i, "\tMover")
    else:
        print(scores[i].item(), '\t', i)

# %%
# Each cell runs the model on a shopping-list prompt and decodes the argmax of
# the logits at index [0, -1] (presumably the last position of the first batch
# element -- TODO confirm logits layout). The prompts vary the order of the
# items in the first and second lists.
text = " Today, when I go to the store, I will buy a plate, a pear, a pen, and a textbook. First, I will get the plate, the pen, the pear, and then the"
logits, cache = model.run_with_cache(text)
model.to_string(logits[0,-1].argmax(-1))

# %%
# Fixed: the original had a duplicated `text = text = ...` assignment here.
text = " Today, when I go to the store, I will buy a plate, a textbook, a pear, and a pen. First, I will get the plate, the pen, the pear, and then the"
logits, cache = model.run_with_cache(text)
model.to_string(logits[0,-1].argmax(-1))
# %%
text = " Today, when I go to the store, I will buy a plate, a pear, a pen, and a textbook. First, I will get the textbook, the pen, the pear, and then the"
logits, cache = model.run_with_cache(text)
model.to_string(logits[0,-1].argmax(-1))

# %%
# Fixed: the original had a duplicated `text = text = ...` assignment here.
text = " Today, when I go to the store, I will buy a textbook, a plate, a pen, and a pear. First, I will get the plate, the pen, the pear, and then the"
logits, cache = model.run_with_cache(text)
model.to_string(logits[0,-1].argmax(-1))
# %%
