#!/usr/bin/env python3
"""
Generate LaTeX-formatted tables from ablation experiment results.

Each table shows metrics (M, F, D) for each experience across different
indexed training configurations.
"""

import json
import os
import sys
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple

def _ablation(
    name: str,
    replay_strategy: str,
    replay_percentage: float,
    *,
    emb_anchor: bool = True,
    proj_anchor: bool = True,
) -> Dict:
    """Build one ablation config dict; every ablation uses two-phase training."""
    return {
        "name": name,
        "replay_strategy": replay_strategy,
        "replay_percentage": replay_percentage,
        "two_phase": True,
        "emb_anchor": emb_anchor,
        "proj_anchor": proj_anchor,
    }


# Ablation definitions keyed by ablation ID (matching run_ablation_full.py).
ABLATIONS = {
    # Domain-model coreset replay at increasing buffer sizes.
    1: _ablation("ablate_replay_domain_05", "domain_model_coreset", 0.05),
    2: _ablation("ablate_replay_domain_10", "domain_model_coreset", 0.10),
    3: _ablation("ablate_replay_domain_20", "domain_model_coreset", 0.20),
    # Random replay at the same buffer sizes.
    4: _ablation("ablate_replay_random_05", "random", 0.05),
    5: _ablation("ablate_replay_random_10", "random", 0.10),
    6: _ablation("ablate_replay_random_20", "random", 0.20),
    # Anchor ablations, each on top of 10% domain-model replay.
    7: _ablation("ablate_no_emb_anchor", "domain_model_coreset", 0.10, emb_anchor=False),
    8: _ablation("ablate_no_proj_anchor", "domain_model_coreset", 0.10, proj_anchor=False),
}

# Ablation names, in ablation-ID order.
ABLATION_NAMES = [config["name"] for _, config in sorted(ABLATIONS.items())]

# Human-readable names used in the LaTeX comment line above each table.
ABLATION_DISPLAY_NAMES = {
    "ablate_replay_domain_05": "Domain Replay (5%)",
    "ablate_replay_domain_10": "Domain Replay (10%)",
    "ablate_replay_domain_20": "Domain Replay (20%)",
    "ablate_replay_random_05": "Random Replay (5%)",
    "ablate_replay_random_10": "Random Replay (10%)",
    "ablate_replay_random_20": "Random Replay (20%)",
    "ablate_no_emb_anchor": "No Embedding Anchor, 10% Domain Replay",
    "ablate_no_proj_anchor": "No Projection Anchor, 10% Domain Replay",
}

# Experiences in training order (Exp 1 .. Exp 4).
EXPERIENCES = ["apibench", "mllm", "hugging-bench-1", "hugging-bench-2"]

# Experience name -> sub-directory of the results root holding its evaluations.
# Currently every experience maps to a directory of the same name.
EXPERIENCE_DIR_MAP = {
    "apibench": "apibench",
    "mllm": "mllm",
    "hugging-bench-1": "hugging-bench-1",
    "hugging-bench-2": "hugging-bench-2",
}



def _first_metrics_file(directory: Path) -> Optional[Path]:
    """
    Return the lexicographically first ``metrics.json`` under *directory*, or None.

    The search is recursive and ``rglob`` also matches ``directory/metrics.json``
    itself, so no separate "direct file" check is needed. Paths compare
    component-wise, so ``checkpoint-*`` sub-directories sort before a top-level
    ``metrics.json``. The ordering is lexicographic, not numeric
    (``checkpoint-1000`` sorts before ``checkpoint-500``).
    """
    if not directory.is_dir():
        return None
    candidates = sorted(directory.rglob("metrics.json"))
    return candidates[0] if candidates else None


def find_metrics_file(results_dir: Path, ablation: str, test_experience: str, trained_up_to_exp: int) -> Optional[Path]:
    """
    Find the metrics.json file for a given ablation, test experience, and training configuration.

    Uses the naming pattern from run_ablation_full.py: {experience_name}-{ablation_name}
    Results are saved to: results/{test_experience}/{experience_name}-{ablation_name}/checkpoint-XXX/

    For trained_up_to_exp == 1 evaluated on apibench, the initial
    ``apibench/apibench-router_test`` run is preferred when it contains a
    metrics.json. Otherwise the lookup falls back to ``apibench-{ablation}``.

    Args:
        results_dir: Base results directory
        ablation: Ablation name (e.g., "ablate_replay_domain_05")
        test_experience: Experience name to test on (e.g., "apibench", "mllm")
        trained_up_to_exp: 1-based index into EXPERIENCES of the last experience the
            model was trained on (1=apibench, 2=mllm, 3=hugging-bench-1, 4=hugging-bench-2)

    Returns:
        Path to metrics.json file, or None if not found. This includes an unknown
        test experience or an out-of-range trained_up_to_exp.
    """
    exp_dir = EXPERIENCE_DIR_MAP.get(test_experience)
    if not exp_dir:
        return None
    exp_results_dir = results_dir / exp_dir
    if not exp_results_dir.exists():
        return None

    # Special case: "Indexed(Exp 1)" on apibench uses the initial apibench
    # checkpoint, i.e. the baseline before any mllm training.
    if trained_up_to_exp == 1 and test_experience == "apibench":
        baseline = _first_metrics_file(results_dir / "apibench" / "apibench-router_test")
        if baseline is not None:
            return baseline

    # The adapter evaluated is the one produced after training on experience
    # number `trained_up_to_exp`, named "<experience>-<ablation>". The explicit
    # lower bound prevents 0/negative values from wrapping around the list.
    if not 1 <= trained_up_to_exp <= len(EXPERIENCES):
        return None
    adapter_dir_name = f"{EXPERIENCES[trained_up_to_exp - 1]}-{ablation}"

    # Look for the adapter directory in the test experience's results.
    return _first_metrics_file(exp_results_dir / adapter_dir_name)


def load_metrics(metrics_path: Optional[Path]) -> Optional[Dict]:
    """
    Load a metrics dictionary from a JSON file.

    Args:
        metrics_path: Path to a metrics.json file; None is accepted and yields None.

    Returns:
        The decoded JSON object, or None if the path is None or missing, the file
        cannot be read or parsed, or its top-level value is not a JSON object.

    Diagnostics are written to stderr so they never end up inside the LaTeX
    rows this script prints on stdout.
    """
    if not metrics_path or not metrics_path.exists():
        return None
    try:
        with open(metrics_path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError;
        # OSError covers permission problems, directories, etc.
        print(f"Error loading {metrics_path}: {e}", file=sys.stderr)
        return None
    if not isinstance(data, dict):
        # Callers use dict.get() on the result, so reject lists/scalars here.
        print(
            f"Error loading {metrics_path}: expected a JSON object, got {type(data).__name__}",
            file=sys.stderr,
        )
        return None
    return data


def _first_present(metrics: Dict, keys: Tuple[str, ...]) -> Optional[float]:
    """Return the value of the first key in *keys* whose value is not None (None if none is).

    Unlike an ``a or b`` chain, a legitimate 0 / 0.0 value is returned rather
    than skipped.
    """
    for key in keys:
        value = metrics.get(key)
        if value is not None:
            return value
    return None


def _as_percent(fraction: Optional[float]) -> Optional[float]:
    """Convert a 0-1 fraction to a percentage rounded to 1 decimal; None passes through."""
    return None if fraction is None else round(fraction * 100, 1)


def extract_metrics(metrics: Dict) -> Tuple[Optional[float], Optional[float], Optional[float]]:
    """
    Extract M, F, D metrics (as percentages rounded to 1 decimal) from a metrics dictionary.

    For each metric the first key whose value is not None wins, so a genuine
    0.0 accuracy is reported as 0.0 instead of falling through to the next key.

    Returns:
        Tuple of (M, F, D) where:
        - M: Model accuracy ("top1_accuracy", then legacy "Accuracy")
        - F: Model family accuracy ("model_family_accuracy_all_examples", which uses
          num_valid as denominator and is comparable to M; then
          "model_family_accuracy", which uses total_family_count as denominator;
          then legacy "Accuracy Model Family")
        - D: Domain accuracy ("domain_accuracy", then legacy "Accuracy Domain")
        Each entry is None when none of its keys has a value.
    """
    m = _as_percent(_first_present(metrics, ("top1_accuracy", "Accuracy")))
    # hier_model_top1 is deliberately not used: it is hierarchical routing
    # accuracy, not family accuracy.
    f = _as_percent(
        _first_present(
            metrics,
            ("model_family_accuracy_all_examples", "model_family_accuracy", "Accuracy Model Family"),
        )
    )
    d = _as_percent(_first_present(metrics, ("domain_accuracy", "Accuracy Domain")))
    return m, f, d


def format_value(value: Optional[float]) -> str:
    """Render one table cell: None becomes an empty cell, anything else its ``str()``."""
    return "" if value is None else str(value)


def find_joint_training_metrics(results_dir: Path, joint_name: str, test_experience: str) -> Optional[Path]:
    """
    Find metrics.json for joint training evaluated on a specific experience.

    Searches ``results_dir/<experience dir>/<joint_name>`` recursively. When several
    metrics.json files exist, the lexicographically first path is returned. Paths
    compare component-wise, so ``checkpoint-*`` sub-directories sort before a
    top-level metrics.json. The ordering is not numeric across checkpoint numbers.

    Args:
        results_dir: Root results directory
        joint_name: Joint training variant name (e.g., "joint-joint_training_upper_bound_1536")
        test_experience: Experience to test on (e.g., "apibench")

    Returns:
        Path to metrics.json if found, None otherwise (including unknown experiences)
    """
    exp_dir = EXPERIENCE_DIR_MAP.get(test_experience)
    if not exp_dir:
        return None

    run_dir = results_dir / exp_dir / joint_name
    if not run_dir.exists():
        return None

    # rglob also matches run_dir/metrics.json itself, so a separate check for a
    # file directly in the run directory would be unreachable.
    metrics_files = sorted(run_dir.rglob("metrics.json"))
    return metrics_files[0] if metrics_files else None


def generate_joint_training_row(results_dir: Path, joint_name: str) -> List[Optional[float]]:
    """
    Build the cell values for the joint-training (upper bound) row.

    The joint model is looked up for each of the four experience slots. A slot
    whose metrics cannot be found or loaded contributes three None cells.
    Forgetting is undefined for joint training (there is no sequential
    training), so the three trailing forgetting cells are always None.

    Args:
        results_dir: Root results directory
        joint_name: Joint training variant name (e.g., "joint-joint_training_upper_bound_1536")

    Returns:
        List of 15 values: 4 experiences * 3 metrics (M, F, D) + 3 forgetting cells (all None)
    """
    blank = (None, None, None)
    cells: List[Optional[float]] = []
    for slot in range(4):
        triplet = blank
        if slot < len(EXPERIENCES):
            path = find_joint_training_metrics(results_dir, joint_name, EXPERIENCES[slot])
            loaded = load_metrics(path) if path else None
            if loaded:
                triplet = extract_metrics(loaded)
        cells.extend(triplet)

    # No sequential phases, hence no forgetting columns.
    cells.extend(blank)
    return cells


# One (M, F, D) accuracy triplet; entries are None when unavailable.
_Triplet = Tuple[Optional[float], Optional[float], Optional[float]]

# Placeholder for an (M, F, D) triplet whose metrics are unavailable.
_EMPTY_TRIPLET: _Triplet = (None, None, None)


def _latex_row(label: str, values: List[Optional[float]], label_width: int = 20) -> str:
    """
    Format one LaTeX table row.

    The row is indented by four spaces. The label is left-aligned in a field of
    ``label_width`` characters. Cells are joined with " & ", and missing values
    render as empty cells via format_value. The row ends with a LaTeX line break.
    """
    cells = " & ".join(format_value(value) for value in values)
    return f"    {label:<{label_width}} & {cells} \\\\"


def _rounded_mean(values: List[Optional[float]]) -> Optional[float]:
    """Return the mean of the non-None entries of *values* rounded to 1 decimal, or None if there are none."""
    present = [value for value in values if value is not None]
    if not present:
        return None
    # Adding 0.0 turns a negative zero (a small negative mean rounded to
    # 1 decimal) into 0.0, so the table never prints "-0.0".
    return round(sum(present) / len(present), 1) + 0.0


def _indexed_row_label(trained_up_to: int) -> str:
    """Row label for the model trained on Exp 1..trained_up_to, e.g. "Indexed(Exp 1-3)"."""
    if trained_up_to == 1:
        return "Indexed(Exp 1)"
    return f"Indexed(Exp 1-{trained_up_to})"


def _lookup_triplet(results_dir: Path, ablation: str, test_experience: str, trained_up_to: int) -> _Triplet:
    """(M, F, D) of the ablation model trained up to *trained_up_to*, tested on *test_experience*; all None if unavailable."""
    metrics_path = find_metrics_file(results_dir, ablation, test_experience, trained_up_to)
    metrics = load_metrics(metrics_path) if metrics_path else None
    return extract_metrics(metrics) if metrics else _EMPTY_TRIPLET


def _collect_accuracy_grid(results_dir: Path, ablation: str) -> Dict[int, List[_Triplet]]:
    """
    Load the accuracy matrix A for one ablation.

    ``grid[i][j]`` is the (M, F, D) triplet of the model trained on experiences
    1..i (1-based i) evaluated on ``EXPERIENCES[j]`` (0-based j). Only
    experiences already trained on are evaluated (j < i). The remaining cells
    are all None. Each metrics file is located and loaded exactly once.
    """
    grid: Dict[int, List[_Triplet]] = {}
    for trained_up_to in range(1, len(EXPERIENCES) + 1):
        grid[trained_up_to] = [
            _lookup_triplet(results_dir, ablation, experience, trained_up_to)
            if exp_idx < trained_up_to
            else _EMPTY_TRIPLET
            for exp_idx, experience in enumerate(EXPERIENCES)
        ]
    return grid


def _forgetting(grid: Dict[int, List[_Triplet]], trained_up_to: int) -> _Triplet:
    """
    Average forgetting of the model trained on Exp 1..i (i = *trained_up_to*), per metric.

        forgetting_i = (1 / (i - 1)) * sum_{j < i} (A[j, j] - A[i, j])

    This measures how much accuracy on each earlier experience dropped since it
    was learned, i.e. the negated backward transfer (BWT). Positive values mean
    forgetting. Negative values mean later training improved earlier experiences.
    The sign is kept deliberately: abs(BWT) would report such an improvement as
    forgetting.

    A metric is None if any required cell is missing. All three are None for
    i == 1, since there is nothing to forget yet.
    """
    if trained_up_to <= 1:
        return _EMPTY_TRIPLET
    num_previous = trained_up_to - 1
    current_row = grid[trained_up_to]
    result: List[Optional[float]] = []
    for metric_idx in range(3):
        # (A[j, j], A[i, j]) for every previous experience j (0-based prev_idx).
        pairs = [
            (grid[prev_idx + 1][prev_idx][metric_idx], current_row[prev_idx][metric_idx])
            for prev_idx in range(num_previous)
        ]
        if any(learned is None or now is None for learned, now in pairs):
            result.append(None)
            continue
        drop = sum((learned - now for learned, now in pairs), 0.0)
        result.append(round(drop / num_previous, 1) + 0.0)  # + 0.0 avoids "-0.0"
    return (result[0], result[1], result[2])


def _compute_ablation_rows(results_dir: Path, ablation: str) -> Dict[str, List[Optional[float]]]:
    """
    Compute every row of an ablation table, keyed by row label in display order.

    Each "Indexed(...)" row holds (M, F, D) for every experience, followed by
    the three forgetting values (M, F, D). The final "Mean" row averages each
    column over the Indexed rows that have a value in it.
    """
    grid = _collect_accuracy_grid(results_dir, ablation)
    rows: Dict[str, List[Optional[float]]] = {}
    for trained_up_to, triplets in grid.items():
        row = [value for triplet in triplets for value in triplet]
        row.extend(_forgetting(grid, trained_up_to))
        rows[_indexed_row_label(trained_up_to)] = row
    # zip(*...) is evaluated before "Mean" is inserted, so Mean averages only Indexed rows.
    rows["Mean"] = [_rounded_mean(list(column)) for column in zip(*rows.values())]
    return rows


def _summary_from_mean_row(results_dir: Path, ablation: str) -> Tuple[_Triplet, _Triplet]:
    """
    Return ((avg M, avg F, avg D), (forgetting M, F, D)) derived from the Mean row.

    The averages are unweighted means of the Mean row's per-experience values,
    skipping missing ones. The forgetting values are the Mean row's forgetting
    columns, which are already averaged over the Indexed rows.
    """
    mean_row = _compute_ablation_rows(results_dir, ablation)["Mean"]
    accuracy_cols = 3 * len(EXPERIENCES)
    avg_m, avg_f, avg_d = (_rounded_mean(mean_row[offset:accuracy_cols:3]) for offset in range(3))
    fgt_m, fgt_f, fgt_d = mean_row[accuracy_cols:accuracy_cols + 3]
    return (avg_m, avg_f, avg_d), (fgt_m, fgt_f, fgt_d)


def generate_table(results_dir: Path, ablation: str) -> str:
    """
    Generate the LaTeX table rows for a single ablation.

    There is one row per indexed configuration ("Indexed(Exp 1)",
    "Indexed(Exp 1-2)", ...), followed by a "Mean" row. Each row holds the
    (M, F, D) accuracies for every experience in EXPERIENCES, left empty for
    experiences not yet trained on, then three forgetting columns (M, F, D).

    Args:
        results_dir: Root results directory
        ablation: Ablation name (e.g., "ablate_replay_domain_05")

    Returns:
        The rows joined by newlines (no tabular environment or header).
    """
    rows = _compute_ablation_rows(results_dir, ablation)
    return "\n".join(_latex_row(label, values) for label, values in rows.items())


def generate_joint_training_table(results_dir: Path, joint_name: str) -> str:
    """
    Generate a LaTeX-formatted table for joint training (single row).

    Args:
        results_dir: Root results directory
        joint_name: Joint training variant name (e.g., "joint-joint_training_upper_bound_1536")

    Returns:
        String containing the formatted row
    """
    return _latex_row("Joint Training", generate_joint_training_row(results_dir, joint_name))


def generate_average_summary_table(results_dir: Path, ablation: str) -> str:
    """
    Generate a summary row with average performance across all experiences.

    The Mean row of generate_table is computed first. Its per-experience values
    are then averaged across Exp 1-4 for each metric type. The row has 6 values:
    - Average Model, Family, and Domain accuracy across experiences (from the Mean row)
    - Model, Family, and Domain forgetting (from the Mean row, already averaged)

    Args:
        results_dir: Root results directory
        ablation: Ablation name

    Returns:
        String containing the formatted row
    """
    (avg_m, avg_f, avg_d), (fgt_m, fgt_f, fgt_d) = _summary_from_mean_row(results_dir, ablation)
    # label_width=21 keeps the existing column alignment of the summary rows.
    return _latex_row(
        "Average Summary",
        [avg_m, avg_f, avg_d, fgt_m, fgt_f, fgt_d],
        label_width=21,
    )


def generate_simple_summary_table(results_dir: Path, ablation: str) -> str:
    """
    Generate a simple summary row: model accuracy, domain accuracy, and domain forgetting.

    The values are derived in the same way as generate_average_summary_table.
    The row has 3 values:
    - Average Model accuracy across experiences (from the Mean row)
    - Average Domain accuracy across experiences (from the Mean row)
    - Domain forgetting (from the Mean row, already averaged)

    Args:
        results_dir: Root results directory
        ablation: Ablation name

    Returns:
        String containing the formatted row
    """
    (avg_m, _avg_f, avg_d), (_fgt_m, _fgt_f, fgt_d) = _summary_from_mean_row(results_dir, ablation)
    # label_width=21 keeps the existing column alignment of the summary rows.
    return _latex_row("Simple Summary", [avg_m, avg_d, fgt_d], label_width=21)


def main(argv: Optional[List[str]] = None) -> None:
    """
    Print LaTeX table rows for every ablation and for the joint-training upper bound.

    For each ablation in ABLATIONS this prints the per-configuration table, the
    average summary row and the simple summary row. Each is preceded by a LaTeX
    comment line naming the ablation and followed by a blank line. All output
    goes to stdout.

    Args:
        argv: Command-line arguments; None means sys.argv[1:]. With no arguments
            the original defaults are used (results in /home/ubuntu/CCO/results).
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Generate LaTeX table rows from ablation experiment results."
    )
    parser.add_argument(
        "--results-dir",
        type=Path,
        default=Path("/home/ubuntu/CCO/results"),
        help="Root results directory (default: %(default)s)",
    )
    parser.add_argument(
        "--joint-name",
        default="joint-joint_training_upper_bound_1536",
        help="Joint-training run directory name (default: %(default)s)",
    )
    args = parser.parse_args(argv)
    results_dir: Path = args.results_dir

    if not results_dir.is_dir():
        # Still produce the (empty) tables, but make the cause visible. stderr
        # keeps the warning out of the LaTeX on stdout.
        print(f"warning: results directory not found: {results_dir}", file=sys.stderr)

    for ablation_config in ABLATIONS.values():
        ablation_name = ablation_config["name"]
        display_name = ABLATION_DISPLAY_NAMES.get(ablation_name, ablation_name)
        sections = (
            (f"\n% {display_name}", generate_table),
            (f"\n% {display_name} - Average Summary", generate_average_summary_table),
            (f"\n% {display_name} - Simple Summary", generate_simple_summary_table),
        )
        for heading, render in sections:
            print(heading)
            print(render(results_dir, ablation_name))
            print()

    # Joint training (upper bound): a single row.
    print("\n% Joint Training (Upper Bound)")
    print(generate_joint_training_table(results_dir, args.joint_name))
    print()


if __name__ == "__main__":
    main()
