#%%
import torch
import numpy as np
import matplotlib.pyplot as plt
from globals import *
from tqdm.auto import tqdm
from transformers import AutoTokenizer
from plotting_master import period_elephants, newline_elephants, context_elephants
from sklearn.metrics import roc_auc_score
from collections import defaultdict
import pickle
from scipy.stats import spearmanr
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Tokenizer and token ids shared by every cell below.
import os  # cell-local import: only needed here, to read the access token

# Access tokens must never be hard-coded in source. The Hugging Face token is
# read from the HF_TOKEN environment variable; when it is unset, `token=None`
# leaves credential lookup to the library (e.g. a cached `huggingface-cli login`).
# NOTE(review): the token that was previously committed on this line should be
# treated as leaked and revoked.
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b", token=os.environ.get("HF_TOKEN"))
all_inputs = torch.load('gemmascope/all_inputs.pt', weights_only=True)
# One token string per flattened position. `.tolist()` hands the tokenizer
# plain Python ints instead of making it iterate over 0-d tensors.
flat_token_strs = tokenizer.convert_ids_to_tokens(all_inputs.flatten().tolist())
# Number of positions per context; used below to build the context-position signal.
context_length = 1024

#%%
# Build the three position signals that activations are later correlated with.
# Each entry of `proximities_dict` is (values, valid_mask); a mask of None
# means every position is used.
# NOTE(review): `distance_to_previous_one` comes from `globals`; the NaN-based
# masks below assume it returns NaN where no distance is defined — confirm.

# Sentence signal: tokens whose string contains a period.
period_signal = np.array(['.' in tok for tok in flat_token_strs])
period_proximities = distance_to_previous_one(period_signal)
period_valid_mask = ~np.isnan(period_proximities)

# Paragraph signal: tokens whose string contains a newline.
newline_signal = np.array(['\n' in tok for tok in flat_token_strs])
newline_proximities = distance_to_previous_one(newline_signal)
newline_valid_mask = ~np.isnan(newline_proximities)

# Context-position signal: 0..context_length-1 repeated once per context.
# The number of contexts is derived from the data rather than hard-coded, so
# the signal always has exactly one entry per flattened token.
num_tokens = len(flat_token_strs)
if num_tokens % context_length != 0:
    raise ValueError(
        f'Token count {num_tokens} is not a multiple of context_length={context_length}'
    )
context_proximities = np.tile(np.arange(context_length), num_tokens // context_length)

proximities_dict = {
    'period': (period_proximities, period_valid_mask),
    'newline': (newline_proximities, newline_valid_mask),
    'context': (context_proximities, None),
}

#%%
def _proximity_spearman(proximities, valid_mask, acts):
    """Return the Spearman correlation between a position signal and activations.

    When `valid_mask` is None all positions are used; otherwise both inputs
    are restricted to the positions where the mask is true.
    """
    if valid_mask is None:
        return spearmanr(proximities, acts)[0]
    keep = valid_mask.astype(bool)
    return spearmanr(proximities[keep], acts[keep])[0]


# spearman_dict[layer][elephant][signal_name] -> Spearman correlation.
spearman_dict = {}
for layer in tqdm(np.arange(0, 26)):
    layer_scores = spearman_dict[layer] = {}
    freqs, saes, proj_tensor = get_mydata(layer, freqs=True, saes=True, proj_data=True)
    elephants = print_elephants_with_pairs(freqs, saes, 0.1)
    actual_elephants = get_elephants_thres(freqs, 0.1)
    # Row `row` of proj_tensor is indexed by position in `elephants`, so the
    # two must have the same length.
    assert len(elephants) == proj_tensor.shape[0]
    for row, e in enumerate(elephants):
        # Only features also returned by get_elephants_thres are scored.
        if e in actual_elephants:
            acts = proj_tensor[row]
            layer_scores[e] = {
                signal_name: _proximity_spearman(signal, signal_mask, acts)
                for signal_name, (signal, signal_mask) in proximities_dict.items()
            }

# Cache the results so the plotting cells can run without recomputing.
with open('final/position_tracking/spearman_dict.pkl', 'wb') as out_file:
    pickle.dump(spearman_dict, out_file)

#%%
# Reload the cached correlations and set up the values shared by the plot cell.
spearman_dict = load_pickle('final/position_tracking/spearman_dict.pkl')
dotsize = 3
layers = np.arange(26)

# elephants_dict[layer] -> list of plain Python values (via .item()) for the
# features returned by get_elephants_thres at the 0.1 threshold.
elephants_dict = {}
for layer in tqdm(layers, desc='Layers'):
    (freqs,) = get_mydata(layer, freqs=True)
    elephants = get_elephants_thres(freqs, 0.1)
    elephants_dict[layer] = [elephant.item() for elephant in elephants]

#%%
# Three-panel scatter of Spearman correlation vs. layer, one panel per
# position signal (period / newline / context). Each dot is one feature,
# coloured by the hand-labelled category lists from `plotting_master`.
fig = make_subplots(rows=1, cols=3,
                    subplot_titles=('Sentence-tracking', 'Paragraph-tracking', 'Context-position-tracking'),
                    shared_yaxes=True, horizontal_spacing=0.02)

# Legend-only placeholder traces: they draw no points (x/y are None) but give
# the legend one entry per category colour.
for legend_name, legend_color in (('Sentence', '#636EFA'),
                                  ('Paragraph', '#EF553B'),
                                  ('Context position', '#00CC96')):
    fig.add_trace(
        go.Scatter(x=[None], y=[None], mode='markers',
                   marker=dict(size=dotsize, color=legend_color),
                   name=legend_name,
                   showlegend=True),
        row=1, col=1
    )

# Accumulate the points of each panel and add ONE trace per panel afterwards.
# Adding a separate trace for every marker creates thousands of traces, which
# makes building, rendering and exporting the figure very slow; per-point
# colour and hover text arrays give the same picture.
panel_keys = ('period', 'newline', 'context')  # column order of the subplots
panel_points = {key: {'x': [], 'y': [], 'color': [], 'text': []} for key in panel_keys}

for layer in layers:
    layer_elephants = elephants_dict[layer]
    for e, scores in spearman_dict[layer].items():
        if e not in layer_elephants:
            continue

        # Later assignments win, so the priority is
        # context > period > newline > unlabelled (grey).
        color = 'grey'
        if e in newline_elephants[layer]:
            color = '#EF553B'
        if e in period_elephants[layer]:
            color = '#636EFA'
        if e in context_elephants[layer]:
            color = '#00CC96'

        for key in panel_keys:
            score = scores[key]
            points = panel_points[key]
            points['x'].append(layer)
            points['y'].append(score)
            points['color'].append(color)
            points['text'].append(f'Layer: {layer}<br>Elephant: {e}<br>Metric: {score:.3f}')

for col, key in enumerate(panel_keys, start=1):
    points = panel_points[key]
    fig.add_trace(
        go.Scatter(x=points['x'], y=points['y'], mode='markers',
                   marker=dict(size=dotsize, symbol='circle', color=points['color']),
                   text=points['text'],
                   hoverinfo='text',
                   showlegend=False),
        row=1, col=col
    )

fig.update_layout(
    width=700,
    height=220,
    showlegend=True,
    # x > 1 places the legend to the right of the plotting area, inside the
    # widened right margin.
    legend=dict(
        yanchor="top",
        y=0.99,
        xanchor="right",
        x=1.24
    ),
    margin=dict(l=50, r=120, t=20, b=10),
    font=dict(size=10)
)

for col in range(1, 4):
    fig.update_xaxes(title_text='Layer',
                     title_font=dict(size=10),
                     row=1, col=col,
                     range=[-0.5, 25.5])

    if col == 1:
        # Only the left-most panel carries the shared y-axis title and ticks.
        fig.update_yaxes(title_text='Spearman correlation',
                         title_font=dict(size=10),
                         row=1, col=col)
    else:
        fig.update_yaxes(showticklabels=False,
                         row=1, col=col)

# Subplot titles are annotations; shrink them to match the rest of the text.
fig.update_annotations(font_size=10)
fig.show()
# NOTE(review): the output directory is literally named "*PLOTS" — confirm
# that the leading asterisk is intentional.
fig.write_image("*PLOTS/proximity.pdf", format='pdf', scale=20)