from pathlib import Path
import re
import json
from typing import Dict, List, NamedTuple
from collections import defaultdict
import pandas as pd

class FeatureExample(NamedTuple):
    """One example block parsed from the examples section of a feature file."""
    text: str  # text following "Full text:" in the example block, stripped
    activation: float  # numeric value parsed from the "Max activation:" line
    active_tokens: str  # contents of the "Active tokens:" line, or "" when absent

class FeatureData(NamedTuple):
    """Parsed contents of one feature interpretation file plus path-derived metadata."""
    feature_id: str  # stem of the feature file name
    interpretation: str  # text after the first '=' * 80 separator, or "" if missing
    examples: List[FeatureExample]  # parsed example blocks, in file order
    model: str  # model directory name, e.g. pythia-6.9b-deduped_64_k32
    variant: str  # readable label produced by get_model_variant_label
    layer: int  # layer index as determined by parse_feature_file

def escape_latex(text: str) -> str:
    r"""Return ``text`` with LaTeX special characters escaped.

    Each of ``& % $ # _ { } ~ ^ \`` is replaced by its LaTeX-safe spelling
    (for example ``\&``, ``\textasciitilde{}`` or ``\textbackslash{}``).
    The substitution is a single pass over the input, so backslashes that an
    escape introduces are never escaped a second time.
    """
    replacements = str.maketrans({
        '&': '\\&',
        '%': '\\%',
        '$': '\\$',
        '#': '\\#',
        '_': '\\_',
        '{': '\\{',
        '}': '\\}',
        '~': '\\textasciitilde{}',
        '^': '\\textasciicircum{}',
        '\\': '\\textbackslash{}',
    })
    return text.translate(replacements)

def get_model_variant_label(variant_name: str) -> str:
    """Convert a variant directory name to a readable label.

    Checks are applied in order:

    1. ``random_control`` (case-sensitive)        -> ``'Control'``
    2. any spelling of ``rerandomi[sz]e[d]_embeddings``
       (case-insensitive)                         -> ``'Randomized incl emb'``
    3. ``rerandomise``/``rerandomize`` without
       ``embeddings`` (case-insensitive)          -> ``'Randomized excl emb'``
    4. ``step0`` (case-sensitive)                 -> ``'Step 0'``
    5. anything else                              -> ``'Trained'``

    Args:
        variant_name: Dataset/variant directory name.

    Returns:
        One of the five labels above.
    """
    lowered = variant_name.lower()
    # All four British/American, with/without trailing "d" spellings: missing
    # any of them would let the name fall through to 'Trained', because the
    # "excl emb" branch below rejects anything mentioning embeddings.
    embedding_markers = (
        'rerandomise_embeddings',
        'rerandomised_embeddings',
        'rerandomize_embeddings',
        'rerandomized_embeddings',
    )
    # 'rerandomised'/'rerandomized' contain these stems, so two suffice.
    rerandom_stems = ('rerandomise', 'rerandomize')

    if 'random_control' in variant_name:
        return 'Control'
    if any(marker in lowered for marker in embedding_markers):
        return 'Randomized incl emb'
    if any(stem in lowered for stem in rerandom_stems) and 'embeddings' not in lowered:
        return 'Randomized excl emb'
    if 'step0' in variant_name:
        return 'Step 0'
    return 'Trained'

def parse_feature_file(file_path: Path) -> FeatureData:
    """Parse a feature interpretation file and extract structured data.

    The path is expected to look like
    ``<model>/<dataset_dir>/explanations/<layer_dir>/<feature_file>``; the
    model name and the variant label are read from those path components.

    The file body is split on runs of 80 ``=`` characters. The text after the
    first separator is the interpretation; the text after the second separator
    holds example blocks separated by runs of 40 ``-`` characters. An example
    block is kept only if it has a ``Full text:`` section and a parsable
    ``Max activation:`` value; ``Active tokens:`` is optional.

    Args:
        file_path: Path to the feature ``.txt`` file (read as UTF-8).

    Returns:
        A ``FeatureData`` holding the interpretation, the parsed examples in
        file order, and path-derived model / variant / layer metadata.

    Raises:
        ValueError: If the layer can be parsed neither from the file stem
            (``feature_<layer>_...``) nor from a parent directory named
            ``layer_<n>``.
    """
    # Explicit UTF-8 so parsing does not depend on the platform locale.
    content = file_path.read_text(encoding='utf-8').strip()

    parts = file_path.parts
    # Model name is two levels up from the explanations directory.
    model = parts[-5]  # e.g. pythia-6.9b-deduped_64_k32
    # The dataset directory name carries the variant info.
    dataset_dir = parts[-4]  # e.g. redpajama-data-1t-sample_plain_text_100M_step0
    variant = get_model_variant_label(dataset_dir)

    feature_id = file_path.stem

    # The filename is the primary source of the layer; fall back to a
    # ``layer_<n>`` parent directory rather than failing on a None match.
    layer_match = (
        re.match(r'feature_(\d+)_', feature_id)
        or re.fullmatch(r'layer_(\d+)', file_path.parent.name)
    )
    if layer_match is None:
        raise ValueError(f"Cannot determine layer for feature file {file_path}")
    layer = int(layer_match.group(1))

    sections = content.split('=' * 80)
    # Anything before the first separator is not used.
    interpretation = sections[1].strip() if len(sections) > 1 else ""

    # Accept signed decimals and scientific notation: a bare [\d.]+ would
    # read "5e-05" as 5.0 and reject negative values entirely.
    activation_pattern = r'Max activation: ([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?)'

    examples: List[FeatureExample] = []
    example_section = sections[2] if len(sections) > 2 else ""
    for block in example_section.split('-' * 40):
        if not block.strip():
            continue

        text_match = re.search(r'Full text:\n(.*?)(?=\nMax activation:|$)', block, re.DOTALL)
        activation_match = re.search(activation_pattern, block)
        tokens_match = re.search(r'Active tokens: (.*?)(?:\n|$)', block)

        if text_match and activation_match:
            examples.append(FeatureExample(
                text=text_match.group(1).strip(),
                activation=float(activation_match.group(1)),
                active_tokens=tokens_match.group(1).strip() if tokens_match else "",
            ))

    return FeatureData(feature_id, interpretation, examples, model, variant, layer)

def format_feature_latex(feature: FeatureData, max_examples: int = 3) -> str:
    r"""Format a single feature's data as a LaTeX fragment.

    The fragment contains, in order:

    - a ``\subsubsection*`` header with the escaped feature id;
    - the interpretation, when it is non-empty;
    - the first ``max_examples`` examples as an ``enumerate`` list;
    - a closing ``\hrulefill``.

    Args:
        feature: Parsed feature to render.
        max_examples: Maximum number of examples to include, taken in the
            order stored on ``feature``. Negative values are treated as 0.

    Returns:
        The LaTeX lines joined with newlines (no trailing newline).
    """
    latex_lines = [f"\\subsubsection*{{Feature {escape_latex(feature.feature_id)}}}"]

    if feature.interpretation:
        latex_lines.extend([
            "\\textbf{Interpretation:}",
            escape_latex(feature.interpretation),
            "",
        ])

    shown = feature.examples[:max(max_examples, 0)]
    # Only open the list when it will hold at least one \item: an empty
    # enumerate environment is a LaTeX error.
    if shown:
        latex_lines.append("\\textbf{Top Examples:}")
        latex_lines.append("\\begin{enumerate}")
        for example in shown:
            latex_lines.append("\\item")
            # \linewidth is the width available inside the indented list item;
            # 0.95\textwidth (plus column padding) overflows it.
            latex_lines.append("\\begin{tabular}{p{0.95\\linewidth}}")
            latex_lines.append(f"Text: {escape_latex(example.text)} \\\\")
            latex_lines.append(f"Activation: {example.activation:.4f} \\\\")
            if example.active_tokens:
                latex_lines.append(f"Active tokens: {escape_latex(example.active_tokens)}")
            latex_lines.append("\\end{tabular}")
        latex_lines.append("\\end{enumerate}")

    latex_lines.append("\\hrulefill")
    return "\n".join(latex_lines)

def create_latex_file(model: str, features: List[FeatureData], output_file: Path):
    """Write a standalone LaTeX document describing ``features`` for one model.

    Features are ordered by (variant, layer, feature_id). A variant heading is
    emitted whenever the variant changes, and a layer heading whenever the
    layer changes or a new variant starts. Each feature is rendered with
    ``format_feature_latex``.

    Args:
        model: Title text for the top-level section (escaped for LaTeX).
        features: Features to include. The caller's list is not modified.
        output_file: Destination ``.tex`` path; overwritten if it exists and
            written as UTF-8.
    """
    latex_output = [
        "\\documentclass{article}",
        "\\usepackage{fullpage}",
        "\\usepackage{enumitem}",
        "\\usepackage{booktabs}",
        "\\usepackage{tabularx}",
        "\\usepackage{xcolor}",
        "",
        "\\begin{document}",
        "",
        f"\\section*{{Model: {escape_latex(model)}}}",
    ]

    # Sort a copy so the caller's list keeps its order; feature_id as the last
    # key makes the order within a layer deterministic.
    ordered = sorted(features, key=lambda f: (f.variant, f.layer, f.feature_id))
    current_variant = None
    current_layer = None

    for feature in ordered:
        if feature.variant != current_variant:
            current_variant = feature.variant
            # A new variant always needs its own layer heading, even if its
            # first layer equals the previous variant's last layer.
            current_layer = None
            latex_output.append(f"\\subsection*{{Variant: {escape_latex(current_variant)}}}")

        if feature.layer != current_layer:
            current_layer = feature.layer
            latex_output.append(f"\\subsection*{{Layer {current_layer}}}")

        latex_output.append(format_feature_latex(feature))
        latex_output.append("")

    latex_output.append("\\end{document}")

    with open(output_file, 'w', encoding='utf-8') as f:
        f.write("\n".join(latex_output))

def process_eval_directory(eval_dir: Path, output_dir: Path, model_glob: str = "pythia-*"):
    """Process all feature files in the evaluation directory.

    Walks ``<eval_dir>/<model>/<variant_dir>/explanations/layer_*/feature_*.txt``.
    For each model it writes, under ``<output_dir>/<model>/``:

    - ``<label>.tex``: a LaTeX document for every variant label;
    - ``<label>_stats.json``: summary counts for that label.

    Features from different variant directories that map to the same label
    are merged into the same output files, because grouping is keyed by the
    label on each parsed feature. Directories and files are visited in sorted
    order so that the output is reproducible.

    Args:
        eval_dir: Root directory containing the model directories.
        output_dir: Directory that receives the per-model output folders;
            created if missing.
        model_glob: Glob pattern, relative to ``eval_dir``, that selects the
            model directories.
    """
    output_dir.mkdir(parents=True, exist_ok=True)

    model_dirs = sorted(eval_dir.glob(model_glob))
    print(f"Found {len(model_dirs)} model directories")

    for model_dir in model_dirs:
        if not model_dir.is_dir():
            continue

        print(f"\nProcessing {model_dir.name}...")

        # Features grouped by variant label.
        variant_features: Dict[str, List[FeatureData]] = defaultdict(list)

        for variant_dir in sorted(model_dir.glob("*")):
            if not variant_dir.is_dir() or variant_dir.name == "explanations":
                continue

            print(f"  Processing variant: {variant_dir.name}")
            explanations_dir = variant_dir / "explanations"
            if not explanations_dir.exists():
                continue

            for layer_dir in sorted(explanations_dir.glob("layer_*")):
                print(f"    Processing {layer_dir.name}")
                for feature_file in sorted(layer_dir.glob("feature_*.txt")):
                    data = parse_feature_file(feature_file)
                    variant_features[data.variant].append(data)

        model_output_dir = output_dir / model_dir.name
        model_output_dir.mkdir(parents=True, exist_ok=True)

        for variant, features in variant_features.items():
            print(f"  Creating LaTeX output for {variant}")
            output_file = model_output_dir / f"{variant}.tex"
            create_latex_file(f"{model_dir.name} - {variant}", features, output_file)

            # ``features`` is never empty: defaultdict entries are only
            # created by the append above.
            stats = {
                'model': model_dir.name,
                'variant': variant,
                'num_layers': len({f.layer for f in features}),
                'num_features': len(features),
                'avg_examples_per_feature': sum(len(f.examples) for f in features) / len(features),
            }

            with open(model_output_dir / f"{variant}_stats.json", 'w', encoding='utf-8') as f:
                json.dump(stats, f, indent=2)

    print(f"\nProcessed features saved to {output_dir}")

if __name__ == "__main__":
    # Local import: only the script entry point needs sys. ``sys.exit`` is
    # used because the bare ``exit`` builtin is injected by the ``site``
    # module and is not guaranteed (e.g. under ``python -S``).
    import sys

    eval_dir = Path("saved_eval")
    output_dir = Path("feature_examples")

    # Require a directory: a plain file of that name would just yield no models.
    if not eval_dir.is_dir():
        print(f"Evaluation directory {eval_dir} not found!", file=sys.stderr)
        sys.exit(1)

    process_eval_directory(eval_dir, output_dir)