"""
Objective Summary Module

Implements the greedy algorithm from Section 4.3 of Methods.md for selecting
sample trajectories that summarize a discovered objective.

Key concepts:
- Trend Fidelity: fid(ξ) = exp(- Σ_{t=1}^{T} (u_t - f*(t))^2)
  where u_t is the sample's score at timestep t and f*(t) is the global trend value
- f_fid(E) = Σ_{ξ ∈ E} fid(ξ) (total fidelity, modular function)
- Diversity: f_div(S) = sum_{j=1}^m sqrt(|S ∩ P_j|) using K-Means partitions
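- Objective function: F(S) = (1 - lambda) * f_fid(S) / f_fid(V) + lambda * f_div(S) / f_div(V)
  (convex combination of the normalized terms, so F(empty) = 0 and F(V) = 1 for the full candidate set V)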
"""

import os
import json
import glob
import numpy as np
from typing import List, Dict, Tuple, Optional, Callable
from openai import OpenAI
from sklearn.cluster import KMeans

try:
    from .constants import OPENAI_API_KEY
    from .trend_functions import TREND_FUNCTIONS
except ImportError:
    from constants import OPENAI_API_KEY
    from trend_functions import TREND_FUNCTIONS


class TrajectoryData:
    """Holds one prompt, its per-checkpoint responses/scores, and selection metadata."""

    def __init__(
        self,
        prompt_idx: int,
        prompt: str,
        responses: List[str],
        scores: List[float],
        embedding: Optional[np.ndarray] = None,
        cluster_label: Optional[int] = None,
        global_trend_func: Optional[Callable] = None,
        global_trend_params: Optional[np.ndarray] = None
    ):
        """
        Args:
            prompt_idx: Position of the prompt in the original dataset.
            prompt: The input prompt string.
            responses: One response per checkpoint (length T).
            scores: One objective score per response (length T).
            embedding: Optional pre-computed embedding of the prompt+responses.
            cluster_label: K-Means cluster id in [0, m), or None if unassigned.
            global_trend_func: Global trend f*(t); called as f(t, *params).
            global_trend_params: Positional parameters passed to global_trend_func.
        """
        # Identity and raw data
        self.prompt_idx = prompt_idx
        self.prompt = prompt
        self.responses = responses
        self.scores = scores
        # Selection metadata
        self.embedding = embedding
        self.cluster_label = cluster_label
        # Global trend used by `fidelity`
        self.global_trend_func = global_trend_func
        self.global_trend_params = global_trend_params

    @property
    def fidelity(self) -> float:
        """
        Trend fidelity (Section 4.3 of Methods.md).

            fid(ξ) = exp(- Σ_{t=1}^{T} (u_t - f*(t))^2)

        u_t is this trajectory's score at step t and f*(t) is the global trend
        evaluated at t = 1..T.

        Returns:
            A value in [0, 1] (1 means the scores match the trend exactly).
            When no global trend is attached, returns the legacy `salience`
            instead, which is not bounded by 1. Returns 0.0 for empty scores.
        """
        trend_missing = self.global_trend_func is None or self.global_trend_params is None
        if trend_missing:
            return self.salience

        if len(self.scores) == 0:
            return 0.0

        timesteps = np.arange(1, len(self.scores) + 1)  # 1-based time axis
        observed = np.array(self.scores)
        expected = self.global_trend_func(timesteps, *self.global_trend_params)

        residual = observed - expected
        return np.exp(-np.sum(residual ** 2))

    @property
    def salience(self) -> float:
        """
        Legacy trend salience |u_T - u_1|, kept for backward compatibility.

        Returns 0.0 when fewer than two scores are available.
        """
        history = self.scores
        return abs(history[-1] - history[0]) if len(history) >= 2 else 0.0

    @property
    def score_trajectory(self) -> List[float]:
        """The full list of per-checkpoint scores (same object as `scores`)."""
        return self.scores


def get_embedding(text: str, model: str = "text-embedding-3-small", client: Optional[OpenAI] = None) -> np.ndarray:
    """
    Embed a single string with the OpenAI embeddings endpoint.

    Before the request:
    - newlines are flattened to spaces and surrounding whitespace is stripped;
    - an empty result is replaced by a single space;
    - the text is cut to its first 8000 characters.

    Args:
        text: Text to embed.
        model: OpenAI embedding model name.
        client: Existing OpenAI client to reuse; a new one is built from
            OPENAI_API_KEY when omitted.

    Returns:
        The embedding vector as a numpy array.
    """
    api = OpenAI(api_key=OPENAI_API_KEY) if client is None else client

    # The API rejects empty input, hence the single-space fallback.
    payload = text.replace("\n", " ").strip() or " "
    # Character-based cap as a rough guard against the model's token limit.
    payload = payload[:8000]

    reply = api.embeddings.create(input=[payload], model=model)
    vector = reply.data[0].embedding
    return np.array(vector)


def get_embeddings_batch(texts: List[str], model: str = "text-embedding-3-small", batch_size: int = 100) -> List[np.ndarray]:
    """
    Get embeddings for multiple texts efficiently.

    Args:
        texts: List of texts to embed
        model: OpenAI embedding model name
        batch_size: Number of texts to process per API call

    Returns:
        List of numpy arrays (embeddings)
    """
    client = OpenAI(api_key=OPENAI_API_KEY)
    all_embeddings = []

    for i in range(0, len(texts), batch_size):
        batch_texts = texts[i:i + batch_size]

        # Clean texts
        cleaned_texts = []
        for text in batch_texts:
            text = text.replace("\n", " ").strip()
            if not text:
                text = " "
            if len(text) > 8000:
                text = text[:8000]
            cleaned_texts.append(text)

        response = client.embeddings.create(input=cleaned_texts, model=model)
        batch_embeddings = [np.array(item.embedding) for item in response.data]
        all_embeddings.extend(batch_embeddings)

    return all_embeddings


def load_global_trend_from_results_dir(
    objectives_results_dir: str,
    objective_name: str
) -> Tuple[Optional[Callable], Optional[np.ndarray], Optional[str], Optional[Dict]]:
    """
    Load the global trend function f*(t) and its parameters for one objective.

    How it finds the data:
    - It searches ``objectives_results_dir`` for ``trend_plots_<objective_name>_*``
      subdirectories, falling back to a looser substring match.
    - It takes the last one in reverse-sorted name order.
    - It reads that directory's ``final_trend_params.json``.

    The trend entry flagged ``is_best`` is used. Otherwise it uses the entry
    with the lowest numeric ``avg_prediction_error``.

    Args:
        objectives_results_dir: Path to the objectives discovery results directory.
        objective_name: Name of the objective to find trend parameters for.

    Returns:
        Tuple of:
        - global_trend_func: The trend function f*(t), or None if not found.
        - global_trend_params: Parameters for the trend function as a numpy
          array, or None.
        - best_trend_type: Name of the best trend type (e.g. 'linear'), or None.
        - trend_params_dict: Full dictionary loaded from the JSON file, or None.
          The dictionary is still returned when the file loaded but no usable
          trend could be extracted from it.
    """
    if not objectives_results_dir or not os.path.isdir(objectives_results_dir):
        print(f"Warning: objectives_results_dir not found or not a directory: {objectives_results_dir}")
        return None, None, None, None

    # Clean objective name to match directory naming convention.
    # The _save_trend_plots method uses: re.sub(r'[^\w\s-]', '', desc)[:50] and re.sub(r'[-\s]+', '_', desc)
    import re
    objective_name_clean = re.sub(r'[^\w\s-]', '', objective_name)[:50]
    objective_name_clean = re.sub(r'[-\s]+', '_', objective_name_clean)

    # Search for trend_plots_<objective_name>_* directories (files are ignored).
    exact_prefix = f"trend_plots_{objective_name_clean}_"
    pattern = os.path.join(objectives_results_dir, f"{exact_prefix}*")
    matching_dirs = [p for p in glob.glob(pattern) if os.path.isdir(p)]

    if matching_dirs:
        # The prefix glob also matches objectives whose name merely starts with
        # this one (e.g. "Help_" matches "Help_and_Harm_<timestamp>"). Prefer
        # entries whose remainder starts with a digit, i.e. looks like the
        # timestamp suffix. This is a heuristic: if none qualify, keep all matches.
        timestamped = [
            p for p in matching_dirs
            if os.path.basename(p)[len(exact_prefix):][:1].isdigit()
        ]
        if timestamped:
            matching_dirs = timestamped
    else:
        # Try a more flexible search if exact match fails.
        pattern_flexible = os.path.join(objectives_results_dir, f"trend_plots_*{objective_name_clean[:20]}*")
        matching_dirs = [p for p in glob.glob(pattern_flexible) if os.path.isdir(p)]

    if not matching_dirs:
        # List every trend_plots entry to help the user diagnose the mismatch.
        all_trend_dirs = glob.glob(os.path.join(objectives_results_dir, "trend_plots_*"))
        print(f"Warning: No exact match for objective '{objective_name}'. Available trend directories:")
        for d in all_trend_dirs:
            print(f"  - {os.path.basename(d)}")
        return None, None, None, None

    # Use the most recent matching directory (by timestamp in name).
    matching_dirs.sort(reverse=True)  # Most recent first
    trend_dir = matching_dirs[0]

    # Load final_trend_params.json
    params_file = os.path.join(trend_dir, "final_trend_params.json")
    if not os.path.exists(params_file):
        print(f"Warning: final_trend_params.json not found in {trend_dir}")
        return None, None, None, None

    with open(params_file, 'r', encoding='utf-8') as f:
        trend_params_dict = json.load(f)

    # Only dict-valued entries describe a fitted trend; skip any other metadata.
    candidates = {
        trend_type: params_info
        for trend_type, params_info in trend_params_dict.items()
        if isinstance(params_info, dict)
    }

    # Prefer the entry explicitly flagged as best.
    best_trend_type = None
    for trend_type, params_info in candidates.items():
        if params_info.get('is_best', False):
            best_trend_type = trend_type
            break

    if best_trend_type is None:
        # Fall back to the one with the lowest avg_prediction_error. Missing,
        # null or non-numeric errors are skipped instead of breaking the comparison.
        best_error = float('inf')
        for trend_type, params_info in candidates.items():
            error = params_info.get('avg_prediction_error')
            if not isinstance(error, (int, float)):
                continue
            if error < best_error:
                best_error = error
                best_trend_type = trend_type

    if best_trend_type is None or best_trend_type not in TREND_FUNCTIONS:
        print(f"Warning: Best trend type '{best_trend_type}' not found in TREND_FUNCTIONS")
        return None, None, None, trend_params_dict

    # Get the trend function and its expected parameter order.
    trend_func = TREND_FUNCTIONS[best_trend_type]['func']
    param_names = TREND_FUNCTIONS[best_trend_type]['params']
    param_dict = candidates[best_trend_type].get('param_dict') or {}

    missing = [name for name in param_names if name not in param_dict]
    if missing:
        print(f"Warning: Trend '{best_trend_type}' in {params_file} is missing parameters: {missing}")
        return None, None, None, trend_params_dict

    # Convert param_dict to numpy array in the order the trend function expects.
    params = np.array([param_dict[name] for name in param_names])

    print(f"Loaded global trend from: {os.path.basename(trend_dir)}")
    print(f"  Best trend type: {best_trend_type}")
    print(f"  Parameters: {param_dict}")
    print(f"  Avg prediction error: {candidates[best_trend_type].get('avg_prediction_error', 'N/A')}")

    return trend_func, params, best_trend_type, trend_params_dict


def compute_total_fidelity(
    trajectories: List[TrajectoryData]
) -> float:
    """
    Total (modular) fidelity f_fid(E) = Σ_{ξ ∈ E} fid(ξ), per Section 4.3 of Methods.md.

    Args:
        trajectories: Trajectories whose ``fidelity`` property can be evaluated.

    Returns:
        Unnormalized sum of per-trajectory fidelities; 0.0 for an empty list.
    """
    per_item = [traj.fidelity for traj in trajectories]
    return sum(per_item) if per_item else 0.0


def compute_kmeans_partitions(
    trajectories: List[TrajectoryData],
    num_clusters: int,
    random_state: int = 42
) -> Tuple[List[TrajectoryData], int]:
    """
    Partition trajectories into clusters using K-Means on their embeddings.

    Only trajectories with a non-None ``embedding`` are clustered; the others
    are left untouched. The input list is mutated in place (``cluster_label``
    is set) and also returned.

    Args:
        trajectories: List of TrajectoryData objects, ideally with embeddings.
        num_clusters: Requested number of clusters (m) for K-Means.
        random_state: Random seed for K-Means.

    Returns:
        Tuple of (the same trajectories list, actual number of clusters used).
        The cluster count is 0 when no trajectory has an embedding or when
        num_clusters < 1. Otherwise it is min(num_clusters, number of embedded
        trajectories).
    """
    # Remember where the embedded trajectories sit so labels can be written back.
    valid_indices = [i for i, t in enumerate(trajectories) if t.embedding is not None]

    # KMeans requires n_clusters >= 1; treat a non-positive request like
    # "nothing to cluster" rather than letting sklearn raise.
    if not valid_indices or num_clusters < 1:
        return trajectories, 0

    embeddings_array = np.array([trajectories[i].embedding for i in valid_indices])

    # Cannot ask for more clusters than there are samples.
    actual_num_clusters = min(num_clusters, len(valid_indices))

    kmeans = KMeans(n_clusters=actual_num_clusters, random_state=random_state, n_init=10)
    labels = kmeans.fit_predict(embeddings_array)

    # Store plain ints so labels can be used as list indices and serialized.
    for idx, label in zip(valid_indices, labels):
        trajectories[idx].cluster_label = int(label)

    return trajectories, actual_num_clusters


def compute_cluster_diversity(
    trajectories: List[TrajectoryData],
    num_clusters: int
) -> float:
    """
    Cluster-based diversity from Methods.md Section 4.3:

        f_div(S) = sum_{j=1}^m sqrt(|S ∩ P_j|)

    The concave square root yields diminishing returns for repeated picks from
    the same cluster, which makes f_div monotone submodular.

    Args:
        trajectories: Trajectories whose cluster_label has been assigned.
        num_clusters: Total number of clusters (m).

    Returns:
        Unnormalized diversity score. It is 0.0 for an empty set or when
        num_clusters is 0. Trajectories with a missing or out-of-range
        cluster_label do not contribute.
    """
    if not trajectories or num_clusters == 0:
        return 0.0

    members = [0] * num_clusters
    in_range_labels = (
        traj.cluster_label
        for traj in trajectories
        if traj.cluster_label is not None and 0 <= traj.cluster_label < num_clusters
    )
    for label in in_range_labels:
        members[label] += 1

    return sum(map(np.sqrt, members))


def compute_normalization_constants(
    all_trajectories: List[TrajectoryData],
    num_clusters: int
) -> Tuple[float, float]:
    """
    Normalizers for the fidelity and diversity terms, taken at the full set V:

    - fid_max = f_fid(V), the sum of every trajectory's fidelity
    - div_max = f_div(V) = sum_{j=1}^m sqrt(|P_j|)

    Dividing by these makes F(empty) = 0 and F(V) = 1.

    Args:
        all_trajectories: Every candidate trajectory, with cluster_label set.
        num_clusters: Number of K-Means clusters (m).

    Returns:
        Tuple (fid_max, div_max). A zero value is replaced by 1.0 so later
        divisions are safe.
    """
    fid_max = sum(traj.fidelity for traj in all_trajectories)
    div_max = compute_cluster_diversity(all_trajectories, num_clusters)

    # A zero normalizer would divide by zero downstream; 1.0 leaves values unscaled.
    safe_fid = fid_max if fid_max != 0 else 1.0
    safe_div = div_max if div_max != 0 else 1.0
    return safe_fid, safe_div


def compute_cosine_similarity(emb1: np.ndarray, emb2: np.ndarray) -> float:
    """
    Cosine similarity between two embedding vectors.

    Returns 0.0 when either vector has zero norm (similarity undefined).
    """
    length_a = np.linalg.norm(emb1)
    length_b = np.linalg.norm(emb2)

    if length_a == 0 or length_b == 0:
        return 0.0

    return np.dot(emb1, emb2) / (length_a * length_b)


def compute_pairwise_diversity(embeddings: List[np.ndarray]) -> float:
    """
    Total pairwise cosine diversity of a set of embeddings.

        Diversity = sum_{i,j, i != j} (1 - cosine_similarity(emb_i, emb_j))

    Each unordered pair is evaluated once and counted twice, covering both
    (i, j) and (j, i).

    Args:
        embeddings: List of embedding vectors.

    Returns:
        Total pairwise diversity; 0.0 when fewer than two embeddings are given.
    """
    if len(embeddings) < 2:
        return 0.0

    running = 0.0
    for pos, anchor in enumerate(embeddings):
        for other in embeddings[pos + 1:]:
            dissimilarity = 1.0 - compute_cosine_similarity(anchor, other)
            running += 2 * dissimilarity

    return running


def compute_objective_function(
    trajectories: List[TrajectoryData],
    lambda_weight: float = 0.5,
    num_clusters: int = 5,
    fid_normalizer: float = 1.0,
    div_normalizer: float = 1.0
) -> float:
    """
    Normalized selection objective (Methods.md Section 4.3):

        F(S) = (1 - lambda) * (f_fid(S) / fid_max) + lambda * (f_div(S) / div_max)

    This is a convex combination of the normalized fidelity and diversity terms.
    When the normalizers come from the full candidate set, F(empty) = 0 and
    F(full) = 1.

    Args:
        trajectories: Selected trajectories, with cluster_label set.
        lambda_weight: Weight of the diversity term, in [0, 1].
        num_clusters: Number of clusters (m) used for diversity.
        fid_normalizer: Fidelity normalizer (fid_max).
        div_normalizer: Diversity normalizer (div_max).

    Returns:
        F(S); 0.0 for an empty selection.
    """
    if not trajectories:
        return 0.0

    fidelity_part = sum(traj.fidelity for traj in trajectories) / fid_normalizer
    diversity_part = compute_cluster_diversity(trajectories, num_clusters) / div_normalizer

    return (1 - lambda_weight) * fidelity_part + lambda_weight * diversity_part


def compute_marginal_gain(
    current_set: List[TrajectoryData],
    candidate: TrajectoryData,
    lambda_weight: float = 0.5,
    num_clusters: int = 5,
    fid_normalizer: float = 1.0,
    div_normalizer: float = 1.0
) -> float:
    """
    Normalized marginal gain of adding `candidate` to `current_set`:

        Delta(V) = F(S U {V}) - F(S)
                 = (1 - lambda) * (fid(V) / fid_max)
                   + lambda * ((sqrt(|S ∩ P_j| + 1) - sqrt(|S ∩ P_j|)) / div_max)

    Here j is the candidate's cluster. A candidate with no cluster, or one
    outside [0, num_clusters), contributes no diversity gain.

    Args:
        current_set: Trajectories already selected, with cluster_label set.
        candidate: Trajectory being considered.
        lambda_weight: Weight of the diversity term, in [0, 1].
        num_clusters: Number of clusters (m).
        fid_normalizer: Fidelity normalizer (fid_max).
        div_normalizer: Diversity normalizer (div_max).

    Returns:
        The normalized marginal gain.
    """
    fidelity_term = candidate.fidelity / fid_normalizer

    label = candidate.cluster_label
    if label is not None and 0 <= label < num_clusters:
        already_in_cluster = len([member for member in current_set if member.cluster_label == label])
        # Increment of sqrt(count) when this cluster grows by one member.
        diversity_term = (np.sqrt(already_in_cluster + 1) - np.sqrt(already_in_cluster)) / div_normalizer
    else:
        diversity_term = 0.0

    return (1 - lambda_weight) * fidelity_term + lambda_weight * diversity_term


def greedy_select_trajectories(
    all_trajectories: List[TrajectoryData],
    k: int = 5,
    lambda_weight: float = 0.5,
    num_clusters: int = 5,
    fid_normalizer: float = 1.0,
    div_normalizer: float = 1.0,
    verbose: bool = True
) -> Tuple[List[TrajectoryData], List[Dict]]:
    """
    Greedy algorithm to select k trajectories that maximize normalized F(S).

    Algorithm 1 from Methods.md Section 4.3:
    1. Initialize S = {}
    2. While |S| < k:
       - For all V in V \\ S, calculate marginal gain Delta(V) = F(S U {V}) - F(S)
       - Select V* = argmax Delta(V)
       - Update S = S U {V*}

    The objective F(S) = (1 - lambda) * f_fid_norm(S) + lambda * f_div_norm(S)
    is monotone submodular for lambda in [0, 1] and non-negative fidelities.
    The greedy algorithm therefore achieves at least (1 - 1/e) of the optimum.

    Ties between equal marginal gains go to the candidate with the lowest
    index in `all_trajectories`. Candidates whose gain is NaN are never chosen.

    Args:
        all_trajectories: Pool of candidate trajectories with cluster_label set.
        k: Number of trajectories to select (clamped to the pool size).
        lambda_weight: Weight for diversity term in [0, 1].
        num_clusters: Number of K-Means clusters (m) for diversity computation.
        fid_normalizer: Normalization constant for fidelity (fid_max).
        div_normalizer: Normalization constant for diversity (div_max).
        verbose: Whether to print progress.

    Returns:
        Tuple of (selected trajectories in selection order, selection history).
        Each history entry records the normalized running scores after that pick.
    """
    if len(all_trajectories) == 0:
        return [], []

    k = min(k, len(all_trajectories))
    selected = []
    # Kept in ascending order so the tie-break (first maximum wins) is explicit
    # and does not depend on set iteration order.
    remaining_indices = list(range(len(all_trajectories)))
    selection_history = []

    for iteration in range(k):
        best_gain = float('-inf')
        best_idx = None

        # Find the trajectory with maximum marginal gain
        for idx in remaining_indices:
            candidate = all_trajectories[idx]
            gain = compute_marginal_gain(
                selected, candidate, lambda_weight, num_clusters,
                fid_normalizer, div_normalizer
            )

            # Strict '>' keeps the earliest candidate on ties and skips NaN gains.
            if gain > best_gain:
                best_gain = gain
                best_idx = idx

        if best_idx is None:
            # Every remaining gain was NaN or -inf; nothing sensible to add.
            break

        # Add best trajectory to selected set
        best_trajectory = all_trajectories[best_idx]
        selected.append(best_trajectory)
        remaining_indices.remove(best_idx)

        # Compute current objective value (normalized)
        current_f = compute_objective_function(
            selected, lambda_weight, num_clusters, fid_normalizer, div_normalizer
        )

        # Compute normalized fidelity and diversity
        total_fidelity_raw = sum(t.fidelity for t in selected)
        total_diversity_raw = compute_cluster_diversity(selected, num_clusters)
        norm_fidelity = total_fidelity_raw / fid_normalizer
        norm_diversity = total_diversity_raw / div_normalizer

        # Record selection history
        history_entry = {
            'iteration': iteration + 1,
            'prompt_idx': best_trajectory.prompt_idx,
            'cluster_label': best_trajectory.cluster_label,
            'marginal_gain': best_gain,
            'total_f': current_f,
            'total_fidelity': norm_fidelity,  # Normalized
            'total_diversity': norm_diversity,  # Normalized
            'fidelity': best_trajectory.fidelity,
            'score_first': best_trajectory.scores[0] if best_trajectory.scores else None,
            'score_last': best_trajectory.scores[-1] if best_trajectory.scores else None,
        }
        selection_history.append(history_entry)

        if verbose:
            cluster_info = f", Cluster: {best_trajectory.cluster_label}" if best_trajectory.cluster_label is not None else ""
            print(f"  Iteration {iteration + 1}/{k}: Selected prompt {best_trajectory.prompt_idx}{cluster_info}")
            print(f"    Marginal Gain: {best_gain:.4f}")
            print(f"    Running Total - F(S): {current_f:.4f}, Fid_norm: {norm_fidelity:.4f}, Div_norm: {norm_diversity:.4f}")

    return selected, selection_history


def format_trajectory_summary(
    trajectory: TrajectoryData,
    checkpoint_names: Optional[List[str]] = None,
    max_response_length: int = 500
) -> str:
    """
    Format a single trajectory for display in the output file.

    The prompt is truncated to 1000 characters and each response to
    `max_response_length` characters; "..." marks any truncation.

    Args:
        trajectory: TrajectoryData object
        checkpoint_names: Optional list of checkpoint names for labeling. If it
            is shorter than the number of responses, the remaining checkpoints
            fall back to "Checkpoint <n>".
        max_response_length: Maximum characters to show per response

    Returns:
        Formatted multi-line string.
    """
    lines = []
    lines.append("=" * 80)
    cluster_info = f", Cluster: {trajectory.cluster_label}" if trajectory.cluster_label is not None else ""
    lines.append(f"PROMPT (Index: {trajectory.prompt_idx}{cluster_info})")
    lines.append("-" * 40)
    lines.append(trajectory.prompt[:1000] + ("..." if len(trajectory.prompt) > 1000 else ""))
    lines.append("")
    lines.append(f"TREND FIDELITY: {trajectory.fidelity:.4f}")
    lines.append(f"Score Trajectory: {' -> '.join([f'{s:.2f}' for s in trajectory.scores])}")
    lines.append("")

    # zip() stops at the shorter of responses/scores.
    for i, (response, score) in enumerate(zip(trajectory.responses, trajectory.scores)):
        # Guard against a checkpoint_names list shorter than the trajectory,
        # which previously raised IndexError.
        if checkpoint_names and i < len(checkpoint_names):
            checkpoint_label = checkpoint_names[i]
        else:
            checkpoint_label = f"Checkpoint {i+1}"
        lines.append(f"--- {checkpoint_label} (Score: {score:.2f}) ---")
        truncated = response[:max_response_length] + ("..." if len(response) > max_response_length else "")
        lines.append(truncated)
        lines.append("")

    return "\n".join(lines)


def save_objective_summary(
    selected_trajectories: List[TrajectoryData],
    selection_history: List[Dict],
    objective_name: str,
    output_path: str,
    checkpoint_names: Optional[List[str]] = None,
    lambda_weight: float = 0.5,
    num_clusters: int = 5,
    fid_normalizer: float = 1.0,
    div_normalizer: float = 1.0,
    best_trend_type: Optional[str] = None
) -> None:
    """
    Save the objective summary to a formatted UTF-8 text file.

    The file contains:
    - a header with the selection settings;
    - a table of the selection history (normalized values);
    - each selected trajectory, rendered by `format_trajectory_summary`.

    Args:
        selected_trajectories: List of selected TrajectoryData objects
        selection_history: List of selection history dictionaries, as produced
            by `greedy_select_trajectories`
        objective_name: Name of the objective being summarized
        output_path: Path to save the output file (overwritten if it exists)
        checkpoint_names: Optional list of checkpoint names
        lambda_weight: Lambda weight used in selection
        num_clusters: Number of K-Means clusters used
        fid_normalizer: Normalization constant for fidelity
        div_normalizer: Normalization constant for diversity
        best_trend_type: Name of the global trend type used for fidelity
    """
    # Write UTF-8 explicitly: the header contains non-ASCII symbols (ξ, Σ), and
    # prompts/responses may hold arbitrary Unicode. The platform default
    # encoding (e.g. cp1252) cannot always represent these.
    with open(output_path, 'w', encoding='utf-8') as f:
        # Header
        f.write("=" * 80 + "\n")
        f.write(f"OBJECTIVE SUMMARY: {objective_name}\n")
        f.write("=" * 80 + "\n\n")

        f.write(f"Number of selected trajectories: {len(selected_trajectories)}\n")
        f.write(f"Lambda weight (diversity trade-off): {lambda_weight}\n")
        f.write(f"Number of K-Means clusters: {num_clusters}\n")
        f.write(f"Fidelity normalizer (fid_max): {fid_normalizer:.4f}\n")
        f.write(f"Diversity normalizer (div_max): {div_normalizer:.4f}\n")
        if best_trend_type:
            f.write(f"Global trend type: {best_trend_type}\n")
        f.write("Objective: F(S) = (1 - lambda) * (f_fid / fid_max) + lambda * (f_div / div_max)\n")
        f.write("Fidelity formula: fid(ξ) = exp(- Σ(u_t - f*(t))^2)\n")
        f.write("Note: All values are normalized so F(empty) = 0 and F(full) = 1\n\n")

        # Selection summary
        f.write("-" * 80 + "\n")
        f.write("SELECTION SUMMARY (all values normalized)\n")
        f.write("-" * 80 + "\n\n")

        f.write(f"{'Iter':<6} {'Prompt':<8} {'Clust':<7} {'Marg.Gain':<11} {'Fid_norm':<11} {'Div_norm':<11} {'F(S)':<10}\n")
        f.write("-" * 75 + "\n")

        for entry in selection_history:
            cluster_label = entry.get('cluster_label', 'N/A')
            cluster_str = str(cluster_label) if cluster_label is not None else 'N/A'
            f.write(f"{entry['iteration']:<6} "
                    f"{entry['prompt_idx']:<8} "
                    f"{cluster_str:<7} "
                    f"{entry['marginal_gain']:<11.4f} "
                    f"{entry['total_fidelity']:<11.4f} "
                    f"{entry['total_diversity']:<11.4f} "
                    f"{entry['total_f']:<10.4f}\n")

        f.write("\n")

        # Individual trajectories
        f.write("=" * 80 + "\n")
        f.write("SELECTED TRAJECTORIES (in order of selection)\n")
        f.write("=" * 80 + "\n\n")

        for i, trajectory in enumerate(selected_trajectories):
            f.write(f"\n[{i+1}/{len(selected_trajectories)}]\n")
            f.write(format_trajectory_summary(trajectory, checkpoint_names))
            f.write("\n")

    print(f"Objective summary saved to: {output_path}")
