#%%
import numpy as np
import matplotlib.pyplot as plt
import torch
import pickle
import argparse
from globals import *
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from tqdm.auto import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from IPython.display import display, HTML
from collections import defaultdict
from sklearn.metrics import roc_auc_score
from plotly.tools import make_subplots

# Layer indices 0..25: iterated by the per-layer AUC sweep and used as the
# x-extent of the reference line in the plot.
layers = np.arange(26)

# %%
# Load the pickled token / part-of-speech dictionaries.  Context managers make
# sure the file handles are closed as soon as the objects are deserialized.
# NOTE: pickle can execute arbitrary code on load -- only use trusted files here.
with open('final/pos/token_dict.pkl', 'rb') as f:
    token_dict = pickle.load(f)
with open('final/pos/pos_dict.pkl', 'rb') as f:
    pos_dict = pickle.load(f)

# Flatten the per-context tag sequences (contexts 0..1999, in order) into one
# 1-D array so tags can be compared element-wise against a tag name.
# presumably pos_dict[i] holds one POS tag per token of context i -- TODO confirm
all_pos = np.array(
    [tag for contexti in range(2000) for tag in pos_dict[contexti]]
)

#%% ACROSS LAYERS PLOT FOR POS 
# For every layer and every selected feature, compute a ROC-AUC between the
# binarized feature activation and a binary part-of-speech indicator, and
# store the results per POS category in final/pos/aucs_dict_<category>.pkl
# as {layer: {feature: auc}}.
N = len(all_pos)

# Category name -> list of raw tag strings (as they appear in all_pos) that
# count as that category.
pos_tag_dict = {
    'adverb': ['adv'],
    'adjective': ['adj'],
    'pronoun': ['pronoun'],
    'preposition': ['prepos'],
    'conjunction': ['conjunction'],
    'article': ['article'],
}

# One 0/1 label vector per category: 1 where the token's tag is any of the
# category's raw tags.
pos_labels = {
    pos_tag: np.any([all_pos == tag for tag in pos_tags], axis=0).astype(int)
    for pos_tag, pos_tags in pos_tag_dict.items()
}

# pos_tag -> {layer: {feature: auc}}
aucs_by_tag = {pos_tag: {} for pos_tag in pos_tag_dict}

# The layer data does not depend on the POS category, so each layer is loaded
# and densified once and then scored against every category (instead of
# reloading all layers once per category).
for layer in tqdm(layers, desc='Processing layers'):
    for pos_tag in pos_tag_dict:
        aucs_by_tag[pos_tag][layer] = {}

    freqs, saes = get_mydata(layer,freqs=True,saes=True)
    # presumably the indices of features whose frequency exceeds 0.1 -- TODO confirm
    elephants = get_elephants_thres(freqs,0.1)

    with open(f'final/pos/layer{layer}_all_data_combined.pkl','rb') as f:
        all_data = pickle.load(f)

    # Densify each context's activations and drop its first row, then stack
    # all contexts along dim 0.
    # assumes the remaining rows line up one-to-one with all_pos -- TODO confirm
    all_acts = torch.cat(
        [all_data[contexti].to_dense()[1:] for contexti in range(2000)],
        dim=0,
    )

    for e in elephants:
        # 1 where feature e is active (> 0) on the token, else 0
        acts = (all_acts[:,e] > 0).int()
        for pos_tag, pos in pos_labels.items():
            # NOTE(review): roc_auc_score's signature is (y_true, y_score); here
            # the feature activation is passed as y_true and the POS label as
            # y_score.  Kept as-is to preserve existing results -- confirm the
            # argument order is intended.
            aucs_by_tag[pos_tag][layer][e] = roc_auc_score(acts, pos)

for pos_tag, aucs_dict in aucs_by_tag.items():
    with open(f'final/pos/aucs_dict_{pos_tag}.pkl','wb') as f:
        pickle.dump(aucs_dict,f)

#%%
# Scatter plot of per-feature AUC versus layer for a single POS category,
# shown interactively and exported as a PDF.
pos_tag = 'conjunction'
aucs_dict = load_pickle(f'final/pos/aucs_dict_{pos_tag}.pkl')

# Flatten {layer: {feature: auc}} into (layer, feature, auc) triples, keeping
# the dictionaries' iteration order.
points = [
    (layer, e, auc)
    for layer, layer_aucs in aucs_dict.items()
    for e, auc in layer_aucs.items()
]
x_vals = [layer for layer, _, _ in points]
y_vals = [auc for _, _, auc in points]
hover_texts = [
    f'Layer: {layer}<br>Feature: {e}<br>AUC: {auc:.3f}'
    for layer, e, auc in points
]

fig = make_subplots(rows=1, cols=1)

# One small marker per (layer, feature) pair; hovering shows the custom text only.
marker_style = {'size': 2, 'color': '#636EFA'}
scatter = go.Scatter(
    x=x_vals,
    y=y_vals,
    mode='markers',
    marker=marker_style,
    text=hover_texts,
    hoverinfo='text',
    showlegend=False,
)
fig.add_trace(scatter)

# Horizontal grey reference line at AUC = 0.5 spanning the layer range.
reference_line = {'color': "grey", 'width': 1}
fig.add_shape(
    type="line",
    x0=min(layers),
    x1=max(layers),
    y0=0.5,
    y1=0.5,
    line=reference_line,
)

fig.update_xaxes(title_text='Layer')
fig.update_yaxes(title_text='AUC')

# Compact figure with tight margins and a fixed y-range.
fig.update_layout(
    width=300,
    height=200,
    margin={'l': 50, 'r': 10, 't': 10, 'b': 10},
    showlegend=False,
    yaxis={'range': [0.05, 0.95]},
)

fig.show()

fig.write_image(f"*PLOTS/pos_auc_{pos_tag}.pdf", scale=20)
