"""
Train small transformers for next-token prediction and apply sparse autoencoder (SAE) interpretations.

This script:
1. Trains a transformer on all data from train.csv for next-token prediction
2. Samples ~2000 test instances with negative toxicity labels
3. Uses SAE to interpret model outputs
4. Trains a second transformer only on non-toxic instances
5. Applies SAE to both models on the same test set
"""

import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import (
    AutoTokenizer, 
    AutoModelForCausalLM,
    GPT2LMHeadModel,
    GPT2Config,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling
)
from datasets import Dataset as HFDataset
from tqdm import tqdm
import os
import json
from typing import List, Dict, Tuple
import argparse


class SparseAutoencoder(nn.Module):
    """
    Sparse Autoencoder for interpreting transformer hidden states.

    A single-hidden-layer autoencoder:
    ``features = ReLU(W_enc @ x + b_enc)`` and
    ``reconstructed = W_dec @ features + b_dec``.
    Training (see :meth:`compute_loss`) minimises MSE reconstruction error plus
    an L1 penalty on ``features`` so that few features are active per input.

    Args:
        input_dim: Size of the last dimension of the activations to encode.
        n_features: Number of dictionary features (width of the sparse code).
        sparsity_weight: Multiplier applied to the L1 term in :meth:`compute_loss`.
    """
    def __init__(self, input_dim: int, n_features: int = 8192, sparsity_weight: float = 1e-3):
        super().__init__()
        self.input_dim = input_dim
        self.n_features = n_features
        self.sparsity_weight = sparsity_weight

        # Encoder: input_dim -> n_features; ReLU keeps the codes non-negative.
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, n_features),
            nn.ReLU()
        )

        # Decoder: n_features -> input_dim
        self.decoder = nn.Linear(n_features, input_dim)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Encode ``x`` into sparse features and decode them back.

        ``nn.Linear`` acts on the last dimension and broadcasts over any leading
        dimensions. Inputs of shape ``(batch, input_dim)``,
        ``(batch, seq_len, input_dim)`` or, more generally, ``(..., input_dim)``
        are all accepted, contiguous or not.

        Args:
            x: Activations of shape ``(..., input_dim)``.

        Returns:
            features: ``(..., n_features)`` non-negative sparse codes.
            reconstructed: ``(..., input_dim)`` reconstruction of ``x``.
        """
        features = self.encoder(x)
        reconstructed = self.decoder(features)
        return features, reconstructed

    def compute_loss(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Compute reconstruction loss + sparsity penalty for a batch ``x``.

        Args:
            x: Activations of shape ``(..., input_dim)``.

        Returns:
            Dict with:
            - ``'loss'``: the total, ``reconstruction_loss + sparsity_weight * sparsity_loss``
            - ``'reconstruction_loss'``: MSE averaged over all elements
            - ``'sparsity_loss'``: mean absolute feature activation
        """
        features, reconstructed = self.forward(x)

        # Flatten all leading dims. reshape (unlike view) also works when x is
        # non-contiguous, e.g. a transposed or sliced activation tensor.
        x_flat = x.reshape(-1, x.shape[-1])
        recon_flat = reconstructed.reshape(-1, reconstructed.shape[-1])
        features_flat = features.reshape(-1, features.shape[-1])

        # Reconstruction loss (MSE)
        recon_loss = nn.functional.mse_loss(recon_flat, x_flat)

        # Sparsity penalty: L1 on features, averaged over samples AND features.
        # NOTE(review): because of the average over n_features, the per-feature
        # penalty is sparsity_weight / n_features. Summing over features is the
        # more common SAE convention; the mean is kept so existing
        # sparsity_weight values keep their meaning.
        sparsity_loss = features_flat.abs().mean()

        total_loss = recon_loss + self.sparsity_weight * sparsity_loss

        return {
            'loss': total_loss,
            'reconstruction_loss': recon_loss,
            'sparsity_loss': sparsity_loss
        }


def load_data(csv_path: str) -> pd.DataFrame:
    """
    Read the training CSV into a DataFrame.

    Progress messages (before and after reading) are printed to stdout.

    Args:
        csv_path: Path (or buffer) accepted by ``pandas.read_csv``.

    Returns:
        The parsed DataFrame, unmodified.
    """
    print(f"Loading data from {csv_path}...")
    frame = pd.read_csv(csv_path)
    row_count = len(frame)
    print(f"Loaded {row_count} instances")
    return frame


def prepare_texts(df: pd.DataFrame, filter_non_toxic: bool = False, filter_toxic: bool = False) -> List[str]:
    """
    Extract comment texts from a dataframe, optionally filtered by toxicity label.

    Args:
        df: Dataframe with 'comment_text' and 'toxic' columns.
        filter_non_toxic: If True, only return instances where toxic == 0.
        filter_toxic: If True, only return instances where toxic == 1.

    Returns:
        The selected 'comment_text' values as a list, in dataframe order.

    Raises:
        ValueError: if both filters are requested. They select disjoint subsets,
            so asking for both is almost certainly a caller mistake and must not
            be resolved silently.
    """
    if filter_non_toxic and filter_toxic:
        raise ValueError("filter_non_toxic and filter_toxic are mutually exclusive")

    if filter_non_toxic:
        df_filtered = df[df['toxic'] == 0]
        print(f"Filtered to {len(df_filtered)} non-toxic instances")
    elif filter_toxic:
        df_filtered = df[df['toxic'] == 1]
        print(f"Filtered to {len(df_filtered)} toxic instances")
    else:
        df_filtered = df
    return df_filtered['comment_text'].tolist()


def load_test_data(csv_path: str, labels_path: str = None) -> pd.DataFrame:
    """
    Load test data from a CSV file, optionally joined with its labels.

    Args:
        csv_path: CSV with an 'id' column and the test texts.
        labels_path: Optional CSV with 'id' and 'toxic' columns. When given, it
            is left-joined on 'id'. Only rows with a usable label are kept:
            rows whose 'toxic' is -1 (unlabelled) and rows with no matching
            label row (NaN after the join) are both dropped.

    Returns:
        The (possibly filtered) test dataframe.
    """
    print(f"Loading test data from {csv_path}...")
    df = pd.read_csv(csv_path)
    print(f"Loaded {len(df)} test instances")
    if labels_path is not None:
        labels_df = pd.read_csv(labels_path)
        df = df.merge(labels_df, on='id', how='left')
        # The left join yields NaN for ids missing from the labels file. Because
        # NaN != -1 is True, the -1 check alone would silently keep those rows.
        has_label = df['toxic'].notna() & (df['toxic'] != -1)
        df = df.loc[has_label, :]
        print(f"Loaded {len(df)} test instances with non-missing toxic labels")
    return df


def tokenize_texts(tokenizer, texts: List[str], max_length: int = 128) -> Dict:
    """
    Tokenize a list of texts into fixed-length PyTorch tensors.

    Every text is truncated, or padded, to exactly ``max_length`` tokens.

    Args:
        tokenizer: Tokenizer callable (e.g. a Hugging Face tokenizer) that accepts
            ``truncation``, ``max_length``, ``padding`` and ``return_tensors``
            keyword arguments.
        texts: Texts to encode.
        max_length: Target sequence length.

    Returns:
        Whatever the tokenizer returns for the batch (for Hugging Face
        tokenizers, a dict-like encoding of 'pt' tensors).
    """
    # Single batched call. The previously defined but never used nested
    # helper that duplicated this call has been removed.
    return tokenizer(
        texts,
        truncation=True,
        max_length=max_length,
        padding='max_length',
        return_tensors='pt'
    )


def train_transformer(
    texts: List[str],
    model_name: str,
    output_dir: str,
    num_train_epochs: int = 3,
    batch_size: int = 8,
    learning_rate: float = 5e-5,
    max_length: int = 128,
    save_steps: int = 5000,
    eval_steps: int = 5000
):
    """
    Train a small GPT-2-style causal language model from scratch for next-token prediction.

    The pretrained ``gpt2`` tokenizer is reused, with EOS doubling as the pad
    token. The model weights are freshly initialised from a reduced config:
    4 layers, 4 heads, 256-dim embeddings, 1024-dim MLP, context = ``max_length``.
    10% of ``texts`` (split seed 42) is held out for evaluation. Because
    ``load_best_model_at_end`` is set with ``metric_for_best_model="loss"``, the
    Trainer restores the checkpoint with the lowest eval loss.

    Args:
        texts: Training documents.
        model_name: Name shown in the progress banner only.
        output_dir: Directory for checkpoints and the final model + tokenizer.
        num_train_epochs: Number of passes over the training split.
        batch_size: Per-device batch size for both training and evaluation.
        learning_rate: Learning rate passed to TrainingArguments (100 warm-up steps).
        max_length: Token length every text is truncated/padded to.
        save_steps: Checkpoint interval in steps.
        eval_steps: Evaluation interval in steps.

    Returns:
        ``(model, tokenizer)``.
    """
    banner = '=' * 60
    print('\n' + banner)
    print(f"Training transformer: {model_name}")
    print(banner)

    print("Initializing tokenizer and model...")
    tokenizer = AutoTokenizer.from_pretrained('gpt2')
    # GPT-2 has no dedicated pad token; reuse EOS so fixed-length padding works.
    tokenizer.pad_token = tokenizer.eos_token

    # Deliberately much smaller than stock GPT-2.
    small_gpt2 = GPT2Config(
        vocab_size=len(tokenizer),
        n_positions=max_length,
        n_ctx=max_length,
        n_embd=256,
        n_layer=4,
        n_head=4,
        n_inner=1024,
        resid_pdrop=0.1,
        embd_pdrop=0.1,
        attn_pdrop=0.1,
    )
    model = GPT2LMHeadModel(small_gpt2)

    print("Tokenizing texts...")

    def encode(batch):
        return tokenizer(
            batch['text'],
            truncation=True,
            max_length=max_length,
            padding='max_length'
        )

    raw_dataset = HFDataset.from_dict({'text': texts})
    encoded_dataset = raw_dataset.map(encode, batched=True, remove_columns=['text'])
    # 90/10 train/eval split with a fixed seed for reproducibility.
    splits = encoded_dataset.train_test_split(test_size=0.1, seed=42)

    # Causal LM, not masked LM.
    collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    training_args = TrainingArguments(
        output_dir=output_dir,
        overwrite_output_dir=True,
        # Optimisation
        num_train_epochs=num_train_epochs,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        learning_rate=learning_rate,
        warmup_steps=100,
        # Logging / evaluation / checkpointing
        logging_steps=100,
        eval_strategy="steps",
        eval_steps=eval_steps,
        save_steps=save_steps,
        save_total_limit=2,
        prediction_loss_only=True,
        # Keep the checkpoint with the lowest eval loss
        load_best_model_at_end=True,
        metric_for_best_model="loss",
        greater_is_better=False,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=collator,
        train_dataset=splits['train'],
        eval_dataset=splits['test'],
    )

    print("Starting training...")
    trainer.train()

    print(f"Saving model to {output_dir}...")
    trainer.save_model()
    tokenizer.save_pretrained(output_dir)

    print("Training completed!")
    return model, tokenizer


def _masked_mean_pool(hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """
    Mean-pool hidden states over the sequence axis, ignoring padding positions.

    Args:
        hidden_states: (batch, seq_len, hidden_dim) activations.
        attention_mask: (batch, seq_len) mask, 1 for real tokens and 0 for padding.

    Returns:
        (batch, hidden_dim) average over the unmasked positions. A row with no
        real tokens pools to zeros instead of dividing by zero.
    """
    mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)  # (batch, seq_len, 1)
    summed = (hidden_states * mask).sum(dim=1)  # (batch, hidden_dim)
    counts = mask.sum(dim=1).clamp(min=1.0)  # (batch, 1)
    return summed / counts


def _pooled_hidden_states(
    model: nn.Module,
    tokenizer,
    texts: List[str],
    hidden_layer_idx: int,
    max_length: int,
    device: str,
) -> torch.Tensor:
    """
    Pool one batch of texts into one vector per text.

    The texts are tokenized and run through the model. The hidden states of
    layer ``hidden_layer_idx`` are then mean-pooled over real tokens only; the
    result stays on ``device``. SAE training and interpretation both use this
    helper so they see identically pooled activations. Call it under
    ``torch.no_grad()``.
    """
    inputs = tokenizer(
        texts,
        truncation=True,
        max_length=max_length,
        padding='max_length',
        return_tensors='pt'
    ).to(device)
    outputs = model(**inputs, output_hidden_states=True)
    hidden_states = outputs.hidden_states[hidden_layer_idx]  # (batch, seq, hidden_dim)
    # With padding='max_length', most positions of a short text are padding.
    # A plain mean over dim=1 would let those pad-token states dominate the
    # pooled vector, so average over real tokens only.
    return _masked_mean_pool(hidden_states, inputs['attention_mask'])


def train_sae(
    model: nn.Module,
    tokenizer,
    test_texts: List[str],
    hidden_layer_idx: int = -1,  # Last layer
    n_features: int = 8192,
    sparsity_weight: float = 1e-3,
    batch_size: int = 16,
    num_epochs: int = 10,
    learning_rate: float = 1e-3,
    max_length: int = 128,
    device: str = None
) -> SparseAutoencoder:
    """
    Train a sparse autoencoder on the pooled hidden states of the model.

    Each text is tokenized (truncated/padded to ``max_length``) and run through
    the model. The hidden states of layer ``hidden_layer_idx`` are averaged over
    that text's real (non-padding) tokens. The SAE is then fit to these pooled
    vectors with Adam.

    Args:
        model: The trained transformer model
        tokenizer: The tokenizer
        test_texts: List of test texts to extract activations from
        hidden_layer_idx: Which layer to extract activations from (-1 = last layer)
        n_features: Number of features in SAE
        sparsity_weight: Weight for sparsity penalty
        batch_size: Batch size for both activation extraction and SAE training
        num_epochs: Number of training epochs
        learning_rate: Learning rate for SAE training
        max_length: Maximum sequence length
        device: Device to use (cuda/cpu); auto-detected when None

    Returns:
        Trained sparse autoencoder (on ``device``, still in train mode)

    Raises:
        ValueError: if ``test_texts`` is empty.
    """
    if len(test_texts) == 0:
        # Fail clearly instead of letting torch.cat([]) raise an opaque error later.
        raise ValueError("train_sae needs at least one text to extract activations from")

    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    print(f"\n{'='*60}")
    print("Training Sparse Autoencoder")
    print(f"{'='*60}")
    print(f"Device: {device}")
    print(f"Hidden layer index: {hidden_layer_idx}")
    print(f"Number of SAE features: {n_features}")
    print(f"Number of test texts: {len(test_texts)}")

    model.to(device)
    model.eval()

    # Extract pooled hidden states from the model
    print("Extracting hidden states from model...")
    all_hidden_states = []

    with torch.no_grad():
        for i in tqdm(range(0, len(test_texts), batch_size), desc="Extracting activations"):
            batch_texts = test_texts[i:i+batch_size]
            pooled_hidden = _pooled_hidden_states(
                model, tokenizer, batch_texts, hidden_layer_idx, max_length, device
            )
            # Keep activations on CPU; batches are moved back to device during training.
            all_hidden_states.append(pooled_hidden.cpu())

    all_hidden_states = torch.cat(all_hidden_states, dim=0)  # (n_samples, hidden_dim)
    hidden_dim = all_hidden_states.shape[-1]

    print(f"Extracted hidden states shape: {all_hidden_states.shape}")
    print(f"Hidden dimension: {hidden_dim}")

    sae = SparseAutoencoder(
        input_dim=hidden_dim,
        n_features=n_features,
        sparsity_weight=sparsity_weight
    ).to(device)

    dataset = torch.utils.data.TensorDataset(all_hidden_states)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)

    optimizer = torch.optim.Adam(sae.parameters(), lr=learning_rate)

    print(f"\nTraining SAE for {num_epochs} epochs...")
    sae.train()

    for epoch in range(num_epochs):
        epoch_losses = []
        epoch_recon_losses = []
        epoch_sparsity_losses = []

        for batch in tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            activations = batch[0].to(device)

            loss_dict = sae.compute_loss(activations)
            loss = loss_dict['loss']

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_losses.append(loss.item())
            epoch_recon_losses.append(loss_dict['reconstruction_loss'].item())
            epoch_sparsity_losses.append(loss_dict['sparsity_loss'].item())

        avg_loss = np.mean(epoch_losses)
        avg_recon = np.mean(epoch_recon_losses)
        avg_sparsity = np.mean(epoch_sparsity_losses)

        print(f"Epoch {epoch+1}/{num_epochs} - Loss: {avg_loss:.4f} "
              f"(Recon: {avg_recon:.4f}, Sparsity: {avg_sparsity:.4f})")

    print("SAE training completed!")
    return sae


def generate_sae_interpretations(
    model: nn.Module,
    tokenizer,
    sae: SparseAutoencoder,
    test_texts: List[str],
    batch_size: int = 16,
    max_length: int = 128,
    device: str = None,
    hidden_layer_idx: int = -1,
    top_k: int = 10,
) -> List[Dict]:
    """
    Generate SAE interpretations for test instances.

    Activations are extracted and pooled exactly as in ``train_sae``: padding is
    ignored, and the layer is ``hidden_layer_idx``, which must match the layer
    the SAE was trained on.

    Args:
        model: The transformer whose activations are interpreted.
        tokenizer: Its tokenizer.
        sae: A sparse autoencoder trained on the same layer's pooled activations.
        test_texts: Texts to interpret.
        batch_size: Batch size for the forward passes.
        max_length: Maximum sequence length.
        device: Device to use (cuda/cpu); auto-detected when None.
        hidden_layer_idx: Which hidden layer to read (-1 = last layer).
        top_k: How many of the most active features to report per instance.

    Returns:
        One dict per text, with these keys:
        - 'text'
        - 'top_features': indices of the top_k features, ascending by activation
        - 'feature_values': the matching activations
        - 'n_active_features': number of activations > 0.01
        - 'mean_activation'
        - 'max_activation'

    Raises:
        ValueError: if ``top_k`` < 1.
    """
    if top_k < 1:
        raise ValueError("top_k must be at least 1")

    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    print(f"\n{'='*60}")
    print("Generating SAE Interpretations")
    print(f"{'='*60}")

    model.to(device)
    model.eval()
    sae.to(device)
    sae.eval()

    interpretations = []

    with torch.no_grad():
        for i in tqdm(range(0, len(test_texts), batch_size), desc="Generating interpretations"):
            batch_texts = test_texts[i:i+batch_size]

            pooled_hidden = _pooled_hidden_states(
                model, tokenizer, batch_texts, hidden_layer_idx, max_length, device
            )

            features, _ = sae(pooled_hidden)
            features = features.cpu().numpy()  # (batch, n_features)

            # Indices of the top_k activations, ascending by value (largest last).
            # Reuse the same argsort for the values instead of sorting twice.
            top_k_features = np.argsort(features, axis=1)[:, -top_k:]
            top_k_values = np.take_along_axis(features, top_k_features, axis=1)

            for j, text in enumerate(batch_texts):
                interpretations.append({
                    'text': text,
                    'top_features': [int(x) for x in top_k_features[j].tolist()],
                    'feature_values': [float(x) for x in top_k_values[j].tolist()],
                    'n_active_features': int((features[j] > 0.01).sum()),  # Features above threshold
                    'mean_activation': float(features[j].mean()),
                    'max_activation': float(features[j].max()),
                })

    print(f"Generated interpretations for {len(interpretations)} instances")
    return interpretations


def save_interpretations(interpretations: List[Dict], output_path: str):
    """
    Save interpretations to a JSON file (2-space indented).

    The records are serialised in memory before ``output_path`` is opened. A
    record that is not JSON-serialisable therefore raises ``TypeError`` while
    any existing file is left untouched, instead of truncating it to a partial,
    invalid JSON document.

    Args:
        interpretations: JSON-serialisable records to write.
        output_path: Destination file path (overwritten on success).
    """
    print(f"Saving interpretations to {output_path}...")
    payload = json.dumps(interpretations, indent=2)
    with open(output_path, 'w', encoding='utf-8') as f:
        f.write(payload)
    print("Saved!")


def _load_or_train_model(model_number: int, model_name: str, output_dir: str,
                         skip_training: bool, get_train_texts, args):
    """Return ``(model, tokenizer)``, either trained on ``get_train_texts()`` or loaded from ``output_dir``."""
    if skip_training:
        print(f"Loading model {model_number} from disk...")
        tokenizer = AutoTokenizer.from_pretrained(output_dir)
        model = GPT2LMHeadModel.from_pretrained(output_dir)
        return model, tokenizer
    # get_train_texts is only called when training, so filtering (and its log
    # line) is skipped for models loaded from disk.
    return train_transformer(
        texts=get_train_texts(),
        model_name=model_name,
        output_dir=output_dir,
        num_train_epochs=args.num_epochs,
        batch_size=args.batch_size,
        learning_rate=args.learning_rate,
        max_length=args.max_length
    )


def _interpret_model(model_number: int, model, tokenizer, test_texts: List[str],
                     args, device: str, output_path: str) -> List[Dict]:
    """Train an SAE on ``model``'s activations, interpret every test text and save the result to ``output_path``."""
    print("\n" + "="*60)
    print(f"MODEL {model_number}: Training SAE")
    print("="*60)
    sae = train_sae(
        model=model,
        tokenizer=tokenizer,
        test_texts=test_texts,
        n_features=args.n_sae_features,
        sparsity_weight=args.sparsity_weight,
        batch_size=args.sae_batch_size,
        num_epochs=args.sae_epochs,
        learning_rate=args.sae_learning_rate,
        max_length=args.max_length,
        device=device
    )
    interpretations = generate_sae_interpretations(
        model=model,
        tokenizer=tokenizer,
        sae=sae,
        test_texts=test_texts,
        batch_size=args.sae_batch_size,
        max_length=args.max_length,
        device=device
    )
    save_interpretations(interpretations, output_path)
    return interpretations


def main():
    """
    CLI entry point: train two transformers, fit an SAE per model and save interpretations.

    Model 1 is trained on the rows selected by --model1_data (default 'toxic':
    only toxic == 1 rows). Model 2 is trained on non-toxic rows only. Both are
    interpreted on the same labelled test set, and the results are written to
    --interpretations_dir as sae_interpretations_model{1,2}.json.
    """
    parser = argparse.ArgumentParser(description='Train transformers and apply SAE interpretations')
    parser.add_argument('--data_path', type=str,
                        default='./train.csv',
                        help='Path to training CSV file')
    parser.add_argument('--test_path', type=str,
                        default='./test.csv',
                        help='Path to test CSV file for SAE training and interpretation')
    parser.add_argument('--test_labels_path', type=str,
                        default='./test_labels.csv',
                        help='Path to test label CSV file')
    parser.add_argument('--model1_dir', type=str, default='./transformer_toxic',
                        help='Output directory for first model (trained on the --model1_data subset)')
    parser.add_argument('--model2_dir', type=str, default='./transformer_non_toxic',
                        help='Output directory for second model (trained on non-toxic only)')
    parser.add_argument('--model1_data', type=str, choices=['toxic', 'all'], default='toxic',
                        help="Training subset for model 1: 'toxic' (only rows with toxic == 1) "
                             "or 'all' (every row of --data_path)")
    parser.add_argument('--interpretations_dir', type=str, default='.',
                        help='Directory where the SAE interpretation JSON files are written')
    parser.add_argument('--n_sae_features', type=int, default=8192,
                        help='Number of features in sparse autoencoder')
    parser.add_argument('--sparsity_weight', type=float, default=1e-3,
                        help='Weight for sparsity penalty in SAE')
    parser.add_argument('--num_epochs', type=int, default=3,
                        help='Number of epochs for transformer training')
    parser.add_argument('--sae_epochs', type=int, default=10,
                        help='Number of epochs for SAE training')
    parser.add_argument('--batch_size', type=int, default=8,
                        help='Batch size for transformer training')
    parser.add_argument('--sae_batch_size', type=int, default=16,
                        help='Batch size for SAE training')
    parser.add_argument('--learning_rate', type=float, default=5e-5,
                        help='Learning rate for transformer training')
    parser.add_argument('--sae_learning_rate', type=float, default=1e-3,
                        help='Learning rate for SAE training')
    parser.add_argument('--max_length', type=int, default=128,
                        help='Maximum sequence length')
    parser.add_argument('--device', type=str, default=None,
                        help='Device to use (cuda/cpu). Auto-detect if not specified')
    parser.add_argument('--skip_model1_training', action='store_true',
                        help='Skip training model 1 (assume already trained)')
    parser.add_argument('--skip_model2_training', action='store_true',
                        help='Skip training model 2 (assume already trained)')

    args = parser.parse_args()

    # Set device
    if args.device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    else:
        device = args.device
    print(f"Using device: {device}")

    df = load_data(args.data_path)

    # Test data is shared by both models for SAE training and interpretation.
    test_df = load_test_data(csv_path=args.test_path, labels_path=args.test_labels_path)
    test_texts = test_df['comment_text'].tolist()

    os.makedirs(args.interpretations_dir, exist_ok=True)
    out_path1 = os.path.join(args.interpretations_dir, 'sae_interpretations_model1.json')
    out_path2 = os.path.join(args.interpretations_dir, 'sae_interpretations_model2.json')

    # ========== Model 1: toxic-only (default) or all data ==========
    model1_toxic_only = args.model1_data == 'toxic'
    model1_desc = 'toxic only' if model1_toxic_only else 'all data'
    model1, tokenizer1 = _load_or_train_model(
        model_number=1,
        model_name='transformer_toxic' if model1_toxic_only else 'transformer_all',
        output_dir=args.model1_dir,
        skip_training=args.skip_model1_training,
        get_train_texts=lambda: prepare_texts(df, filter_toxic=model1_toxic_only),
        args=args,
    )
    interpretations1 = _interpret_model(1, model1, tokenizer1, test_texts, args, device, out_path1)
    # Drop our reference to model 1 so its memory can be reclaimed before model 2 trains.
    del model1, tokenizer1

    # ========== Model 2: Train on non-toxic instances only ==========
    model2, tokenizer2 = _load_or_train_model(
        model_number=2,
        model_name='transformer_non_toxic',
        output_dir=args.model2_dir,
        skip_training=args.skip_model2_training,
        get_train_texts=lambda: prepare_texts(df, filter_non_toxic=True),
        args=args,
    )
    interpretations2 = _interpret_model(2, model2, tokenizer2, test_texts, args, device, out_path2)

    print("\n" + "="*60)
    print("ALL TASKS COMPLETED!")
    print("="*60)
    print(f"Model 1 ({model1_desc}) {len(interpretations1)} interpretations saved to: {out_path1}")
    print(f"Model 2 (non-toxic only) {len(interpretations2)} interpretations saved to: {out_path2}")


if __name__ == "__main__":
    # Run the full train -> SAE -> interpret pipeline only when executed as a script.
    main()
