# feature_discovery.py
import torch
from sae_lens import SAE, LanguageModelSAERunnerConfig, SAETrainingRunner
import numpy as np
from sklearn.cluster import SpectralClustering
import pickle
import os
import re

def sanitize_model_name(name):
    """Return ``name`` made safe for use inside a filename.

    Every character other than an ASCII letter, a digit, ``_`` or ``-`` is
    replaced one-for-one with ``_``, so the result has the same length as
    the input.
    """
    unsafe_chars = re.compile(r'[^a-zA-Z0-9_-]')
    filename_safe = unsafe_chars.sub('_', name)
    return filename_safe

# Task-specific prompt used for a quick generation sanity check before SAE
# training starts. The text is emitted to the model verbatim.
_VERIFICATION_PROMPT = ("""
You are a helpful assistant skilled at reasoning for tic tac toe. You first think about the reasoning process as an internal monologue and then provide the user with the answer. Respond in the following format: <think>
...
</think>
<answer>
...
</answer>
You are a helpful assistant skilled at reasoning for tic tac toe. You first think about the reasoning process as an internal monologue and then provide the user with the answer. Respond in the following format: <think>
...
</think>
<answer>
...
</answer>.
Board state:
Row 0: X, X, O. Row 1: empty, empty, O. Row 2: O, X, X.
It is Player 2's turn.
Recommend the best move which the player can play. Here is the definition of best move:
The determination of a "best move" follows a strict hierarchy of objectives:

1.  **Priority 1: Fastest Win.** If the current player can force a win, the 'best_move' will be the move that leads to the quickest possible victory (a win in the minimum number of subsequent turns).

2.  **Priority 2: Secure a Draw.** If a win is not possible, but the player can force a draw, the 'best_move' will contain all moves that guarantee at least a draw.

3.  **Priority 3: Slowest Loss.** If the player is in a losing position where every move leads to an eventual loss, the 'best_move' will be the move that prolong the game as long as possible before the loss occurs.

4.  **Terminal State: No Moves.** If the game has already concluded (a player has won, or the board is full), no further moves can be made. In this case, the 'best_move' will be None.Mapping:
Player 1 (X) Tokens:
1 -> (0,0), 2 -> (0,1), 3 -> (0,2), 4 -> (1,0), 5 -> (1,1), 6 -> (1,2), 7 -> (2,0), 8 -> (2,1), 9 -> (2,2)
Player 2 (O) Tokens:
10 -> (0,0), 11 -> (0,1), 12 -> (0,2), 13 -> (1,0), 14 -> (1,1), 15 -> (1,2), 16 -> (2,0), 17 -> (2,1), 18 -> (2,2), None -> No Move can be played
Thus your final answer should be one of the following if the next player to move is player 1: 1 or 2 or 3 or 4 or 5 or 6 or 7 or 8 or 9 or None
And your final answer should be one of the following if the next player to move is player 2: 10 or 11 or 12 or 13 or 14 or 15 or 16 or 17 or 18 or None
Please provide your reasoning in the following format:
<think> Your chain-of-thought reasoning here </think>
<answer> Your final move here </answer>
Remember to output exactly one of the best moves. You can only have one set of <think>...</think> and <answer>...</answer> in your response. The think section should be at the beginning of your response.
                               """
)


def _train_sae_on_model(model, architecture_name, layer, device):
    """Train a sparse autoencoder on ``model``'s residual stream at ``layer``.

    The already-loaded (fine-tuned) ``model`` is handed to the runner via
    ``override_model`` so that the activation store is built around it.
    Returns the trained SAE as produced by ``runner.run()``.
    """
    cfg = LanguageModelSAERunnerConfig(
        # Recorded in the SAE config; the weights actually used for
        # activations come from ``override_model`` below.
        model_name=architecture_name,
        hook_name=f"blocks.{layer}.hook_resid_post",
        hook_layer=layer,
        d_in=model.cfg.d_model,

        model_from_pretrained_kwargs={"trust_remote_code": True},

        # Dataset
        dataset_path="NeelNanda/c4-10k",
        is_dataset_tokenized=False,
        streaming=True,

        # SAE Parameters
        expansion_factor=16,
        l1_coefficient=3e-4,

        # Training
        lr=4e-4,
        train_batch_size_tokens=4096,
        context_size=512,
        training_tokens=20_000_000,

        # Other
        device=device,
        log_to_wandb=False,
        n_checkpoints=0,
        checkpoint_path="checkpoints",
    )

    # Assigning ``runner.model`` after construction is not enough: the runner
    # builds its activation store inside ``__init__`` and the store keeps its
    # own reference to whichever model existed at that moment. Passing the
    # model in up front makes the store (and therefore every training
    # activation) use the fine-tuned weights, and avoids loading the base
    # model a second time.
    print("Initializing SAETrainingRunner with the loaded fine-tuned model (override_model).")
    runner = SAETrainingRunner(cfg, override_model=model)

    # Fail loudly if the activation store is wired to a different model.
    # NOTE(review): relies on the store exposing ``.model``; the getattr
    # defaults make this a no-op if that attribute layout ever changes.
    store = getattr(runner, "activations_store", None)
    store_model = getattr(store, "model", None)
    if store_model is not None and store_model is not model:
        raise RuntimeError(
            "SAETrainingRunner's activation store is not using the supplied "
            "model; SAE training would run on the wrong weights."
        )

    print("\n--- Verifying model inside SAE Runner ---")
    # Generation sanity check on a prompt specific to the fine-tuning task;
    # the output is only printed for a human to eyeball.
    generated_text = runner.model.generate(
        _VERIFICATION_PROMPT, max_new_tokens=256, temperature=0.1
    )

    print(f"VERIFICATION PROMPT: {_VERIFICATION_PROMPT}")
    print(f"GENERATED TEXT FROM RUNNER'S MODEL: {generated_text}")
    print("--- Verification complete. Starting SAE training... ---\n")

    return runner.run()


def _cluster_sae_features(sae, d_in):
    """Group SAE features by spectral clustering of their decoder directions.

    Returns a plain dict mapping ``int`` cluster label to the list of feature
    indices (row indices of the decoder matrix) assigned to that cluster.
    """
    with torch.no_grad():
        # float32 on CPU so the tensor can be handed to NumPy/scikit-learn.
        directions = sae.W_dec.detach().float().cpu()
        # One row per SAE feature is required. SAE Lens stores W_dec as
        # (d_sae, d_in), so normally no transpose is needed; only flip when
        # the matrix is clearly laid out the other way round.
        if directions.shape[-1] != d_in and directions.shape[0] == d_in:
            directions = directions.T
        # Clamp so an all-zero (dead) decoder row cannot produce NaNs.
        row_norms = torch.norm(directions, dim=1, keepdim=True).clamp_min(1e-12)
        unit_directions = directions / row_norms

    n_features = unit_directions.shape[0]
    # Roughly 20 features per cluster, capped at 100 clusters; never 0.
    n_clusters = max(1, min(100, n_features // 20))

    clustering = SpectralClustering(
        n_clusters=n_clusters,
        affinity='nearest_neighbors',
        # The neighbour graph cannot ask for more neighbours than samples.
        n_neighbors=min(15, n_features),
        random_state=42,
        n_jobs=-1
    ).fit(unit_directions.numpy())

    clusters = {}
    for feature_idx, label in enumerate(clustering.labels_):
        # int() turns the NumPy label into a plain Python key.
        clusters.setdefault(int(label), []).append(feature_idx)
    return clusters


def run_feature_discovery(model, display_name: str, architecture_name: str, layer: int, device):
    """
    Run the feature discovery pipeline for a given model and layer.

    Loads a cached SAE for ``display_name``/``layer`` if one exists under
    ``sae_artifacts/``; otherwise trains one on ``model`` and saves it. Then
    loads cached feature clusters if present, or clusters the SAE's decoder
    directions and caches the result.

    Args:
        model: The loaded HookedTransformer model object (the fine-tuned
            model whose activations the SAE is trained on).
        display_name: The user-friendly name for the model; sanitized and
            used in the artifact filenames.
        architecture_name: The underlying HuggingFace model name, recorded in
            the SAE training config.
        layer: The layer to probe (``blocks.{layer}.hook_resid_post``).
        device: The torch device to use.

    Returns:
        A ``(sae, clusters)`` tuple, where ``clusters`` maps a cluster label
        to the list of SAE feature indices in that cluster.

    Raises:
        RuntimeError: If the training runner's activation store ends up bound
            to a model other than ``model``.
    """
    sanitized_name = sanitize_model_name(display_name)
    output_dir = "sae_artifacts"
    os.makedirs(output_dir, exist_ok=True)

    # NOTE(review): SAE Lens' save_model/load_from_pretrained appear to treat
    # this path as a directory despite the ".safetensors" suffix — confirm.
    # The name is left unchanged so existing caches keep working.
    sae_path = os.path.join(output_dir, f"sae_{sanitized_name}_layer_{layer}.safetensors")
    clusters_path = os.path.join(output_dir, f"clusters_{sanitized_name}_layer_{layer}.pkl")

    print(f"sae path for saving checkpoints:\n{sae_path}")
    print(f"Clusters path for saving clusters:\n{clusters_path}")

    if os.path.exists(sae_path):
        print(f"Loading pre-trained SAE from {sae_path}")
        sae = SAE.load_from_pretrained(sae_path, device=device)
        print("SAE loaded successfully.")
    else:
        print(f"No cached SAE found. Starting SAE training for {display_name} at layer {layer}...")
        sae = _train_sae_on_model(model, architecture_name, layer, device)

        sae.save_model(sae_path)
        print(f"SAE training complete. Saved to {sae_path}")
        sae = sae.to(device)

    if os.path.exists(clusters_path):
        print(f"Loading pre-computed clusters from {clusters_path}")
        # SECURITY: pickle executes arbitrary code on load. Only load cluster
        # files this pipeline wrote itself; never point this at untrusted data.
        # NOTE: cluster caches written by older revisions that transposed
        # W_dec index input dimensions, not SAE features — delete them to
        # force a recompute.
        with open(clusters_path, 'rb') as f:
            clusters = pickle.load(f)
    else:
        print("Clustering SAE features...")
        clusters = _cluster_sae_features(sae, model.cfg.d_model)

        with open(clusters_path, 'wb') as f:
            pickle.dump(clusters, f)
        print(f"SAE feature clusters saved to {clusters_path}")

    return sae, clusters