#%%
##### RUN WITH: 
#streamlit run interactive.py --server.address 0.0.0.0 --server.port 8501 --server.enableCORS false --server.enableXsrfProtection false

import os
import pickle

import numpy as np
import streamlit as st
import torch
from sae_lens import SAE
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoTokenizer

from globals import *

# SECURITY: the Hugging Face access token used to be hard-coded on this line.
# It is now read from the HF_TOKEN environment variable so that no credential
# lives in source control; the previously committed token should be revoked.
# When HF_TOKEN is unset, token=None leaves credential lookup to the library's
# defaults (e.g. a cached `huggingface-cli login`) -- NOTE(review): confirm
# this fallback is acceptable for your deployment.
tokenizer = AutoTokenizer.from_pretrained(
    "google/gemma-2-2b", token=os.environ.get("HF_TOKEN")
)

#%%
@st.cache_resource
def load_inputs():
    """Load the object serialized at ``gemmascope/all_inputs.pt``.

    Decorated with ``st.cache_resource`` so repeated Streamlit reruns reuse the
    loaded object. The module-level code indexes the result by context number.
    """
    inputs_path = "gemmascope/all_inputs.pt"
    return torch.load(inputs_path)

@st.cache_resource
def load_layer_data(layer):
    """Fetch the data for one layer and derive its list of elephant choices.

    Args:
        layer: Layer index, forwarded unchanged to ``get_mydata``.

    Returns:
        A 4-tuple ``(freqs, saes, elephant_acts_dict, elephants)``: the three
        values returned by ``get_mydata`` followed by the result of
        ``print_elephants_with_pairs`` called with ``0.1`` and ``flat=False``
        (0.1 is presumably a frequency threshold -- confirm against the
        helper's definition).
    """
    layer_freqs, layer_saes, acts_by_elephant = get_mydata(
        layer, freqs=True, saes=True, elephant_acts=True
    )
    return (
        layer_freqs,
        layer_saes,
        acts_by_elephant,
        print_elephants_with_pairs(layer_freqs, layer_saes, 0.1, flat=False),
    )

def decode_tokens(inputs):
    """Decode every element of ``inputs`` individually with the module tokenizer.

    Returns a list holding one ``tokenizer.decode`` result per element, in the
    same order as ``inputs``.
    """
    return list(map(tokenizer.decode, inputs))

def show_html(html):
    """Embed ``html`` in the page through Streamlit's HTML component.

    The component is requested at width 900 and height 1000 with scrolling
    enabled.
    """
    render_component = st.components.v1.html
    render_component(html, width=900, height=1000, scrolling=True)

# Page configuration is issued before any other Streamlit element, followed by
# the two sidebar controls: the layer (0..25, default 5) and the context number
# (0..5000, default 0).
st.set_page_config(layout="wide")
layer = st.sidebar.slider("Select Layer", min_value=0, max_value=25, value=5)
contextj = st.sidebar.number_input(
    "Select Context Number", min_value=0, max_value=5000, value=0
)

# Both loaders are decorated with st.cache_resource, so they are only executed
# again when their arguments change (here: when another layer is selected).
all_inputs = load_inputs()
freqs, saes, elephant_acts_dict, elephant_choices = load_layer_data(layer)

# Map each sidebar label to an (elephant1, elephant2) tuple. An entry of
# elephant_choices with one element is a single elephant and is stored with
# None as its partner; any other entry is treated as a pair and stored as-is.
elephant_label_map = {}
for pair in elephant_choices:
    first = pair[0]
    if len(pair) == 1:
        label = f"#{first}, f={freqs[first]:.3f}"
        elephant_label_map[label] = (first, None)
    else:
        second = pair[1]
        label = f"#{first}, f={freqs[first]:.3f} and #{second}, f={freqs[second]:.3f}"
        elephant_label_map[label] = pair

# With no options st.selectbox returns None, and looking None up in the map
# would raise KeyError. Tell the user and end this script run instead.
if not elephant_label_map:
    st.warning(f"No elephants available for layer {layer}.")
    st.stop()

elephant_label = st.sidebar.selectbox("Select Elephant Pair", list(elephant_label_map))
elephant1, elephant2 = elephant_label_map[elephant_label]

# Entry of all_inputs for the selected context, decoded element by element.
inputs = all_inputs[contextj]
token_strs = decode_tokens(inputs)


def _scaled_context_acts(elephant):
    """Return the scaled activations of ``elephant`` for the selected context.

    Densifies the stored activations, slices out positions
    ``contextj*1024`` up to ``(contextj+1)*1024`` (presumably 1024 tokens per
    context -- confirm), then applies ``remove_bos_acts`` and ``scale_acts``.
    """
    start = contextj * 1024
    dense_acts = elephant_acts_dict[elephant].to_dense()
    context_acts = dense_acts[start:start + 1024]
    return scale_acts(remove_bos_acts(context_acts))


# One activation vector for a single elephant, two for a pair; the heading text
# differs accordingly.
if elephant2 is None:
    scaled_vectors = [_scaled_context_acts(elephant1)]
    heading = f"### Layer {layer} | Text ID {contextj} | Elephant {elephant1}"
else:
    scaled_vectors = [
        _scaled_context_acts(elephant1),
        _scaled_context_acts(elephant2),
    ]
    heading = f"### Layer {layer} | Text ID {contextj} | Elephants {elephant1}, {elephant2}"

html_output = highlight_tokens_scaled_html(token_strs, *scaled_vectors, fixed_width=600)
st.markdown(heading)
show_html(html_output)


# %%
