#%%
from transformers import AutoTokenizer, AutoModelForCausalLM
import numpy as np
import argparse
import torch
from openai import OpenAI
from globals import *
import datasets
import os
import pickle
from tqdm.auto import tqdm

# Credentials are read from the environment so that no secret lives in the
# source file. NOTE(security): an access token was previously committed in
# this file; treat it as leaked and revoke it.
#   OPENROUTER_API_KEY - key for the OpenRouter endpoint used by the judge model
#   HF_TOKEN           - Hugging Face token for the gated Gemma checkpoint; when
#                        unset, None is passed and the library falls back to
#                        its own credential lookup (e.g. a cached login).
OPENROUTER_API_KEY = os.environ.get('OPENROUTER_API_KEY', 'INSERT API KEY HERE')
HF_TOKEN = os.environ.get('HF_TOKEN')
MODEL_NAME = "google/gemma-2-2b"

# Inference only: disable autograd globally.
torch.set_grad_enabled(False)
device = torch.device("cuda:0")
device_map = {'': 0}
model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        device_map=device_map,
        torch_dtype=torch.bfloat16,
        token=HF_TOKEN
    ).to(device)
DATASET = "togethercomputer/RedPajama-Data-1T-Sample"
dataset = datasets.load_dataset(DATASET, split="train")
# Fixed seed so every run sees the documents in the same order.
dataset = dataset.shuffle(seed=42)
# Same token as the model load, so both downloads authenticate identically.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=HF_TOKEN)
# OpenAI-compatible client pointed at OpenRouter; used by get_feature_alignment.
client = OpenAI(
  base_url="https://openrouter.ai/api/v1",
  api_key=OPENROUTER_API_KEY
)

#%%
# Command-line interface: `--layer` (int, required) selects the layer index
# used throughout the rest of the script.
parser = argparse.ArgumentParser()
parser.add_argument('--layer', type=int, required=True, help='Layer number to process')
args = parser.parse_args()
layer = args.layer

#%%
# Run configuration, read as module-level globals by the code below.
# Model id sent with every judge request in get_feature_alignment.
gpt_model = "google/gemini-2.5-flash-preview"
# Number of texts to collect; collection stops once this many are found.
n_data = 200
# Passed to get_input_up_to_newline: shorter texts are used whole, longer ones
# are cut at the first newline token at or after this position.
max_input_tokens = 400
# Used both as max_new_tokens for model.generate and as max_tokens for the
# judge request.
max_output_tokens = 100
# A run of consecutive active tokens must be strictly longer than this to be
# kept as an example phrase.
min_tokens = 5
# Multiplier applied to the steering vector.
steer_scale = 2
# When True, print completions/evaluations and display the HTML highlight.
to_print = False
# Forwarded to model.generate as do_sample.
do_sample = False

#%%
# Collect up to `n_data` documents, in dataset iteration order, that are at
# least 1024 tokens long. Tokenization truncates at 1024 tokens, so a length
# of 1024 means "1024 or more". Only the length is needed, so plain Python
# ids are requested instead of a tensor.
texts = []
for item in dataset:
    token_ids = tokenizer(
        item['text'],
        add_special_tokens=True,
        max_length=1024,
        truncation=True
    )['input_ids']
    if len(token_ids) >= 1024:
        texts.append(item['text'])
    if len(texts) >= n_data:
        break

def make_hook_fn(U, add_vector, bias=None):
    """Build a forward hook that edits the hidden state at the last position.

    At the last sequence position of every batch row, the hook:
      1. subtracts the bias (centering),
      2. removes the component ``h @ U @ U.T`` (an orthogonal projection onto
         the column span of ``U`` when its columns are orthonormal; the
         caller in this file passes the Q factor of a QR decomposition),
      3. adds ``add_vector``,
      4. adds the bias back,
    and writes the result into the hidden-state tensor in place.

    Args:
        U: matrix of shape (hidden_dim, k) spanning the subspace to erase.
        add_vector: vector of shape (hidden_dim,) added after erasure.
        bias: centering offset. When None (the default, matching the original
            call signature), the module-level ``b_dec`` is looked up each time
            the hook fires, so it only has to exist by then.

    Returns:
        A ``hook_fn(module, input, output)`` suitable for
        ``register_forward_hook``. It accepts either a bare tensor or a
        sequence whose first element is the hidden-state tensor, and returns
        the same kind of object (a tuple in the sequence case).
    """
    def hook_fn(module, input, output):
        offset = b_dec if bias is None else bias
        # Unpacking a bare tensor with `a, *rest = output` would split it
        # along the batch axis, so tensors are handled explicitly.
        output_is_tensor = torch.is_tensor(output)
        hidden_states = output if output_is_tensor else output[0]

        centered = hidden_states[:, -1, :] - offset
        erased = centered - centered @ U @ U.T
        hidden_states[:, -1, :] = erased + add_vector + offset

        if output_is_tensor:
            return hidden_states
        return (hidden_states, *output[1:])
    return hook_fn

def get_feature_alignment(completion, f1_activations, f2_activations):
    """Ask the judge model which example set `completion` resembles more.

    The prompt lists the two example sets (one example per line), then the
    new text, and asks for an answer of 1, 2, or UNCLEAR. The request goes
    through the module-level `client` using `gpt_model` and
    `max_output_tokens`.

    Returns:
        A `(completion, evaluation)` tuple, where `evaluation` is the
        stripped reply text. If anything in the request or in reading the
        reply raises, the error is printed and `evaluation` is "".
    """
    line_break = "\n"
    set_1_block = line_break.join(f1_activations)
    set_2_block = line_break.join(f2_activations)
    prompt = (
        "You are an annotation assistant. Given two sets of example texts, and a new text, decide which set the new text is more similar to.\n"
        "\n"
        "Set 1 examples:\n"
        f"{set_1_block}\n"
        "\n"
        "Set 2 examples:\n"
        f"{set_2_block}\n"
        "\n"
        "New text:\n"
        f"{completion}\n"
        "\n"
        "Answer with just: 1, 2, or UNCLEAR"
    )

    # Best-effort call: a failed request must not abort the surrounding run,
    # so the empty string is kept as the verdict on any error.
    evaluation = ""
    try:
        reply = client.chat.completions.create(
            model=gpt_model,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.0,
            max_tokens=max_output_tokens,
        )
        evaluation = reply.choices[0].message.content.strip()
    except Exception as e:
        print(f"❌ OpenAI API Error: {e}")
    return completion, evaluation

def get_input_up_to_newline(text, tokenizer, max_input_tokens):
    """Trim `text` to a prefix that ends at a newline token where possible.

    The text is encoded with special tokens. If it has at most
    `max_input_tokens` tokens it is returned unchanged. Otherwise tokens are
    scanned from index `max_input_tokens` onward, and the prefix ends just
    after the first token whose decoded form contains a newline. If no such
    token exists, the prefix is the first `max_input_tokens` tokens.

    Returns:
        `(prefix_text, prefix_token_ids)`. In the untrimmed case these are
        the original `text` and its full token list. In the trimmed case
        `prefix_text` is the prefix decoded with special tokens skipped.
    """
    token_ids = tokenizer.encode(text, add_special_tokens=True)
    if len(token_ids) <= max_input_tokens:
        return text, token_ids

    def _decodes_to_newline(position):
        piece = tokenizer.decode(token_ids[position], skip_special_tokens=True)
        return '\n' in piece

    # Lazy search: stops decoding at the first newline-bearing token.
    cut = next(
        (
            position + 1
            for position in range(max_input_tokens, len(token_ids))
            if _decodes_to_newline(position)
        ),
        max_input_tokens,
    )
    prefix_ids = token_ids[:cut]
    return tokenizer.decode(prefix_ids, skip_special_tokens=True), prefix_ids

#%%
def _collect_active_phrases(token_strs, active_flags, min_len):
    """Join runs of consecutive active tokens into phrases.

    Only runs strictly longer than `min_len` tokens are kept.
    """
    phrases = []
    run = []
    for token, is_active in zip(token_strs, active_flags):
        if is_active:
            run.append(token)
            continue
        if run and len(run) > min_len:
            phrases.append("".join(run))
        run = []
    # Flush a run that extends to the end of the sequence.
    if run and len(run) > min_len:
        phrases.append("".join(run))
    return phrases


def _generate_completion(encoding):
    """Generate a continuation for `encoding` and return it decoded, without the prompt."""
    generated_ids = model.generate(
        **encoding,
        max_new_tokens=max_output_tokens,
        do_sample=do_sample)
    new_tokens = generated_ids[0][encoding['input_ids'].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)


print(f'------- STARTING LAYER {layer} -------')
# get_mydata, print_elephants_with_pairs, load_sae_lens (and the display
# helpers used at the end) are not defined in this file; presumably they come
# from `from globals import *`.
freqs, saes = get_mydata(layer,freqs=True,saes=True)
elephants = print_elephants_with_pairs(freqs,saes,thres=0.2,flat=False)
# Only groups of exactly two features are handled below.
elephants = [p for p in elephants if len(p) == 2]

sae = load_sae_lens(f'layer_{layer}/width_16k/canonical', 'gemma_2_2b').to(device)
# NOTE: `saes` is rebound here; from now on it is the SAE decoder matrix,
# indexed by feature id. `b_dec` is also read by the hook from make_hook_fn.
saes = sae.W_dec.detach()
b_dec = sae.b_dec.detach()

# Create the output directory up front so a long run cannot fail at the final
# pickle.dump because the directory is missing.
os.makedirs('autointerp', exist_ok=True)

for feature_pair in elephants:
    results = {}
    # Iterate the texts actually collected: fewer than n_data may have been
    # found, and indexing by range(n_data) would then raise IndexError.
    for data_idx, text in enumerate(tqdm(texts, desc=f'FEATURE PAIR {feature_pair}')):
        input_text, input_tokens = get_input_up_to_newline(text, tokenizer, max_input_tokens)
        encoding = tokenizer(
            input_text,
            return_tensors="pt",
            add_special_tokens=True,
            max_length=len(input_tokens),
            truncation=False
        ).to(device)
        inputs = encoding['input_ids']
        token_strs = [tokenizer.decode(x) for x in inputs[0]]

        # Only hidden states are read, so no labels are passed (computing a
        # loss would be wasted work). Index [layer+1] is presumably the output
        # of decoder layer `layer` (index 0 being the embeddings) - TODO
        # confirm; [0] selects the single batch row.
        model_acts = model(**encoding,
                        output_hidden_states=True,
                        ).hidden_states[layer+1][0]
        sae_acts = sae.encode(model_acts)

        # Projection of the bias-centered activations onto each feature's
        # decoder direction; position 0 (presumably the BOS token) is zeroed
        # so it cannot dominate the maximum.
        centered_model_acts = model_acts - b_dec
        proj_feature_1 = centered_model_acts @ saes[feature_pair[0]]
        proj_feature_2 = centered_model_acts @ saes[feature_pair[1]]
        proj_feature_1[0] = 0
        proj_feature_2[0] = 0
        max_proj_feature_1 = abs(proj_feature_1.max())
        max_proj_feature_2 = abs(proj_feature_2.max())
        max_proj_features = max(max_proj_feature_1, max_proj_feature_2)

        f1_og_acts = sae_acts[:, feature_pair[0]]
        f2_og_acts = sae_acts[:, feature_pair[1]]

        f1_og_acts[0] = 0
        f2_og_acts[0] = 0
        # Plain Python lists: the loops below test every element, and testing
        # tensor elements one by one forces a device sync per token.
        f1_acts = (f1_og_acts > 0).tolist()
        f2_acts = (f2_og_acts > 0).tolist()

        # Which feature fired most recently in the prompt (feature 1 wins
        # when both fire on the same token); None if neither ever fires.
        last_active = None
        for f1, f2 in zip(f1_acts, f2_acts):
            if f1:
                last_active = "1"
            elif f2:
                last_active = "2"

        f1_activations = _collect_active_phrases(token_strs, f1_acts, min_tokens)
        f2_activations = _collect_active_phrases(token_strs, f2_acts, min_tokens)

        # The judge needs examples from both sets; skip contexts lacking either.
        if not f1_activations or not f2_activations:
            if to_print:
                print(f"Skipping context {data_idx} due to insufficient activations: F1={len(f1_activations)}, F2={len(f2_activations)}")
            continue

        results[data_idx] = {}
        results[data_idx]['text'] = text
        results[data_idx]['last_active'] = last_active
        results[data_idx]['f1_examples'] = f1_activations
        results[data_idx]['f2_examples'] = f2_activations

        # Baseline: unsteered completion.
        completion, evaluation = get_feature_alignment(
            _generate_completion(encoding),
            f1_activations,
            f2_activations
        )

        if to_print:
            print('UNSTEERED COMPLETION')
            print(completion)
            print(evaluation)

        results[data_idx]['steer_none_completion'] = completion
        results[data_idx]['steer_none_evaluation'] = evaluation

        # Steer towards each feature of the pair in turn.
        for i, steer_feature in enumerate(feature_pair):
            friend_feature = feature_pair[1 - i]
            steer_direction = saes[steer_feature]
            friend_direction = saes[friend_feature]
            # Orthonormal basis (Q factor) for the span of the two directions.
            U = torch.stack([steer_direction, friend_direction])
            U = torch.linalg.qr(U.T)[0]
            add_vector = max_proj_features * steer_direction * steer_scale

            handle = model.model.layers[layer].register_forward_hook(make_hook_fn(U, add_vector))
            # try/finally: a failed generate must not leave the steering hook
            # attached for every later forward pass.
            try:
                if to_print:
                    print(f'STEERED COMPLETION {i+1}')
                steered_text = _generate_completion(encoding)
            finally:
                handle.remove()

            completion, evaluation = get_feature_alignment(
                    steered_text,
                    f1_activations,
                    f2_activations
            )

            if to_print:
                print(completion)
                print(evaluation)
            results[data_idx][f'steer_{i+1}_completion'] = completion
            results[data_idx][f'steer_{i+1}_evaluation'] = evaluation

        if to_print:
            html_output = highlight_tokens_scaled_html(token_strs, f1_og_acts, f2_og_acts)
            display(HTML(html_output))

    with open(f'autointerp/layer{layer}_{feature_pair}.pkl', 'wb') as f:
        pickle.dump(results, f)
    print(f'------- FINISHED FEATURE PAIR {feature_pair} -------')
print(f'------- FINISHED LAYER {layer} -------')
