# %%
import os
import sys
import pandas as pd
import torch
from tqdm import tqdm
import plotly.express as px
import plotly.graph_objects as go
import numpy as np
import re
import plotly

# Resolve this script's directory and its parent ("project") directory, then
# make both importable so that the project-local `utils` module can be found.
script_dir = os.path.dirname(os.path.abspath(__file__))
project_dir = os.path.join(script_dir, '..')
sys.path.extend([script_dir, project_dir])

from utils import load_sae, principal_angles_from_vectors

# %%
# Nothing in this script trains a model, so autograd is switched off globally.
torch.set_grad_enabled(False)

# Prefer CUDA, then Apple MPS, and finally fall back to the CPU so the script
# still runs on machines without any accelerator.
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'
cache_dir = None

model_name = 'gemma-2-2b'

dataset = 'stas/c4-en-10k'

# Neither name depends on the layer, so resolve them once before the loop.
reference_model_name = 'gpt2-small' if 'gpt2' in model_name else model_name
file_model_name = reference_model_name.replace('google/', '')

# Per-layer lookup tables, keyed by layer index:
#   density_dict[layer] -> 'frac_nonzero' values ordered by feature_idx
#   encoders[layer]     -> sae.W_enc transposed
#   decoders[layer]     -> sae.W_dec
density_dict = {}
encoders = {}
decoders = {}
for layer_idx in range(26):
    sae = load_sae(reference_model_name, width='16k', layer_idx=layer_idx, location='res', device=device)

    # Load the activation frequencies stored for this layer's reference SAE.
    path = f"{project_dir}/data/frequencies/{file_model_name}/{layer_idx}.json"
    df = pd.read_json(path)

    # Index by feature_idx, keep only the first row of any duplicated feature,
    # and sort so that position k of `densities` is the k-th smallest feature_idx.
    df = df.set_index('feature_idx')
    df = df[~df.index.duplicated(keep='first')]
    df = df.sort_index()

    densities = torch.tensor(df['frac_nonzero'].values)
    density_dict[layer_idx] = densities

    # The encoder is stored transposed relative to sae.W_enc
    # (presumably so that rows correspond to latents, like W_dec -- confirm).
    encoder = sae.W_enc.T
    decoder = sae.W_dec
    # The biases are not referenced by the rest of the visible script; they are
    # kept as module-level names for interactive inspection.
    encoder_bias = sae.b_enc
    decoder_bias = sae.b_dec
    encoders[layer_idx] = encoder
    decoders[layer_idx] = decoder


# %%
# =============================================================================
# Histogram of composition of nullspace
# =============================================================================
u_path = f'{project_dir}/data/U/{model_name}.pt'
# map_location moves U onto the device chosen for this run, so the later
# `encoder @ U` product does not fail when U was saved from another device
# (assumes load_sae placed the SAE weights on `device` -- confirm).
U = torch.load(u_path, map_location=device)
# Report the path that was actually read rather than a hard-coded relative one.
print(f'Loaded U from {u_path}')
# %%
nullspace_threshold = 10
layer_idx = 25
encoder, densities = encoders[layer_idx], density_dict[layer_idx]

# Multiply every encoder row by U and measure which fraction of the resulting
# norm sits in the last `nullspace_threshold` coordinates.
# NOTE(review): the plot title suggests those trailing coordinates span the
# quasi-null space of the unembedding matrix -- confirm how U is ordered.
proj_on_U = encoder @ U
tail_norm = proj_on_U[:, -nullspace_threshold:].norm(dim=1)
full_norm = proj_on_U.norm(dim=1)
nullspace_comp = (tail_norm / full_norm).cpu()

# Boolean flag per latent, used to colour the histogram.
density_above_01 = densities > 0.1

# Histogram of the tail-norm fraction, split by the density flag, with a box
# plot in the margin and a logarithmic count axis.
fig = px.histogram(
    x=nullspace_comp,
    log_y=True,
    labels={'x': f'ρ<sub>{nullspace_threshold}</sub>'},
    color=density_above_01,
    color_discrete_sequence=['#636EFA', '#EF553B'],
    marginal='box',
    nbins=100,
)

# Replace the stringified boolean trace names with readable legend entries.
fig.for_each_trace(
    lambda trace: trace.update(name='>0.1' if trace.name == 'True' else '<0.1')
)

# Axis label and figure title.
fig.update_layout(
    yaxis_title='Count',
    title='Norm in <b>W</b><sub>U</sub> Quasi-Null Space',
)

# Legend: titled, horizontal, and placed below the plot.
fig.update_layout(
    legend=dict(
        title_text='Density',
        yanchor="bottom",
        y=-0.4,
        xanchor="left",
        x=0.1,
        orientation='h',
    )
)
fig.show()


# %%
# =========================================================================
# antipodality score vs density
# =========================================================================

# Take the decoder of the layer under analysis explicitly. Previously this cell
# reused whatever `decoder` was left over from the loading loop (always the
# last layer), which would silently mismatch `encoder` if `layer_idx` changed.
decoder = decoders[layer_idx]

# Scale every row to unit length so that a matrix product of the rows yields
# cosine similarities.
normalized_W_enc = encoder / encoder.norm(dim=1, keepdim=True)
normalized_W_dec = decoder / decoder.norm(dim=1, keepdim=True)

# Pairwise cosine similarities between encoder rows. The diagonal is zeroed in
# place so a latent is never selected as its own nearest neighbour.
cosine_similarities_enc = torch.mm(normalized_W_enc, normalized_W_enc.t())
cosine_similarities_enc.fill_diagonal_(0)
row_positions = range(cosine_similarities_enc.shape[0])

# For each latent: the other latent with the largest |cosine| and the signed
# similarity to it.
most_similar_enc_latents = cosine_similarities_enc.abs().argmax(dim=1)
max_enc_cos_sim = cosine_similarities_enc[row_positions, most_similar_enc_latents]

# Same computation for the decoder rows.
cosine_similarities_dec = torch.mm(normalized_W_dec, normalized_W_dec.t())
cosine_similarities_dec.fill_diagonal_(0)
most_similar_dec_latents = cosine_similarities_dec.abs().argmax(dim=1)
max_dec_cos_sim = cosine_similarities_dec[row_positions, most_similar_dec_latents]

# Element-wise product of encoder and decoder similarities. It is large and
# positive when both similarities are strong and share the same sign.
product = cosine_similarities_enc * cosine_similarities_dec

# Per latent: the partner with the highest product, and that product (plotted
# as the antipodality score further down).
highest_product_latents = product.argmax(dim=1)
max_product = product[row_positions, highest_product_latents]

# %%
# make scatter plot of scores vs density
# `densities` is already a tensor; copy it with detach().clone() instead of
# torch.tensor(...), which emits a UserWarning when given an existing tensor.
layer_density = densities.detach().clone()
scores = max_product.cpu()

# Scatter of antipodality score against density with histograms in both margins.
# NOTE(review): no `color` argument is passed, so `color_continuous_scale` and
# the 'color' label look inert -- confirm before removing them.
fig = px.scatter(
    x=scores,
    y=layer_density,
    opacity=0.7,
    color_continuous_scale='Bluered',
    labels={'x': 'Antipodality Score (<i>s</i>)', 'y': 'Density', 'color': 'cos'},
    marginal_x='histogram',
    marginal_y='histogram',
)

fig.update_layout(title='Antipodality Score vs. Density')

fig.show()

# %%
# ==============================================================================
# ablation plots
# ==============================================================================
num_docs = 300
model_name = 'google/gemma-2-2b'

# Directory and file name are both derived from the dataset's short name
# (the part after the slash) and the number of documents.
dataset_tag = dataset.split("/")[1]
out_path = f'{project_dir}/data/ablations/{model_name}/{model_name}_pretrained/{dataset_tag}_{num_docs}/'
df_path = f'{out_path}/{dataset_tag}_{num_docs}.feather'
df = pd.read_feather(df_path)

# Drop every row whose 'tokens' value equals 2.
keep_rows = df['tokens'] != 2
df = df[keep_rows]

# Keep at most the first `limit` remaining rows.
limit = 10000
df = df.iloc[:limit]

# %%
columns = df.columns

# Collect the latent indices that have an 'ablated_entropies_<idx>' column.
# The pattern is anchored so digits in unrelated column names cannot produce
# spurious indices, and the result is sorted so the plotting order is
# deterministic and ascending.
latent_column_pattern = re.compile(r'^ablated_entropies_(\d+)$')
latent_indices = sorted(
    {
        int(match.group(1))
        for match in map(latent_column_pattern.match, columns)
        if match
    }
)

# Per-latent entropy changes relative to the un-ablated 'entropies' column:
#   te_entropy[i]    -> 'ablated_entropies_<i>'       minus baseline
#   de_entropy[i]    -> 'ablated_entropies_fixed_<i>' minus baseline
#   delta_entropy[i] -> te_entropy[i] - de_entropy[i]
te_entropy = {}
de_entropy = {}
delta_entropy = {}
one_minus_te_over_de = {}
baseline_entropy = df['entropies']
for i in latent_indices:
    te_entropy[i] = df[f'ablated_entropies_{i}'] - baseline_entropy
    de_entropy[i] = df[f'ablated_entropies_fixed_{i}'] - baseline_entropy
    delta_entropy[i] = te_entropy[i] - de_entropy[i]
    # NOTE(review): despite the variable name, this evaluates
    # 1 - mean|de| / mean|te| -- confirm which ratio is intended.
    one_minus_te_over_de[i] = 1 - (de_entropy[i].abs().mean() / te_entropy[i].abs().mean())
# %%
# make boxplots of delta_entropy for each latent
fig = go.Figure()
random_indices_deltas = []
filtered_latents = []

# Latents whose tick label is rendered in bold.
highlighted_latents = (13749, 14325)
latent_color, random_color = plotly.colors.qualitative.Plotly[:2]

for i in latent_indices:
    tail_fraction = nullspace_comp[i]
    if tail_fraction > 0.3 and densities[i] > 0.1:
        # Skip a latent whose highest-product partner already has a box.
        partner = highest_product_latents[i].item()
        if partner in filtered_latents:
            continue
        label = f'<b>{i}</b>' if i in highlighted_latents else f'{i}'
        fig.add_trace(
            go.Box(
                y=delta_entropy[i],
                name=label,
                boxmean='sd',
                marker_color=latent_color,
            )
        )
        filtered_latents.append(i)
    if tail_fraction < 0.2:
        # Latents with a small tail fraction are pooled into the "Random" box.
        random_indices_deltas.extend(delta_entropy[i].tolist())

# The pooled comparison box goes last, i.e. right-most on the axis.
fig.add_trace(
    go.Box(
        y=random_indices_deltas,
        name='Random',
        boxmean='sd',
        marker_color=random_color,
    )
)

# Dashed vertical separator between the per-latent boxes and the Random box:
# it spans the full plot height and sits half a category before the last box.
separator_x = len(filtered_latents) - 0.5
separator = dict(
    type="line",
    xref="x",
    yref="paper",
    x0=separator_x,
    y0=0,
    x1=separator_x,
    y1=1,
    line=dict(
        color="Black",
        width=1.5,
        dash="dash",
    ),
)

# Hide the legend, draw the separator, and set the titles.
fig.update_layout(showlegend=False, shapes=[separator])
fig.update_layout(
    title='(b) Changes in Entropy Upon Ablation',
    legend_title_text='Latent Type',
)
fig.update_xaxes(title_text='Latent Index')
fig.update_yaxes(title_text='ΔEntropy')
fig.show()


# %%
# =============================================================================
# rotation of the subspace
# =============================================================================

num_layers = 26
avg_principal_angles = torch.zeros((num_layers, num_layers))
median_principal_angles = torch.zeros((num_layers, num_layers))
threshold = 0.2

# True: use every latent whose density exceeds `threshold`.
# False: use a random sample of latents whose density is below `threshold`.
dense = True
num_random_latents = 100

# subspaces[layer] holds the selected encoder rows of that layer.
subspaces = {}
for layer_idx in range(num_layers):
    layer_densities = density_dict[layer_idx]
    if dense:
        dense_idx = (layer_densities > threshold).nonzero(as_tuple=True)[0]
    else:
        # Sample up to `num_random_latents` distinct latents below the
        # threshold. The sampled positions must index into the candidate list:
        # using them directly as latent indices (as was done before) picks
        # arbitrary low-numbered latents regardless of their density.
        candidate_idx = (layer_densities < threshold).nonzero(as_tuple=True)[0]
        sampled_positions = torch.randperm(len(candidate_idx))[:num_random_latents]
        dense_idx = candidate_idx[sampled_positions]
    dense_encoder = encoders[layer_idx][dense_idx]
    subspaces[layer_idx] = dense_encoder

# Mean and median principal angle for every ordered pair of layers.
with tqdm(total=num_layers * num_layers, desc='Computing principal angles') as pbar:
    for layer_idx in range(num_layers):
        for layer_idx2 in range(num_layers):
            # The helper receives the transposed row selections.
            angles = principal_angles_from_vectors(
                subspaces[layer_idx].T,
                subspaces[layer_idx2].T,
            )
            avg_principal_angles[layer_idx, layer_idx2] = angles.mean().item()
            median_principal_angles[layer_idx, layer_idx2] = angles.median().item()
            pbar.update(1)

# Convert to degrees (presumably the helper returns radians -- confirm).
avg_principal_angles = avg_principal_angles * 180 / np.pi
median_principal_angles = median_principal_angles * 180 / np.pi
# %%
# Heatmap of the median principal angle between every pair of layers.
layer_labels = list(map(str, range(26)))
fig = px.imshow(
    median_principal_angles,
    color_continuous_scale='Blues',
    labels=dict(x="Layer index", y="Layer index", color="Principal angle"),
    x=layer_labels,
    y=layer_labels,
)

# Axis titles and a stacked colour-bar title.
colorbar_settings = dict(
    title="Median<br>Principal<br>Angle",
)
fig.update_layout(
    xaxis_title='Layer Index',
    yaxis_title='Layer Index',
    coloraxis_colorbar=colorbar_settings,
)

# Custom hover text showing the x value and the angle with two decimals.
fig.update_traces(
    hovertemplate='Layer index: %{x}<br>Principal angle: %{z:.2f}<extra></extra>',
)

fig.show()


# %%
# =============================================================================
# cos sim of densest latents
# =============================================================================
layer_idx = 25
encoder, densities = encoders[layer_idx], density_dict[layer_idx]

# Order latents from densest to sparsest and keep the first 50 indices.
num_densest = 50
dense_idx = densities.argsort(descending=True)[:num_densest]

# Encoder rows of those latents, in the same order.
dense_encoder = encoder[dense_idx]

# %%
# Compute the pairwise cosine similarities
# Pairwise cosine similarity between all selected encoder rows: the two
# unsqueezed views broadcast against each other to give one value per pair.
cosine_similarities = torch.nn.functional.cosine_similarity(
    dense_encoder.unsqueeze(0), dense_encoder.unsqueeze(1), dim=-1
)

# Blank out the strict upper triangle (k=1 leaves the diagonal visible) by
# writing NaN there, so only the lower triangle and the diagonal are drawn.
similarity_matrix = cosine_similarities.cpu().numpy()
mask = np.triu(np.ones_like(similarity_matrix), k=1)
cosine_similarities_masked = np.where(mask.astype(bool), np.nan, similarity_matrix)

# Heatmap of the masked matrix; ticks are positions in the density ranking.
rank_labels = [str(rank) for rank in range(len(dense_idx))]
fig = px.imshow(
    cosine_similarities_masked,
    color_continuous_scale='Picnic',
    labels=dict(x="Latent index", y="Latent index", color="Cosine similarity"),
    x=rank_labels,
    y=rank_labels,
    origin="upper"
)

# White background, axis titles, colour-bar title, and figure title.
fig.update_layout(
    plot_bgcolor='white',
    paper_bgcolor='white',
    xaxis_title='Latent Index',
    yaxis_title='Latent Index',
    coloraxis_colorbar=dict(
        title="Encoder<br>Cosine<br>Similarity",
    ),
    title='Cosine Similarity of Dense Latents',
)

# Hover text with both axis values and the similarity to two decimals.
fig.update_traces(
    hovertemplate='Latent index: %{x}<br>Latent index: %{y}<br>Cosine similarity: %{z:.2f}<extra></extra>',
)

# Pin the colour scale to the full cosine range.
fig.update_coloraxes(cmin=-1, cmax=1)

fig.show()
# %%
