
import os
import sys
import argparse
import json
import pickle
import numpy as np
import matplotlib.pyplot as plt
from typing import List, Dict, Optional, Union
from tqdm import tqdm
import torch
from datetime import datetime

# Make the `src` package importable when this file is run directly as a script.
# `current_dir` is the directory containing this file and `project_root` is its
# parent; the parent is appended to sys.path (only if it is not already there)
# so that the `from src...` imports below can resolve.
# NOTE(review): assumes the `src` package lives directly under `project_root` — confirm.
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(current_dir)
if project_root not in sys.path:
    # append (not insert) means earlier sys.path entries still take precedence
    sys.path.append(project_root)

from src.model_fit import ModelRewardEvaluator
from src.custom_rewards import LLMRewardFunction, RewardModelFunction
import src.constants as constants

# Display names used in plot legends / axis labels, keyed by the method names
# that appear as keys of `model_paths_dict`. The plotting functions look names
# up with `.get(method, method)`, so a method missing from this table is shown
# under its raw key.
METHOD_PLOT_NAMES_DICT = {
    'ground_truth': 'Ground-Truth',
    'proposed_method': 'Obj-Disco',
    'baseline_random': 'Iter-Filter',
    'baseline_static': 'One-Shot',
    'baseline_fixed_5': 'Fixed-5',
    'baseline_fixed_15': 'Fixed-15',
    'method_ablation_obj_1': 'Obj-Disco (1 Obj Per Traj)',
    'method_ablation_obj_2': 'Obj-Disco (2 Obj Per Traj)',
    'method_ablation_traj_3': 'Obj-Disco (3 Parallel Traj)',
    'method_ablation_traj_10': 'Obj-Disco (10 Parallel Traj)',
    'method_ablation_traj_25': 'Obj-Disco (25 Parallel Traj)',
    # Add more mappings as needed
}

# Fixed colors for each method (hex RGB), keyed like METHOD_PLOT_NAMES_DICT.
# The plotting functions fall back to '#1f77b4' (blue) for methods not listed.
# NOTE: several keys intentionally or accidentally share a color (e.g.
# 'proposed_method', 'method_ablation_obj_2' and 'method_ablation_traj_3' are
# all '#d62728'); methods that share a color are indistinguishable by color
# when drawn on the same figure.
METHOD_COLORS_DICT = {
    'ground_truth': '#2ca02c',      # Green
    'proposed_method': '#d62728',   # Red
    'baseline_random': '#FFD700',   # Gold (faint yellow)
    'baseline_static': '#7f7f7f',   # Gray
    'baseline_fixed_5': '#add8e6',  # Light Blue
    'baseline_fixed_15': '#1f77b4', # Blue
    'method_ablation_obj_1': '#ff7f0e', # Orange
    'method_ablation_obj_2': '#d62728', # Red
    'method_ablation_traj_3': '#d62728',  # Red (same as 'proposed_method')
    'method_ablation_traj_10': '#ff7f0e',  # Orange (same as 'method_ablation_obj_1')
    'method_ablation_traj_25': '#FFD700',  # Gold (same as 'baseline_random')
}


def _str_to_bool(value):
    """Convert a command-line token such as ``true``/``False``/``1``/``no`` to a bool.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is ``True``
    because any non-empty string is truthy. This converter parses the text instead.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ('true', 't', 'yes', 'y', 'on', '1'):
        return True
    # The empty string is kept as False to match what ``bool("")`` used to give.
    if normalized in ('false', 'f', 'no', 'n', 'off', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")


def parse_args():
    """Parse command-line arguments, optionally overlaying a config file.

    Processing order:
      1. Command-line flags are parsed with argparse.
      2. If ``--config_file`` is given, the YAML (``.yaml``/``.yml``) or JSON
         file is loaded and every key that matches an existing argument name
         overwrites the parsed value -- i.e. the config file takes precedence
         over the command line. Keys that match no argument are reported and
         ignored.
      3. ``model_paths_dict``, ``reward_objectives`` and ``reward_manual_weights``
         are decoded from JSON when they are non-empty strings.

    Boolean options ``--multi_turn`` and ``--use_detailed_rubric`` accept an
    explicit value (``--multi_turn False``) or no value (``--multi_turn``,
    meaning True). ``--load_in_4bit`` defaults to True and can be disabled
    with ``--no_load_in_4bit``.

    Returns:
        argparse.Namespace with the final argument values.

    Raises:
        ValueError: if the config file does not contain a mapping at top level.
        json.JSONDecodeError: if one of the JSON-valued options is malformed.
    """
    parser = argparse.ArgumentParser(description='Calculate Model-Fit metrics for model trajectories')

    # Model paths
    parser.add_argument('--model_paths_dict', type=str, default='',
                       help='JSON string of method names to model path lists. '
                            'Format for single instance: {"method": [checkpoint1, checkpoint2, ...]}. '
                            'Format for multiple instances: {"method": [[instance1-cp1, instance1-cp2, ...], [instance2-cp1, ...]]}')
    parser.add_argument('--ground_truth_method', type=str, default='ground_truth',
                       help='Name of the ground-truth method in model_paths_dict')

    # Dataset parameters
    parser.add_argument('--dataset_name', type=str, default='Anthropic/hh-rlhf',
                       help='Name of evaluation dataset')
    # nargs='?' + const=True lets the flag be used bare or with an explicit value.
    parser.add_argument('--multi_turn', type=_str_to_bool, nargs='?', const=True, default=False,
                       help='Whether the dataset is multi-turn')
    parser.add_argument('--dataset_split', type=str, default='test',
                       help='Dataset split to use')
    parser.add_argument('--num_eval_samples', type=int, default=100,
                       help='Number of evaluation samples')

    # Reward function parameters
    parser.add_argument('--reward_function_type', type=str, default='llm',
                       help='Type of reward function: llm or reward_model')
    parser.add_argument('--reward_model_name', type=str, default='meta-llama/Llama-3.1-8B-Instruct',
                       help='Model name for reward scoring')

    # LLMRewardFunction specific parameters
    parser.add_argument('--reward_use_api', action='store_true', default=False,
                       help='Use API for reward scoring (LLM only)')
    parser.add_argument('--reward_objectives', type=str, default='["harmlessness", "thoroughness", "clarity"]',
                       help='JSON list of objectives (LLM only)')
    parser.add_argument('--reward_combiner_type', type=str, default='linear',
                       help='Type of reward combiner (LLM only)')
    parser.add_argument('--reward_manual_weights', type=str, default='',
                       help='JSON string of manual weights for objectives (LLM only)')
    parser.add_argument('--reward_manual_bias', type=float, default=0.0,
                       help='Bias term for reward combiner (LLM only)')
    parser.add_argument('--use_detailed_rubric', type=_str_to_bool, nargs='?', const=True, default=False,
                       help='Use detailed rubric for LLM scoring (LLM only)')

    # RewardModelFunction specific parameters
    parser.add_argument('--reward_use_quantization', action='store_true', default=False,
                       help='Use 4-bit quantization for reward model (RewardModel only)')

    # Generation parameters
    parser.add_argument('--batch_size', type=int, default=8,
                       help='Batch size for generation')
    parser.add_argument('--max_new_tokens', type=int, default=512,
                       help='Maximum new tokens to generate')
    parser.add_argument('--temperature', type=float, default=0.7,
                       help='Sampling temperature')
    parser.add_argument('--top_p', type=float, default=0.9,
                       help='Top-p sampling parameter')

    # Evaluation parameters
    parser.add_argument('--num_runs', type=int, default=2,
                       help='Number of evaluation runs for standard error')
    parser.add_argument('--seed', type=int, default=42,
                       help='Base random seed')

    # Device and memory
    parser.add_argument('--device', type=str, default='cuda',
                       help='Device to use')
    parser.add_argument('--load_in_4bit', action='store_true', default=True,
                       help='Use 4-bit quantization (default: enabled)')
    # A store_true flag whose default is already True can never be switched
    # off, so provide the negative form writing to the same destination.
    parser.add_argument('--no_load_in_4bit', dest='load_in_4bit', action='store_false',
                       help='Disable 4-bit quantization')

    # Save parameters
    parser.add_argument('--save_dir', type=str, default='',
                       help='Directory to save results')
    parser.add_argument('--experiment_name', type=str, default='model_fit_exp',
                       help='Name for this experiment')

    # Config file
    parser.add_argument('--config_file', type=str, default=None,
                       help='Path to yaml/json config file')

    # Prompt length filtering
    parser.add_argument('--max_prompt_length', type=int, default=None,
                       help='Maximum prompt length in tokens for filtering')
    parser.add_argument('--base_model_name', type=str, default=None,
                       help='Base model name for tokenizer (required if max_prompt_length is set)')

    args = parser.parse_args()

    # Load config file if provided
    if args.config_file:
        with open(args.config_file) as f:
            if args.config_file.endswith(('.yaml', '.yml')):
                import yaml
                config = yaml.safe_load(f)
            else:
                config = json.load(f)

        # An empty YAML document loads as None; treat it as "no overrides".
        if config is None:
            config = {}
        if not isinstance(config, dict):
            raise ValueError(
                f"Config file {args.config_file} must contain a mapping at top level, "
                f"got {type(config).__name__}"
            )

        # Update args with config values (config wins over the command line)
        ignored_keys = []
        for key, value in config.items():
            if hasattr(args, key):
                setattr(args, key, value)
            else:
                ignored_keys.append(key)
        if ignored_keys:
            # Surface typos instead of dropping them silently.
            print(f"Warning: ignoring unknown config keys: {ignored_keys}")

    # Parse JSON strings (values coming from a config file may already be decoded)
    if isinstance(args.model_paths_dict, str) and args.model_paths_dict:
        args.model_paths_dict = json.loads(args.model_paths_dict)

    if isinstance(args.reward_objectives, str) and args.reward_objectives:
        args.reward_objectives = json.loads(args.reward_objectives)

    if isinstance(args.reward_manual_weights, str) and args.reward_manual_weights:
        args.reward_manual_weights = json.loads(args.reward_manual_weights)

    return args

def initialize_reward_function(args):
    """Build the ground-truth reward function selected by ``args.reward_function_type``.

    Args:
        args: Parsed arguments. ``'llm'`` constructs an ``LLMRewardFunction``
            from the ``reward_*`` / ``use_detailed_rubric`` options and the
            dataset type looked up in ``constants.DATASET_NAMES_DICT``;
            ``'reward_model'`` constructs a ``RewardModelFunction``.

    Returns:
        The constructed reward function object.

    Raises:
        ValueError: if ``args.reward_function_type`` is neither ``'llm'`` nor
            ``'reward_model'``.
    """
    kind = args.reward_function_type
    print(f"\nInitializing {kind} reward function...")

    # Reject unsupported types before touching any model code.
    if kind not in ('llm', 'reward_model'):
        raise ValueError(f"Unknown reward function type: {kind}. "
                        f"Supported types: 'llm', 'reward_model'")

    if kind == 'llm':
        # LLM judge scoring several named objectives, merged by a combiner.
        # Empty/falsy manual weights are passed on as None.
        weights = args.reward_manual_weights or None
        reward_function = LLMRewardFunction(
            model_name=args.reward_model_name,
            use_api=args.reward_use_api,
            combiner_type=args.reward_combiner_type,
            manual_weights=weights,
            manual_bias=args.reward_manual_bias,
            device=args.device,
            objective_names=args.reward_objectives,
            dataset_type=constants.DATASET_NAMES_DICT[args.dataset_name],
            use_detailed_rubric=args.use_detailed_rubric,
            normalize_scores=True,
        )
    else:
        # Pre-trained reward model (e.g., from HuggingFace).
        # NOTE(review): max_length is fed from --max_new_tokens; confirm that is
        # the intended limit for the reward model's input.
        reward_function = RewardModelFunction(
            model_name=args.reward_model_name,
            device=args.device,
            max_length=args.max_new_tokens,
            use_quantization=args.reward_use_quantization,
            init_args=args,
            normalize_scores=True,
        )

    print(f"Initialized: {reward_function.__class__.__name__}")
    return reward_function


def calculate_model_fit_with_runs(
    method_name: str,
    model_paths: List,
    ground_truth_T_rewards_per_run: List[float],
    ground_truth_T_paths: List[str],
    evaluator: "ModelRewardEvaluator",
    num_runs: int,
    base_seed: int,
    run_prompts: Optional[Dict[int, List]] = None
) -> Dict:
    """
    Calculate Model-Fit for a method across multiple runs.

    For every instance, run and checkpoint the model is scored with
    ``evaluator.evaluate_model`` and Model-Fit is the ratio
    ``estimated_reward / ground_truth_reward_for_that_run``. The ratio is
    stored uncapped (it may exceed 1.0); when the ground-truth reward of a run
    is exactly 0 the Model-Fit is recorded as 0.0 to avoid dividing by zero.

    Args:
        method_name: Name of the method (used for logging and error messages)
        model_paths: List of model paths. Can be:
            - Simple list: [checkpoint-1, checkpoint-2, ...] for single instance
            - List of lists: [[instance1-cp1, instance1-cp2, ...], [instance2-cp1, ...]] for multiple instances
        ground_truth_T_rewards_per_run: List of ground truth rewards, one per run
            (must contain at least ``num_runs`` entries)
        ground_truth_T_paths: Ground truth model paths used for the final rewards.
            Currently unused; kept so existing callers keep working.
        evaluator: ModelRewardEvaluator instance
        num_runs: Number of evaluation runs per model
        base_seed: Base random seed. Currently unused; kept for interface
            compatibility.
        run_prompts: Optional pre-sampled prompts for each run (keyed by run
            index) to ensure consistency; a missing run falls back to
            ``eval_prompts=None``

    Returns:
        Dictionary with 'model_fit' and 'estimated_rewards' keys, each mapping
        the 1-based timestep to a list of values ordered instance-major,
        run-minor (``num_instances * num_runs`` values per timestep).

    Raises:
        ValueError: if ``model_paths`` is empty, instances differ in length,
            or fewer than ``num_runs`` ground-truth rewards are supplied.
    """
    if not model_paths:
        raise ValueError(f"No model paths provided for method '{method_name}'")

    # Handle both single instance and multiple instances
    if isinstance(model_paths[0], str):
        # Single instance: convert to list of lists format
        model_paths_instances = [model_paths]
    else:
        # Multiple instances: already in correct format
        model_paths_instances = model_paths

    num_instances = len(model_paths_instances)
    num_timesteps = len(model_paths_instances[0])

    # Validate that all instances have same number of timesteps
    for instance_idx, instance_paths in enumerate(model_paths_instances):
        if len(instance_paths) != num_timesteps:
            raise ValueError(
                f"All instances must have same number of timesteps: "
                f"{method_name} instance {instance_idx} has {len(instance_paths)}, "
                f"expected {num_timesteps}"
            )

    # Fail before any (expensive) model evaluation rather than part-way through.
    if len(ground_truth_T_rewards_per_run) < num_runs:
        raise ValueError(
            f"Need one ground-truth reward per run: got "
            f"{len(ground_truth_T_rewards_per_run)} for {num_runs} runs"
        )

    # Store both model_fit values and raw estimated rewards
    results = {
        'model_fit': {t: [] for t in range(1, num_timesteps + 1)},
        'estimated_rewards': {t: [] for t in range(1, num_timesteps + 1)}
    }

    # Iterate over instances
    for instance_idx, instance_paths in enumerate(model_paths_instances):
        print(f"\n{'='*60}")
        print(f"Method: {method_name} - Instance {instance_idx + 1}/{num_instances}")
        print(f"{'='*60}")

        # Multiple runs per instance
        for run in range(num_runs):
            print(f"\n  Run {run + 1}/{num_runs}")

            # Get pre-sampled prompts for this run if provided
            eval_prompts = run_prompts.get(run) if run_prompts else None
            ground_truth_for_this_run = ground_truth_T_rewards_per_run[run]

            # Calculate Model-Fit for each timestep
            for t, model_path in enumerate(instance_paths, start=1):
                print(f"    Timestep {t}/{num_timesteps}: {os.path.basename(model_path)}")

                # Evaluate model with the same prompts for this run
                estimated_reward = evaluator.evaluate_model(model_path, eval_prompts=eval_prompts)

                # Ratio against the run-specific ground truth (deliberately not
                # capped here; the plotting code caps at 1.0 for display).
                if ground_truth_for_this_run == 0:
                    model_fit = 0.0
                else:
                    model_fit = estimated_reward / ground_truth_for_this_run

                results['estimated_rewards'][t].append(estimated_reward)
                results['model_fit'][t].append(model_fit)

                print(f"      Estimated Reward: {estimated_reward:.4f}, GT: {ground_truth_for_this_run:.4f}, Model-Fit: {model_fit:.4f}")

    return results


def plot_model_fit_results(
    all_results: Dict[str, Dict],
    save_path: str,
    title: str = "Model-Fit Trajectories"
):
    """
    Plot Model-Fit trajectories with standard-error bars and shaded bands.

    For each method the mean Model-Fit per timestep (capped at 1.0) is drawn
    against a 1-based model index, with error bars of one standard error
    (population std / sqrt(n), using the sample count of that timestep). The
    figure is written to ``save_path`` and additionally as a PDF next to it
    (same path with the extension replaced by ``.pdf``).

    Args:
        all_results: Dict mapping method names to results dict containing 'model_fit'
            and 'estimated_rewards'. The old format, where the value is the
            ``{timestep: [values]}`` mapping itself, is also accepted.
        save_path: Path to save the plot
        title: Plot title (not used anymore, kept for backward compatibility)
    """
    import math  # local import: this module's header does not import math

    fig = plt.figure(figsize=(12, 8))
    try:
        # Collect all plotted means to determine the y-axis range
        all_values = []
        # Largest number of timesteps over all methods (drives the x ticks)
        max_index = 0

        for method_name, method_results in all_results.items():
            # New format nests the series under 'model_fit'; old format is the series itself
            if 'model_fit' in method_results:
                results = method_results['model_fit']
            else:
                results = method_results

            timesteps = sorted(results.keys())
            max_index = max(max_index, len(timesteps))
            if not timesteps:
                # Nothing to draw for this method
                continue

            # Means are capped at 1.0; the standard error uses each timestep's own n
            means = [min(1.0, np.mean(results[t])) for t in timesteps]
            ses = [np.std(results[t]) / np.sqrt(len(results[t])) for t in timesteps]

            all_values.extend(means)

            # Convert timesteps to 1-based model indices for display
            model_indices = list(range(1, len(timesteps) + 1))

            # Get color for this method
            color = METHOD_COLORS_DICT.get(method_name, '#1f77b4')  # Default blue if not found

            # Plot with error bars
            plt.errorbar(
                model_indices, means, yerr=ses,
                label=METHOD_PLOT_NAMES_DICT.get(method_name, method_name),
                color=color,
                linewidth=2,
                marker='o',
                markersize=6,
                capsize=5,
                capthick=1.5,
                alpha=0.9
            )

            # Shaded +/- one standard error band (upper edge capped at 1.0)
            lower = [m - se for m, se in zip(means, ses)]
            upper = [min(1.0, m + se) for m, se in zip(means, ses)]
            plt.fill_between(
                model_indices, lower, upper,
                color=color,
                alpha=0.2
            )

        # X-axis configuration: discrete integers only
        model_indices = list(range(1, max_index + 1))
        plt.xticks(model_indices, model_indices)

        # Y-axis configuration: lower = max(0.4, min_value - 0.05), upper = 1.05
        min_value = min(all_values) if all_values else 0.8
        y_lower = max(0.4, min_value - 0.05)
        y_upper = 1.05
        plt.ylim(y_lower, y_upper)

        # Y ticks every 0.05, from the nearest 0.05 at or below y_lower up to 1.0
        y_tick_start = math.floor(y_lower * 20) / 20
        y_ticks = []
        current_tick = y_tick_start
        while current_tick <= 1.0:
            y_ticks.append(current_tick)
            current_tick = round(current_tick + 0.05, 2)  # Round to avoid floating point drift
        plt.yticks(y_ticks)

        # Labels (no title)
        plt.xlabel('Model Index', fontsize=14, fontweight='bold')
        plt.ylabel('Model-Fit', fontsize=14, fontweight='bold')

        # Grid
        plt.grid(True, alpha=0.3, linestyle='--', linewidth=0.5)

        # Legend
        plt.legend(
            loc='best',
            fontsize=11,
            framealpha=0.95,
            edgecolor='gray',
            title='Method',
            title_fontsize=12
        )

        # Tight layout
        plt.tight_layout()

        # Save
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"\nPlot saved to: {save_path}")

        # Also save as PDF for publication quality. Swap only the file
        # extension: a plain str.replace('.png', '.pdf') would leave a non-.png
        # path unchanged (overwriting the image just written) and could also
        # rewrite '.png' occurring inside a directory name.
        pdf_path = os.path.splitext(save_path)[0] + '.pdf'
        if pdf_path != save_path:
            plt.savefig(pdf_path, format='pdf', bbox_inches='tight')
        print(f"PDF saved to: {pdf_path}")
    finally:
        # Release the figure even if plotting or saving failed
        plt.close(fig)


def plot_area_under_model_fit_results(
    all_results: Dict[str, Dict],
    save_path: str,
    ground_truth_method: str = 'ground_truth'
):
    """
    Plot bar chart of Area Under the Model-Fit Curve (AUC) for each method.

    Calculates AUC for each trial using trapezoidal integration over the
    timesteps (Model-Fit values capped at 1.0), then plots the mean AUC per
    method with standard-error bars. Means and standard errors are divided by
    the ground-truth method's mean AUC when that mean is positive. The figure
    is written to ``save_path`` and additionally as a PDF next to it (same
    path with the extension replaced by ``.pdf``).

    Note that with a single timestep every trapezoidal area is 0.

    Args:
        all_results: Dict mapping method names to results dict containing 'model_fit'
            (the old format, where the value is the ``{timestep: [values]}``
            mapping itself, is also accepted)
        save_path: Path to save the plot
        ground_truth_method: Name of the ground-truth method (placed on far left).
            If it is absent from the results, a warning is printed and the
            first method in plot order is used as the normalisation reference.

    Returns:
        Dict mapping method name to the list of raw (un-normalised) per-trial
        AUC values. Methods without any values are omitted; an empty dict is
        returned (and nothing is plotted) when no method has values.
    """
    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid; use the
    # new name when available and fall back on older NumPy versions.
    integrate = getattr(np, 'trapezoid', None) or np.trapz

    # Calculate AUC for each method and each trial
    method_aucs = {}

    for method_name, method_results in all_results.items():
        # Extract model_fit values (backward compatible)
        if 'model_fit' in method_results:
            results = method_results['model_fit']
        else:
            results = method_results

        timesteps = sorted(results.keys())
        # Only trials present at every timestep can form a complete curve
        num_trials = min((len(results[t]) for t in timesteps), default=0)
        if num_trials == 0:
            print(f"Warning: no Model-Fit values for method '{method_name}'; skipping it in the AUC plot")
            continue

        trial_aucs = []
        for trial_idx in range(num_trials):
            # Model-Fit of this trial at each timestep (cap at 1.0)
            y_values = [min(1.0, results[t][trial_idx]) for t in timesteps]
            x_values = list(timesteps)
            trial_aucs.append(integrate(y_values, x_values))

        method_aucs[method_name] = trial_aucs

    if not method_aucs:
        print("\nNo Model-Fit results available; skipping AUC bar plot")
        return method_aucs

    # Order methods: ground_truth first, then others alphabetically
    method_order = []
    if ground_truth_method in method_aucs:
        method_order.append(ground_truth_method)
    for method in sorted(method_aucs.keys()):
        if method not in method_order:
            method_order.append(method)

    # Calculate means and standard errors
    means = []
    std_errors = []
    colors = []
    labels = []

    for method in method_order:
        aucs = method_aucs[method]
        means.append(np.mean(aucs))
        std_errors.append(np.std(aucs) / np.sqrt(len(aucs)))
        colors.append(METHOD_COLORS_DICT.get(method, '#1f77b4'))
        labels.append(METHOD_PLOT_NAMES_DICT.get(method, method))

    # Normalize by ground-truth AUC (so ground-truth ≈ 1.0, others are relative)
    if ground_truth_method in method_order:
        gt_idx = method_order.index(ground_truth_method)
    else:
        gt_idx = 0
        print(f"Warning: ground-truth method '{ground_truth_method}' has no results; "
              f"normalizing AUC by '{method_order[gt_idx]}' instead")
    gt_mean = means[gt_idx]
    if gt_mean > 0:
        means = [m / gt_mean for m in means]
        std_errors = [se / gt_mean for se in std_errors]

    # Create figure
    fig, ax = plt.subplots(figsize=(10, 7))
    try:
        # Create bar positions
        x_pos = np.arange(len(method_order))
        bar_width = 0.6

        # Plot bars with error bars
        bars = ax.bar(
            x_pos, means,
            width=bar_width,
            color=colors,
            edgecolor='black',
            linewidth=1.2,
            alpha=0.85,
            yerr=std_errors,
            capsize=6,
            error_kw={'linewidth': 1.5, 'capthick': 1.5, 'ecolor': 'black'}
        )

        # Add value labels just above the error bar of each bar
        for bar, mean, se in zip(bars, means, std_errors):
            height = bar.get_height()
            ax.annotate(
                f'{mean:.2f}',
                xy=(bar.get_x() + bar.get_width() / 2, height + se + 0.02),
                ha='center', va='bottom',
                fontsize=11, fontweight='bold',
                color='black'
            )

        # Configure axes
        ax.set_xticks(x_pos)
        ax.set_xticklabels(labels, fontsize=12, fontweight='medium')

        # Y-axis configuration: leave headroom above the tallest bar + error bar
        all_values = [m + se for m, se in zip(means, std_errors)]
        y_max = max(all_values) + 0.15 if all_values else 1.0
        y_min = 0
        ax.set_ylim(y_min, y_max)

        # Labels
        ax.set_xlabel('Method', fontsize=14, fontweight='bold')
        ax.set_ylabel('Normalized AUC (relative to Ground-Truth)', fontsize=14, fontweight='bold')

        # Grid (horizontal only for bar charts)
        ax.grid(True, alpha=0.3, linestyle='--', linewidth=0.5, axis='y')
        ax.set_axisbelow(True)  # Grid behind bars

        # Remove top and right spines for cleaner look
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)

        # Tight layout
        plt.tight_layout()

        # Save
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"\nAUC bar plot saved to: {save_path}")

        # Also save as PDF for publication quality. Swap only the file
        # extension: str.replace('.png', '.pdf') would leave a non-.png path
        # unchanged and overwrite the image that was just written.
        pdf_path = os.path.splitext(save_path)[0] + '.pdf'
        if pdf_path != save_path:
            plt.savefig(pdf_path, format='pdf', bbox_inches='tight')
        print(f"AUC PDF saved to: {pdf_path}")
    finally:
        # Release the figure even if plotting or saving failed
        plt.close(fig)

    # Print AUC summary
    print("\n" + "-"*50)
    print("Normalized AUC Summary (relative to Ground-Truth):")
    print("-"*50)
    for method, mean, se in zip(method_order, means, std_errors):
        display_name = METHOD_PLOT_NAMES_DICT.get(method, method)
        n_trials = len(method_aucs[method])
        print(f"  {display_name}: {mean:.4f} ± {se:.4f} (n={n_trials})")
    print("-"*50)

    return method_aucs


def save_results(results: Dict, save_path: str):
    """Pickle ``results`` to ``save_path`` and report where it was written.

    Args:
        results: Object to serialise (a dict in this script).
        save_path: Destination file; an existing file is overwritten.
    """
    with open(save_path, mode='wb') as out_file:
        pickle.dump(results, out_file)

    message = f"\nResults saved to: {save_path}"
    print(message)


def main():
    """Run the full Model-Fit evaluation pipeline.

    Steps:
      1. Parse arguments and validate ``model_paths_dict`` (non-empty, contains
         the ground-truth method, every method/instance has the same number of
         checkpoints). Validation happens before anything is written to disk
         or any model is loaded so that bad input fails fast.
      2. Create ``<save_dir>/Model-Fit-<experiment_name>-<timestamp>`` and save
         the arguments there as ``config.json``.
      3. Build the reward function and the evaluator, and pre-sample one set of
         evaluation prompts per run so all models see the same prompts.
      4. Evaluate the final ground-truth checkpoint(s) per run; the per-run
         reward (averaged over ground-truth instances) is the Model-Fit
         denominator.
      5. Compute Model-Fit for every method, pickle the results, draw the
         trajectory and AUC plots, and print summary statistics.

    Raises:
        ValueError: if ``model_paths_dict`` is missing, lacks the ground-truth
            method, contains an empty path list, or has inconsistent lengths.
    """
    args = parse_args()

    # ------------------------------------------------------------------
    # Validate model paths (cheap, so do it before any side effects)
    # ------------------------------------------------------------------
    if not args.model_paths_dict:
        raise ValueError("model_paths_dict must be provided")

    if args.ground_truth_method not in args.model_paths_dict:
        raise ValueError(f"Ground-truth method '{args.ground_truth_method}' not found in model_paths_dict")

    for method, paths in args.model_paths_dict.items():
        if not paths:
            raise ValueError(f"Method '{method}' has no model paths in model_paths_dict")

    # Get number of timesteps (handle both single and multiple instances)
    ground_truth_paths = args.model_paths_dict[args.ground_truth_method]
    if isinstance(ground_truth_paths[0], str):
        # Single instance
        num_timesteps = len(ground_truth_paths)
    else:
        # Multiple instances
        num_timesteps = len(ground_truth_paths[0])
    if num_timesteps == 0:
        raise ValueError(f"Ground-truth method '{args.ground_truth_method}' has no checkpoints")

    # Validate all methods have same number of timesteps
    for method, paths in args.model_paths_dict.items():
        if isinstance(paths[0], str):
            # Single instance
            if len(paths) != num_timesteps:
                raise ValueError(f"All methods must have same number of timesteps. "
                               f"{method} has {len(paths)}, expected {num_timesteps}")
        else:
            # Multiple instances
            for instance_idx, instance_paths in enumerate(paths):
                if len(instance_paths) != num_timesteps:
                    raise ValueError(f"All instances must have same number of timesteps. "
                                   f"{method} instance {instance_idx} has {len(instance_paths)}, expected {num_timesteps}")

    # Final (time T) ground-truth checkpoint of every ground-truth instance
    if isinstance(ground_truth_paths[0], str):
        # Single instance
        ground_truth_T_paths = [ground_truth_paths[-1]]
    else:
        # Multiple instances: get last checkpoint from each instance
        ground_truth_T_paths = [instance[-1] for instance in ground_truth_paths]

    # ------------------------------------------------------------------
    # Output directory and config snapshot
    # ------------------------------------------------------------------
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    # --save_dir defaults to '' and os.makedirs('') raises FileNotFoundError;
    # an empty save_dir simply means "relative to the current directory".
    if args.save_dir:
        os.makedirs(args.save_dir, exist_ok=True)
    experiment_dir = os.path.join(args.save_dir, f"Model-Fit-{args.experiment_name}-{timestamp}")
    os.makedirs(experiment_dir, exist_ok=True)

    # Save config for reproducibility (non-JSON values are stringified)
    config_save_path = os.path.join(experiment_dir, 'config.json')
    with open(config_save_path, 'w') as f:
        json.dump(vars(args), f, indent=2, default=str)
    print(f"Config saved to: {config_save_path}")

    # ------------------------------------------------------------------
    # Reward function and evaluator
    # ------------------------------------------------------------------
    reward_function = initialize_reward_function(args)

    evaluator = ModelRewardEvaluator(
        reward_function=reward_function,
        dataset_name=args.dataset_name,
        dataset_split=args.dataset_split,
        num_eval_samples=args.num_eval_samples,
        batch_size=args.batch_size,
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature,
        top_p=args.top_p,
        device=args.device,
        load_in_4bit=args.load_in_4bit,
        seed=args.seed,
        multi_turn=args.multi_turn,
        max_prompt_length=args.max_prompt_length,
        base_model_name=args.base_model_name
    )

    # Pre-sample prompts for each run to ensure consistency across all models
    print("\n" + "="*70)
    print("PRE-SAMPLING EVALUATION PROMPTS FOR CONSISTENCY")
    print("="*70)

    run_prompts = {}
    for run in range(args.num_runs):
        print(f"Sampling prompts for run {run + 1}/{args.num_runs}")
        run_prompts[run] = evaluator.sample_eval_prompts()
    print(f"Sampled {args.num_eval_samples} prompts for each of {args.num_runs} runs")

    # ------------------------------------------------------------------
    # First, evaluate ground-truth model(s) at time T
    # ------------------------------------------------------------------
    print("\n" + "="*70)
    print("EVALUATING GROUND-TRUTH MODEL(S) AT TIME T")
    print("="*70)

    # Collect ground truth rewards organized by run
    ground_truth_T_rewards_per_run = []

    for run in range(args.num_runs):
        print(f"\nEvaluating ground-truth for Run {run + 1}/{args.num_runs}")
        run_rewards = []

        # Evaluate all ground truth instances for this run
        for path_idx, ground_truth_T_path in enumerate(ground_truth_T_paths):
            print(f"  Instance {path_idx + 1}/{len(ground_truth_T_paths)}: {os.path.basename(ground_truth_T_path)}")
            # Use pre-sampled prompts for this run
            reward = evaluator.evaluate_model(ground_truth_T_path, eval_prompts=run_prompts[run])
            run_rewards.append(reward)
            print(f"    Reward: {reward:.4f}")

        # Average across instances for this run
        run_average = float(np.mean(run_rewards))
        ground_truth_T_rewards_per_run.append(run_average)
        print(f"  Run {run + 1} average: {run_average:.4f}")

    # Calculate overall statistics
    ground_truth_T_mean = np.mean(ground_truth_T_rewards_per_run)
    ground_truth_T_std = np.std(ground_truth_T_rewards_per_run)
    ground_truth_T_se = ground_truth_T_std / np.sqrt(len(ground_truth_T_rewards_per_run))

    print("\n" + "-"*40)
    print("Ground-truth T rewards per run:")
    for run, reward in enumerate(ground_truth_T_rewards_per_run):
        print(f"  Run {run + 1}: {reward:.4f}")
    print("-"*40)
    print(f"Overall average: {ground_truth_T_mean:.4f} "
          f"(±{ground_truth_T_se:.4f}, std={ground_truth_T_std:.4f}, n={len(ground_truth_T_rewards_per_run)})")

    # ------------------------------------------------------------------
    # Calculate Model-Fit for all methods
    # ------------------------------------------------------------------
    all_results = {}

    for method_name, model_paths in args.model_paths_dict.items():
        print(f"\n{'='*70}")
        print(f"EVALUATING METHOD: {method_name}")
        print(f"{'='*70}")
        results = calculate_model_fit_with_runs(
            method_name=method_name,
            model_paths=model_paths,
            ground_truth_T_rewards_per_run=ground_truth_T_rewards_per_run,
            ground_truth_T_paths=ground_truth_T_paths,
            evaluator=evaluator,
            num_runs=args.num_runs,
            base_seed=args.seed,
            run_prompts=run_prompts
        )

        all_results[method_name] = results

    # Save raw results with both model_fit and estimated_rewards
    results_dict = {
        'model_fit_results': all_results,  # Per method: 'model_fit' and 'estimated_rewards'
        'ground_truth_T_rewards_per_run': ground_truth_T_rewards_per_run,  # List of per-run ground truth rewards
        'ground_truth_T_mean': ground_truth_T_mean,  # Overall average
        'ground_truth_T_paths': ground_truth_T_paths,  # List of ground truth model paths used
        'args': vars(args)
    }

    pkl_path = os.path.join(experiment_dir, 'model_fit_results.pkl')
    save_results(results_dict, pkl_path)

    # Plot results
    plot_path = os.path.join(experiment_dir, 'model_fit_plot.png')
    plot_model_fit_results(
        all_results,
        plot_path,
        title=f"Model-Fit Trajectories ({args.experiment_name})"
    )

    # Plot Area Under Model-Fit Curve bar chart
    auc_plot_path = os.path.join(experiment_dir, 'model_fit_auc_plot.png')
    plot_area_under_model_fit_results(
        all_results,
        auc_plot_path,
        ground_truth_method=args.ground_truth_method
    )

    # ------------------------------------------------------------------
    # Print summary statistics
    # ------------------------------------------------------------------
    print("\n" + "="*70)
    print("SUMMARY STATISTICS")
    print("="*70)

    for method_name, method_results in all_results.items():
        print(f"\n{method_name}:")

        # Extract model_fit for summary
        if 'model_fit' in method_results:
            model_fit_results = method_results['model_fit']
            estimated_rewards_results = method_results['estimated_rewards']
        else:
            # Backward compatibility
            model_fit_results = method_results
            estimated_rewards_results = None

        for t in sorted(model_fit_results.keys()):
            model_fit_values = model_fit_results[t]
            mean_fit = np.mean(model_fit_values)
            std_fit = np.std(model_fit_values)
            se_fit = std_fit / np.sqrt(len(model_fit_values))

            print(f" Model fit values at timestep {t}: {model_fit_values}")

            if estimated_rewards_results:
                reward_values = estimated_rewards_results[t]
                mean_reward = np.mean(reward_values)
                std_reward = np.std(reward_values)
                se_reward = std_reward / np.sqrt(len(reward_values))
                print(f"  Timestep {t}: Model-Fit={mean_fit:.4f}±{se_fit:.4f}, Reward={mean_reward:.4f}±{se_reward:.4f} (n={len(model_fit_values)})")
            else:
                print(f"  Timestep {t}: {mean_fit:.4f} ± {se_fit:.4f} (std={std_fit:.4f}, n={len(model_fit_values)})")

    print(f"\nAll results saved to: {experiment_dir}")


if __name__ == '__main__':
    main()