"""
Generate human-understandable interpretations for each SAE feature using an LLM.

This script:
1. Loads SAE interpretations from a JSON file
2. For each feature, collects texts where the feature is highly active
3. Uses an LLM (default: Qwen/Qwen2.5-14B-Instruct) to generate human-readable descriptions
4. Saves feature interpretations to a JSON file
"""

import json
import argparse
import re
import numpy as np
from collections import defaultdict
from tqdm import tqdm
from typing import List, Dict, Tuple
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM


def load_interpretations(json_path: str) -> List[Dict]:
    """Load SAE interpretations from a JSON file.

    Args:
        json_path: Path of the JSON file to read.

    Returns:
        The decoded JSON content. The rest of this script treats it as a
        list of dicts carrying 'text', 'top_features' and 'feature_values'.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    print(f"Loading interpretations from {json_path}...")
    # JSON is UTF-8 by specification. Passing the encoding explicitly avoids
    # depending on the platform default (e.g. cp1252 on Windows), which would
    # mis-decode or reject non-ASCII example texts.
    with open(json_path, 'r', encoding='utf-8') as f:
        interpretations = json.load(f)
    print(f"Loaded {len(interpretations)} interpretations")
    return interpretations


def collect_feature_examples(
    interpretations: List[Dict],
    activation_threshold: float = 0.5,
    max_examples_per_feature: int = 50
) -> Dict[int, List[Tuple[str, float]]]:
    """
    Group texts by the features that fire strongly on them.

    Every record contributes its text once for each (feature, value) pair
    whose value reaches the threshold. Per feature, the strongest examples
    are kept; examples with equal activation stay in input order.

    Args:
        interpretations: Records with 'text', 'top_features' and
            'feature_values'; the last two are paired up positionally
        activation_threshold: Smallest activation (inclusive) that counts
            as the feature being active
        max_examples_per_feature: Cap on the examples retained per feature

    Returns:
        Mapping feature_id -> list of (text, activation_value) tuples,
        ordered from highest to lowest activation
    """
    print(f"\nCollecting feature examples (threshold={activation_threshold})...")

    grouped = defaultdict(list)

    for record in tqdm(interpretations, desc="Processing interpretations"):
        text = record['text']
        feature_ids = record['top_features']
        activations = record['feature_values']

        strong_hits = (
            (fid, act)
            for fid, act in zip(feature_ids, activations)
            if act >= activation_threshold
        )
        for fid, act in strong_hits:
            grouped[fid].append((text, act))

    # Rank each feature's examples by activation and keep only the top slice.
    # Values are replaced in place; the key set does not change while iterating.
    for fid, hits in grouped.items():
        ranked = sorted(hits, key=lambda pair: pair[1], reverse=True)
        grouped[fid] = ranked[:max_examples_per_feature]

    print(f"Collected examples for {len(grouped)} features")

    # Summary statistics (skipped when no feature passed the threshold)
    sizes = [len(hits) for hits in grouped.values()]
    if sizes:
        print(f"  Min examples per feature: {min(sizes)}")
        print(f"  Max examples per feature: {max(sizes)}")
        print(f"  Mean examples per feature: {np.mean(sizes):.1f}")

    return grouped


# Smaller checkpoint tried when the requested model cannot be loaded.
_FALLBACK_MODEL_NAME = "Qwen/Qwen2.5-7B-Instruct"


def _load_model_and_tokenizer(model_name: str, device: str):
    """Load one checkpoint plus its tokenizer and put the model in eval mode on `device`."""
    device_name = str(device)
    on_cuda = device_name.startswith('cuda')
    # device_map="auto" is only requested for the bare 'cuda' device. An
    # explicit target such as 'cuda:1' or 'mps' is honoured by moving the
    # model there afterwards, so it ends up where the inputs will be sent.
    auto_placement = device_name == 'cuda'

    # NOTE: trust_remote_code=True allows code shipped with the checkpoint
    # repository to run locally; only use model names you trust.
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16 if on_cuda else torch.float32,
        device_map="auto" if auto_placement else None,
        trust_remote_code=True
    )
    if not auto_placement:
        model = model.to(device)
    model.eval()
    return model, tokenizer


def load_llm(model_name: str = "Qwen/Qwen2.5-14B-Instruct", device: str = None):
    """
    Load LLM model and tokenizer, falling back to a smaller model on failure.

    Args:
        model_name: Name of the LLM model to use
        device: Device to use ('cuda', 'cpu', or a specific device such as
            'cuda:0'). Auto-detect ('cuda' if available, else 'cpu') if None

    Returns:
        Tuple (model, tokenizer, device) where `device` is the resolved
        device string that inputs should be moved to.

    Raises:
        Exception: Whatever the loader raised, if the fallback model also
            fails or `model_name` already is the fallback model.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    print(f"\nLoading LLM model: {model_name}")
    print(f"Device: {device}")

    try:
        model, tokenizer = _load_model_and_tokenizer(model_name, device)
    except Exception as e:
        # Deliberately broad: any load failure (missing checkpoint, download
        # error, out of memory) triggers one attempt with the smaller model.
        print(f"Error loading model {model_name}: {e}")
        if model_name == _FALLBACK_MODEL_NAME:
            # Retrying the identical checkpoint cannot succeed; surface the error.
            raise
        print("Trying alternative model...")
        print(f"Loading fallback model: {_FALLBACK_MODEL_NAME}")
        model, tokenizer = _load_model_and_tokenizer(_FALLBACK_MODEL_NAME, device)

    print("Model loaded successfully!")
    return model, tokenizer, device


# Tag names some models use to wrap intermediate reasoning in their output.
_REASONING_TAG_NAMES = r'think|reasoning|thought|scratchpad'
# A complete <tag>...</tag> block (the closing tag must match the opening one).
_REASONING_BLOCK_RE = re.compile(
    rf'<({_REASONING_TAG_NAMES})>.*?</\1>', flags=re.DOTALL | re.IGNORECASE
)
_REASONING_CLOSE_RE = re.compile(rf'</(?:{_REASONING_TAG_NAMES})>', flags=re.IGNORECASE)
_REASONING_OPEN_RE = re.compile(rf'<(?:{_REASONING_TAG_NAMES})>', flags=re.IGNORECASE)


def _strip_reasoning(response: str) -> str:
    """Return `response` without reasoning blocks, keeping only the answer text."""
    # 1. Drop well-formed <tag>...</tag> blocks.
    cleaned = _REASONING_BLOCK_RE.sub('', response)
    # 2. A closing tag without its opening tag (the opening tag may have been
    #    part of the prompt): the answer is whatever follows the last one.
    cleaned = _REASONING_CLOSE_RE.split(cleaned)[-1]
    # 3. An opening tag that never closes (generation stopped mid-reasoning):
    #    everything after it is unfinished reasoning, not an answer.
    cleaned = _REASONING_OPEN_RE.split(cleaned, maxsplit=1)[0]
    return cleaned.strip()


def generate_feature_interpretation(
    model,
    tokenizer,
    feature_id: int,
    examples: List[Tuple[str, float]],
    device: str,
    max_length: int = 512
) -> str:
    """
    Use LLM to generate human-readable interpretation for a feature.

    The first 10 examples (each cut to 200 characters) are listed in a prompt
    asking for a 1-2 sentence description; the reply is decoded greedily and
    any reasoning-tag content is removed.

    Args:
        model: LLM model exposing `generate`
        tokenizer: Tokenizer for the model (must provide a chat template)
        feature_id: ID of the feature
        examples: List of (text, activation_value) tuples where feature is active
        device: Device to run inference on
        max_length: Currently unused; generation is capped at 200 new tokens.
            Kept so existing callers keep working.

    Returns:
        Human-readable interpretation string. May be empty if the model
        produced nothing outside of its reasoning tags.
    """
    max_prompt_examples = 10
    max_chars_per_example = 200

    # Prepare examples text (limited in count and length to bound the prompt)
    examples_text = "".join(
        f"{i}. [Activation: {value:.3f}] {text[:max_chars_per_example]}\n"
        for i, (text, value) in enumerate(examples[:max_prompt_examples], 1)
    )

    # Create prompt
    prompt = f"""You are analyzing features from a Sparse Autoencoder (SAE) that was trained on a transformer model.

A SAE feature is a learned pattern that activates when certain linguistic or semantic patterns appear in text.

Below are examples of texts where Feature #{feature_id} is highly active:

{examples_text}

Based on these examples, provide a concise, human-understandable interpretation of what Feature #{feature_id} represents. 
Focus on the common semantic, linguistic, or conceptual patterns shared across these examples.

Your interpretation should:
1. Be concise (1-2 sentences)
2. Describe what kind of content or pattern triggers this feature
3. Be specific and informative

Interpretation:"""

    messages = [{"role": "user", "content": prompt}]

    # enable_thinking=False is the chat-template switch Qwen3 documents for
    # turning off its thinking mode; extra keyword arguments are handed to the
    # template, and templates that do not reference the flag ignore it.
    chat_text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False
    )

    model_inputs = tokenizer([chat_text], return_tensors="pt").to(device)

    # Greedy decoding. No temperature is passed: it has no effect when
    # do_sample=False and only makes the generation config inconsistent.
    generation_kwargs = {
        "max_new_tokens": 200,
        "do_sample": False,
        "pad_token_id": tokenizer.eos_token_id,
    }

    with torch.no_grad():
        generated_ids = model.generate(
            **model_inputs,
            **generation_kwargs
        )

    # generate() returns prompt + completion; keep only the completion tokens.
    completion_ids = [
        output_ids[len(input_ids):]
        for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]

    response = tokenizer.batch_decode(completion_ids, skip_special_tokens=True)[0]

    return _strip_reasoning(response)


def generate_all_feature_interpretations(
    feature_examples: Dict[int, List[Tuple[str, float]]],
    model,
    tokenizer,
    device: str,
    min_examples: int = 3
) -> Dict[int, Dict]:
    """
    Produce one interpretation record per feature.

    Features with fewer than `min_examples` examples are not sent to the
    LLM and receive a placeholder text. A failure while interpreting one
    feature is printed and stored as an "Error: ..." text, so the loop
    carries on with the remaining features.

    Args:
        feature_examples: Dictionary mapping feature_id -> list of
            (text, activation_value) examples
        model: LLM model
        tokenizer: Tokenizer
        device: Device to use
        min_examples: Minimum number of examples required to generate interpretation

    Returns:
        Dictionary mapping feature_id -> record with 'feature_id',
        'interpretation', 'num_examples' and 'avg_activation'; successful
        records additionally carry 'max_activation' and 'min_activation'
    """
    banner = '=' * 60
    print(f"\n{banner}")
    print(f"Generating Feature Interpretations")
    print(f"{banner}")
    print(f"Features to interpret: {len(feature_examples)}")

    feature_interpretations = {}

    for fid, samples in tqdm(feature_examples.items(), desc="Interpreting features"):
        sample_count = len(samples)

        if sample_count >= min_examples:
            try:
                summary = generate_feature_interpretation(
                    model, tokenizer, fid, samples, device
                )
                scores = [score for _, score in samples]
                entry = {
                    'feature_id': fid,
                    'interpretation': summary,
                    'num_examples': sample_count,
                    'avg_activation': np.mean(scores),
                    'max_activation': max(scores),
                    'min_activation': min(scores)
                }
            except Exception as e:
                # Best effort: report the failure and keep going
                print(f"\nError interpreting feature {fid}: {e}")
                entry = {
                    'feature_id': fid,
                    'interpretation': f"Error: {str(e)}",
                    'num_examples': sample_count,
                    'avg_activation': np.mean([score for _, score in samples]) if samples else 0.0
                }
        else:
            # Too few examples to say anything meaningful about this feature
            entry = {
                'feature_id': fid,
                'interpretation': f"Not enough examples (only {sample_count} found)",
                'num_examples': sample_count,
                'avg_activation': np.mean([score for _, score in samples]) if samples else 0.0
            }

        feature_interpretations[fid] = entry

    return feature_interpretations


def save_feature_interpretations(
    feature_interpretations: Dict[int, Dict],
    output_path: str
):
    """Write the interpretation records to `output_path` as a JSON array.

    The records are emitted in ascending feature_id order with 2-space
    indentation; the mapping keys themselves are not written.
    """
    print(f"\nSaving feature interpretations to {output_path}...")

    # Flatten the mapping into a list ordered by feature id
    ordered_records = []
    for fid in sorted(feature_interpretations):
        ordered_records.append(feature_interpretations[fid])

    with open(output_path, 'w') as out_file:
        json.dump(ordered_records, out_file, indent=2)

    print(f"Saved {len(ordered_records)} feature interpretations!")


def main():
    """Parse the command line and run the whole interpretation pipeline.

    Steps: load the SAE interpretation records, collect high-activation
    examples per feature, load the LLM, generate one interpretation per
    feature, and write the results to the output JSON file.
    """
    parser = argparse.ArgumentParser(
        description='Generate human-understandable interpretations for SAE features using LLM'
    )

    # (flag, type, default, help) -- registered in this order
    cli_options = (
        ('--interpretations_path', str, './sae_interpretations_model1.json',
         'Path to SAE interpretations JSON file'),
        ('--output_path', str, './sae_feature_interpretations_model1.json',
         'Output path for feature interpretations'),
        ('--model_name', str, 'Qwen/Qwen2.5-14B-Instruct',
         'LLM model name to use for interpretation'),
        ('--activation_threshold', float, 0.5,
         'Minimum activation value to consider a feature as active'),
        ('--max_examples_per_feature', int, 50,
         'Maximum number of examples to collect per feature'),
        ('--min_examples', int, 3,
         'Minimum number of examples required to generate interpretation'),
        ('--device', str, None,
         'Device to use (cuda/cpu). Auto-detect if not specified'),
    )
    for flag, value_type, default_value, help_text in cli_options:
        parser.add_argument(flag, type=value_type, default=default_value, help=help_text)

    args = parser.parse_args()

    # Stage 1: read the per-text records
    interpretations = load_interpretations(args.interpretations_path)

    # Stage 2: group strongly-activating texts by feature
    feature_examples = collect_feature_examples(
        interpretations,
        activation_threshold=args.activation_threshold,
        max_examples_per_feature=args.max_examples_per_feature
    )

    # Stage 3: bring up the LLM
    model, tokenizer, device = load_llm(args.model_name, args.device)

    # Stage 4: ask the LLM about every feature
    feature_interpretations = generate_all_feature_interpretations(
        feature_examples,
        model,
        tokenizer,
        device,
        min_examples=args.min_examples
    )

    # Stage 5: persist the results
    save_feature_interpretations(feature_interpretations, args.output_path)

    rule = "=" * 60
    print(f"\n{rule}")
    print("ALL TASKS COMPLETED!")
    print(rule)
    print(f"Feature interpretations saved to: {args.output_path}")

# Run the pipeline only when this file is executed as a script, so importing
# the module (e.g. to reuse its functions) has no side effects.
if __name__ == '__main__':
    main()
