import json
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
from typing import Dict, List, Tuple, Optional, Set
import pandas as pd
import argparse
from sklearn.metrics import roc_curve, auc
import seaborn as sns
import math
import re

# Typeset regular text and mathtext with the STIX fonts so labels render
# consistently in the saved PNG/PDF figures.
plt.rcParams['font.family'] = 'STIXGeneral'
plt.rcParams['mathtext.fontset'] = 'stix'

# Line colour (hex) for each label that get_variant_label() can return.
COLORS = dict([
    ('Trained', '#808080'),
    ('Randomized excl emb', '#1f77b4'),
    ('Randomized incl emb', '#ff7f0e'),
    ('Step 0', '#2ca02c'),
    ('Control', '#000000'),
])

def get_variant_label(variant_name: str) -> str:
    """Map a variant directory name to the human-readable label used in plots.

    Checks are applied in priority order:

    1. ``random_control`` (case-sensitive) -> ``'Control'``
    2. ``rerandomised_embeddings`` / ``rerandomize_embeddings`` (any case)
       -> ``'Randomized incl emb'``
    3. ``rerandomise`` / ``rerandomize`` (any case) -> ``'Randomized excl emb'``
    4. ``step0`` (case-sensitive) -> ``'Step 0'``
    5. anything else -> ``'Trained'``
    """
    lowered = variant_name.lower()
    if 'random_control' in variant_name:
        return 'Control'
    if any(tag in lowered for tag in ('rerandomised_embeddings', 'rerandomize_embeddings')):
        return 'Randomized incl emb'
    if 'rerandomise' in lowered or 'rerandomize' in lowered:
        return 'Randomized excl emb'
    return 'Step 0' if 'step0' in variant_name else 'Trained'

def parse_model_size(model_name: str) -> int:
    """Return the model size in millions of parameters, parsed from its name.

    Gemma names map to 2000 and Llama names to 8000 (case-insensitive match).
    For Pythia, the second ``-``-separated field is read as ``<N>b`` (billions)
    or ``<N>m`` (millions), e.g. ``pythia-70m-deduped`` -> 70 and
    ``pythia-6.9b-deduped`` -> 6900.

    Raises:
        ValueError: if the model family is not recognised, or the Pythia size
            field is not numeric.
        IndexError: if a Pythia name has no ``-`` separated size field.
    """
    lowered = model_name.lower()
    if 'gemma' in lowered:
        return 2000
    if 'llama' in lowered:
        return 8000
    if 'pythia' in lowered:
        size_str = model_name.split('-')[1].lower()
        if 'b' in size_str:
            # round(), not int(): the float product can land just below the
            # integer (1.005 * 1000 == 1004.999...), which int() truncates.
            return round(float(size_str.replace('b', '')) * 1000)
        return int(size_str.replace('m', ''))
    raise ValueError(f"Unknown model format: {model_name}")

def extract_token_count(variant_dir_name: str) -> str:
    """Return the first ``<digits>M`` tag found in a variant directory name.

    For example ``"trained_100M"`` -> ``"100M"``. The match is case-sensitive,
    so a lowercase ``m`` (as in ``70m``) is not treated as a token count.
    Returns ``"unknown"`` when the name contains no such tag.
    """
    found = re.search(r'(\d+)M', variant_dir_name)
    return "unknown" if found is None else f"{found.group(1)}M"

def get_all_token_counts(model_dir: Path) -> Set[str]:
    """Return the distinct token-count tags (e.g. ``{'100M', '1000M'}``) found
    in the names of the variant sub-directories of ``model_dir``.

    Non-directories and the ``explanations`` entry are ignored, as are names
    for which ``extract_token_count`` returns ``"unknown"``.
    """
    candidates = (
        extract_token_count(entry.name)
        for entry in model_dir.glob("*")
        if entry.is_dir() and entry.name != "explanations"
    )
    return {count for count in candidates if count != "unknown"}

def load_layer_metrics(eval_dir: Path, metric_name: str) -> Dict:
    """Load per-layer ``eval_result_metrics`` from JSON result files.

    Files are matched with ``*{metric_name}*.json`` in ``eval_dir``. The layer
    index is parsed from the ``_layer_<N>_`` fragment of each file's stem.

    Files are skipped (not raised on) when:
    - their name has no parsable ``_layer_<int>`` fragment,
    - they are not valid UTF-8 JSON,
    - they lack an ``eval_result_metrics`` key.

    Returns:
        Mapping ``layer -> data['eval_result_metrics']``. If several files
        resolve to the same layer, whichever is read last wins.
    """
    metrics = {}
    for result_file in eval_dir.glob(f"*{metric_name}*.json"):
        # Parse the layer first: a matching file without `_layer_` used to
        # raise an uncaught IndexError and abort the whole load.
        try:
            layer = int(result_file.stem.split('_layer_')[1].split('_')[0])
        except (IndexError, ValueError):
            continue
        try:
            with open(result_file, encoding='utf-8') as f:
                data = json.load(f)
            metrics[layer] = data['eval_result_metrics']
        except (KeyError, TypeError, ValueError):
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError;
            # TypeError covers a JSON payload that is not an object.
            continue
    return metrics

def compute_roc_data(feature_data: List) -> Tuple[np.ndarray, np.ndarray, float]:
    """Pool per-feature score records into one ROC curve.

    Args:
        feature_data: list of per-feature record lists. Each record is a dict
            with a ``'ground_truth'`` label and a ``'probability'`` score.

    Returns:
        ``(fpr_grid, tpr_on_grid, roc_auc)``. ``fpr_grid`` is 100 evenly spaced
        points in [0, 1], and ``tpr_on_grid`` is the ROC TPR linearly
        interpolated onto it. ``roc_auc`` is computed from the raw
        (uninterpolated) curve.
    """
    records = [record for feature_scores in feature_data for record in feature_scores]
    labels = [record['ground_truth'] for record in records]
    scores = [record['probability'] for record in records]

    raw_fpr, raw_tpr, _thresholds = roc_curve(labels, scores)
    area = auc(raw_fpr, raw_tpr)

    # Resample onto a fixed FPR grid so curves from different runs line up.
    grid = np.linspace(0, 1, 100)
    return grid, np.interp(grid, raw_fpr, raw_tpr), area

def load_roc_data(model_dir: Path, scorer_type: str = 'fuzz') -> Dict:
    """Collect the raw per-feature score records for every variant and layer.

    Layout read below::

        <model_dir>/<variant>/explanations/layer_<N>/<scorer_type>/*_score.txt

    Each ``*_score.txt`` file is parsed as JSON and kept as-is.

    Returns:
        ``{variant_dir_name: {layer: [parsed_score_file, ...]}}``. Variants and
        layers with no readable score files are omitted. Directories not named
        ``layer_<int>`` and files that are not valid JSON are skipped rather
        than aborting the whole load.
    """
    data = {}
    for variant_dir in model_dir.glob("*"):
        if not variant_dir.is_dir() or variant_dir.name == "explanations":
            continue
        explanation_dir = variant_dir / "explanations"
        if not explanation_dir.exists():
            continue

        variant_data = {}
        for layer_dir in explanation_dir.glob("layer_*"):
            try:
                layer = int(layer_dir.name.split('_')[1])
            except ValueError:
                # e.g. `layer_notes`: previously crashed the whole load.
                continue
            scores_dir = layer_dir / scorer_type
            if not scores_dir.exists():
                continue

            layer_data = []
            for score_file in scores_dir.glob("*_score.txt"):
                try:
                    layer_data.append(json.loads(score_file.read_text(encoding='utf-8')))
                except ValueError:
                    # Covers json.JSONDecodeError and UnicodeDecodeError.
                    continue
            if layer_data:
                variant_data[layer] = layer_data

        if variant_data:
            data[variant_dir.name] = variant_data

    return data

def _select_roc_layers(all_layers: List[int]) -> List[int]:
    """Choose the layers that get a ROC panel (at most 8).

    Returns every layer if there are at most 8. Otherwise returns the first
    and last layer plus an evenly strided subset (at most 6) of the layers
    in between.
    """
    if len(all_layers) <= 8:
        return list(all_layers)
    stride = math.ceil((len(all_layers) - 2) / 6)
    return [all_layers[0]] + all_layers[1:-1:stride] + [all_layers[-1]]


def _load_cached_aucs(cache_filename: Path) -> Optional[Dict]:
    """Read a ``{variant: [(auc, layer), ...]}`` cache file.

    Returns None (after printing the error) if the file is unreadable or
    malformed.
    """
    try:
        with open(cache_filename, 'r') as f:
            cached = json.load(f)
        # JSON stores the pairs as lists; restore (float, int) tuples.
        return {variant: [(float(score), int(layer)) for score, layer in entries]
                for variant, entries in cached.items()}
    except (ValueError, TypeError, KeyError, AttributeError) as e:
        # ValueError includes json.JSONDecodeError; the others cover
        # structurally wrong cache contents.
        print(f"Error loading cache: {e}, recomputing ROC data")
        return None


def plot_roc_curves(model_dir: Path, save_dir: Path, scorer_type: str = 'fuzz',
                    token_count: Optional[str] = None, cache_dir: Optional[Path] = None,
                    use_cache: bool = False):
    """Plot one ROC panel per (selected) layer for every variant of a model.

    Figures are written to
    ``save_dir/<scorer_type>_roc_curves[_<token_count>].{png,pdf}``.

    Args:
        model_dir: model directory whose variant sub-directories hold scores
            (see ``load_roc_data``).
        save_dir: existing directory to write the figures into.
        scorer_type: score sub-directory to read (e.g. ``'fuzz'``,
            ``'detection'``).
        token_count: if given, only variants whose name contains it are used.
        cache_dir: if given, the per-variant AUCs are written to
            ``<cache_dir>/<model>_<scorer>[_<tokens>]_roc_data.json``.
        use_cache: if True and that cache file exists and parses, its
            contents are returned immediately and no figure is drawn.

    Returns:
        ``{variant_name: [(auc, layer), ...]}`` with 0-indexed layers, or None
        when there is no score data. Only layers that received a panel
        (see ``_select_roc_layers``) appear in the lists.
    """
    cache_filename = None
    if cache_dir is not None:
        cache_key = f"{model_dir.name}_{scorer_type}" + (f"_{token_count}" if token_count else "")
        cache_filename = cache_dir / f"{cache_key}_roc_data.json"
        if use_cache and cache_filename.exists():
            print(f"Loading cached ROC data from {cache_filename}")
            cached = _load_cached_aucs(cache_filename)
            if cached is not None:
                return cached

    roc_data = load_roc_data(model_dir, scorer_type)
    if token_count:
        roc_data = {k: v for k, v in roc_data.items() if token_count in k}
    if not roc_data:
        return None

    all_layers = sorted({layer for variant_data in roc_data.values() for layer in variant_data})
    print(f"Found {len(all_layers)} layers for {model_dir.name}" +
          (f" with {token_count} tokens" if token_count else ""))

    # Size the grid from the layers actually plotted; sizing it before
    # subsampling left empty panels (e.g. 6 plots in a 2x4 grid).
    layers_to_plot = _select_roc_layers(all_layers)
    n_cols = min(4, len(layers_to_plot))
    n_rows = math.ceil(len(layers_to_plot) / n_cols)
    print(f"Plotting {len(layers_to_plot)} layers with {n_rows} rows and {n_cols} columns")

    fig = plt.figure(figsize=(6.5, 3.5))
    gs = fig.add_gridspec(n_rows, n_cols, hspace=0.4, wspace=0.05)

    aucs_by_variant = {variant: [] for variant in roc_data}
    legend_handles = {}  # variant label -> first line drawn for it

    for idx, layer in enumerate(layers_to_plot):
        row, col = divmod(idx, n_cols)
        ax = fig.add_subplot(gs[row, col])

        for variant_name, variant_data in roc_data.items():
            if layer not in variant_data:
                continue
            fpr, tpr, roc_auc = compute_roc_data(variant_data[layer])
            variant_label = get_variant_label(variant_name)
            (line,) = ax.plot(fpr, tpr, color=COLORS[variant_label],
                              label=variant_label, linewidth=1.5)
            legend_handles.setdefault(variant_label, line)
            aucs_by_variant[variant_name].append((roc_auc, layer))

        ax.plot([0, 1], [0, 1], 'k--', lw=1)  # chance diagonal
        ax.set_xlim([0.0, 1.0])
        ax.set_ylim([0.0, 1.05])
        ax.set_title(f'Layer {layer + 1}', fontsize=9)
        ax.set_xticks([0, 0.5, 1])
        ax.set_yticks([0, 0.5, 1])
        ax.tick_params(axis='both', which='both', length=2, labelsize=7)

        # Tick labels and axis labels only on the outer edge of the grid.
        # (Tick labels were previously cleared on every panel, which made
        # these branches dead code.)
        if col == 0:
            ax.set_ylabel('True Positive Rate', fontsize=8)
        else:
            ax.set_yticklabels([])
        if row == n_rows - 1:
            ax.set_xlabel('False Positive Rate', fontsize=8)
        else:
            ax.set_xticklabels([])

    # One legend entry per variant, gathered from every panel. Reading
    # handles from the last axes alone dropped variants missing from the last
    # layer and presented that single layer's AUCs as if they were global.
    if legend_handles:
        fig.legend(list(legend_handles.values()), list(legend_handles),
                   loc='center', bbox_to_anchor=(0.5, 0.0),
                   ncol=3, fontsize=8, frameon=True)
    fig.subplots_adjust(bottom=0.2, wspace=0.05, hspace=0.2)

    title = f"{scorer_type.capitalize()} ROC Curves"
    filename = f"{scorer_type}_roc_curves"
    if token_count:
        title += f" ({token_count} tokens)"
        filename += f"_{token_count}"
    fig.suptitle(title)

    fig.savefig(save_dir / f"{filename}.png", dpi=300, bbox_inches='tight')
    fig.savefig(save_dir / f"{filename}.pdf", dpi=300, bbox_inches='tight')
    plt.close(fig)

    if cache_filename is not None:
        # Callers may pass a cache_dir they have not created yet.
        cache_filename.parent.mkdir(parents=True, exist_ok=True)
        cache_data = {variant: [(float(score), int(layer)) for score, layer in entries]
                      for variant, entries in aucs_by_variant.items()}
        with open(cache_filename, 'w') as f:
            json.dump(cache_data, f)
        print(f"Saved ROC data to cache: {cache_filename}")

    return aucs_by_variant

def _iter_variant_dirs(model_dir: Path, token_count: Optional[str]):
    """Yield the variant directories of ``model_dir``.

    Skips non-directories and the ``explanations`` entry. When
    ``token_count`` is given, also skips directories whose name does not
    contain it.
    """
    for variant_dir in model_dir.glob("*"):
        if not variant_dir.is_dir() or variant_dir.name == "explanations":
            continue
        if token_count and token_count not in variant_dir.name:
            continue
        yield variant_dir


def _plot_auc_series(ax, model_aucs: Dict, token_count: Optional[str]) -> None:
    """Plot AUC against 1-indexed layer for each variant in ``model_aucs``.

    ``model_aucs`` maps variant name -> list of ``(auc, layer)`` pairs with
    0-indexed layers, as returned by ``plot_roc_curves``.
    """
    for variant_name, auc_data in model_aucs.items():
        if token_count and token_count not in variant_name:
            continue
        if not auc_data:
            continue
        # sorted() rather than list.sort() so the caller's lists are untouched.
        auc_values, layers = zip(*sorted(auc_data, key=lambda pair: pair[1]))
        variant_label = get_variant_label(variant_name)
        # Shift to 1-indexed layers so the AUC rows line up with every other
        # row (previously they were plotted one layer to the left).
        ax.plot([layer + 1 for layer in layers], auc_values,
                color=COLORS[variant_label], label=variant_label, linewidth=1.5)


def _plot_token_entropy(ax, model_dir: Path, group: str, key: str,
                        token_count: Optional[str]) -> None:
    """Plot the per-layer mean of ``key`` read from
    ``<variant>/explanations/layer_<N>/<group>/*_score.txt`` files."""
    for variant_dir in _iter_variant_dirs(model_dir, token_count):
        explanation_dir = variant_dir / "explanations"
        if not explanation_dir.exists():
            continue

        layer_scores = []
        for layer_dir in explanation_dir.glob("layer_*"):
            try:
                layer = int(layer_dir.name.split('_')[1])
            except ValueError:
                continue  # not a layer_<int> directory
            scores_dir = layer_dir / group
            if not scores_dir.exists():
                continue
            scores = []
            for score_file in scores_dir.glob("*_score.txt"):
                try:
                    with open(score_file, 'r') as f:
                        score_data = json.load(f)
                    if key in score_data:
                        scores.append(score_data[key])
                except (json.JSONDecodeError, KeyError):
                    continue
            if scores:
                layer_scores.append((layer, sum(scores) / len(scores)))

        if layer_scores:
            layer_scores.sort(key=lambda pair: pair[0])
            layers, scores = zip(*layer_scores)
            variant_label = get_variant_label(variant_dir.name)
            ax.plot([layer + 1 for layer in layers], scores,
                    color=COLORS[variant_label], label=variant_label, linewidth=1.5)


def _plot_sae_bench_metric(ax, model_dir: Path, group: str, key: str,
                           token_count: Optional[str]) -> None:
    """Plot ``eval_result_metrics[group][key]`` per layer from the
    ``<variant>/sae_bench/core`` results. ``ce_loss_score`` is drawn for the
    trained variant only."""
    for variant_dir in _iter_variant_dirs(model_dir, token_count):
        variant_label = get_variant_label(variant_dir.name)
        if key == 'ce_loss_score' and variant_label != 'Trained':
            continue
        eval_dir = variant_dir / "sae_bench" / "core"
        if not eval_dir.exists():
            continue

        layer_metrics = load_layer_metrics(eval_dir, "custom_sae_eval_results")
        points = []
        for layer in sorted(layer_metrics):
            try:
                points.append((layer + 1, layer_metrics[layer][group][key]))  # 1-indexed
            except (KeyError, TypeError):
                continue

        if points:
            layers, values = zip(*points)
            ax.plot(layers, values, color=COLORS[variant_label],
                    label=variant_label, linewidth=1.5)


def plot_metrics_by_layer(model_dirs: List[Path], save_dir: Path,
                          aucs_by_model: Optional[Dict] = None,
                          detection_aucs_by_model: Optional[Dict] = None,
                          token_count: Optional[str] = None):
    """Plot a grid of per-layer metrics: one row per metric, one column per model.

    For ``token_count == '1000M'`` the rows are:
    - reconstruction (explained variance, cosine similarity),
    - L1 sparsity,
    - fuzz/detection AUROC,
    - trained-only CE loss score,
    - token entropy.

    For any other value only the two AUROC rows are drawn.

    Args:
        model_dirs: model directories, one column each (order preserved).
        save_dir: existing directory to write
            ``metrics_by_layer[_<token_count>].{png,pdf}`` into.
        aucs_by_model / detection_aucs_by_model: ``{model_size:
            {variant: [(auc, layer), ...]}}``, keyed by ``parse_model_size``
            of the model directory name. Missing sizes leave that AUC panel
            empty.
        token_count: if given, only variants whose name contains it are
            plotted.
    """
    if not model_dirs:
        print("No model directories given; skipping metrics-by-layer plot")
        return

    if token_count == '1000M':
        metrics = {
            'Reconstruction': [
                ('reconstruction_quality', 'explained_variance', 'Explained Variance (R²)'),
                ('reconstruction_quality', 'cossim', 'Cosine Similarity')
            ],
            'Sparsity': [
                ('sparsity', 'l1', 'L1 Norm')
            ],
            'Model': [
                ('AUC', 'auc', 'AUROC (Fuzz)'),
                ('AUC_detection', 'auc', 'AUROC (Detection)'),
                ('model_performance_preservation', 'ce_loss_score', 'CE Loss Score (Trained)')
            ],
            'Complexity': [
                ('token_entropy', 'entropy', 'Token Distribution Entropy')
            ]
        }
    else:
        metrics = {
            'Model': [
                ('AUC', 'auc', 'AUROC (Fuzz)'),
                ('AUC_detection', 'auc', 'AUROC (Detection)'),
            ]
        }

    # Flatten the categories into an ordered list of (group, key, y-label) rows.
    metric_rows = [row for metric_list in metrics.values() for row in metric_list]
    n_rows = len(metric_rows)

    model_sizes = [(parse_model_size(d.name), d) for d in model_dirs]
    auc_sources = {'AUC': aucs_by_model or {}, 'AUC_detection': detection_aucs_by_model or {}}

    fig, axes = plt.subplots(n_rows, len(model_sizes),
                             figsize=(6.5, 1.5 * n_rows), squeeze=False)

    for row_idx, (group, key, y_label) in enumerate(metric_rows):
        for col_idx, (size, model_dir) in enumerate(model_sizes):
            ax = axes[row_idx, col_idx]

            if group in auc_sources:
                model_aucs = auc_sources[group].get(size)
                if model_aucs:
                    _plot_auc_series(ax, model_aucs, token_count)
            elif group == 'token_entropy':
                _plot_token_entropy(ax, model_dir, group, key, token_count)
            else:
                _plot_sae_bench_metric(ax, model_dir, group, key, token_count)

            if col_idx == 0:
                ax.set_ylabel(y_label, fontsize=8)
            if row_idx == n_rows - 1:
                ax.set_xlabel('Layer', fontsize=8)
            if row_idx == 0:
                if size == 2000:
                    column_title = "Gemma 2B"
                elif size == 8000:
                    column_title = "Llama 8B"
                else:
                    column_title = f"{size/1000:.1f}B" if size >= 1000 else f"{size}M"
                ax.set_title(column_title, fontsize=9)
            if 'norm' in y_label.lower():
                ax.set_yscale('log')
            ax.grid(False)
            ax.tick_params(axis='both', which='major', labelsize=7)

    # Legend from every panel, de-duplicated by label. Using axes[0, 0] alone
    # missed variants absent from that panel, and ncol=0 (no labels) crashed.
    handles_by_label = {}
    for ax in axes.flat:
        for handle, handle_label in zip(*ax.get_legend_handles_labels()):
            handles_by_label.setdefault(handle_label, handle)
    if handles_by_label:
        fig.legend(list(handles_by_label.values()), list(handles_by_label),
                   loc='center', bbox_to_anchor=(0.5, 0),
                   ncol=len(handles_by_label), fontsize=8, frameon=True)

    title = "Metrics By Layer"
    filename = "metrics_by_layer"
    if token_count:
        title += f" ({token_count} tokens)"
        filename += f"_{token_count}"

    fig.suptitle(title)
    fig.tight_layout()
    fig.savefig(save_dir / f"{filename}.png", dpi=300, bbox_inches='tight')
    fig.savefig(save_dir / f"{filename}.pdf", dpi=300, bbox_inches='tight')
    plt.close(fig)
    
def main():
    """Command-line entry point: build the ROC and metrics-by-layer figures.

    For each token budget ('100M', '1000M'), a fixed list of model directory
    names is resolved under ``--eval-dir``. For each model, fuzz and
    detection ROC figures are written to ``plots/token_<count>/<model>/``.
    A combined metrics-by-layer figure is then written to
    ``plots/token_<count>/``.
    """
    parser = argparse.ArgumentParser(description='Plot SAE evaluation metrics')
    # The default is the root the model paths used to be hard-coded to, so a
    # run without arguments resolves exactly the same directories as before;
    # previously this option was parsed but silently ignored.
    parser.add_argument('--eval-dir', type=str,
                        default='/user/work/cp20141/repos/sae/experiments/saved_eval',
                        help='Directory containing evaluation results')
    parser.add_argument('--cache-dir', type=str, default='roc_cache',
                        help='Directory to cache ROC curve data')
    parser.add_argument('--use-cache', action='store_true',
                        help='Use cached ROC data if available')
    args = parser.parse_args()

    eval_dir = Path(args.eval_dir)
    plots_dir = Path("plots")
    plots_dir.mkdir(exist_ok=True)

    # The ROC cache is only read from / written to when --use-cache is set.
    cache_dir = Path(args.cache_dir) if args.use_cache else None
    if cache_dir is not None:
        cache_dir.mkdir(parents=True, exist_ok=True)

    # Model directory names (relative to eval_dir) evaluated per token budget.
    token_model_names = {
        '100M': [
            "pythia-70m-deduped_64_k32",
            "pythia-160m-deduped_64_k32",
            "pythia-410m-deduped_64_k32",
            "pythia-1b-deduped_64_k32",
            "pythia-6.9b-deduped_64_k32",
            "gemma-2-2b_64_k32",
        ],
        '1000M': [
            "pythia-70m-deduped_64_k32",
            "pythia-410m-deduped_64_k32",
        ],
    }

    for token_count, model_names in token_model_names.items():
        print(f"\nProcessing token count: {token_count}")
        model_dirs = [eval_dir / name for name in model_names]
        print(f"Found {len(model_dirs)} model directories: {model_dirs}")
        missing = [d for d in model_dirs if not d.is_dir()]
        if missing:
            print(f"Warning: model directories not found: {missing}")

        token_fuzz_aucs = {}
        token_detection_aucs = {}

        print(f"\nProcessing {token_count} token evaluations...")
        token_plots_dir = plots_dir / f"token_{token_count}"
        token_plots_dir.mkdir(parents=True, exist_ok=True)

        for model_dir in model_dirs:
            size = parse_model_size(model_dir.name)
            model_token_plots_dir = token_plots_dir / model_dir.name
            model_token_plots_dir.mkdir(parents=True, exist_ok=True)

            fuzz_aucs = plot_roc_curves(model_dir, model_token_plots_dir, 'fuzz', token_count,
                                        cache_dir=cache_dir, use_cache=args.use_cache)
            detection_aucs = plot_roc_curves(model_dir, model_token_plots_dir, 'detection', token_count,
                                             cache_dir=cache_dir, use_cache=args.use_cache)

            # AUC results are keyed by model size, as plot_metrics_by_layer expects.
            if fuzz_aucs:
                token_fuzz_aucs[size] = fuzz_aucs
            if detection_aucs:
                token_detection_aucs[size] = detection_aucs

        print(f"Model sizes with fuzz AUCs: {list(token_fuzz_aucs)}")

        plot_metrics_by_layer(model_dirs, token_plots_dir, token_fuzz_aucs,
                              token_detection_aucs, token_count)
        
# Run the plotting pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()