"""
Generate a prompt dataset from test.csv with transformer outputs and SAE features.

This script:
1. Loads test data from CSV file
2. Assigns each row to one of two transformer models
3. Generates a short greedy continuation with the assigned transformer
4. Looks up pre-computed SAE features for each text from interpretation JSON files
5. Matches features with their interpretations
6. Saves the prompt dataset to JSON
"""

import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from transformers import AutoTokenizer, GPT2LMHeadModel
from tqdm import tqdm
import json
import argparse
import random
import os
from typing import List, Dict, Tuple, Optional
# Make CUDA kernel launches synchronous so device-side errors surface at the
# Python call that triggered them (a debugging aid; it slows GPU execution).
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'


class SparseAutoencoder(nn.Module):
    """
    Sparse Autoencoder for interpreting transformer hidden states.
    Matches the structure from train_transformer_sae.py

    The parameter layout (``encoder.0`` Linear followed by ReLU, and a
    ``decoder`` Linear) must stay unchanged so saved ``state_dict``
    checkpoints keep loading.
    """
    def __init__(self, input_dim: int, n_features: int = 8192, sparsity_weight: float = 1e-3):
        super().__init__()
        self.input_dim = input_dim
        self.n_features = n_features
        # Stored for parity with the training script; not used in forward().
        self.sparsity_weight = sparsity_weight

        # Encoder: input_dim -> n_features (ReLU keeps activations non-negative)
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, n_features),
            nn.ReLU()
        )

        # Decoder: n_features -> input_dim
        self.decoder = nn.Linear(n_features, input_dim)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Encode ``x`` into SAE features and reconstruct it.

        Args:
            x: Tensor whose last dimension is ``input_dim``, e.g.
               (batch_size, seq_len, input_dim) or (batch_size, input_dim).
               Any number of leading dimensions is accepted, and the tensor
               does not have to be contiguous.
        Returns:
            features: same leading dims as ``x``, last dim ``n_features``
            reconstructed: same shape as ``x``
        """
        leading_shape = x.shape[:-1]
        # reshape (not view): view() raises on non-contiguous inputs such as
        # transposed or sliced hidden states.
        flat = x.reshape(-1, x.shape[-1])

        features = self.encoder(flat)
        reconstructed = self.decoder(features)

        features = features.reshape(*leading_shape, self.n_features)
        reconstructed = reconstructed.reshape(*leading_shape, reconstructed.shape[-1])
        return features, reconstructed


def load_test_data(csv_path: str, labels_path: Optional[str] = None) -> pd.DataFrame:
    """
    Load test data from a CSV file, optionally joined with a labels CSV.

    Args:
        csv_path: Path to the test CSV. Must contain an ``id`` column when
            ``labels_path`` is given.
        labels_path: Optional path to a labels CSV joined on ``id`` that has
            a ``toxic`` column. Rows whose label is ``-1``, or that have no
            matching label row, are dropped.

    Returns:
        The (optionally merged and filtered) DataFrame. The index is not
        reset after filtering.
    """
    print(f"Loading test data from {csv_path}...")
    df = pd.read_csv(csv_path)
    print(f"Loaded {len(df)} test instances")
    if labels_path is not None:
        labels_df = pd.read_csv(labels_path)
        df = df.merge(labels_df, on='id', how='left')
        # A left merge leaves NaN for ids absent from the labels file, and
        # NaN != -1 is True, so missing labels must be dropped explicitly.
        has_label = df['toxic'].notna() & (df['toxic'] != -1)
        df = df.loc[has_label, :]
        print(f"Loaded {len(df)} test instances with non-missing toxic labels")
    return df


def _load_lm_and_tokenizer(model_dir: str, device: str) -> Tuple[GPT2LMHeadModel, object]:
    """Load one GPT-2 LM-head model and its tokenizer, moved to ``device`` in eval mode."""
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    if tokenizer.pad_token is None:
        # Some tokenizers (e.g. GPT-2's) define no pad token; reuse EOS so padding works.
        tokenizer.pad_token = tokenizer.eos_token
    model = GPT2LMHeadModel.from_pretrained(model_dir)
    model.to(device)
    model.eval()
    return model, tokenizer


def load_models(
    model1_dir: str = './transformer_toxic',
    model2_dir: str = './transformer_non_toxic',
    device: Optional[str] = None
) -> Tuple[GPT2LMHeadModel, object, GPT2LMHeadModel, object, str]:
    """
    Load both transformer models and their tokenizers.

    Args:
        model1_dir: Directory of model 1 (transformer_toxic).
        model2_dir: Directory of model 2 (transformer_non_toxic).
        device: Target device; auto-detected ('cuda' if available, else 'cpu')
            when None.

    Returns:
        ``(model1, tokenizer1, model2, tokenizer2, device)``, where ``device``
        is the resolved device string the models were moved to.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    print(f"\nLoading models (device: {device})...")

    print(f"Loading model 1 from {model1_dir}...")
    model1, tokenizer1 = _load_lm_and_tokenizer(model1_dir, device)

    print(f"Loading model 2 from {model2_dir}...")
    model2, tokenizer2 = _load_lm_and_tokenizer(model2_dir, device)

    print("Models loaded successfully!")
    return model1, tokenizer1, model2, tokenizer2, device


def _load_single_sae(
    sae_path: str,
    label: str,
    hidden_dim: int,
    n_features: int,
    device: str
) -> Optional[SparseAutoencoder]:
    """Load one SAE checkpoint; return None if it is missing or fails to load.

    ``label`` (e.g. "SAE model 1") is the name used in the printed messages.
    """
    if not os.path.exists(sae_path):
        print(f"{label} not found at {sae_path}, will use interpretation files")
        return None

    print(f"Loading {label} from {sae_path}...")
    try:
        sae = SparseAutoencoder(input_dim=hidden_dim, n_features=n_features)
        # NOTE(review): torch.load unpickles the file, so only load checkpoints
        # from trusted sources (weights_only=True is available on newer torch).
        sae.load_state_dict(torch.load(sae_path, map_location=device))
        sae.to(device)
        sae.eval()
        print(f"{label} loaded successfully!")
        return sae
    except Exception as e:
        # Best effort: a broken checkpoint should not abort the whole run.
        print(f"Warning: Could not load {label}: {e}")
        return None


def load_sae_models(
    model1_dir: str,
    model2_dir: str,
    hidden_dim: int = 256,
    n_features: int = 8192,
    device: Optional[str] = None
) -> Tuple[Optional[SparseAutoencoder], Optional[SparseAutoencoder]]:
    """
    Try to load the ``sae_model.pt`` checkpoint stored in each model directory.

    Args:
        model1_dir: Directory that may contain model 1's ``sae_model.pt``.
        model2_dir: Directory that may contain model 2's ``sae_model.pt``.
        hidden_dim: SAE input dimension (must match the checkpoint).
        n_features: SAE dictionary size (must match the checkpoint).
        device: Target device; auto-detected when None.

    Returns:
        ``(sae1, sae2)``. Each entry is None when its checkpoint is missing or
        cannot be loaded; a message is printed in either case.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    sae1 = _load_single_sae(os.path.join(model1_dir, 'sae_model.pt'), "SAE model 1",
                            hidden_dim, n_features, device)
    sae2 = _load_single_sae(os.path.join(model2_dir, 'sae_model.pt'), "SAE model 2",
                            hidden_dim, n_features, device)
    return sae1, sae2


def load_sae_interpretations(
    json_path: str
) -> Dict[str, Dict]:
    """
    Load per-text SAE interpretations from a JSON file.

    The file is expected to hold a list of objects, each with a ``'text'``
    key. If the same text appears more than once, the first occurrence wins.

    Returns:
        Dictionary mapping text -> interpretation dict; empty if the file
        does not exist.
    """
    if not os.path.exists(json_path):
        print(f"Warning: SAE interpretations file not found: {json_path}")
        return {}

    print(f"\nLoading SAE interpretations from {json_path}...")
    # JSON is UTF-8 by spec; be explicit so the platform's default encoding
    # cannot mangle non-ASCII comment text.
    with open(json_path, 'r', encoding='utf-8') as f:
        interpretations = json.load(f)

    text_to_interpretation: Dict[str, Dict] = {}
    for interp in interpretations:
        # setdefault keeps the first interpretation seen for each text.
        text_to_interpretation.setdefault(interp['text'], interp)

    print(f"Loaded {len(interpretations)} interpretations, {len(text_to_interpretation)} unique texts")
    return text_to_interpretation


def load_feature_interpretations(json_path: str) -> Dict[int, Dict]:
    """
    Load SAE feature interpretations from a JSON file.

    The file is expected to hold a list of objects, each with a
    ``'feature_id'`` key. Later entries overwrite earlier ones with the same id.

    Returns:
        Dictionary mapping feature_id -> feature interpretation dict; empty if
        the file does not exist.
    """
    if not os.path.exists(json_path):
        print(f"Warning: Feature interpretations file not found: {json_path}")
        return {}

    print(f"\nLoading feature interpretations from {json_path}...")
    # JSON is UTF-8 by spec; do not rely on the platform default encoding.
    with open(json_path, 'r', encoding='utf-8') as f:
        feature_interps = json.load(f)

    feature_dict = {feat['feature_id']: feat for feat in feature_interps}

    print(f"Loaded {len(feature_dict)} feature interpretations")
    return feature_dict


def _generate_padded_batch(
    model: GPT2LMHeadModel,
    tokenizer,
    batch_texts: List[str],
    max_length: int,
    num_new_tokens: int,
    model_max_positions: int,
    device: str
) -> List[str]:
    """Run one left-padded greedy ``generate`` call; return only the decoded new tokens per text."""
    # Decoder-only models need left padding so generation continues right after
    # each prompt. Setting the attribute works across transformers versions;
    # the per-call ``padding_side`` kwarg is only recognized by newer releases.
    original_padding_side = tokenizer.padding_side
    tokenizer.padding_side = 'left'
    try:
        inputs = tokenizer(
            batch_texts,
            truncation=True,
            max_length=max_length,
            padding=True,
            return_tensors='pt'
        )
    finally:
        tokenizer.padding_side = original_padding_side

    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device).to(torch.bool)

    # Shrink the generation budget if the padded prompt leaves too little room.
    padded_len = input_ids.shape[1]
    batch_num_new_tokens = num_new_tokens
    if padded_len + batch_num_new_tokens > model_max_positions:
        batch_num_new_tokens = max(1, model_max_positions - padded_len - 5)

    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=batch_num_new_tokens,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            do_sample=False
        )

    # With left padding every prompt occupies the first ``padded_len`` columns,
    # so the new tokens start at the same offset in every output row.
    return [
        tokenizer.decode(row[padded_len:], skip_special_tokens=True)
        if row.shape[0] > padded_len else ""
        for row in outputs
    ]


def generate_text_batch(
    model: GPT2LMHeadModel,
    tokenizer,
    texts: List[str],
    max_length: int = 100,
    num_new_tokens: int = 20,
    device: str = 'cuda',
    batch_size: int = 8
) -> List[str]:
    """
    Generate greedy continuations for ``texts``, processing ``batch_size`` at a time.

    If a batch fails, each of its texts is retried individually with
    ``generate_text``.

    Args:
        model: Causal LM used for generation.
        tokenizer: Matching tokenizer; its ``pad_token`` is set to EOS if unset.
        texts: Prompts to continue.
        max_length: Maximum prompt length in tokens. It is further capped so
            that prompt plus generated tokens fit in the model's positions.
        num_new_tokens: Maximum number of tokens to generate per prompt.
        device: Device the input tensors are moved to.
        batch_size: Number of prompts per ``generate`` call.

    Returns:
        One decoded continuation per input text, in input order. Items that
        failed even individually hold a ``"Generation error: ..."`` string.
    """
    if not texts:
        return []

    # Get model's maximum position embeddings to avoid IndexError
    model_max_positions = getattr(model.config, 'n_positions', None) or \
                         getattr(model.config, 'max_position_embeddings', None) or \
                         1024  # Default fallback

    # Keep room for the generated tokens plus a small safety buffer.
    max_input_length = max(1, model_max_positions - num_new_tokens - 10)
    max_length = min(max_length, max_input_length)

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    all_generated_texts: List[str] = []

    for i in tqdm(range(0, len(texts), batch_size), desc="Generating text for batch"):
        batch_texts = texts[i:i + batch_size]
        try:
            # The helper returns the whole batch or raises, so a failure can
            # never leave partial results that would misalign the outputs.
            batch_results = _generate_padded_batch(
                model, tokenizer, batch_texts, max_length,
                num_new_tokens, model_max_positions, device
            )
        except Exception as e:
            print(f"Warning: Batch generation failed for batch starting at index {i}, falling back to individual: {e}")
            batch_results = []
            for text in batch_texts:
                try:
                    batch_results.append(
                        generate_text(model, tokenizer, text, max_length, num_new_tokens, device)
                    )
                except Exception as e2:
                    batch_results.append(f"Generation error: {str(e2)}")

        all_generated_texts.extend(batch_results)

    return all_generated_texts


def generate_text(
    model: GPT2LMHeadModel,
    tokenizer,
    text: str,
    max_length: int = 100,
    num_new_tokens: int = 20,
    device: str = 'cuda'
) -> str:
    """
    Generate a greedy continuation of a single text.

    Args:
        model: Causal LM used for generation.
        tokenizer: Matching tokenizer; its ``pad_token`` is set to EOS if unset.
        text: Prompt to continue.
        max_length: Maximum prompt length in tokens. It is further capped so
            that prompt plus generated tokens fit in the model's positions.
        num_new_tokens: Maximum number of tokens to generate.
        device: Device the input tensors are moved to.

    Returns:
        Only the newly generated text (the prompt is not included). The
        function does not raise for bad input or failures. It returns
        ``"[Empty input text]"`` for empty or non-string input, and
        ``"Tokenization error: ..."`` or ``"Generation error: ..."`` on failure.
    """
    # Non-string rows (e.g. NaN from pandas) carry no usable text either.
    if not isinstance(text, str) or not text.strip():
        return "[Empty input text]"

    # Get model's maximum position embeddings to avoid IndexError
    model_max_positions = getattr(model.config, 'n_positions', None) or \
                         getattr(model.config, 'max_position_embeddings', None) or \
                         1024  # Default fallback

    # Keep room for the generated tokens plus a small safety buffer.
    max_input_length = max(1, model_max_positions - num_new_tokens - 10)
    max_length = min(max_length, max_input_length)

    try:
        inputs = tokenizer(
            text,
            truncation=True,
            max_length=max_length,
            padding=False,
            return_tensors='pt'
        )
    except Exception as e:
        return f"Tokenization error: {str(e)}"

    input_ids = inputs["input_ids"].to(device)
    if "attention_mask" in inputs:
        attention_mask = inputs["attention_mask"].to(device).to(torch.bool)
    else:
        # Unpadded single input: every token is real.
        attention_mask = torch.ones_like(input_ids, dtype=torch.bool, device=device)

    # Because max_length was clamped above, this should only trigger when
    # num_new_tokens alone nearly fills the context; shrink the budget then.
    current_length = input_ids.shape[1]
    if current_length + num_new_tokens > model_max_positions:
        num_new_tokens = max(1, model_max_positions - current_length - 5)

    try:
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        with torch.no_grad():
            # Greedy decoding, matching generate_text_batch, so fallback
            # results are produced the same way as batched ones.
            outputs = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=num_new_tokens,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                do_sample=False
            )

        prompt_len = input_ids.shape[1]
        if outputs.shape[1] > prompt_len:
            return tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
        return ""
    except Exception as e:
        return f"Generation error: {str(e)}"

def extract_sae_features(
    model: GPT2LMHeadModel,
    tokenizer,
    sae: SparseAutoencoder,
    text: str,
    max_length: int = 128,
    device: str = 'cuda',
    top_k: int = 10
) -> Dict:
    """
    Encode the mean-pooled last-layer hidden state of ``text`` with ``sae``.

    Args:
        model: LM whose last hidden layer feeds the SAE.
        tokenizer: Tokenizer matching ``model``.
        sae: Sparse autoencoder whose input dim matches the model's hidden size.
        text: Input text; truncated/padded to ``max_length`` tokens.
        max_length: Token length of the model input.
        device: Kept for backward compatibility. Inputs are placed on the
            device holding the model's parameters, and the pooled state is
            moved to the SAE's device/dtype.
        top_k: Number of highest-activation features to report.

    Returns:
        Dict with ``top_features`` (feature ids, highest activation first),
        ``feature_values`` (their activations) and ``n_active_features``
        (count of activations > 0.01). Returns empty lists and 0 for empty or
        non-string text and on any error (errors are printed with a
        traceback). ``top_k <= 0`` yields empty lists.
    """
    if not isinstance(text, str) or not text.strip():
        return {
            'top_features': [],
            'feature_values': [],
            'n_active_features': 0
        }

    try:
        vocab_size = model.config.vocab_size if hasattr(model, 'config') else 50257

        inputs = tokenizer(
            text,
            truncation=True,
            max_length=max_length,
            padding='max_length',
            return_tensors='pt'
        )
        input_ids = inputs['input_ids']
        attention_mask = inputs.get('attention_mask', None)

        # Replace ids outside the embedding table so the lookup cannot fail.
        invalid_mask = (input_ids < 0) | (input_ids >= vocab_size)
        if invalid_mask.any():
            pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
            if pad_token_id is None:
                pad_token_id = 50256  # GPT-2 default
            pad_token_id = int(min(max(pad_token_id, 0), vocab_size - 1))
            input_ids = input_ids.clone()
            input_ids[invalid_mask] = pad_token_id

        # Inputs must live where the model's weights live.
        model_device = next(model.parameters()).device
        model_inputs = {'input_ids': input_ids.to(model_device)}
        if attention_mask is not None:
            model_inputs['attention_mask'] = attention_mask.to(model_device)

        with torch.no_grad():
            outputs = model(**model_inputs, output_hidden_states=True)
            hidden_states = outputs.hidden_states[-1]  # Last layer
            # NOTE(review): this mean also covers padded positions (input is
            # padded to max_length). Keep it unless train_transformer_sae.py
            # pooled with an attention mask, so features match SAE training.
            pooled_hidden = hidden_states.mean(dim=1)  # (batch, hidden_dim)

            # The SAE may sit on a different device or dtype than the LM.
            sae_param = next(sae.parameters())
            pooled_hidden = pooled_hidden.to(device=sae_param.device, dtype=sae_param.dtype)

            features, _ = sae(pooled_hidden)
            features = features.float().cpu().numpy()[0]  # (n_features,)

        # Clamp k: argsort(...)[-0:] would otherwise select every feature.
        k = max(0, min(top_k, features.shape[0]))
        top_k_features = np.argsort(features)[::-1][:k]
        top_k_values = features[top_k_features]

        return {
            'top_features': [int(x) for x in top_k_features.tolist()],
            'feature_values': [float(x) for x in top_k_values.tolist()],
            'n_active_features': int((features > 0.01).sum())
        }
    except Exception as e:
        print(f"Warning: Error extracting SAE features: {e}")
        import traceback
        traceback.print_exc()
        return {
            'top_features': [],
            'feature_values': [],
            'n_active_features': 0
        }


def get_sae_features_for_text(text: str, sae_interp_dict: Dict[str, Dict]) -> Dict:
    """
    Look up pre-computed SAE features for ``text``.

    Tries an exact key match first. It then falls back to comparing texts with
    leading and trailing whitespace stripped; inner whitespace must still
    match. The fallback scans every key, so each miss costs
    O(len(sae_interp_dict)).

    Returns:
        Dict with ``top_features`` and ``feature_values`` (each truncated to
        the first 10 entries) and ``n_active_features``. Returns empty lists
        and 0 when nothing matches or ``text`` is not a string.
    """
    def summarize(interp: Dict) -> Dict:
        # Shape an interpretation record into the feature summary callers expect.
        return {
            'top_features': interp.get('top_features', [])[:10],
            'feature_values': interp.get('feature_values', [])[:10],
            'n_active_features': interp.get('n_active_features', 0)
        }

    if text in sae_interp_dict:
        return summarize(sae_interp_dict[text])

    if isinstance(text, str):
        text_normalized = text.strip()
        for key, interp in sae_interp_dict.items():
            if isinstance(key, str) and key.strip() == text_normalized:
                return summarize(interp)

    # No SAE features available for this text
    return {
        'top_features': [],
        'feature_values': [],
        'n_active_features': 0
    }


def format_features_with_interpretations(
    top_features: List[int],
    feature_values: List[float],
    feature_interps: Dict[int, Dict],
    top_k: int = 5
) -> str:
    """
    Render the first ``top_k`` SAE features as human-readable lines.

    Each line reads ``Feature #<id> (activation: <value>): <interpretation>``.
    Features are paired positionally with ``feature_values``. If the two lists
    differ in length, only the paired prefix is rendered instead of raising
    ``IndexError``.

    Returns:
        Newline-separated lines, or ``"No SAE features available."`` when
        there is nothing to show (including ``top_k <= 0``).
    """
    pairs = list(zip(top_features, feature_values))[:max(top_k, 0)]
    if not pairs:
        return "No SAE features available."

    lines = []
    for feature_id, feature_value in pairs:
        interp_text = feature_interps.get(feature_id, {}).get('interpretation', 'No interpretation available')
        lines.append(f"Feature #{feature_id} (activation: {feature_value:.3f}): {interp_text}")

    # strip() drops any trailing whitespace carried by the last interpretation.
    return "\n".join(lines).strip()


def assign_model_to_rows(
    df: pd.DataFrame,
    assignment_strategy: str = 'random',
    seed: Optional[int] = None
) -> pd.DataFrame:
    """
    Assign each row to one of the two models.

    Args:
        df: DataFrame with test data (not modified).
        assignment_strategy:
            - 'random': each row independently picks a model with p=0.5.
            - 'alternating': even positions get transformer_toxic, odd
              positions transformer_non_toxic.
            - 'balanced': exactly ``len(df) // 2`` rows get transformer_toxic,
              and the list is then shuffled.
        seed: Optional seed for a private RNG, making 'random' and 'balanced'
            reproducible. When None, the global ``random`` module state is used.

    Returns:
        A copy of ``df`` with an ``'assigned_model'`` column added.

    Raises:
        ValueError: For an unknown ``assignment_strategy``.
    """
    # random.Random exposes the same random()/shuffle() API as the module.
    rng = random if seed is None else random.Random(seed)
    toxic, non_toxic = 'transformer_toxic', 'transformer_non_toxic'
    n_rows = len(df)

    if assignment_strategy == 'random':
        assignments = [toxic if rng.random() < 0.5 else non_toxic for _ in range(n_rows)]
    elif assignment_strategy == 'alternating':
        assignments = [toxic if i % 2 == 0 else non_toxic for i in range(n_rows)]
    elif assignment_strategy == 'balanced':
        n_model1 = n_rows // 2
        assignments = [toxic] * n_model1 + [non_toxic] * (n_rows - n_model1)
        rng.shuffle(assignments)
    else:
        raise ValueError(f"Unknown assignment strategy: {assignment_strategy}")

    df = df.copy()
    df['assigned_model'] = assignments

    print("\nModel assignments:")
    print(f"  transformer_toxic: {(df['assigned_model'] == 'transformer_toxic').sum()}")
    print(f"  transformer_non_toxic: {(df['assigned_model'] == 'transformer_non_toxic').sum()}")

    return df


def _generate_for_rows(
    model: GPT2LMHeadModel,
    tokenizer,
    rows: List[Dict],
    max_length: int,
    num_new_tokens: int,
    device: str,
    batch_size: int
) -> List[str]:
    """Generate continuations for ``rows`` (in order) and free cached GPU memory afterwards."""
    texts = [row['text'] for row in rows]
    if not texts:
        return []
    generated = generate_text_batch(
        model=model,
        tokenizer=tokenizer,
        texts=texts,
        max_length=max_length,
        num_new_tokens=num_new_tokens,
        device=device,
        batch_size=batch_size
    )
    # startswith so explicit devices like 'cuda:0' also release cached memory.
    if str(device).startswith('cuda'):
        torch.cuda.empty_cache()
    return generated


def _build_prompt_entries(
    rows: List[Dict],
    generated: List[str],
    model_name: str,
    sae_interp: Dict[str, Dict],
    feature_interp: Dict[int, Dict],
    desc: str
) -> Dict:
    """Build one dataset entry per row, keyed by the row's original DataFrame index."""
    entries = {}
    for i, row in enumerate(tqdm(rows, desc=desc)):
        text = row['text']
        # Defensive: a missing generation becomes an empty string.
        generated_text = generated[i] if i < len(generated) else ""

        sae_features = get_sae_features_for_text(text, sae_interp)
        features_text = format_features_with_interpretations(
            top_features=sae_features['top_features'],
            feature_values=sae_features['feature_values'],
            feature_interps=feature_interp,
            top_k=5
        )

        entries[row['idx']] = {
            'id': row['id'],
            'input_text': text,
            'assigned_model': model_name,
            'generated_text': generated_text,
            'sae_features': {
                'top_features': sae_features['top_features'],
                'feature_values': sae_features['feature_values'],
                'n_active_features': sae_features['n_active_features'],
                'features_with_interpretations': features_text
            }
        }
    return entries


def generate_prompt_dataset(
    test_csv_path: str,
    test_labels_path: str = None,
    model1_dir: str = './transformer_toxic',
    model2_dir: str = './transformer_non_toxic',
    sae_interp1_path: str = './sae_interpretations_model1.json',
    sae_interp2_path: str = './sae_interpretations_model2.json',
    feature_interp1_path: str = './sae_feature_interpretations_model1.json',
    feature_interp2_path: str = './sae_feature_interpretations_model2.json',
    output_path: str = './prompt_dataset.json',
    assignment_strategy: str = 'random',
    num_new_tokens: int = 20,
    max_length: int = 128,
    device: str = None,
    batch_size: int = 32
):
    """
    Generate a prompt dataset with transformer outputs and SAE features.

    Each test row is assigned to one of the two models. That model generates
    a continuation, and the row is annotated with the model's pre-computed
    SAE features and their interpretations. Entries are written to
    ``output_path`` as a JSON list in original row order.

    Returns:
        The list of entries that was written.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    print(f"Using device: {device}")

    test_df = load_test_data(test_csv_path, labels_path=test_labels_path)
    test_df = assign_model_to_rows(test_df, assignment_strategy=assignment_strategy)

    model1, tokenizer1, model2, tokenizer2, device = load_models(
        model1_dir=model1_dir,
        model2_dir=model2_dir,
        device=device
    )

    # Per-text SAE features (pre-computed, loaded from JSON files)
    print("\nLoading SAE feature interpretations...")
    sae_interp1 = load_sae_interpretations(sae_interp1_path)
    sae_interp2 = load_sae_interpretations(sae_interp2_path)

    # Human-readable descriptions of individual features
    feature_interp1 = load_feature_interpretations(feature_interp1_path)
    feature_interp2 = load_feature_interpretations(feature_interp2_path)

    # Group rows by assigned model so each model runs in batches.
    print("\nOrganizing data for batch processing...")
    model1_rows: List[Dict] = []
    model2_rows: List[Dict] = []
    for idx, row in test_df.iterrows():
        record = {'id': row.get('id', str(idx)), 'text': row['comment_text'], 'idx': idx}
        if row['assigned_model'] == 'transformer_toxic':
            model1_rows.append(record)
        else:
            model2_rows.append(record)

    print(f"  Model 1 (transformer_toxic): {len(model1_rows)} rows")
    print(f"  Model 2 (transformer_non_toxic): {len(model2_rows)} rows")

    print("\nGenerating text for Model 1 in batches...")
    model1_generated = _generate_for_rows(
        model1, tokenizer1, model1_rows, max_length, num_new_tokens, device, batch_size
    )

    print("\nGenerating text for Model 2 in batches...")
    model2_generated = _generate_for_rows(
        model2, tokenizer2, model2_rows, max_length, num_new_tokens, device, batch_size
    )

    print("\nAdding SAE features to all entries...")
    prompt_dataset_dict = {}  # original index -> entry, for restoring row order
    prompt_dataset_dict.update(_build_prompt_entries(
        model1_rows, model1_generated, 'transformer_toxic',
        sae_interp1, feature_interp1, "Processing Model 1 entries"
    ))
    prompt_dataset_dict.update(_build_prompt_entries(
        model2_rows, model2_generated, 'transformer_non_toxic',
        sae_interp2, feature_interp2, "Processing Model 2 entries"
    ))

    prompt_dataset = [prompt_dataset_dict[idx] for idx in sorted(prompt_dataset_dict)]

    print(f"\nSaving prompt dataset to {output_path}...")
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(prompt_dataset, f, indent=2)

    print(f"Saved {len(prompt_dataset)} entries to {output_path}")

    return prompt_dataset


def main():
    """Parse command-line arguments and build the prompt dataset."""
    parser = argparse.ArgumentParser(description='Generate prompt dataset with transformer outputs and SAE features')
    parser.add_argument('--test_csv_path', type=str,
                       default='./test.csv',
                       help='Path to test CSV file')
    parser.add_argument('--test_labels_path', type=str, default='./test_labels.csv',
                       help='Path to test label CSV file (pass an empty string to skip label filtering)')
    parser.add_argument('--model1_dir', type=str, default='./transformer_toxic',
                       help='Directory for transformer model 1')
    parser.add_argument('--model2_dir', type=str, default='./transformer_non_toxic',
                       help='Directory for transformer model 2')
    parser.add_argument('--sae_interp1_path', type=str, default='./sae_interpretations_model1.json',
                       help='Path to SAE interpretations for model 1')
    parser.add_argument('--sae_interp2_path', type=str, default='./sae_interpretations_model2.json',
                       help='Path to SAE interpretations for model 2')
    parser.add_argument('--feature_interp1_path', type=str, default='./sae_feature_interpretations_model1.json',
                       help='Path to feature interpretations for model 1')
    parser.add_argument('--feature_interp2_path', type=str, default='./sae_feature_interpretations_model2.json',
                       help='Path to feature interpretations for model 2')
    parser.add_argument('--output_path', type=str, default='./prompt_dataset.json',
                       help='Output path for prompt dataset JSON file')
    parser.add_argument('--assignment_strategy', type=str, default='random',
                       choices=['random', 'alternating', 'balanced'],
                       help='Strategy for assigning models to rows')
    parser.add_argument('--num_new_tokens', type=int, default=20,
                       help='Number of new tokens to generate')
    parser.add_argument('--max_length', type=int, default=128,
                       help='Maximum sequence length')
    parser.add_argument('--device', type=str, default=None,
                       help='Device to use (cuda/cpu). Auto-detect if not specified')
    parser.add_argument('--batch_size', type=int, default=32,
                       help='Number of texts per generation batch')
    parser.add_argument('--seed', type=int, default=None,
                       help="Seed for Python's random module (model assignment). Unseeded if omitted")

    args = parser.parse_args()

    if args.seed is not None:
        # Row-to-model assignment draws from the global ``random`` module.
        random.seed(args.seed)

    generate_prompt_dataset(
        test_csv_path=args.test_csv_path,
        # An empty string on the CLI means "no labels file".
        test_labels_path=args.test_labels_path or None,
        model1_dir=args.model1_dir,
        model2_dir=args.model2_dir,
        sae_interp1_path=args.sae_interp1_path,
        sae_interp2_path=args.sae_interp2_path,
        feature_interp1_path=args.feature_interp1_path,
        feature_interp2_path=args.feature_interp2_path,
        output_path=args.output_path,
        assignment_strategy=args.assignment_strategy,
        num_new_tokens=args.num_new_tokens,
        max_length=args.max_length,
        device=args.device,
        batch_size=args.batch_size
    )


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()