"""
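Main runner for the ablation study comparing 5 factuality methods.

Usage:
    python3.10 src/ablation/comparison_runner.py config/ablation/method_comparison.json
    python3.10 src/ablation/comparison_runner.py config/ablation/method_comparison.json --replot <path/to/comparison_results.json>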
"""

import sys
import os

# Add parent directory to path for imports
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))

import json
import numpy as np
import pandas as pd
from pathlib import Path
from typing import Dict, List, Any
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns

from src.ablation.methods import (
    DifferentiableCoherent,
    HardBaseline,
    HashimotoIndependent,
    BoostedIndependent,
    XGBoostAccuracy
)
from src.reasonining_graph_dataset import Reasoning_Graph_Dataset
from src.utilities import set_seed, generate_kfold_splits, split_dataset


class ComparisonRunner:
    """Run every configured factuality method and collect comparable metrics.

    The runner loads a JSON config, loads the dataset once, then evaluates each
    enabled method for every alpha using one of three protocols:
    seeded random train/cal/test splits, k-fold CV (via the shared
    ``split_dataset`` utility), or precomputed results from hyperparameter
    tuning. Finally it prints a summary, draws comparison plots and writes the
    raw per-trial metrics to JSON.
    """

    def __init__(self, config_path: str):
        """Load the JSON configuration and initialise empty result buckets.

        Args:
            config_path: Path to the comparison config JSON file.
        """
        with open(config_path, 'r', encoding='utf-8') as f:
            self.config = json.load(f)

        self.dataset = None
        # One list of per-trial metric dicts per method. The key order is the
        # order in which methods are run, summarised and plotted.
        self.results = {
            'differentiable_coherent': [],
            'hard_baseline': [],
            'hashimoto_independent': [],
            'boosted_independent': [],
            'xgboost_accuracy': []
        }

    def load_dataset(self):
        """Load the dataset once from ``config['dataset_path']`` with ``config['feature_cols']``."""
        print(f"Loading dataset: {self.config['dataset_path']}")
        self.dataset = Reasoning_Graph_Dataset(
            self.config['dataset_path'],
            self.config['feature_cols']
        )
        print(f"✓ Loaded {len(self.dataset)} examples")

    def create_method(self, method_name: str):
        """Instantiate the method named ``method_name`` from its config section.

        Args:
            method_name: Key under ``config['methods']``.

        Returns:
            The method instance, or ``None`` if the method is disabled.

        Raises:
            ValueError: If ``method_name`` is not one of the known methods.
        """
        method_config = self.config['methods'][method_name]

        if not method_config['enabled']:
            return None

        hyperparams = method_config['hyperparams']
        feature_cols = method_config['feature_cols']
        config = {
            'C': hyperparams['C'],
            # 'all' is shorthand for the dataset-wide feature column list.
            'feature_cols': self.config['feature_cols'] if feature_cols == 'all' else feature_cols,
            'dataset': self.dataset,  # Pass dataset for hard conformal evaluation
            'hyperparams': hyperparams,
        }

        method_classes = {
            'differentiable_coherent': DifferentiableCoherent,
            'hard_baseline': HardBaseline,
            'hashimoto_independent': HashimotoIndependent,
            'boosted_independent': BoostedIndependent,
            'xgboost_accuracy': XGBoostAccuracy,
        }
        if method_name not in method_classes:
            raise ValueError(f"Unknown method: {method_name}")

        # Method-specific extras on top of the shared config.
        if method_name == 'differentiable_coherent':
            config['hyperparams_path'] = method_config.get('hyperparams_path', None)
        elif method_name in ('boosted_independent', 'xgboost_accuracy'):
            config['learning_config'] = method_config.get('learning_config', {})

        return method_classes[method_name](config)

    def split_dataset(self, seed: int):
        """Randomly split dataset indices into disjoint train/cal/test index arrays.

        ``evaluation_protocol.train_ratio`` and ``evaluation_protocol.test_ratio``
        fix the train and test sizes (rounded down). Calibration gets whatever
        remains.

        Args:
            seed: Seed for the shuffle; the same seed always yields the same split.

        Returns:
            Tuple ``(train_idx, cal_idx, test_idx)`` of numpy index arrays.

        Raises:
            ValueError: If the train and test sizes together exceed the dataset
                size, which would otherwise make the test set overlap the
                training set.
        """
        protocol = self.config['evaluation_protocol']
        n_total = len(self.dataset)
        n_train = int(n_total * protocol['train_ratio'])
        n_test = int(n_total * protocol['test_ratio'])
        n_cal = n_total - n_train - n_test
        if n_cal < 0:
            raise ValueError(
                f"train_ratio ({protocol['train_ratio']}) + test_ratio ({protocol['test_ratio']}) "
                f"exceed the dataset size ({n_total} examples)"
            )

        rng = np.random.default_rng(seed=seed)
        indices = np.arange(n_total)
        rng.shuffle(indices)

        train_idx = indices[:n_train]
        cal_idx = indices[n_train:n_train + n_cal]
        test_idx = indices[n_train + n_cal:]

        return train_idx, cal_idx, test_idx

    def _load_examples(self, indices):
        """Return ``(X, Y)`` lists of examples and labels for the given dataset indices."""
        X = []
        Y = []
        for idx in indices:
            example, labels = self.dataset[idx]
            X.append(example)
            Y.append(labels)
        return X, Y

    def prepare_data_split(self, indices, rng=None):
        """Prepare X, Y and per-example noise for a data split.

        Args:
            indices: Dataset indices to load.
            rng: Optional ``numpy.random.Generator`` used for the noise draws.
                When omitted, the global NumPy RNG is used, which is only
                reproducible if it has been seeded elsewhere.

        Returns:
            Tuple ``(X, Y, noise)`` of lists; ``noise`` holds one N(0, 1) draw per index.
        """
        sampler = np.random if rng is None else rng
        noise = sampler.normal(0, 1, len(indices))
        X, Y = self._load_examples(indices)
        return X, Y, noise.tolist()

    @staticmethod
    def _draw_noise_dict(index_groups, rng) -> Dict:
        """Draw one N(0, 1) noise value per distinct dataset index across all groups.

        An index that appears in several groups (e.g. an example used for both
        training and calibration) gets a single value, so every split and the
        global ``noise_dict`` agree on that example's noise.
        """
        unique_indices = list(dict.fromkeys(idx for group in index_groups for idx in group))
        draws = rng.normal(0, 1, len(unique_indices)).tolist()
        return dict(zip(unique_indices, draws))

    def _split_with_noise(self, indices, noise_dict: Dict):
        """Return ``(indices, X, Y, noise)`` for ``indices``, taking noise values from ``noise_dict``."""
        X, Y = self._load_examples(indices)
        return indices, X, Y, [noise_dict[idx] for idx in indices]

    def _calibrate_predict_evaluate(self, method_name: str, method, alpha: float,
                                    train, cal, test, noise_dict: Dict,
                                    trial_info: Dict[str, Any], trial_label: str):
        """Calibrate ``method``, predict on the test split and evaluate the predictions.

        Args:
            train, cal, test: ``(indices, X, Y, noise)`` tuples for each split.
            noise_dict: Noise value for every index used by the trial (needed for
                hard conformal evaluation).
            trial_info: Extra keys (e.g. ``seed`` or ``fold_idx``) added to the metrics.
            trial_label: Text identifying the trial in error messages.

        Returns:
            The metrics dict from ``method.evaluate`` (plus ``alpha`` and
            ``trial_info``), or ``None`` if any step raised. Failures are printed
            with a traceback so one failing trial does not abort the whole run.
        """
        _, X_train, Y_train, noise_train = train
        cal_idx, X_cal, Y_cal, noise_cal = cal
        test_idx, X_test, Y_test, noise_test = test

        try:
            # Calibrate (pass training data, indices, and noise_dict for hard evaluation)
            threshold = method.calibrate(
                X_cal, Y_cal, noise_cal, alpha,
                X_train=X_train, Y_train=Y_train, noise_train=noise_train,
                cal_indices=list(cal_idx), noise_dict=noise_dict
            )

            # Predict (pass alpha and test_indices for hard evaluation)
            predictions = method.predict(X_test, noise_test, threshold, alpha=alpha, test_indices=list(test_idx))

            # Evaluate (pass test indices and dataset for coverage calculation)
            metrics = method.evaluate(predictions, Y_test, test_indices=test_idx, dataset=self.dataset)
            metrics['alpha'] = alpha
            metrics.update(trial_info)

            return metrics

        except Exception as e:
            print(f"ERROR in {method_name} (alpha={alpha}, {trial_label}): {e}")
            import traceback
            traceback.print_exc()
            return None

    def run_single_trial(self, method_name: str, method, alpha: float, seed: int):
        """Run one random-split trial for one method and alpha.

        Both the split and the per-example noise are derived from ``seed``, so
        re-running with the same seed reproduces the trial.

        Returns:
            The metrics dict (with ``alpha`` and ``seed``), or ``None`` on failure.
        """
        train_idx, cal_idx, test_idx = self.split_dataset(seed)

        # Noise uses a stream derived from the seed but distinct from the shuffle's.
        noise_rng = np.random.default_rng([seed, 1])
        noise_dict = self._draw_noise_dict((train_idx, cal_idx, test_idx), noise_rng)

        return self._calibrate_predict_evaluate(
            method_name, method, alpha,
            train=self._split_with_noise(train_idx, noise_dict),
            cal=self._split_with_noise(cal_idx, noise_dict),
            test=self._split_with_noise(test_idx, noise_dict),
            noise_dict=noise_dict,
            trial_info={'seed': seed},
            trial_label=f"seed={seed}",
        )

    def run_kfold_trial(self, method_name: str, method, alpha: float, fold_idx: int, n_folds: int):
        """
        Run a single k-fold trial for one method and alpha.

        Uses the shared ``split_dataset`` utility (the module-level import, not
        ``self.split_dataset``) with the same fixed seed as the learned model, so
        every method sees identical folds. The per-example noise is seeded from
        that seed and ``fold_idx``, so a fold is reproducible and shares noise
        across alphas.

        Returns:
            The metrics dict (with ``alpha`` and ``fold_idx``), or ``None`` on failure.
        """
        protocol = self.config['evaluation_protocol']
        train_ratio = protocol['train_ratio']
        split_seed = 42  # Fixed seed for reproducibility across all methods

        # Use the same split_dataset function as the learned model
        train_idx, val_idx, test_idx = split_dataset(
            self.dataset,
            train_ratio=train_ratio,
            val_ratio=1.0 - train_ratio,  # val_ratio is the remaining after train
            seed=split_seed,
            fold_idx=fold_idx,
            n_folds=n_folds
        )

        if method_name in ('boosted_independent', 'xgboost_accuracy'):
            # These methods need no early stopping, so they train on train + val.
            # NOTE(review): val is also used for calibration, so calibration
            # examples were seen during training. This breaks the exchangeability
            # assumption behind the conformal coverage guarantee; coverage for
            # these two methods may be optimistic. Kept to preserve the protocol.
            fit_idx = list(train_idx) + list(val_idx)
        else:
            fit_idx = train_idx

        # One noise value per example, even when it appears in both fit and cal.
        noise_rng = np.random.default_rng([split_seed, fold_idx])
        noise_dict = self._draw_noise_dict((fit_idx, val_idx, test_idx), noise_rng)

        return self._calibrate_predict_evaluate(
            method_name, method, alpha,
            train=self._split_with_noise(fit_idx, noise_dict),
            cal=self._split_with_noise(val_idx, noise_dict),
            test=self._split_with_noise(test_idx, noise_dict),
            noise_dict=noise_dict,
            trial_info={'fold_idx': fold_idx},
            trial_label=f"fold={fold_idx}",
        )

    def load_precomputed_results(self, method_name: str, method_config: Dict) -> List[Dict]:
        """Load precomputed per-alpha results from a hyperparameter-tuning output file.

        Each available alpha becomes one pseudo-trial (``seed`` = -1). Claim
        counts are converted to fractions using dataset-wide averages computed
        from ``self.dataset.raw_data['data']``.

        Raises:
            ValueError: If no ``hyperparams_path`` is configured, the method has
                no precomputed results key, or the dataset has no examples.
        """
        hyperparams_path = method_config.get('hyperparams_path')
        if not hyperparams_path:
            raise ValueError(f"Method {method_name} has use_precomputed_results=True but no hyperparams_path")

        with open(hyperparams_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        # Determine which key to use in the results
        results_keys = {'differentiable_coherent': 'learned', 'hard_baseline': 'baseline'}
        if method_name not in results_keys:
            raise ValueError(f"Precomputed results not available for method {method_name}")
        results_key = results_keys[method_name]

        # Average claims (all / manually annotated true) per example, used to
        # convert absolute claim counts into fractions.
        examples = self.dataset.raw_data['data']
        if not examples:
            raise ValueError("Cannot convert precomputed results: dataset has no examples")
        n_examples = len(examples)
        avg_claims_per_example = sum(len(ex['claims']) for ex in examples) / n_examples
        avg_true_claims_per_example = sum(
            sum(1 for c in ex['claims'] if int(c['manual_annotation']) == 1)
            for ex in examples
        ) / n_examples

        precomputed_results = []
        for alpha in self.config['alphas']:
            alpha_str = str(alpha)
            if alpha_str not in data['results']:
                print(f"  WARNING: No precomputed results for alpha={alpha}, skipping")
                continue

            alpha_data = data['results'][alpha_str][results_key]

            # Precomputed results store an absolute claim count (e.g. 2.5 claims);
            # the ablation expects a fraction (e.g. 0.34 = 34%).
            avg_claims_absolute = alpha_data['avg_claims_retained']
            avg_retention_fraction = (
                avg_claims_absolute / avg_claims_per_example if avg_claims_per_example > 0 else 0.0
            )

            # true_retention = true_claims_retained / total_true_claims, with
            # true_claims_retained = precision * claims_retained.
            precision = alpha_data.get('precision', 0.0)
            true_claims_retained = precision * avg_claims_absolute
            true_retention = true_claims_retained / avg_true_claims_per_example if avg_true_claims_per_example > 0 else 0.0

            # Already averaged across folds upstream, so stored as a single trial.
            precomputed_results.append({
                'alpha': alpha,
                'seed': -1,  # Indicate this is precomputed
                'coverage': alpha_data['coverage'],
                'marginal_coverage': alpha_data.get('marginal_coverage', 0.0),
                'avg_retention': avg_retention_fraction,  # Now a fraction
                'true_retention': true_retention,  # Computed from precision
                'n_examples': -1,  # Not applicable for precomputed
            })

        return precomputed_results

    def run_all_trials(self):
        """Run every enabled method over all alphas, appending metrics to ``self.results``.

        Per method, the protocol is chosen by config: precomputed results
        (``use_precomputed_results``), k-fold CV (``use_kfold_cv``), or seeded
        random splits (default). A single method instance is created per method
        and reused for all alphas and trials. Failed trials (``None``) are dropped.
        """
        alphas = self.config['alphas']
        default_n_trials = self.config['evaluation_protocol']['n_trials']
        base_seed = self.config['evaluation_protocol']['base_seed']

        for method_name in self.results.keys():
            print(f"\n{'=' * 80}")
            print(f"METHOD: {method_name}")
            print(f"{'=' * 80}")

            method_config = self.config['methods'][method_name]
            if not method_config['enabled']:
                print("  Skipped (disabled)")
                continue

            if method_config.get('use_precomputed_results', False):
                print("  Using precomputed results from hyperparameter tuning")
                precomputed = self.load_precomputed_results(method_name, method_config)
                self.results[method_name].extend(precomputed)
                print(f"  ✓ Loaded {len(precomputed)} precomputed alpha results")
                continue

            use_kfold = method_config.get('use_kfold_cv', False)
            method = self.create_method(method_name)

            if use_kfold:
                # Use k-fold CV - same methodology as learned model
                n_folds = method_config.get('n_folds', default_n_trials)
                print(f"  Using {n_folds}-fold CV (same as learned model)")

                for alpha in alphas:
                    print(f"\n  Alpha: {alpha}")

                    for fold_idx in tqdm(range(n_folds), desc=f"  Folds ({n_folds})"):
                        metrics = self.run_kfold_trial(method_name, method, alpha, fold_idx, n_folds)
                        if metrics is not None:
                            self.results[method_name].append(metrics)
            else:
                # Use random trial-based evaluation (original behavior)
                n_trials = method_config.get('n_trials', default_n_trials)

                for alpha in alphas:
                    print(f"\n  Alpha: {alpha}")

                    for trial in tqdm(range(n_trials), desc=f"  Trials ({n_trials})"):
                        metrics = self.run_single_trial(method_name, method, alpha, base_seed + trial)
                        if metrics is not None:
                            self.results[method_name].append(metrics)

    def analyze_results(self):
        """Print mean ± std of each summary metric per method and alpha."""
        print(f"\n{'=' * 80}")
        print("RESULTS SUMMARY")
        print(f"{'=' * 80}\n")

        # (printed label, metrics column); labels are padded to a common width.
        metric_rows = (
            ('Coverage:', 'coverage'),
            ('Marginal Cov:', 'marginal_coverage'),
            ('Avg Retention:', 'avg_retention'),
            ('True Retention:', 'true_retention'),
        )

        for method_name, results in self.results.items():
            if not results:
                continue

            print(f"\n{method_name.upper()}:")
            print("─" * 80)

            df = pd.DataFrame(results)

            for alpha in self.config['alphas']:
                alpha_results = df[df['alpha'] == alpha]
                if len(alpha_results) == 0:
                    continue

                print(f"\n  Alpha = {alpha}:")
                for label, column in metric_rows:
                    values = alpha_results[column]
                    print(f"    {label:<17}{values.mean():.4f} ± {values.std():.4f}")

    @staticmethod
    def _plot_style(styles: Dict[str, Dict], method_name: str) -> Dict[str, Any]:
        """Return plot kwargs for ``method_name``, falling back to a grey dotted style."""
        return styles.get(method_name, {'color': '#7f7f7f', 'marker': '*', 'linestyle': ':', 'linewidth': 1.5, 'label': method_name, 'zorder': 1})

    def create_plots(self):
        """Create coverage and retention vs. alpha comparison plots.

        Each method is plotted at its per-alpha mean over trials. If
        ``config['output']['save_plots']`` is true, writes
        ``ablation_comparison.png``, ``math_combined.pdf``,
        ``coverage_vs_alpha.pdf`` and ``retention_vs_alpha.pdf`` to
        ``config['output']['dir']``. Every figure is closed after saving.
        """
        output_dir = Path(self.config['output']['dir'])
        output_dir.mkdir(parents=True, exist_ok=True)

        if not self.config['output']['save_plots']:
            return

        # Aggregate results by method and alpha
        method_data = {}
        for method_name, results in self.results.items():
            if not results:
                continue

            df = pd.DataFrame(results)
            method_data[method_name] = df.groupby('alpha').agg({
                'coverage': 'mean',
                'avg_retention': 'mean',
                'true_retention': 'mean'
            }).reset_index()

        if not method_data:
            print("No results to plot")
            return

        styles = {
            'differentiable_coherent': {
                'color': '#1f77b4', 'marker': 'o', 'linestyle': '-', 'linewidth': 3,
                'label': 'Differentiable Conformal Factuality', 'zorder': 4
            },
            'hard_baseline': {
                'color': '#ff7f0e', 'marker': 's', 'linestyle': '--', 'linewidth': 2.5,
                'label': 'Conformal Factuality', 'zorder': 3
            },
            'hashimoto_independent': {
                'color': '#2ca02c', 'marker': 'D', 'linestyle': '--', 'linewidth': 2.5,
                'label': 'Independent Factuality', 'zorder': 3
            },
            'boosted_independent': {
                'color': '#d62728', 'marker': '^', 'linestyle': ':', 'linewidth': 1.5,
                'label': 'Boosted Independent', 'zorder': 2
            },
            'xgboost_accuracy': {
                'color': '#9467bd', 'marker': 'v', 'linestyle': '-.', 'linewidth': 2,
                'label': 'XGBoost + Conformal', 'zorder': 2
            },
        }

        alphas = sorted(self.config['alphas'])
        target_coverage = [1 - alpha for alpha in alphas]
        # The PDF figures start their alpha ticks at 0.02; the PNG starts at min(alpha).
        pdf_xticks = np.arange(0.02, max(alphas) + 0.005, 0.02)
        # NOTE(review): avg_retention appears to be a fraction (see
        # load_precomputed_results) while the retention axis labels say "(%)";
        # confirm the intended units.

        # ========================================
        # 1. Combined plot (PNG) - for quick viewing
        # ========================================
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))

        for method_name, df in method_data.items():
            style = self._plot_style(styles, method_name)
            ax1.plot(df['alpha'].values, df['coverage'].values, **style)
            ax2.plot(df['alpha'].values, df['avg_retention'].values, **style)

        ax1.plot(alphas, target_coverage, 'k--', label='Target Coverage (1-α)', linewidth=2, zorder=1)

        ax1.set_xlabel("Alpha (Miscoverage Level)", fontsize=12)
        ax1.set_ylabel("Coverage", fontsize=12)
        ax1.set_title("Coverage vs Alpha", fontsize=14, fontweight='bold')
        ax1.legend(fontsize=9, loc='best')
        ax1.grid(True, alpha=0.3)
        ax1.set_ylim([0.7, 1.0])
        ax1.set_xticks(np.arange(min(alphas), max(alphas) + 0.005, 0.02))

        ax2.set_xlabel("Alpha (Miscoverage Level)", fontsize=12)
        ax2.set_ylabel("Avg Claims Retained (%)", fontsize=12)
        ax2.set_title("Claims Retained vs Alpha", fontsize=14, fontweight='bold')
        ax2.legend(fontsize=9, loc='best')
        ax2.grid(True, alpha=0.3)
        ax2.set_xticks(np.arange(min(alphas), max(alphas) + 0.005, 0.02))

        plt.tight_layout()
        plt.savefig(output_dir / 'ablation_comparison.png', dpi=300, bbox_inches='tight')
        print(f"\n✓ Saved: {output_dir / 'ablation_comparison.png'}")
        plt.close(fig)

        # ========================================
        # 2. Combined side-by-side plot (PDF for LaTeX/ICML)
        #    Retention on left, Coverage on right, legend in coverage subplot
        # ========================================
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4), sharey=False)

        lines = []
        labels = []
        for method_name, df in method_data.items():
            style = self._plot_style(styles, method_name)
            ax1.plot(df['alpha'].values, df['avg_retention'].values, **style)
            line, = ax2.plot(df['alpha'].values, df['coverage'].values, **style)
            lines.append(line)
            labels.append(style['label'])

        # Target coverage line (only on coverage plot)
        target_line, = ax2.plot(alphas, target_coverage, 'k--', linewidth=2, zorder=1)
        lines.append(target_line)
        labels.append('Target (1-α)')

        ax1.set_xlabel("Miscoverage Level (α)", fontsize=14, fontweight='bold')
        ax1.set_ylabel("Claims Retained (%)", fontsize=14, fontweight='bold')
        ax1.tick_params(axis='both', which='major', labelsize=12)
        ax1.set_xticks(pdf_xticks)
        ax1.grid(True, alpha=0.3, linestyle='--')

        ax2.set_xlabel("Miscoverage Level (α)", fontsize=14, fontweight='bold')
        ax2.set_ylabel("Coverage", fontsize=14, fontweight='bold')
        ax2.tick_params(axis='both', which='major', labelsize=12)
        ax2.set_xticks(pdf_xticks)
        ax2.grid(True, alpha=0.3, linestyle='--')
        ax2.set_ylim([0.7, 1.0])

        ax2.legend(lines, labels, loc='lower left', fontsize=10,
                   framealpha=0.95, edgecolor='black')

        plt.tight_layout()
        plt.savefig(output_dir / 'math_combined.pdf', bbox_inches='tight', pad_inches=0.05)
        print(f"✓ Saved: {output_dir / 'math_combined.pdf'}")
        plt.close(fig)

        # ========================================
        # 3. Separate coverage plot (PDF for LaTeX/ICML, no title)
        # ========================================
        fig, ax = plt.subplots(1, 1, figsize=(5, 4))

        for method_name, df in method_data.items():
            ax.plot(df['alpha'].values, df['coverage'].values, **self._plot_style(styles, method_name))

        ax.plot(alphas, target_coverage, 'k--', label='Target (1-α)', linewidth=2, zorder=1)

        ax.set_xlabel("Miscoverage Level (α)", fontsize=14, fontweight='bold')
        ax.set_ylabel("Coverage", fontsize=14, fontweight='bold')
        ax.tick_params(axis='both', which='major', labelsize=12)
        ax.set_xticks(pdf_xticks)
        ax.legend(fontsize=10, loc='best', framealpha=0.95, edgecolor='black')
        ax.grid(True, alpha=0.3, linestyle='--')
        ax.set_ylim([0.7, 1.0])

        plt.tight_layout()
        plt.savefig(output_dir / 'coverage_vs_alpha.pdf', bbox_inches='tight', pad_inches=0.05)
        print(f"✓ Saved: {output_dir / 'coverage_vs_alpha.pdf'}")
        plt.close(fig)

        # ========================================
        # 4. Separate retention plot (PDF for LaTeX/ICML, no title)
        # ========================================
        fig, ax = plt.subplots(1, 1, figsize=(5, 4))

        for method_name, df in method_data.items():
            ax.plot(df['alpha'].values, df['avg_retention'].values, **self._plot_style(styles, method_name))

        ax.set_xlabel("Miscoverage Level (α)", fontsize=14, fontweight='bold')
        ax.set_ylabel("Claims Retained (%)", fontsize=14, fontweight='bold')
        ax.tick_params(axis='both', which='major', labelsize=12)
        ax.set_xticks(pdf_xticks)
        ax.legend(fontsize=10, loc='best', framealpha=0.95, edgecolor='black')
        ax.grid(True, alpha=0.3, linestyle='--')

        plt.tight_layout()
        plt.savefig(output_dir / 'retention_vs_alpha.pdf', bbox_inches='tight', pad_inches=0.05)
        print(f"✓ Saved: {output_dir / 'retention_vs_alpha.pdf'}")
        # Previously missing: without this the last figure stayed open.
        plt.close(fig)

    @staticmethod
    def _json_default(obj):
        """Convert NumPy scalars/arrays, which ``json`` cannot serialise natively."""
        if isinstance(obj, np.generic):
            return obj.item()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    def save_results(self):
        """Write ``self.results`` to ``<output dir>/comparison_results.json``.

        NumPy scalar/array values in the metrics (e.g. ``np.int64`` counts from
        ``evaluate``) are converted to plain Python types instead of crashing
        the dump at the end of a long run.
        """
        output_dir = Path(self.config['output']['dir'])
        output_dir.mkdir(parents=True, exist_ok=True)

        output_file = output_dir / 'comparison_results.json'

        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(self.results, f, indent=2, default=self._json_default)

        print(f"\n✓ Saved results: {output_file}")

    def run(self):
        """Run the full pipeline: load data, run trials, summarise, plot and save."""
        print("\n" + "=" * 80)
        print("FACTUALITY METHODS COMPARISON")
        print("=" * 80 + "\n")

        self.load_dataset()
        self.run_all_trials()
        self.analyze_results()
        self.create_plots()
        self.save_results()

        print("\n" + "=" * 80)
        print("✓ COMPARISON COMPLETE")
        print("=" * 80 + "\n")


def main():
    """Command-line entry point.

    Parses a positional ``config`` path (the comparison config JSON). Without
    ``--replot`` the full comparison is run. With ``--replot PATH``, the
    previously saved results JSON at PATH is loaded into the runner and only the
    plots are regenerated; no trials are run.
    """
    import argparse

    cli = argparse.ArgumentParser(description="Run factuality methods comparison")
    cli.add_argument('config', help='Path to config JSON')
    cli.add_argument('--replot', help='Path to existing comparison_results.json to replot')
    opts = cli.parse_args()

    runner = ComparisonRunner(opts.config)

    if not opts.replot:
        runner.run()
        return

    # Replot-only mode: reuse per-trial metrics saved by an earlier run.
    print(f"Loading results from: {opts.replot}")
    with open(opts.replot, 'r') as results_file:
        runner.results = json.load(results_file)
    runner.create_plots()
    print("✓ Plots regenerated from existing results")


if __name__ == "__main__":
    main()
