# %%
import os
import sys
import pandas as pd
import torch
from transformer_lens.utils import load_dataset, tokenize_and_concatenate
from tqdm import tqdm
from torch.utils.data import DataLoader
from functools import partial
import numpy as np

# Locate this script's directory and its parent (treated as the project
# root), then make both importable so the local `utils` module resolves.
script_dir = os.path.dirname(os.path.abspath(__file__))
project_dir = os.path.join(script_dir, os.pardir)
sys.path.extend([script_dir, project_dir])

from utils import load_tokenized_data, load_sae, adjust_vectors, load_model_from_tl_name
import transformer_lens.utils as tlutils

# set seed
# Seed both torch and numpy: numpy's global RNG is used later to sample the
# control latents, so this keeps the selection reproducible.
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)

# %%
# This script only runs inference, so autograd is switched off globally.
torch.set_grad_enabled(False)

# Pick the best available accelerator. Fall back to the CPU instead of
# unconditionally assuming Apple's MPS backend exists when CUDA is missing.
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'
cache_dir = None

# load model
model_name = 'google/gemma-2-2b'
# NOTE(review): `dtype` is defined here but is not passed to
# `load_model_from_tl_name` below — confirm whether the loader should get it.
dtype = torch.bfloat16
model, tokenizer = load_model_from_tl_name(model_name, cache_dir=cache_dir, device=device)
# Residual-stream layer at which the SAE is loaded and the ablation is applied.
layer_idx = 25

# %%
# load SAE
batch_size = 16
dataset = 'stas/c4-en-10k'
num_docs = 550

ckpt_name = f'{model_name}_pretrained'
reference_model_name = 'gpt2-small' if 'gpt2' in model_name else model_name
sae = load_sae(reference_model_name, width='16k', layer_idx=layer_idx, location='res', device=device)
# load densities of reference SAE
file_model_name = reference_model_name.replace('google/', '')
path = f"{project_dir}/data/frequencies/{file_model_name}/{layer_idx}.json"

df = pd.read_json(path)

# Index by 'feature_idx', keep the first row for any duplicated index and
# sort, so that row order follows the feature index.
df = df.set_index('feature_idx')
df = df[~df.index.duplicated(keep='first')]
df = df.sort_index()

densities = torch.tensor(df['frac_nonzero'].values)

# Rows of `encoder` index latents (encoder.shape[0] is used as the latent
# count below); `decoder` is indexed per latent in the ablation loop.
encoder = sae.W_enc.T
decoder = sae.W_dec
encoder_bias = sae.b_enc
decoder_bias = sae.b_dec

# Build the path once so the log message reports the file actually loaded.
u_path = f'{project_dir}/data/U/{model_name.replace("google/", "")}.pt'
U = torch.load(u_path, map_location=device)
print(f'Loaded U from {u_path}')

# For every latent, the fraction of its encoder row's norm (after projecting
# onto U) that falls into the last `nullspace_threshold` columns of U.
# NOTE(review): presumably those trailing columns span the approximate
# nullspace of some downstream map — confirm against how U was produced.
nullspace_threshold = 10
proj_on_U = encoder @ U
nullspace_comp = proj_on_U[:, -nullspace_threshold:].norm(dim=1) / proj_on_U.norm(dim=1)
nullspace_comp = nullspace_comp.cpu()

# get indices of the latents with nullspace comp > 0.2
nullspace_latent_indices = torch.where(nullspace_comp > 0.2)[0].tolist()

# sample at random len(nullspace_latent_indices) indices from the remaining latents
# A set makes the membership test O(1); the candidate list keeps its
# ascending order, so the seeded draw below is unchanged.
nullspace_latent_set = set(nullspace_latent_indices)
remaining_latent_indices = [i for i in range(encoder.shape[0]) if i not in nullspace_latent_set]
remaining_latent_indices = np.random.choice(remaining_latent_indices, len(nullspace_latent_indices), replace=False).tolist()
latent_indices = nullspace_latent_indices + remaining_latent_indices
print(f'Nullspace latents: {nullspace_latent_indices}')
print(f'Remaining latents: {remaining_latent_indices}')


# %%

# load data
# Keep only the first `num_docs` documents of the train split and tokenize
# their 'text' column with max_length=64.
# NOTE: despite its name, `first_1k` holds `num_docs` documents.
data = load_dataset(dataset, split='train')
first_1k = data.select(list(range(num_docs)))
tokenized_data = tokenize_and_concatenate(
    first_1k,
    tokenizer,
    max_length=64,
    column_name='text',
)
# tokenized_data = load_tokenized_data(cache_dir, dataset, 50, context_length=64, model_name=model_name)


# %%
# Dictionary to store activations
activations_dict = dict()

# Define a hook that ablates the given latents
def ablation_hook(value, hook, latent_weight, ablation_value):
    """Forward hook that rewrites `value` along one latent's direction.

    `latent_weight` is normalised to unit length and handed, together with
    `ablation_value`, to the project helper `adjust_vectors`; whatever that
    helper returns replaces the hooked activation.
    NOTE(review): presumably `adjust_vectors` sets the component of `value`
    along the unit direction to `ablation_value` — confirm in `utils`.

    `hook` is required by the hook signature and is not used.
    """
    unit_direction = latent_weight / latent_weight.norm()
    return adjust_vectors(value, unit_direction, ablation_value)

def ln_final_scale_hook(value, hook, scale_values):
    """Forward hook that discards the incoming scale and returns `scale_values`.

    Used to pin the hooked scale to values recorded on another run. `value`
    and `hook` are part of the hook signature and are otherwise ignored.
    """
    # scale hook (batch, seq, 1)
    return scale_values
    

# Run bookkeeping: the list of per-batch result dicts and the batch iterator.
# Batches are not shuffled, so rows are written in dataset order.
columns = dict()
stored_entries = list()
dl = DataLoader(tokenized_data, shuffle=False, batch_size=batch_size)

# Results are written below the project's data/ablations directory; create
# the directory up front if it is missing.
out_path = f'{project_dir}/data/ablations/{model_name}/{ckpt_name}/{dataset.split("/")[1]}_{num_docs}/'
os.makedirs(out_path, exist_ok=True)
df_path=f'{out_path}/{dataset.split("/")[1]}_{num_docs}.feather'

# %%
# Hook point whose value is recorded on every run and overridden in the
# "fixed scale" runs.
LN_SCALE_HOOK = 'ln_final.hook_scale'


def _to_numpy(tensor):
    """Move a floating-point tensor to the CPU as a float32 NumPy array.

    The upcast makes the conversion work for half-precision dtypes too
    (NumPy has no bfloat16); float32 inputs are unaffected.
    """
    return tensor.float().cpu().numpy()


def _summarize_logits(logits):
    """Return (top_logits, preds, entropies) over the last dim of `logits`.

    Entropy is computed from `log_softmax` instead of `log(softmax)`, so a
    probability that underflows to zero contributes 0 rather than
    0 * log(0) = NaN.
    """
    top_logits = _to_numpy(logits.max(dim=-1).values)
    preds = logits.argmax(dim=-1).cpu().numpy()
    log_probs = logits.log_softmax(dim=-1)
    entropies = _to_numpy(-torch.sum(log_probs.exp() * log_probs, dim=-1))
    return top_logits, preds, entropies


def _record_ablated_run(entry, tokens, hooks, suffix):
    """Run the model with `hooks` and store its statistics in `entry`.

    Results are written under the keys 'ablated_ln_scales_<suffix>',
    'ablated_top_logits_<suffix>', 'ablated_preds_<suffix>' and
    'ablated_entropies_<suffix>', in that order.
    """
    with model.hooks(fwd_hooks=hooks):
        # Only the ln scale is read back, so cache nothing else.
        ablated_logits, cache = model.run_with_cache(tokens, names_filter=LN_SCALE_HOOK)
    # squeeze(-1), not squeeze(): keep the batch axis for 1-sequence batches.
    ablated_ln_scales = _to_numpy(cache[LN_SCALE_HOOK].squeeze(-1))
    del cache
    top_logits, preds, entropies = _summarize_logits(ablated_logits)

    entry[f'ablated_ln_scales_{suffix}'] = ablated_ln_scales
    entry[f'ablated_top_logits_{suffix}'] = top_logits
    entry[f'ablated_preds_{suffix}'] = preds
    entry[f'ablated_entropies_{suffix}'] = entropies


def _save_entries(entries, path):
    """Merge per-batch result dicts into one DataFrame and write it to `path`.

    For every key the per-batch arrays are concatenated along axis 0 and
    flattened, giving one row per token position. Returns the DataFrame.
    """
    merged = {}
    for entry_dict in entries:
        for key, array in entry_dict.items():
            merged.setdefault(key, []).append(array)
    frame = pd.DataFrame({key: np.concatenate(arrays, axis=0).flatten() for key, arrays in merged.items()})
    frame.to_feather(path)
    return frame


# The ablation hook of each latent does not depend on the batch, so build
# all of them once up front.
resid_hook_name = tlutils.get_act_name('resid_post', layer_idx)
ablation_hooks = {}
for latent_idx in latent_indices:
    latent_weight = decoder[latent_idx]
    # Decoder bias projected on the latent's unit direction
    # (alternative considered by the original author: -encoder_bias[latent_idx]).
    ablation_value = decoder_bias @ latent_weight / latent_weight.norm()
    ablation_hooks[latent_idx] = (
        resid_hook_name,
        partial(ablation_hook, latent_weight=latent_weight, ablation_value=ablation_value),
    )

with torch.no_grad():
    for i, batch in enumerate(tqdm(dl)):
        tokens = batch["tokens"]
        device_tokens = tokens.to(device)

        # Clean (un-ablated) reference run.
        logits, cache = model.run_with_cache(device_tokens, names_filter=LN_SCALE_HOOK)
        ln_scales = cache[LN_SCALE_HOOK]
        del cache
        top_logits, preds, entropies = _summarize_logits(logits)
        del logits

        entry = {}
        entry['tokens'] = tokens.cpu().numpy()
        entry['top_logits'] = top_logits
        entry['preds'] = preds
        entry['ln_scales'] = _to_numpy(ln_scales.squeeze(-1))
        entry['entropies'] = entropies

        # The fixed-scale runs replace the ln scale with the clean run's values.
        fixed_scale_hook = (LN_SCALE_HOOK, partial(ln_final_scale_hook, scale_values=ln_scales))

        for latent_idx in latent_indices:
            latent_hook = ablation_hooks[latent_idx]

            # perform normal ablation
            _record_ablated_run(entry, device_tokens, [latent_hook], f'{latent_idx}')

            # perform ablation with fixed ln scale
            _record_ablated_run(entry, device_tokens, [latent_hook, fixed_scale_hook], f'fixed_{latent_idx}')

        stored_entries.append(entry)

        # Periodic checkpoint so partial results survive an interrupted run.
        if i % 10 == 0:
            df = _save_entries(stored_entries, df_path)

    # After the loop, do the final save
    df = _save_entries(stored_entries, df_path)

# %%
