import os
from typing import List, Optional, Tuple

import torch
from huggingface_hub import login
from nnsight import LanguageModel
from sae_lens import SAE  # pip install sae-lens
from transformers import AutoTokenizer
# Configuration constants
HF_TOKEN = os.environ.get("HF_TOKEN", "")  # Hugging Face token; "" when the variable is unset
MODEL_NAME = "google/gemma-2-9b-it"  # model id loaded by initialize_models()
SAE_NAME = "gemma-scope-9b-pt-res"  # passed as `release` to SAE.from_pretrained()
LAYER_ID = 26  # layer whose output is traced/steered and whose SAE is loaded
TOP_K = 1000  # top activations inspected per token by trace_feature_act()
DEVICE = "cuda"  # device for the model, the SAE and the input tensors

# ANSI escape sequences used to colorize terminal output
class Colors:
    """Namespace of terminal color codes; ``RESET`` ends a colored span."""

    RED = "\033[91m"
    ORANGE = "\033[93m"
    GREEN = "\033[92m"
    BLUE = "\033[94m"
    RESET = "\033[0m"

# Feature indices to highlight, mapped to the terminal color hids2feats()
# prints them in. The per-entry labels mirror `feature_names` in
# steer_feature().
HIGHLIGHT_FEATURES = {
    13751: Colors.RED,  # labeled "UK"
    15261: Colors.ORANGE,  # labeled "Speak Chinese"
    9591: Colors.GREEN,  # labeled "Russia"
}


def hids2feats(ids: torch.Tensor, hids: torch.Tensor, sae, llm, top_k: int = 1000) -> None:
    """Print, for every token, the highlighted SAE features active on it.

    The hidden states are encoded with ``sae``; for each token position the
    ``top_k`` largest feature activations are inspected, and those that are
    strictly positive and listed in ``HIGHLIGHT_FEATURES`` are printed in
    their configured color, in descending order of activation. One line is
    printed per token, starting with ``position:token``.

    Args:
        ids: Token ids; a leading dimension of size 1 is squeezed away
            (presumably shape ``(1, seq)`` -- confirm against the caller).
        hids: Hidden states aligned with ``ids`` (presumably
            ``(1, seq, d_model)`` -- confirm against the caller).
        sae: Object exposing ``encode(tensor)`` that returns per-position
            feature activations.
        llm: Object exposing ``tokenizer.decode``.
        top_k: Maximum number of activations examined per token. Values
            larger than the number of SAE features are clamped.
    """
    # Inference only: no autograd graph is needed for the SAE forward pass.
    with torch.no_grad():
        feature_acts = sae.encode(hids.to(DEVICE)).squeeze(0)
    ids = ids.squeeze(0)

    # torch.topk raises when k exceeds the size of the reduced dimension.
    k = min(top_k, feature_acts.shape[-1])
    top_values, top_indices = feature_acts.topk(k, dim=-1, largest=True, sorted=True)

    # Move everything to Python objects once instead of calling .item() on
    # every element, which forces a device sync per value.
    value_rows = top_values.tolist()
    index_rows = top_indices.tolist()

    for pos_id, (input_id, values, indices) in enumerate(zip(ids, value_rows, index_rows)):
        token = llm.tokenizer.decode(input_id)
        top_feats = [f"{pos_id}:{token}"]

        for feat_val, feat_idx in zip(values, indices):
            # Only active (positive) features from the highlight set are shown.
            if feat_val > 0 and feat_idx in HIGHLIGHT_FEATURES:
                color = HIGHLIGHT_FEATURES[feat_idx]
                top_feats.append(f"{color}({feat_idx},{round(feat_val, 2)}){Colors.RESET}")

        print(' '.join(top_feats))


def trace_feature_act(prompt: str, llm, sae, layer_id: int = LAYER_ID) -> None:
    """Generate a reply to ``prompt`` and print the SAE features of its tokens.

    The prompt is wrapped as a single user message, rendered with the chat
    template and run through ``llm.generate`` (greedy, up to 128 new tokens).
    The output of layer ``layer_id`` is saved during the run and handed,
    together with the prompt token ids, to ``hids2feats``.

    Args:
        prompt: User message text.
        llm: nnsight language model; its bundled ``tokenizer`` is used both
            for the chat template and for decoding.
        sae: Sparse autoencoder forwarded to ``hids2feats``.
        layer_id: Index into ``llm.model.layers`` of the layer to trace.
    """
    messages = [{"role": "user", "content": prompt}]
    # Use the tokenizer that belongs to `llm` rather than a module-level
    # `tokenizer` global, so the function also works when imported.
    input_ids = llm.tokenizer.apply_chat_template(
        messages,
        return_tensors="pt",
        return_dict=True,
        add_generation_prompt=True,
    ).input_ids.to(DEVICE)

    with llm.generate(max_new_tokens=128, do_sample=False) as tracer:
        with tracer.invoke(input_ids):
            # Element 0 of the layer output is the hidden-state tensor.
            # NOTE(review): hids2feats pairs it with the prompt ids, which
            # assumes the saved value covers exactly the prompt positions.
            hids = llm.model.layers[layer_id].output[0].save()
        output = llm.generator.output.save()

    decode_tokens = llm.tokenizer.batch_decode(output)
    print("Original prompt and generation:")
    print(decode_tokens[0])
    print(" * " * 10)
    print("Feature activations:")
    hids2feats(input_ids, hids, sae, llm, TOP_K)


def steer_feature(
    prompt: str,
    sae,
    feat_list: List[int],
    strengths: List[int],
    layer_id: int = LAYER_ID,
    llm: "Optional[LanguageModel]" = None,
) -> None:
    """Generate a reply to ``prompt`` while adding SAE decoder directions.

    For each ``(feature id, strength)`` pair, ``sae.W_dec[feature id]`` scaled
    by ``strength`` is added to position ``-1`` of the output of layer
    ``layer_id`` inside the generation trace. The decoded generation is then
    printed. Feature ids outside ``[0, sae.W_dec.shape[0])`` are skipped with
    a warning.

    Args:
        prompt: User message text.
        sae: Sparse autoencoder exposing the decoder matrix ``W_dec``.
        feat_list: Feature ids to activate; may be empty for a baseline run.
        strengths: One scale factor per entry of ``feat_list``.
        layer_id: Index into ``llm.model.layers`` of the layer to modify.
        llm: Language model to run. When omitted, the module-level ``llm``
            is used, which keeps existing call sites working.

    Raises:
        ValueError: If ``feat_list`` and ``strengths`` differ in length.
        RuntimeError: If no ``llm`` is passed and none exists at module level.
    """
    if llm is None:
        # Backward compatibility: earlier callers relied on a global `llm`.
        llm = globals().get("llm")
        if llm is None:
            raise RuntimeError(
                "steer_feature needs a language model: pass `llm=` or define a module-level `llm`"
            )
    if len(feat_list) != len(strengths):
        # zip() would silently drop the unmatched entries.
        raise ValueError(
            f"feat_list has {len(feat_list)} entries but strengths has {len(strengths)}"
        )

    messages = [{"role": "user", "content": prompt}]
    input_ids = llm.tokenizer.apply_chat_template(
        messages,
        return_tensors="pt",
        return_dict=True,
        add_generation_prompt=True,
    ).input_ids.to(DEVICE)

    # Map feature IDs to semantic labels
    feature_names = {9591: "Russia", 13751: "UK", 15261: "Speak Chinese"}
    labeled_features = [f"{feature_names.get(f, f)}({f})" for f in feat_list]
    print(f"Force activating features: {labeled_features}")

    num_features = sae.W_dec.shape[0]
    with llm.generate(max_new_tokens=128, do_sample=False) as tracer:
        with tracer.invoke(input_ids):
            for feat_id, strength in zip(feat_list, strengths):
                if 0 <= feat_id < num_features:
                    feature = sae.W_dec[feat_id]
                    llm.model.layers[layer_id].output[0][:, -1, :] += feature * strength
                else:
                    print(f"Warning: Feature ID {feat_id} is out of bounds")
        steered = llm.generator.output.save()

    steered_decode = llm.tokenizer.batch_decode(steered)
    print(steered_decode[0])
    print(" * " * 10)


def initialize_models() -> tuple[LanguageModel, AutoTokenizer, SAE]:
    """Load the language model, its tokenizer and the SAE.

    Logs in to the Hugging Face Hub only when ``HF_TOKEN`` is non-empty, then
    loads ``MODEL_NAME`` as an nnsight ``LanguageModel`` on ``DEVICE``, the
    matching ``AutoTokenizer``, and the SAE for layer ``LAYER_ID`` from the
    ``SAE_NAME`` release (moved to ``DEVICE``).

    Returns:
        The tuple ``(llm, tokenizer, sae)``.
    """
    # HF_TOKEN defaults to "" when the environment variable is unset; an empty
    # string is not a usable credential, so skip the explicit login then.
    # NOTE(review): presumably already-cached credentials are picked up by the
    # download calls below in that case -- confirm for gated models.
    if HF_TOKEN:
        login(HF_TOKEN)

    llm = LanguageModel(MODEL_NAME, device_map=DEVICE)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    # NOTE(review): the width/L0 suffix is hard-coded; it presumably has to
    # match an SAE that exists for LAYER_ID -- revisit when changing the layer.
    sae_id = f"layer_{LAYER_ID}/width_16k/average_l0_233"
    # Only the SAE itself is needed; the other two returned values are unused.
    sae, _cfg_dict, _sparsity = SAE.from_pretrained(
        release=SAE_NAME,
        sae_id=sae_id,
    )
    sae = sae.to(DEVICE)

    return llm, tokenizer, sae



def steer_prompt(prompt, strengths):
    """Run the standard six steering variants of ``prompt``.

    ``strengths`` must hold exactly three values, applied to features 9591,
    13751 and 15261 respectively. The variants are: no steering, each feature
    alone, and 15261 combined with 9591 and with 13751. Uses the module-level
    ``sae``.
    """
    s_9591, s_13751, s_15261 = strengths
    variants = (
        ((), ()),
        ((9591,), (s_9591,)),
        ((13751,), (s_13751,)),
        ((15261,), (s_15261,)),
        ((9591, 15261), (s_9591, s_15261)),
        ((13751, 15261), (s_13751, s_15261)),
    )
    for feature_ids, weights in variants:
        steer_feature(prompt, sae, list(feature_ids), strengths=list(weights))

def run_experiments(llm: LanguageModel, sae: SAE) -> None:
    """Run the feature tracing pass followed by all steering experiments.

    First traces feature activations for the two capital-city prompts, then
    steers three prompts through ``steer_prompt`` and three more through
    explicit per-variant strengths passed to ``steer_feature``.
    """
    prompts = {
        "uk_capital": "Answer the question in Chinese: What is the capital of UK? Directly give the answer",
        "russia_capital": "Answer the question in Chinese: What is the capital of Russia? Directly give the answer",
        "mount_fuji": "Where is Mount Fuji? Directly give the answer",
        "tourist": "Could you recommend a tourist attraction? Directly give the answer",
        "university": "Tell me a university. Directly give the answer",
        "alcoholic": "Recommend me a traditional alcoholic beverage. Directly give the answer",
        "dish": "Recommend me a traditional dish. Directly give the answer",
        "daughter": "Please name my newborn baby daughter. Directly give the answer",
    }

    # Trace feature activations
    print("=== Tracing feature activations ===")
    for key in ("uk_capital", "russia_capital"):
        trace_feature_act(prompts[key], llm, sae)

    print("\n=== Feature steering experiments ===")
    # Prompts that use one strength per feature across all variants.
    shared_strength_runs = (
        ("mount_fuji", (400, 400, 400)),
        ("university", (500, 300, 300)),
        ("tourist", (900, 400, 300)),
    )
    for key, per_feature in shared_strength_runs:
        steer_prompt(prompts[key], list(per_feature))

    # Prompts whose strengths are tuned per variant; each strength tuple lines
    # up positionally with the feature combination at the same index.
    feature_combos = ((), (9591,), (13751,), (15261,), (9591, 15261), (13751, 15261))
    tuned_runs = (
        ("alcoholic", ((), (200,), (400,), (300,), (200, 200), (300, 300))),
        ("dish", ((), (300,), (300,), (300,), (200, 200), (400, 300))),
        ("daughter", ((), (400,), (200,), (500,), (100, 100), (500, 200))),
    )
    for key, strength_sets in tuned_runs:
        for feature_ids, weights in zip(feature_combos, strength_sets):
            steer_feature(prompts[key], sae, list(feature_ids), strengths=list(weights))

if __name__ == "__main__":
    # Keep these as module-level names: steer_prompt() reads `sae` and
    # steer_feature() can read `llm` from the module globals instead of
    # receiving them as arguments.
    llm, tokenizer, sae = initialize_models()
    run_experiments(llm, sae)
