"""
Unified Evaluator - Supports full test set evaluation and detailed metric statistics

Features:
1. Single trajectory evaluation (onestep + rollout) - with complete visualization
2. Batch test set evaluation (efficient)
3. Conservation error statistics
4. Out-of-bounds rate/magnitude statistics
5. Rollout error vs Horizon
6. Complete visualization output
7. .dat file export
"""

import os
import h5py
import torch
import numpy as np
import joblib
import matplotlib
matplotlib.use('Agg') # Force headless mode, generate files only without popup windows
import matplotlib.pyplot as plt
from pathlib import Path
from typing import Dict, List, Tuple, Optional
from dataclasses import dataclass, asdict, field
from tqdm import tqdm


@dataclass
class EvaluationMetrics:
    """Per-step evaluation metrics for a single trajectory.

    Every field is a list; the onestep/rollout evaluation loops in this module
    append one entry per predicted time step. The shallow-water fields are only
    filled when the evaluator's dataset_type is 'shallow_water' and otherwise
    stay empty.
    """
    # Basic errors, computed over every channel of the prediction
    mae: List[float] = field(default_factory=list)
    rmse: List[float] = field(default_factory=list)

    # Conservation errors. "Mass" is the sum of the predicted field (channel 0 for
    # multi-channel states); drift is measured against the total of the t=0 state.
    conservation_drift: List[float] = field(default_factory=list)  # Relative drift |M - M0| / (|M0| + eps)
    conservation_absolute: List[float] = field(default_factory=list)  # Absolute drift |M - M0|
    total_mass: List[float] = field(default_factory=list)  # Predicted total mass
    true_mass: List[float] = field(default_factory=list)  # True total mass at the same step

    # Out-of-bounds statistics of the predicted field (see compute_bound_violation):
    # rates are percentages of points beyond the bound by more than its eps tolerance,
    # magnitudes are mean excess over the whole field.
    violation_rate_lower: List[float] = field(default_factory=list)
    violation_rate_upper: List[float] = field(default_factory=list)
    violation_magnitude_lower: List[float] = field(default_factory=list)
    violation_magnitude_upper: List[float] = field(default_factory=list)

    # Conditional Mean OOB Magnitude - only counts out-of-bounds points
    cond_magnitude_lower: List[float] = field(default_factory=list)
    cond_magnitude_upper: List[float] = field(default_factory=list)

    # Extreme values (predicted field; channel 0 for multi-channel states)
    min_values: List[float] = field(default_factory=list)
    max_values: List[float] = field(default_factory=list)

    # Extreme values (true field; same channel selection as above)
    true_min_values: List[float] = field(default_factory=list)
    true_max_values: List[float] = field(default_factory=list)

    # Shallow water per-field errors (optional): MAE of h, mx, my and each
    # field's relative drift against its own t=0 total
    mae_h: List[float] = field(default_factory=list)
    mae_mx: List[float] = field(default_factory=list)
    mae_my: List[float] = field(default_factory=list)
    cons_drift_h: List[float] = field(default_factory=list)
    cons_drift_mx: List[float] = field(default_factory=list)
    cons_drift_my: List[float] = field(default_factory=list)

    # Shallow water h-field out-of-bounds statistics (optional), against a lower bound of 0
    h_violation_rate_lower: List[float] = field(default_factory=list)
    h_cond_magnitude_lower: List[float] = field(default_factory=list)

    def to_dict(self) -> Dict[str, List[float]]:
        """Return a plain dict of field name -> list (``dataclasses.asdict`` copies the lists)."""
        return asdict(self)


def compute_conservation_error(pred_mass: float, initial_mass: float, eps: float = 1e-8) -> Tuple[float, float]:
    """Measure how far a total mass has drifted from a reference mass.

    Args:
        pred_mass: Total mass of the predicted state.
        initial_mass: Reference total mass.
        eps: Added to ``|initial_mass|`` so a zero reference does not divide by zero.

    Returns:
        ``(relative_drift, absolute_drift)`` where
        ``absolute_drift = |pred_mass - initial_mass|`` and
        ``relative_drift = absolute_drift / (|initial_mass| + eps)``.
    """
    gap = abs(pred_mass - initial_mass)
    scale = abs(initial_mass) + eps
    return gap / scale, gap


def compute_bound_violation(field: np.ndarray, lower_bound: Optional[float] = None,
                            upper_bound: Optional[float] = None, eps: float = 1e-6) -> Dict[str, float]:
    """
    Calculate out-of-bounds statistics

    A point violates the lower (upper) bound only when it lies more than ``eps``
    below (above) that bound. The same violation mask drives the rate and both
    magnitudes, so points inside the ``eps`` tolerance contribute nothing and
    ``magnitude == rate / 100 * cond_magnitude`` holds.

    Args:
        field: Array of values (any shape).
        lower_bound: Lower bound to check, or None to skip the lower side.
        upper_bound: Upper bound to check, or None to skip the upper side.
        eps: Tolerance beyond the bound before a point counts as violating.

    Returns:
        Dict of Python floats:
        rate_lower/upper: Out-of-bounds rate (%)
        magnitude_lower/upper: Mean OOB magnitude over entire field (non-violating points count as 0)
        cond_magnitude_lower/upper: Conditional Mean OOB Magnitude (only counts out-of-bounds points)
        min_value, max_value: Field extrema (NaN for an empty field)
    """
    values = np.asarray(field)
    total_points = values.size
    stats = {
        'rate_lower': 0.0,
        'rate_upper': 0.0,
        'magnitude_lower': 0.0,
        'magnitude_upper': 0.0,
        'cond_magnitude_lower': 0.0,  # Conditional Mean OOB Magnitude
        'cond_magnitude_upper': 0.0,
        # An empty field has no extrema; .min() would raise on it.
        'min_value': float(values.min()) if total_points else float('nan'),
        'max_value': float(values.max()) if total_points else float('nan'),
    }
    if total_points == 0:
        return stats

    def _accumulate(side: str, violating: np.ndarray, excess: np.ndarray) -> None:
        """Fill rate/magnitude/cond_magnitude for one side from a violation mask."""
        count = int(np.count_nonzero(violating))
        stats[f'rate_{side}'] = count / total_points * 100
        if count:
            over = excess[violating]
            # Mean over the entire field: non-violating points contribute exactly 0.
            stats[f'magnitude_{side}'] = float(over.sum() / total_points)
            # Conditional mean: only the violating points.
            stats[f'cond_magnitude_{side}'] = float(over.mean())

    if lower_bound is not None:
        _accumulate('lower', values < (lower_bound - eps), lower_bound - values)

    if upper_bound is not None:
        _accumulate('upper', values > (upper_bound + eps), values - upper_bound)

    return stats


class UnifiedEvaluator:
    """Unified Evaluator"""

    def __init__(
        self,
        model: torch.nn.Module,
        device: torch.device,
        dataset_type: str,
        ndt: int = 1,
        lower_bound: Optional[float] = None,
        upper_bound: Optional[float] = None
    ):
        self.model = model.to(device)
        self.model.eval()
        self.device = device
        self.dataset_type = dataset_type
        self.ndt = ndt
        self.lower_bound = lower_bound
        self.upper_bound = upper_bound

        # Get model output format
        model_name = model.__class__.__name__
        if 'FluxNet_D' in model_name:
            self.output_format = 'fluxnet_d'
        elif 'FluxNet_SW_2D' == model_name:
            self.output_format = 'fluxnet_sw'
        elif 'FluxNet_SW_Baseline' == model_name:
            self.output_format = 'sw_baseline'
        elif 'FluxNet' in model_name:
            self.output_format = 'fluxnet_ul'
        else:
            self.output_format = 'cnn'

    def _load_h5_data(self, h5_file_path: str) -> Dict:
        """Read one trajectory file into float32 NumPy arrays.

        Time-dependent datasets are subsampled along axis 0 with stride
        ``self.ndt``; static fields (``u``, ``vmax``) are read in full.

        Raises:
            ValueError: if ``self.dataset_type`` is not recognised.
        """
        # dataset_type -> ((result key, HDF5 dataset name, subsample in time?), ...)
        layouts = {
            'convection_diffusion': (('c', 'c', True), ('u', 'u', False)),
            'traffic_flow': (('rho', 'rho', True), ('vmax', 'vmax', False)),
            'shallow_water': (('h', 'h', True), ('mx', 'mx', True), ('my', 'my', True)),
            'spinodal_decomposition': (('phi', 'phi_data', True),),
        }
        with h5py.File(h5_file_path, 'r') as h5:
            layout = layouts.get(self.dataset_type)
            if layout is None:
                raise ValueError(f"Unknown dataset_type: {self.dataset_type}")
            loaded = {}
            for out_key, h5_key, time_sampled in layout:
                dset = h5[h5_key]
                raw = dset[::self.ndt] if time_sampled else dset[:]
                loaded[out_key] = raw.astype(np.float32)
            return loaded

    def _prepare_input(self, data: Dict, t: int) -> torch.Tensor:
        """Prepare model input"""
        if self.dataset_type == 'convection_diffusion':
            input_array = np.stack([data['c'][t], data['u']], axis=0)
        elif self.dataset_type == 'traffic_flow':
            input_array = np.stack([data['rho'][t], data['vmax']], axis=0)
        elif self.dataset_type == 'shallow_water':
            input_array = np.stack([data['h'][t], data['mx'][t], data['my'][t]], axis=0)
        elif self.dataset_type == 'spinodal_decomposition':
            input_array = data['phi'][t][np.newaxis, :]
        else:
            raise ValueError(f"Unknown dataset_type: {self.dataset_type}")

        return torch.from_numpy(input_array).unsqueeze(0).to(self.device)

    def _get_conserved_field(self, data: Dict, t: int) -> np.ndarray:
        """Get conserved field"""
        if self.dataset_type == 'convection_diffusion':
            return data['c'][t]
        elif self.dataset_type == 'traffic_flow':
            return data['rho'][t]
        elif self.dataset_type == 'shallow_water':
            return np.stack([data['h'][t], data['mx'][t], data['my'][t]], axis=0)
        elif self.dataset_type == 'spinodal_decomposition':
            return data['phi'][t]
        else:
            raise ValueError(f"Unknown dataset_type: {self.dataset_type}")

    def _get_prediction(self, model_output) -> np.ndarray:
        """Get prediction from model output"""
        if self.output_format == 'fluxnet_d':
            pred = model_output[0]
        elif self.output_format in ['fluxnet_sw', 'sw_baseline']:
            pred = model_output[0]
        elif self.output_format == 'fluxnet_ul':
            pred = model_output[0]
        else:
            pred = model_output[0]

        return pred.squeeze().cpu().numpy()

    def evaluate_single_trajectory_onestep(
        self,
        h5_file_path: str,
        save_dir: Optional[str] = None,
        visualize: bool = True
    ) -> EvaluationMetrics:
        """
        Onestep evaluation of a single trajectory

        Each step uses ground truth as input for prediction: the model receives
        the true state at time t and its output is compared with the true state
        at t + 1.

        Args:
            h5_file_path: Trajectory file to evaluate.
            save_dir: Output directory for plots and data dumps.
            visualize: If True and ``save_dir`` is given, write visualizations.

        Returns:
            EvaluationMetrics with one entry per predicted step.
        """
        data = self._load_h5_data(h5_file_path)
        metrics = EvaluationMetrics()
        num_steps = self._num_transitions(data)

        initial_field = self._get_conserved_field(data, 0)
        initial_mass = self._total_mass(initial_field)

        all_preds: List[np.ndarray] = []
        all_targets: List[np.ndarray] = []
        for t in range(num_steps):
            input_tensor = self._prepare_input(data, t)
            target = self._get_conserved_field(data, t + 1)

            with torch.no_grad():
                pred = self._get_prediction(self.model(input_tensor))

            all_preds.append(pred.copy())
            all_targets.append(target.copy())
            self._record_step_metrics(metrics, pred, target, initial_field, initial_mass)

        if visualize and save_dir:
            os.makedirs(save_dir, exist_ok=True)
            self._visualize_evaluation(
                all_preds, all_targets, metrics, save_dir, 'onestep', h5_file_path
            )
        return metrics

    def evaluate_single_trajectory_rollout(
        self,
        h5_file_path: str,
        save_dir: Optional[str] = None,
        visualize: bool = True
    ) -> EvaluationMetrics:
        """
        Rollout evaluation of a single trajectory

        Autoregressive prediction: starting from the true t=0 state, each
        prediction is fed back as the next input and compared with the true
        state at the same step.

        Args:
            h5_file_path: Trajectory file to evaluate.
            save_dir: Output directory for plots and data dumps.
            visualize: If True and ``save_dir`` is given, write visualizations.

        Returns:
            EvaluationMetrics with one entry per predicted step.
        """
        data = self._load_h5_data(h5_file_path)
        metrics = EvaluationMetrics()
        num_steps = self._num_transitions(data)

        initial_field = self._get_conserved_field(data, 0)
        initial_mass = self._total_mass(initial_field)

        all_preds: List[np.ndarray] = []
        all_targets: List[np.ndarray] = []
        state = initial_field.copy()
        for t in range(num_steps):
            input_tensor = self._rollout_input(data, state)
            target = self._get_conserved_field(data, t + 1)

            with torch.no_grad():
                pred = self._get_prediction(self.model(input_tensor))

            all_preds.append(pred.copy())
            all_targets.append(target.copy())
            state = pred.copy()  # autoregressive: the prediction becomes the next input
            self._record_step_metrics(metrics, pred, target, initial_field, initial_mass)

        if visualize and save_dir:
            os.makedirs(save_dir, exist_ok=True)
            self._visualize_evaluation(
                all_preds, all_targets, metrics, save_dir, 'rollout', h5_file_path
            )
        return metrics

    def _num_transitions(self, data: Dict) -> int:
        """Number of t -> t+1 transitions in *data* (stored frames minus one)."""
        time_key = {
            'shallow_water': 'h',
            'convection_diffusion': 'c',
            'traffic_flow': 'rho',
        }.get(self.dataset_type, 'phi')
        return data[time_key].shape[0] - 1

    @staticmethod
    def _total_mass(state: np.ndarray) -> float:
        """Sum of channel 0 for multi-channel (ndim > 2) states, else of the whole state.

        Accumulated in float64 so float32 round-off is not reported as drift.
        """
        primary = state[0] if state.ndim > 2 else state
        return float(primary.sum(dtype=np.float64))

    def _rollout_input(self, data: Dict, state: np.ndarray) -> torch.Tensor:
        """Build the model input from the model's own previous prediction *state*."""
        if self.dataset_type == 'convection_diffusion':
            input_array = np.stack([state, data['u']], axis=0)
        elif self.dataset_type == 'traffic_flow':
            input_array = np.stack([state, data['vmax']], axis=0)
        elif self.dataset_type == 'shallow_water':
            input_array = state
        else:  # spinodal: add a channel axis to a bare 2D field
            input_array = state[np.newaxis, :] if state.ndim == 2 else state
        # .float() guards against predictions that are not float32.
        return torch.from_numpy(input_array).unsqueeze(0).float().to(self.device)

    def _record_step_metrics(
        self,
        metrics: EvaluationMetrics,
        pred: np.ndarray,
        target: np.ndarray,
        initial_field: np.ndarray,
        initial_mass: float,
    ) -> None:
        """Append every per-step statistic comparing *pred* with *target*.

        Shared by onestep and rollout evaluation so both report identical
        metrics. All stored values are Python floats.
        """
        diff = pred - target
        metrics.mae.append(float(np.mean(np.abs(diff))))
        metrics.rmse.append(float(np.sqrt(np.mean(diff ** 2))))

        # Conservation and bound statistics use channel 0 for multi-channel states;
        # the prediction's rank decides the channel selection for both arrays.
        multi_channel = pred.ndim > 2
        pred_field = pred[0] if multi_channel else pred
        target_field = target[0] if multi_channel else target

        pred_mass = float(pred_field.sum(dtype=np.float64))
        rel_drift, abs_drift = compute_conservation_error(pred_mass, initial_mass)
        metrics.conservation_drift.append(rel_drift)
        metrics.conservation_absolute.append(abs_drift)
        metrics.total_mass.append(pred_mass)
        metrics.true_mass.append(float(target_field.sum(dtype=np.float64)))

        viol = compute_bound_violation(pred_field, self.lower_bound, self.upper_bound)
        for attr, key in (
            ('violation_rate_lower', 'rate_lower'),
            ('violation_rate_upper', 'rate_upper'),
            ('violation_magnitude_lower', 'magnitude_lower'),
            ('violation_magnitude_upper', 'magnitude_upper'),
            ('cond_magnitude_lower', 'cond_magnitude_lower'),
            ('cond_magnitude_upper', 'cond_magnitude_upper'),
            ('min_values', 'min_value'),
            ('max_values', 'max_value'),
        ):
            getattr(metrics, attr).append(float(viol[key]))

        metrics.true_min_values.append(float(target_field.min()))
        metrics.true_max_values.append(float(target_field.max()))

        if self.dataset_type == 'shallow_water':
            self._record_shallow_water_metrics(metrics, pred, target, initial_field)

    @staticmethod
    def _record_shallow_water_metrics(
        metrics: EvaluationMetrics,
        pred: np.ndarray,
        target: np.ndarray,
        initial_field: np.ndarray,
    ) -> None:
        """Per-field (h, mx, my) MAE and drift, plus h lower-bound (h >= 0) violations."""
        mae_lists = (metrics.mae_h, metrics.mae_mx, metrics.mae_my)
        drift_lists = (metrics.cons_drift_h, metrics.cons_drift_mx, metrics.cons_drift_my)
        for channel, (mae_list, drift_list) in enumerate(zip(mae_lists, drift_lists)):
            mae_list.append(float(np.mean(np.abs(pred[channel] - target[channel]))))
            rel_drift, _ = compute_conservation_error(
                float(pred[channel].sum(dtype=np.float64)),
                float(initial_field[channel].sum(dtype=np.float64)),
            )
            drift_list.append(rel_drift)

        # Only the lower bound (0) is checked for h.
        h_viol = compute_bound_violation(pred[0], lower_bound=0.0, upper_bound=None)
        metrics.h_violation_rate_lower.append(float(h_viol['rate_lower']))
        metrics.h_cond_magnitude_lower.append(float(h_viol['cond_magnitude_lower']))

    def _visualize_evaluation(
        self,
        all_preds: List[np.ndarray],
        all_targets: List[np.ndarray],
        metrics: EvaluationMetrics,
        save_dir: str,
        mode: str,
        h5_file_path: str
    ):
        """Complete visualization and data export for one evaluated trajectory.

        Writes into ``save_dir`` (names prefixed with ``mode``):
          * ``{mode}_metrics.png``: MAE/RMSE, total mass, conservation drift, violation rate
          * ``{mode}_minmax.png``: predicted/true extrema plus the configured bounds
          * dataset-specific field snapshots at up to five key steps
          * ``{mode}_data.pkl``: metrics, predictions and targets (joblib)
          * ``.dat`` exports for the 2D datasets

        Snapshots and .dat files are skipped when there are no predicted steps.
        """
        os.makedirs(save_dir, exist_ok=True)
        num_steps = len(all_preds)
        title = mode.upper()

        # 1. Error curves
        fig, axes = plt.subplots(2, 2, figsize=(14, 10))

        ax = axes[0, 0]  # MAE/RMSE
        ax.plot(metrics.mae, 'b-', linewidth=2, label='MAE')
        ax.plot(metrics.rmse, 'r--', linewidth=2, label='RMSE')
        ax.set_xlabel('Time Step')
        ax.set_ylabel('Error')
        ax.set_title(f'{title} Error')
        ax.legend()
        ax.grid(True, alpha=0.3)

        ax = axes[0, 1]  # Conservation curves
        ax.plot(metrics.total_mass, 'r-', linewidth=2, label='Predicted')
        ax.plot(metrics.true_mass, 'b--', linewidth=2, label='True')
        ax.set_xlabel('Time Step')
        ax.set_ylabel('Total Mass')
        ax.set_title('Conservation')
        ax.legend()
        ax.grid(True, alpha=0.3)

        ax = axes[1, 0]  # Conservation drift
        ax.plot(metrics.conservation_drift, 'g-', linewidth=2)
        ax.set_xlabel('Time Step')
        ax.set_ylabel('Relative Drift')
        ax.set_title('Conservation Drift')
        # A log axis needs positive data; an exactly conserving model (all-zero
        # drift) would otherwise trigger matplotlib warnings and an unusable axis.
        if any(d > 0 for d in metrics.conservation_drift):
            ax.set_yscale('log')
        ax.grid(True, alpha=0.3)

        ax = axes[1, 1]  # Out-of-bounds rate
        ax.plot(metrics.violation_rate_lower, 'b-', linewidth=2, label='Lower')
        ax.plot(metrics.violation_rate_upper, 'r-', linewidth=2, label='Upper')
        ax.set_xlabel('Time Step')
        ax.set_ylabel('Violation Rate (%)')
        ax.set_title('Bound Violation')
        ax.legend()
        ax.grid(True, alpha=0.3)

        fig.suptitle(f'{title} Evaluation - {Path(h5_file_path).stem}')
        fig.tight_layout()
        fig.savefig(os.path.join(save_dir, f'{mode}_metrics.png'), dpi=150)
        plt.close(fig)

        # 2. Extrema curves (including true field)
        fig, ax = plt.subplots(figsize=(10, 6))
        ax.plot(metrics.min_values, 'b-', linewidth=2, label='Pred Min')
        ax.plot(metrics.max_values, 'r-', linewidth=2, label='Pred Max')
        if metrics.true_min_values:
            ax.plot(metrics.true_min_values, 'b--', linewidth=1.5, alpha=0.7, label='True Min')
            ax.plot(metrics.true_max_values, 'r--', linewidth=1.5, alpha=0.7, label='True Max')
        if self.lower_bound is not None:
            ax.axhline(y=self.lower_bound, color='b', linestyle=':', alpha=0.5, label=f'Lower Bound ({self.lower_bound})')
        if self.upper_bound is not None:
            ax.axhline(y=self.upper_bound, color='r', linestyle=':', alpha=0.5, label=f'Upper Bound ({self.upper_bound})')
        ax.set_xlabel('Time Step')
        ax.set_ylabel('Value')
        ax.set_title(f'{title} Min/Max Values')
        ax.legend()
        ax.grid(True, alpha=0.3)
        fig.tight_layout()
        fig.savefig(os.path.join(save_dir, f'{mode}_minmax.png'), dpi=150)
        plt.close(fig)

        # 3. Field visualization at key steps. With no predicted steps the old code
        # produced key_steps == [-1] and crashed indexing the empty prediction list.
        if num_steps > 0:
            key_steps = [0, num_steps // 4, num_steps // 2, 3 * num_steps // 4, num_steps - 1]
        else:
            key_steps = []

        if key_steps:
            if self.dataset_type in ['spinodal_decomposition']:
                self._visualize_2d_snapshots(all_preds, all_targets, key_steps, save_dir, mode)
            elif self.dataset_type in ['convection_diffusion', 'traffic_flow']:
                self._visualize_1d_snapshots(all_preds, all_targets, key_steps, save_dir, mode)
            elif self.dataset_type == 'shallow_water':
                self._visualize_sw_snapshots(all_preds, all_targets, key_steps, save_dir, mode)

        # 4. Save data
        plot_data = {
            'metrics': metrics.to_dict(),
            'all_preds': all_preds,
            'all_targets': all_targets,
            'mode': mode,
            'h5_file': h5_file_path
        }
        joblib.dump(plot_data, os.path.join(save_dir, f'{mode}_data.pkl'))

        # 5. Save .dat files (2D fields)
        if key_steps and self.dataset_type in ['spinodal_decomposition', 'shallow_water']:
            self._save_dat_files(all_preds, all_targets, key_steps, save_dir, mode)

    def _visualize_2d_snapshots(self, all_preds, all_targets, key_steps, save_dir, mode):
        """2D field snapshot visualization.

        Writes ``{mode}_snapshots.png`` with one column per key step and rows for
        prediction, ground truth (sharing one colour range) and absolute error.
        For multi-channel arrays (ndim > 2) only channel 0 is shown. Does nothing
        when ``key_steps`` is empty.
        """
        if not key_steps:
            return
        n = len(key_steps)
        # squeeze=False keeps ``axes`` 2D even for a single key step, so that
        # axes[row, col] indexing works for n == 1.
        fig, axes = plt.subplots(3, n, figsize=(4 * n, 12), squeeze=False)

        for i, t in enumerate(key_steps):
            pred = all_preds[t]
            target = all_targets[t]

            if len(pred.shape) > 2:
                pred = pred[0]
                target = target[0]

            error = np.abs(pred - target)
            vmin = min(pred.min(), target.min())
            vmax = max(pred.max(), target.max())

            im0 = axes[0, i].imshow(pred, cmap='viridis', vmin=vmin, vmax=vmax)
            axes[0, i].set_title(f't={t} Pred')
            plt.colorbar(im0, ax=axes[0, i], shrink=0.8)

            im1 = axes[1, i].imshow(target, cmap='viridis', vmin=vmin, vmax=vmax)
            axes[1, i].set_title(f't={t} True')
            plt.colorbar(im1, ax=axes[1, i], shrink=0.8)

            im2 = axes[2, i].imshow(error, cmap='hot')
            axes[2, i].set_title(f't={t} Error')
            plt.colorbar(im2, ax=axes[2, i], shrink=0.8)

        fig.suptitle(f'{mode.upper()} Field Comparison')
        fig.tight_layout()
        fig.savefig(os.path.join(save_dir, f'{mode}_snapshots.png'), dpi=150)
        plt.close(fig)

    def _visualize_1d_snapshots(self, all_preds, all_targets, key_steps, save_dir, mode):
        """1D field snapshot visualization plus a space-time diagram.

        Writes ``{mode}_snapshots.png``: true vs predicted profile (row 1) and
        absolute error (row 2) per key step, with y-limits shared across each row.
        Also writes ``{mode}_spacetime.png``: prediction, truth and error over all
        steps. For arrays with more than one dimension only channel 0 is plotted.
        Does nothing when there are no key steps or predictions.
        """
        if not key_steps or not all_preds:
            return

        def first_channel(pred, target):
            # Channel selection follows the prediction's rank for both arrays.
            return (pred[0], target[0]) if pred.ndim > 1 else (pred, target)

        def padded_limits(arr):
            lo, hi = np.min(arr), np.max(arr)
            span = hi - lo
            margin = span * 0.05 if span > 0 else 0.1  # avoid a zero-height axis
            return [lo - margin, hi + margin]

        pairs = [first_channel(all_preds[t], all_targets[t]) for t in key_steps]

        # Shared y-ranges, computed vectorized: row 1 spans all predicted and true
        # values, row 2 spans all absolute errors.
        all_values = np.concatenate([np.ravel(a) for pair in pairs for a in pair])
        all_errors = np.concatenate([np.abs(p - q).ravel() for p, q in pairs])
        ylim_row1 = padded_limits(all_values)
        ylim_row2 = padded_limits(all_errors)

        n = len(key_steps)
        # squeeze=False keeps ``axes`` 2D even for a single key step.
        fig, axes = plt.subplots(2, n, figsize=(4 * n, 8), squeeze=False)
        for i, (t, (pred, target)) in enumerate(zip(key_steps, pairs)):
            x = np.arange(len(pred))
            # First row: true vs predicted values
            axes[0, i].plot(x, target, 'b-', linewidth=2, label='True')
            axes[0, i].plot(x, pred, 'r--', linewidth=2, label='Pred')
            axes[0, i].set_title(f't={t}')
            axes[0, i].legend()
            axes[0, i].grid(True, alpha=0.3)
            axes[0, i].set_ylim(ylim_row1)
            # Second row: error
            error = np.abs(pred - target)
            axes[1, i].plot(x, error, 'g-', linewidth=2)
            axes[1, i].set_title(f'Error (MAE={error.mean():.2e})')
            axes[1, i].grid(True, alpha=0.3)
            axes[1, i].set_ylim(ylim_row2)
        fig.suptitle(f'{mode.upper()} Field Comparison')
        fig.tight_layout()
        fig.savefig(os.path.join(save_dir, f'{mode}_snapshots.png'), dpi=150)
        plt.close(fig)

        # Space-time plot (rows = time steps, columns = space)
        pred_stack = np.array([p[0] if p.ndim > 1 else p for p in all_preds])
        target_stack = np.array([q[0] if q.ndim > 1 else q for q in all_targets])

        fig, axes = plt.subplots(1, 3, figsize=(15, 5))

        vmin = min(pred_stack.min(), target_stack.min())
        vmax = max(pred_stack.max(), target_stack.max())

        im0 = axes[0].imshow(pred_stack, aspect='auto', cmap='viridis', vmin=vmin, vmax=vmax)
        axes[0].set_title('Predicted')
        axes[0].set_xlabel('Space')
        axes[0].set_ylabel('Time')
        plt.colorbar(im0, ax=axes[0])

        im1 = axes[1].imshow(target_stack, aspect='auto', cmap='viridis', vmin=vmin, vmax=vmax)
        axes[1].set_title('True')
        axes[1].set_xlabel('Space')
        plt.colorbar(im1, ax=axes[1])

        error_stack = np.abs(pred_stack - target_stack)
        im2 = axes[2].imshow(error_stack, aspect='auto', cmap='hot')
        axes[2].set_title('Error')
        axes[2].set_xlabel('Space')
        plt.colorbar(im2, ax=axes[2])

        fig.suptitle(f'{mode.upper()} Space-Time Plot')
        fig.tight_layout()
        fig.savefig(os.path.join(save_dir, f'{mode}_spacetime.png'), dpi=150)
        plt.close(fig)

    def _visualize_sw_snapshots(self, all_preds, all_targets, key_steps, save_dir, mode):
        """Write one 3x3 comparison figure per key step for the shallow-water state.

        Rows are the h, mx and my channels; columns are prediction, ground truth
        and absolute error. Prediction and truth of a row share one colour range.
        Each figure is saved as ``{mode}_sw_t{step:04d}.png`` in ``save_dir``.
        """
        for step in key_steps:
            predicted_state = all_preds[step]
            true_state = all_targets[step]

            fig, grid = plt.subplots(3, 3, figsize=(12, 12))

            for row, label in enumerate(('h', 'mx', 'my')):
                predicted = predicted_state[row]
                truth = true_state[row]
                abs_error = np.abs(predicted - truth)

                shared_scale = {
                    'cmap': 'viridis',
                    'vmin': min(predicted.min(), truth.min()),
                    'vmax': max(predicted.max(), truth.max()),
                }
                panels = (
                    (predicted, shared_scale, f'Pred {label}'),
                    (truth, shared_scale, f'True {label}'),
                    (abs_error, {'cmap': 'hot'}, f'Error {label}'),
                )
                for col, (image, imshow_kwargs, heading) in enumerate(panels):
                    cell = grid[row, col]
                    mappable = cell.imshow(image, **imshow_kwargs)
                    cell.set_title(heading)
                    plt.colorbar(mappable, ax=cell, shrink=0.8)

            plt.suptitle(f'{mode.upper()} t={step}')
            plt.tight_layout()
            plt.savefig(os.path.join(save_dir, f'{mode}_sw_t{step:04d}.png'), dpi=150)
            plt.close()

    def _save_dat_files(self, all_preds, all_targets, key_steps, save_dir, mode):
        """Save .dat files"""
        dat_dir = os.path.join(save_dir, 'dat_files')
        os.makedirs(dat_dir, exist_ok=True)

        for t in key_steps:
            pred = all_preds[t]
            target = all_targets[t]

            if len(pred.shape) > 2:
                # Multi-channel
                for c in range(pred.shape[0]):
                    self._write_dat(pred[c], os.path.join(dat_dir, f'{mode}_pred_c{c}_t{t:04d}.dat'))
                    self._write_dat(target[c], os.path.join(dat_dir, f'{mode}_true_c{c}_t{t:04d}.dat'))
            elif len(pred.shape) == 2:
                self._write_dat(pred, os.path.join(dat_dir, f'{mode}_pred_t{t:04d}.dat'))
                self._write_dat(target, os.path.join(dat_dir, f'{mode}_true_t{t:04d}.dat'))

    def _write_dat(self, field: np.ndarray, filepath: str):
        """Write Tecplot format .dat file"""
        if len(field.shape) != 2:
            return

        I, J = field.shape
        with open(filepath, 'w') as f:
            f.write(f'VARIABLE="x","y","value"\n')
            f.write(f'ZONE t="field", I={I}, J={J}, F=POINT\n')
            for value in field.flatten():
                f.write(f"{value:.6f}\n")

    def evaluate_test_set_batch(
        self,
        test_folder: str,
        output_dir: str,
        batch_size: int = 32,
        mode: str = 'onestep',
        visualize_trajectories: Optional[List[str]] = None
    ) -> Dict:
        """
        Batch evaluation of entire test set

        Every ``*.h5`` file directly inside ``test_folder`` is evaluated (sorted by
        path). Per-trajectory metrics are aggregated with ``_aggregate_metrics``.
        The summary is saved to ``output_dir`` as ``test_set_summary_{mode}.pkl``
        and ``.json``, together with the summary plot and a gallery of the
        per-trajectory figures.

        Args:
            test_folder: Test set directory
            output_dir: Output directory (created if missing)
            batch_size: Currently unused; kept for API compatibility
            mode: 'onestep' or 'rollout' (any value other than 'onestep' runs rollout)
            visualize_trajectories: Trajectories to visualize.
                - ``None``: the first 3 files.
                - ``'all'``: every file.
                - Otherwise: a path/name (or list of them). Entries are matched
                  against the discovered files by file name or stem. For example,
                  ``'traj_001.h5'``, ``'traj_001'`` and an absolute path to that
                  file all select it.

        Returns:
            Summary statistics dictionary from ``_aggregate_metrics``.

        Raises:
            ValueError: if ``test_folder`` contains no ``*.h5`` files.
        """
        import json

        os.makedirs(output_dir, exist_ok=True)

        h5_files = sorted(Path(test_folder).glob("*.h5"))
        print(f"\nFound {len(h5_files)} test trajectories")
        if not h5_files:
            raise ValueError(f"No *.h5 test trajectories found in {test_folder!r}")

        # Decide which trajectories need visualization
        if visualize_trajectories is None:
            vis_files = h5_files[:3]
        elif visualize_trajectories == 'all':
            vis_files = list(h5_files)
        else:
            requested = visualize_trajectories
            if isinstance(requested, (str, os.PathLike)):
                # A single path would otherwise be iterated character by character
                requested = [requested]
            # Match by name/stem rather than Path equality: callers may pass bare
            # file names or absolute paths while h5_files are built from test_folder.
            wanted = {Path(f).name for f in requested}
            vis_files = [p for p in h5_files if p.name in wanted or p.stem in wanted]
        vis_set = set(vis_files)

        evaluate_one = (self.evaluate_single_trajectory_onestep if mode == 'onestep'
                        else self.evaluate_single_trajectory_rollout)

        # Evaluate one by one, collecting per-trajectory metrics
        all_metrics = []
        for h5_file in tqdm(h5_files, desc=f"Evaluating test set ({mode})"):
            need_vis = h5_file in vis_set
            traj_dir = os.path.join(output_dir, 'trajectories', h5_file.stem) if need_vis else None
            all_metrics.append(evaluate_one(str(h5_file), save_dir=traj_dir, visualize=need_vis))

        # Aggregate statistics
        summary = self._aggregate_metrics(all_metrics)

        # Save summary results
        joblib.dump(summary, os.path.join(output_dir, f'test_set_summary_{mode}.pkl'))

        # Save JSON format summary results (scalar entries only, for other tools)
        json_summary = {k: float(v) if isinstance(v, (np.floating, np.integer)) else v
                        for k, v in summary.items() if not isinstance(v, list)}
        with open(os.path.join(output_dir, f'test_set_summary_{mode}.json'), 'w', encoding='utf-8') as f:
            json.dump(json_summary, f, indent=2)

        # Summary visualization (with error bars)
        self._visualize_summary(all_metrics, summary, output_dir, mode)

        # Create summary visualization folder (all_onestep / all_rollout)
        self._create_summary_gallery(h5_files, vis_files, output_dir, mode)

        # Print statistics
        self._print_summary(summary)

        return summary

    def _create_summary_gallery(self, h5_files: List[Path], vis_files: List[Path],
                                 output_dir: str, mode: str):
        """
        Create summary visualization folder, gathering all visualization images together
        """
        gallery_dir = os.path.join(output_dir, f'all_{mode}')
        os.makedirs(gallery_dir, exist_ok=True)

        # Copy key visualization images for each trajectory
        for h5_file in vis_files:
            traj_dir = os.path.join(output_dir, 'trajectories', h5_file.stem)
            if not os.path.exists(traj_dir):
                continue

            # Copy metrics and snapshots images
            for img_name in [f'{mode}_metrics.png', f'{mode}_snapshots.png',
                             f'{mode}_minmax.png', f'{mode}_spacetime.png']:
                src_path = os.path.join(traj_dir, img_name)
                if os.path.exists(src_path):
                    dst_name = f'{h5_file.stem}_{img_name}'
                    import shutil
                    shutil.copy2(src_path, os.path.join(gallery_dir, dst_name))

        # Create combined gallery image (if trajectory count is appropriate)
        if len(vis_files) <= 16 and len(vis_files) > 0:
            self._create_combined_gallery_image(vis_files, output_dir, gallery_dir, mode)

    def _aggregate_metrics(self, all_metrics: List[EvaluationMetrics]) -> Dict:
        """Aggregate metrics from multiple trajectories"""
        # Aggregate by time step
        max_steps = max(len(m.mae) for m in all_metrics)

        # Collect metrics from all trajectories for each time step
        mae_by_step = [[] for _ in range(max_steps)]
        rmse_by_step = [[] for _ in range(max_steps)]
        cons_drift_by_step = [[] for _ in range(max_steps)]
        viol_lower_by_step = [[] for _ in range(max_steps)]
        viol_upper_by_step = [[] for _ in range(max_steps)]
        cond_mag_lower_by_step = [[] for _ in range(max_steps)]
        cond_mag_upper_by_step = [[] for _ in range(max_steps)]

        for m in all_metrics:
            for t, val in enumerate(m.mae):
                mae_by_step[t].append(val)
            for t, val in enumerate(m.rmse):
                rmse_by_step[t].append(val)
            for t, val in enumerate(m.conservation_drift):
                cons_drift_by_step[t].append(val)
            for t, val in enumerate(m.violation_rate_lower):
                viol_lower_by_step[t].append(val)
            for t, val in enumerate(m.violation_rate_upper):
                viol_upper_by_step[t].append(val)
            for t, val in enumerate(m.cond_magnitude_lower):
                cond_mag_lower_by_step[t].append(val)
            for t, val in enumerate(m.cond_magnitude_upper):
                cond_mag_upper_by_step[t].append(val)

        # Calculate mean and standard deviation
        summary = {
            'num_trajectories': len(all_metrics),
            'max_steps': max_steps,

            # Statistics by time step
            'mae_mean_by_step': [np.mean(s) if s else 0 for s in mae_by_step],
            'mae_std_by_step': [np.std(s) if s else 0 for s in mae_by_step],
            'rmse_mean_by_step': [np.mean(s) if s else 0 for s in rmse_by_step],
            'rmse_std_by_step': [np.std(s) if s else 0 for s in rmse_by_step],
            'cons_drift_mean_by_step': [np.mean(s) if s else 0 for s in cons_drift_by_step],
            'cons_drift_std_by_step': [np.std(s) if s else 0 for s in cons_drift_by_step],
            'viol_lower_mean_by_step': [np.mean(s) if s else 0 for s in viol_lower_by_step],
            'viol_upper_mean_by_step': [np.mean(s) if s else 0 for s in viol_upper_by_step],
            'cond_mag_lower_mean_by_step': [np.mean([v for v in s if v > 0]) if any(v > 0 for v in s) else 0 for s in cond_mag_lower_by_step],
            'cond_mag_upper_mean_by_step': [np.mean([v for v in s if v > 0]) if any(v > 0 for v in s) else 0 for s in cond_mag_upper_by_step],

            # Global statistics (time-averaged)
            'mae_overall_mean': np.mean([np.mean(m.mae) for m in all_metrics]),
            'mae_overall_std': np.std([np.mean(m.mae) for m in all_metrics]),
            'rmse_overall_mean': np.mean([np.mean(m.rmse) for m in all_metrics]),
            'rmse_overall_std': np.std([np.mean(m.rmse) for m in all_metrics]),
            'cons_drift_max': max(max(m.conservation_drift) for m in all_metrics),
            'cons_drift_mean': np.mean([np.mean(m.conservation_drift) for m in all_metrics]),
            'cons_drift_std': np.std([np.mean(m.conservation_drift) for m in all_metrics]),
            'viol_lower_mean': np.mean([np.mean(m.violation_rate_lower) for m in all_metrics]),
            'viol_lower_std': np.std([np.mean(m.violation_rate_lower) for m in all_metrics]),
            'viol_upper_mean': np.mean([np.mean(m.violation_rate_upper) for m in all_metrics]),
            'viol_upper_std': np.std([np.mean(m.violation_rate_upper) for m in all_metrics]),
            'min_value_overall': min(min(m.min_values) for m in all_metrics),
            'max_value_overall': max(max(m.max_values) for m in all_metrics),
        }

        # Conditional Mean OOB Magnitude statistics
        # Calculate average OOB magnitude for all out-of-bounds points
        all_cond_lower = [v for m in all_metrics for v in m.cond_magnitude_lower if v > 0]
        all_cond_upper = [v for m in all_metrics for v in m.cond_magnitude_upper if v > 0]
        summary['cond_magnitude_lower_mean'] = np.mean(all_cond_lower) if all_cond_lower else 0.0
        summary['cond_magnitude_lower_std'] = np.std(all_cond_lower) if all_cond_lower else 0.0
        summary['cond_magnitude_upper_mean'] = np.mean(all_cond_upper) if all_cond_upper else 0.0
        summary['cond_magnitude_upper_std'] = np.std(all_cond_upper) if all_cond_upper else 0.0

        # Rollout@T statistics (T = 0.25, 0.5, 0.75, 1.0) - use final moment instead of global average
        for T_ratio in [0.25, 0.5, 0.75, 1.0]:
            T_idx = int(T_ratio * max_steps) - 1
            T_idx = max(0, min(T_idx, max_steps - 1))
            if mae_by_step[T_idx]:
                summary[f'mae_at_T{T_ratio}'] = np.mean(mae_by_step[T_idx])
                summary[f'mae_at_T{T_ratio}_std'] = np.std(mae_by_step[T_idx])
                summary[f'rmse_at_T{T_ratio}'] = np.mean(rmse_by_step[T_idx])
                summary[f'rmse_at_T{T_ratio}_std'] = np.std(rmse_by_step[T_idx])
                summary[f'cons_drift_at_T{T_ratio}'] = np.mean(cons_drift_by_step[T_idx])
                summary[f'cons_drift_at_T{T_ratio}_std'] = np.std(cons_drift_by_step[T_idx])
                summary[f'viol_lower_at_T{T_ratio}'] = np.mean(viol_lower_by_step[T_idx])
                summary[f'viol_lower_at_T{T_ratio}_std'] = np.std(viol_lower_by_step[T_idx])
                summary[f'viol_upper_at_T{T_ratio}'] = np.mean(viol_upper_by_step[T_idx])
                summary[f'viol_upper_at_T{T_ratio}_std'] = np.std(viol_upper_by_step[T_idx])

        # Shallow water three-field separate statistics
        if all_metrics[0].mae_h:
            summary['sw_mae_h_mean'] = np.mean([np.mean(m.mae_h) for m in all_metrics])
            summary['sw_mae_h_std'] = np.std([np.mean(m.mae_h) for m in all_metrics])
            summary['sw_mae_mx_mean'] = np.mean([np.mean(m.mae_mx) for m in all_metrics])
            summary['sw_mae_mx_std'] = np.std([np.mean(m.mae_mx) for m in all_metrics])
            summary['sw_mae_my_mean'] = np.mean([np.mean(m.mae_my) for m in all_metrics])
            summary['sw_mae_my_std'] = np.std([np.mean(m.mae_my) for m in all_metrics])
            summary['sw_cons_drift_h_mean'] = np.mean([np.mean(m.cons_drift_h) for m in all_metrics])
            summary['sw_cons_drift_h_std'] = np.std([np.mean(m.cons_drift_h) for m in all_metrics])
            summary['sw_cons_drift_mx_mean'] = np.mean([np.mean(m.cons_drift_mx) for m in all_metrics])
            summary['sw_cons_drift_mx_std'] = np.std([np.mean(m.cons_drift_mx) for m in all_metrics])
            summary['sw_cons_drift_my_mean'] = np.mean([np.mean(m.cons_drift_my) for m in all_metrics])
            summary['sw_cons_drift_my_std'] = np.std([np.mean(m.cons_drift_my) for m in all_metrics])
            # h-field out-of-bounds statistics
            if all_metrics[0].h_violation_rate_lower:
                summary['sw_h_viol_rate_mean'] = np.mean([np.mean(m.h_violation_rate_lower) for m in all_metrics])
                summary['sw_h_viol_rate_std'] = np.std([np.mean(m.h_violation_rate_lower) for m in all_metrics])
                all_h_cond = [v for m in all_metrics for v in m.h_cond_magnitude_lower if v > 0]
                summary['sw_h_cond_mag_mean'] = np.mean(all_h_cond) if all_h_cond else 0.0
                summary['sw_h_cond_mag_std'] = np.std(all_h_cond) if all_h_cond else 0.0

        return summary

    def _visualize_summary(self, all_metrics: List[EvaluationMetrics], summary: Dict,
                           output_dir: str, mode: str):
        """Render the test-set summary figure ``{mode}_summary.png`` into ``output_dir``.

        The figure is a 2x2 grid:
        - per-step MAE (mean with ±std error bars);
        - per-step RMSE (mean with ±std error bars);
        - per-step relative conservation drift (mean with ±std error bars, log y-axis);
        - mean lower/upper bound-violation rates.

        All curves come from ``summary``; ``all_metrics`` is not read here.
        """
        steps = range(summary['max_steps'])

        fig, axes = plt.subplots(2, 2, figsize=(14, 10))

        # (axis, mean key, std key, extra errorbar kwargs, y label, title, log y-scale)
        errorbar_panels = [
            (axes[0, 0], 'mae_mean_by_step', 'mae_std_by_step', {},
             'MAE', f'{mode.upper()} MAE (mean ± std)', False),
            (axes[0, 1], 'rmse_mean_by_step', 'rmse_std_by_step', {'color': 'orange'},
             'RMSE', f'{mode.upper()} RMSE (mean ± std)', False),
            (axes[1, 0], 'cons_drift_mean_by_step', 'cons_drift_std_by_step', {'color': 'green'},
             'Relative Drift', f'{mode.upper()} Conservation Drift (mean ± std)', True),
        ]
        for ax, mean_key, std_key, extra, ylabel, title, log_y in errorbar_panels:
            ax.errorbar(steps, summary[mean_key], yerr=summary[std_key],
                        fmt='-o', capsize=3, markersize=2, linewidth=1, **extra)
            ax.set_xlabel('Time Step')
            ax.set_ylabel(ylabel)
            ax.set_title(title)
            if log_y:
                ax.set_yscale('log')
            ax.grid(True, alpha=0.3)

        viol_ax = axes[1, 1]
        viol_ax.plot(steps, summary['viol_lower_mean_by_step'], 'b-', linewidth=2, label='Lower')
        viol_ax.plot(steps, summary['viol_upper_mean_by_step'], 'r-', linewidth=2, label='Upper')
        viol_ax.set_xlabel('Time Step')
        viol_ax.set_ylabel('Violation Rate (%)')
        viol_ax.set_title(f'{mode.upper()} Bound Violation (mean)')
        viol_ax.legend()
        viol_ax.grid(True, alpha=0.3)

        plt.suptitle(f'Test Set Summary ({summary["num_trajectories"]} trajectories)')
        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, f'{mode}_summary.png'), dpi=150)
        plt.close()

    def _create_combined_gallery_image(self, vis_files: List[Path], output_dir: str,
                                        gallery_dir: str, mode: str):
        """
        Create combined gallery image, aggregating visualizations of all trajectories.

        Tiles the ``<stem>_{mode}_metrics.png`` images already copied into
        ``gallery_dir`` into a grid of at most 4 columns. Each cell is sized to
        the largest image, so nothing is clipped or overlapped. The result is
        saved as ``{mode}_combined_gallery.png``. This is best-effort: failures
        are printed, not raised. ``output_dir`` is not read here.
        """
        from PIL import Image

        candidate_paths = (os.path.join(gallery_dir, f'{h5_file.stem}_{mode}_metrics.png')
                           for h5_file in vis_files)
        image_paths = [p for p in candidate_paths if os.path.exists(p)]
        if not image_paths:
            return

        # Calculate grid layout
        n = len(image_paths)
        cols = min(4, n)
        rows = (n + cols - 1) // cols

        img_list = []
        try:
            for p in image_paths:
                img_list.append(Image.open(p))

            # Use the largest image as the cell size so differently sized figures don't collide
            cell_w = max(img.size[0] for img in img_list)
            cell_h = max(img.size[1] for img in img_list)

            combined = Image.new('RGB', (cols * cell_w, rows * cell_h), color='white')
            for idx, img in enumerate(img_list):
                row, col = divmod(idx, cols)
                combined.paste(img, (col * cell_w, row * cell_h))

            combined.save(os.path.join(gallery_dir, f'{mode}_combined_gallery.png'))
        except Exception as e:
            # Gallery is a convenience artefact; never abort the evaluation over it
            print(f"Failed to create combined gallery image: {e}")
        finally:
            # Close every image that was opened, even if a later step failed
            for img in img_list:
                img.close()

    def _print_summary(self, summary: Dict):
        """Print summary statistics"""
        print("\n" + "=" * 60)
        print("Test Set Evaluation Summary")
        print("=" * 60)
        print(f"Number of trajectories: {summary['num_trajectories']}")
        print(f"Maximum time steps: {summary['max_steps']}")
        print(f"\nMAE: {summary['mae_overall_mean']:.6e} ± {summary['mae_overall_std']:.6e}")
        print(f"RMSE: {summary['rmse_overall_mean']:.6e} ± {summary['rmse_overall_std']:.6e}")
        print(f"\nConservation drift: mean={summary['cons_drift_mean']:.6e}, max={summary['cons_drift_max']:.6e}")

        if summary['viol_lower_mean'] > 0 or summary['viol_upper_mean'] > 0:
            print(f"\nOut-of-bounds rate: lower={summary['viol_lower_mean']:.4f}%, upper={summary['viol_upper_mean']:.4f}%")

        print(f"\nValue range: [{summary['min_value_overall']:.6f}, {summary['max_value_overall']:.6f}]")
        print("=" * 60)


def evaluate_model_on_test_set(
    model: torch.nn.Module,
    test_folder: str,
    dataset_type: str,
    output_dir: str,
    ndt: int = 1,
    lower_bound: Optional[float] = None,
    upper_bound: Optional[float] = None,
    device: Optional[torch.device] = None,
    mode: str = 'both',
    visualize_trajectories: Optional[List[str]] = None
) -> Dict:
    """
    Convenience function: Evaluate model performance on test set

    Builds a ``UnifiedEvaluator`` and runs ``evaluate_test_set_batch`` for the
    requested mode(s). Outputs go to ``output_dir/onestep`` and/or
    ``output_dir/rollout``.

    Args:
        model: Model to evaluate
        test_folder: Directory containing the ``*.h5`` test trajectories
        dataset_type: Dataset identifier forwarded to ``UnifiedEvaluator``
        output_dir: Root output directory
        ndt, lower_bound, upper_bound: Forwarded to ``UnifiedEvaluator``
        device: Evaluation device; defaults to CUDA when available, else CPU
        mode: 'onestep', 'rollout', or 'both'
        visualize_trajectories: List of h5 files to visualize (see
            ``UnifiedEvaluator.evaluate_test_set_batch``)

    Returns:
        Dict mapping each evaluated mode ('onestep' and/or 'rollout') to its
        summary statistics dictionary.

    Raises:
        ValueError: if ``mode`` is not 'onestep', 'rollout' or 'both'.
    """
    valid_modes = ('onestep', 'rollout', 'both')
    if mode not in valid_modes:
        # Previously an unknown mode silently produced an empty result
        raise ValueError(f"mode must be one of {valid_modes}, got {mode!r}")

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    evaluator = UnifiedEvaluator(
        model=model,
        device=device,
        dataset_type=dataset_type,
        ndt=ndt,
        lower_bound=lower_bound,
        upper_bound=upper_bound
    )

    results = {}
    for run_mode in ('onestep', 'rollout'):
        if mode in (run_mode, 'both'):
            results[run_mode] = evaluator.evaluate_test_set_batch(
                test_folder,
                os.path.join(output_dir, run_mode),
                mode=run_mode,
                visualize_trajectories=visualize_trajectories
            )

    return results
