#!/usr/bin/env python3
"""
Fixed Objectives Discovery Module.

This module implements two objective discovery methods:
1. ObtainFixedObjectives: Builds a fixed list of objectives by proposing candidates from model trajectories and keeping those that pass verification
2. FixedObjectivesDiscovery: Selects best objectives from a pre-defined fixed list
"""

import random
import time
import numpy as np
from typing import List, Dict, Tuple, Any, Optional, Set
from src.objectives_discovery import BaseObjectivesDiscovery
from src.constants import OBJECTIVE_DISCOVERY_PROMPT, OBJECTIVE_DISCOVERY_WITH_EXISTING_PROMPT, DATASET_NAMES_DICT
from src.calc_objectives_fit import ObjectivesFit
from src.model_generation import apply_chat_template_to_prompt
import logging


class ObtainFixedObjectives(BaseObjectivesDiscovery):
    """
    Obtain a fixed list of objectives by proposing and verifying candidates.

    Each iteration of the discovery loop:
    - Randomly samples prompts from the dataset
    - Generates a trajectory per prompt (one response from each model in the sequence)
    - Asks the proposer (API or local model) for objectives describing the trajectories
    - Verifies each proposed objective
    - Keeps verified objectives until ``k`` are accepted or ``max_iterations`` is reached
    """

    def __init__(
        self,
        dataset: List[Dict[str, str]],
        model_sequence: List[str],
        k: int = 10,
        num_parallel_trajectories: int = 3,
        objectives_per_trajectory: int = 3,
        max_iterations: int = 100,
        **kwargs
    ):
        """
        Initialize the fixed objectives discovery method.

        Args:
            dataset: Dataset to sample from; each sample needs an 'input' key
            model_sequence: List of model checkpoints [π_θ_1, ..., π_θ_T]
            k: Number of objectives to discover
            num_parallel_trajectories: Number of prompts (trajectories) sampled per iteration
            objectives_per_trajectory: Number of objectives to request per trajectory
            max_iterations: Safety limit on the number of discovery iterations
            **kwargs: Additional arguments for base class
        """
        super().__init__(dataset, model_sequence, k, **kwargs)

        self.num_parallel_trajectories = num_parallel_trajectories
        self.objectives_per_trajectory = objectives_per_trajectory
        self.max_iterations = max_iterations

        # Every objective text the proposer has returned, accepted or not
        self.all_proposed_objectives = set()

    # NOTE: _format_single_trajectory and _create_batch_discovery_prompt are
    # not defined in this class; they are presumably inherited from
    # BaseObjectivesDiscovery (verify there).

    def _discover_objectives_from_samples(
        self,
        samples: List[Dict[str, str]],
        existing_objectives: List[str]
    ) -> List[str]:
        """
        Propose objectives from the trajectories of a batch of samples.

        Every model in ``self.model_sequence`` answers every prompt; the
        responses for one prompt across the sequence form its trajectory.
        All trajectories are combined into a single proposer prompt, which
        requests ``objectives_per_trajectory`` objectives per trajectory.

        Args:
            samples: Samples to build trajectories from (each needs an 'input' key)
            existing_objectives: Already discovered objectives, passed to the
                proposer prompt

        Returns:
            List of proposed objectives (empty if no trajectory could be built)
        """
        if not samples:
            print("  No samples to generate trajectories from")
            return []

        batch_prompts = [sample['input'] for sample in samples]
        num_models = len(self.model_sequence)

        print(f"Generating trajectories for {len(batch_prompts)} prompts...")

        # responses_by_model[m][s] is model m's response to prompt s
        responses_by_model = []
        for model_idx, model_path in enumerate(self.model_sequence):
            print(f"  Generating responses from model {model_idx + 1}/{num_models}")
            try:
                model_responses = self.generate_responses_batched(
                    model_path=model_path,
                    prompts=batch_prompts,
                    max_new_tokens=512,
                    batch_size=8
                )
            except Exception as e:
                # Best effort: a failing checkpoint contributes empty responses
                # (dropped below) instead of aborting the whole batch.
                print(f"    Error generating responses from {model_path}: {e}")
                if self.logger:
                    self.logger.error(f"Failed to generate responses from {model_path}: {e}")
                model_responses = [""] * len(batch_prompts)
            responses_by_model.append(model_responses)

        batch_trajectories = []
        for sample_idx, prompt in enumerate(batch_prompts):
            # Transpose [model][sample] into this sample's responses in model
            # order, dropping empty ones.
            # NOTE(review): dropping an empty response shifts later responses to
            # earlier positions, so any positional iteration labels in the
            # formatted trajectory may no longer match checkpoint indices.
            responses = [
                model_responses[sample_idx]
                for model_responses in responses_by_model
                if model_responses[sample_idx]
            ]

            if not responses:
                print(f"  No valid responses for sample {sample_idx + 1}")
                continue

            # Render the (possibly multi-turn) prompt with the first model's
            # chat template so every trajectory is formatted consistently.
            prompt_text = apply_chat_template_to_prompt(
                model_path=self.model_sequence[0],
                prompt=prompt,
                max_length=2000  # cap to keep the proposer prompt manageable
            )
            batch_trajectories.append(self._format_single_trajectory(
                prompt_text,
                responses,
                trajectory_num=len(batch_trajectories) + 1
            ))

        if not batch_trajectories:
            print("  No valid trajectories generated")
            return []

        proposer_prompt = self._create_batch_discovery_prompt(
            batch_trajectories,
            existing_objectives
        )

        num_requested = self.objectives_per_trajectory * len(batch_trajectories)
        print(f"Proposing objectives from {len(batch_trajectories)} trajectories...")
        if self.use_api_proposer:
            return self._propose_objectives_with_api(proposer_prompt, num_requested)
        return self._propose_objectives_with_local_model(proposer_prompt, num_requested)

    def obtain_objectives(self) -> Tuple[List[str], Dict[str, Any]]:
        """
        Run the propose-and-verify loop until k objectives are accepted.

        The loop also stops after ``self.max_iterations`` iterations, and is
        skipped entirely when the dataset is empty.

        Returns:
            Tuple of (discovered_objectives, statistics)
        """
        start_time = time.time()
        self._print_run_configuration()

        # Sampling from an empty dataset yields no prompts, so no iteration
        # could make progress; skip the loop instead of spinning through it.
        has_prompts = len(self.dataset) > 0
        if not has_prompts:
            print("Dataset is empty; nothing to sample from")
            if self.logger:
                self.logger.warning("Dataset is empty; skipping fixed objectives discovery")

        iteration = 0
        while (
            has_prompts
            and len(self.discovered_objectives) < self.k
            and iteration < self.max_iterations
        ):
            iteration += 1
            self.discovery_stats['total_iterations'] = iteration

            print(f"\n{'='*60}")
            print(f"ITERATION {iteration}")
            print(f"Current objectives: {len(self.discovered_objectives)}/{self.k}")
            print(f"{'='*60}")

            if self.logger:
                self.logger.info("\n" + "*"*80)
                self.logger.info("*" + f"ITERATION {iteration}".center(78) + "*")
                self.logger.info("*"*80)
                self.logger.info(f"Progress: {len(self.discovered_objectives)}/{self.k} objectives discovered")
                if self.discovered_objectives:
                    self.logger.info("Currently discovered objectives:")
                    for idx, obj in enumerate(self.discovered_objectives, 1):
                        self.logger.info(f"  {idx}. {obj}")
                self.logger.info("")

            # Step 1: Sample random prompts
            print(f"\nSampling {self.num_parallel_trajectories} random prompts...")
            sampled_prompts = random.sample(
                self.dataset,
                min(self.num_parallel_trajectories, len(self.dataset))
            )

            # Step 2: Propose objectives from these prompts' trajectories
            proposals = self._discover_objectives_from_samples(
                sampled_prompts,
                self.discovered_objectives
            )
            if not proposals:
                print("No objectives proposed in this iteration")
                continue

            print(f"\nProposed {len(proposals)} objectives")
            self.discovery_stats['total_proposals'] += len(proposals)
            self.all_proposed_objectives.update(proposals)

            # Step 3: Verify each proposed objective
            print("\nVerifying proposed objectives...")
            for proposal_idx, objective in enumerate(proposals):
                if len(self.discovered_objectives) >= self.k:
                    print(f"\n✓ Target of {self.k} objectives reached!")
                    break
                self._verify_and_record_proposal(objective, proposal_idx, len(proposals))

        self._finalize_and_print_summary(start_time)
        return self.discovered_objectives, self.discovery_stats

    def _print_run_configuration(self) -> None:
        """Print and log the discovery configuration before the loop starts."""
        print("\n" + "="*60)
        print("FIXED OBJECTIVES DISCOVERY")
        print("="*60)
        print(f"Target: {self.k} valid objectives")
        print(f"Dataset size: {len(self.dataset)} samples")
        print(f"Model sequence: {len(self.model_sequence)} models")
        print(f"Trajectories per batch: {self.num_parallel_trajectories}")
        print(f"Objectives per trajectory: {self.objectives_per_trajectory}")
        print("="*60)

        if self.logger:
            self.logger.info("\n" + "#"*80)
            self.logger.info("#" + "STARTING FIXED OBJECTIVES DISCOVERY".center(78) + "#")
            self.logger.info("#"*80)
            self.logger.info("Configuration:")
            self.logger.info(f"  Target objectives (k): {self.k}")
            self.logger.info(f"  Dataset size: {len(self.dataset)} samples")
            self.logger.info(f"  Model sequence: {len(self.model_sequence)} models")
            self.logger.info(f"  Trajectories per batch: {self.num_parallel_trajectories}")
            self.logger.info(f"  Objectives per trajectory: {self.objectives_per_trajectory}")
            self.logger.info(f"  Max iterations: {self.max_iterations}")
            self.logger.info("")

    def _verify_and_record_proposal(
        self,
        objective: str,
        proposal_idx: int,
        num_proposals: int
    ) -> None:
        """
        Verify one proposed objective and record it as discovered or rejected.

        Objectives that are already discovered or were previously rejected are
        skipped without re-verification.

        Args:
            objective: Proposed objective text
            proposal_idx: 0-based position of the proposal in its batch (for display)
            num_proposals: Size of the proposal batch (for display)
        """
        if objective in self.discovered_objectives:
            print(f"  Skipping duplicate: {objective[:50]}...")
            return

        if any(obj['objective'] == objective for obj in self.rejected_objectives):
            print(f"  Skipping previously rejected: {objective[:50]}...")
            return

        print(f"\nVerifying objective {proposal_idx + 1}/{num_proposals}: {objective[:50]}...")
        is_valid, verification_details = self._verify_objective(objective)

        if is_valid:
            print(f"  ✓ ACCEPTED: Interpretability={verification_details['interpretability_score']:.2f}, Trend={verification_details['trend_type']}")

            if self.logger:
                self.logger.info(f"\n✓ OBJECTIVE ACCEPTED: {objective}")
                self.logger.info(f"  Interpretability: {verification_details['interpretability_score']:.4f}")
                self.logger.info(f"  Trend: {verification_details['trend_type']} (error: {verification_details['trend_error']:.4f})")

            self.discovered_objectives.append(objective)
            return

        print("  ✗ REJECTED: ", end="")
        if not verification_details['interpretable']:
            print(f"Interpretability={verification_details['interpretability_score']:.2f} ", end="")
        if not verification_details['follows_trend']:
            print(f"Trend_error={verification_details['trend_error']:.2f}", end="")
        print()

        if self.logger:
            self.logger.info(f"\n✗ OBJECTIVE REJECTED: {objective}")
            if not verification_details['interpretable']:
                self.logger.info(f"  Failed interpretability: {verification_details['interpretability_score']:.4f}")
            if not verification_details['follows_trend']:
                self.logger.info(f"  Failed trend: {verification_details['trend_error']:.4f}")

        self.rejected_objectives.append({
            'objective': objective,
            'reason': verification_details
        })

    def _finalize_and_print_summary(self, start_time: float) -> None:
        """
        Fill in the final discovery statistics, then print and log a summary.

        Args:
            start_time: ``time.time()`` value taken when discovery started
        """
        stats = self.discovery_stats
        stats['time_elapsed'] = time.time() - start_time
        stats['discovered_count'] = len(self.discovered_objectives)
        stats['rejected_count'] = len(self.rejected_objectives)
        # Accepted objectives over all proposals, including skipped repeats
        stats['acceptance_rate'] = (
            stats['discovered_count'] / stats['total_proposals']
            if stats['total_proposals'] > 0 else 0.0
        )

        print("\n" + "="*60)
        print("DISCOVERY COMPLETE")
        print("="*60)
        print(f"Objectives discovered: {len(self.discovered_objectives)}/{self.k}")
        print(f"Total proposals: {stats['total_proposals']}")
        print(f"Rejected: {stats['rejected_count']}")
        print(f"Acceptance rate: {stats['acceptance_rate']:.2%}")
        print(f"Time elapsed: {stats['time_elapsed']:.2f} seconds")

        if self.discovered_objectives:
            print("\nDiscovered objectives:")
            for i, obj in enumerate(self.discovered_objectives, 1):
                print(f"  {i}. {obj}")

        if self.logger:
            self.logger.info("\n" + "#"*80)
            self.logger.info("#" + "FIXED DISCOVERY COMPLETE".center(78) + "#")
            self.logger.info("#"*80)
            self.logger.info("Final statistics:")
            self.logger.info(f"  Objectives discovered: {len(self.discovered_objectives)}/{self.k}")
            self.logger.info(f"  Total proposals: {stats['total_proposals']}")
            self.logger.info(f"  Rejected: {stats['rejected_count']}")
            self.logger.info(f"  Acceptance rate: {stats['acceptance_rate']:.2%}")
            self.logger.info(f"  Time elapsed: {stats['time_elapsed']:.2f} seconds")
            self.logger.info("")


class FixedObjectivesDiscovery(BaseObjectivesDiscovery):
    """
    Fixed objectives discovery method that selects from a pre-defined list.

    This class implements a discovery approach that:
    - Loads a fixed list of objectives from a file
    - Iteratively selects the best objectives from this fixed list
    - Uses the same selection and verification logic as ProposedObjectivesDiscovery
    - But skips the informative samples and candidate discovery phases
    """

    def __init__(
        self,
        dataset: List[Dict[str, str]],
        model_sequence: List[str],
        ground_truth_reward,  # RewardFunction instance
        fixed_objectives_filepath: str,
        num_fixed_objs: int,
        k: int = 10,
        num_samples_final_eval: int = 25,
        num_samples_select_best: int = 20,
        combination_function_type: str = 'linear_regression',
        combination_function_params: Optional[Dict[str, Any]] = None,
        train_test_split_idx: Optional[int] = None,
        max_iterations: int = 50,
        **kwargs
    ):
        """
        Configure selection of ``k`` objectives from a pool loaded from a file.

        The pool is loaded (and possibly randomly subsampled) during
        construction via ``_load_fixed_objectives``.

        Args:
            dataset: Dataset to sample from
            model_sequence: List of model checkpoints [π_θ_1, ..., π_θ_T]
            ground_truth_reward: RewardFunction instance representing R*
            fixed_objectives_filepath: Path to a file with one objective per line
            num_fixed_objs: Size of the pool drawn from the file's objectives
            k: Number of objectives to ultimately select
            num_samples_final_eval: Stored as ``self.num_samples_final_eval``;
                presumably the sample count for a final evaluation (not used
                in the code visible here)
            num_samples_select_best: Number of random samples for objective selection phase
            combination_function_type: Type of g function ('linear', 'linear_regression', etc.)
            combination_function_params: Parameters for combination function
                (``None`` is stored as an empty dict)
            train_test_split_idx: Index to split model sequence for train/test;
                defaults to ``len(model_sequence) // 2``
            max_iterations: Maximum iterations before stopping
            **kwargs: Additional arguments for base class
        """
        super().__init__(dataset, model_sequence, k, **kwargs)

        # Scoring / selection configuration
        self.ground_truth_reward = ground_truth_reward
        self.fixed_objectives_filepath = fixed_objectives_filepath
        self.num_fixed_objs = num_fixed_objs
        self.num_samples_final_eval = num_samples_final_eval
        self.num_samples_select_best = num_samples_select_best
        self.combination_function_type = combination_function_type
        self.combination_function_params = combination_function_params or {}
        self.max_iterations = max_iterations

        # Split point in the model sequence; defaults to its midpoint
        if train_test_split_idx is not None:
            split_idx = train_test_split_idx
        else:
            split_idx = len(model_sequence) // 2
        self.train_test_split_idx = split_idx

        # Handle to the ObjectivesFit class used to build Obj-Error calculators
        self.ObjectivesFit = ObjectivesFit

        # Per-run bookkeeping
        self.iteration_history = []
        self.current_objectives = []

        # Candidate pool: read from file last, once the settings it uses are set
        self.fixed_objectives = self._load_fixed_objectives()

    def _load_fixed_objectives(self) -> List[str]:
        """
        Read the objectives file and pick the pool used for discovery.

        The file holds one objective per line (read as UTF-8). Blank lines are
        ignored, surrounding whitespace is stripped, and repeated lines are
        collapsed to their first occurrence. If more than
        ``self.num_fixed_objs`` distinct objectives remain, a random subset of
        that size is drawn with the ``random`` module; otherwise all are used.

        Returns:
            List of fixed objectives to use for discovery

        Raises:
            ValueError: If ``self.num_fixed_objs`` is negative.
            FileNotFoundError: If the objectives file does not exist.
            RuntimeError: If the file cannot be read or decoded.
        """
        path = self.fixed_objectives_filepath
        print(f"Loading fixed objectives from: {path}")

        if self.num_fixed_objs < 0:
            raise ValueError(f"num_fixed_objs must be non-negative, got {self.num_fixed_objs}")

        # Keep the try limited to file I/O so unrelated errors are not
        # re-labelled as loading failures.
        try:
            with open(path, 'r', encoding='utf-8') as f:
                lines = [line.strip() for line in f if line.strip()]
        except FileNotFoundError as e:
            raise FileNotFoundError(f"Fixed objectives file not found: {path}") from e
        except (OSError, UnicodeDecodeError) as e:
            raise RuntimeError(f"Error loading fixed objectives: {e}") from e

        print(f"Loaded {len(lines)} objectives from file")
        if self.logger:
            self.logger.info(f"Loaded {len(lines)} objectives from {path}")

        # Duplicates would bias random sampling and overstate the pool size
        # (the pool is later treated as a set); dict.fromkeys keeps file order.
        all_objectives = list(dict.fromkeys(lines))
        num_duplicates = len(lines) - len(all_objectives)
        if num_duplicates:
            print(f"Ignored {num_duplicates} duplicate objectives")
            if self.logger:
                self.logger.info(f"Ignored {num_duplicates} duplicate objectives")

        if len(all_objectives) > self.num_fixed_objs:
            selected_objectives = random.sample(all_objectives, self.num_fixed_objs)
            print(f"Randomly selected {self.num_fixed_objs} objectives to use")
        else:
            selected_objectives = all_objectives
            print(f"Using all {len(selected_objectives)} objectives (requested {self.num_fixed_objs})")

        if not selected_objectives:
            print("Warning: fixed objectives pool is empty")
            if self.logger:
                self.logger.warning("Fixed objectives pool is empty; discovery will have no candidates")

        if self.logger:
            self.logger.info(f"Selected {len(selected_objectives)} objectives for discovery")
            self.logger.info("Fixed objectives pool:")
            for i, obj in enumerate(selected_objectives, 1):
                self.logger.info(f"  {i}. {obj}")

        return selected_objectives

    def _build_obj_fit_calculator(self, samples: List[Dict[str, str]]):
        """Create an ObjectivesFit calculator over ``samples`` with this run's settings."""
        return self.ObjectivesFit(
            dataset=samples,
            model_sequence=self.model_sequence,
            ground_truth_objective=self.ground_truth_reward,
            combination_function_type=self.combination_function_type,
            combination_function_params=self.combination_function_params,
            num_samples=len(samples),
            train_test_split_idx=self.train_test_split_idx,
            scorer_model=self.scorer_model,
            device=self.device,
            cache_responses=True,
            use_different_prompts=False,
            dataset_type=DATASET_NAMES_DICT[self.dataset_name],
            use_detailed_rubric=True,
            batching=True,
            batch_size=8,
            model_cache_size=1,
            normalize_scores=True,
            logger=self.logger,
            max_concurrent=self.max_concurrent
        )

    def _select_best_objective(
        self,
        candidates: Set[str],
        current_objectives: List[str]
    ) -> Tuple[Optional[str], float]:
        """
        Select the candidate whose addition gives the lowest Obj-Error.
        Adapted from ProposedObjectivesDiscovery._select_best_objective.

        A fresh random subset of ``self.num_samples_select_best`` samples is
        drawn. Each candidate is scored by the Obj-Error of
        ``current_objectives + [candidate]`` on that subset. The returned score is
        ``obj_error - baseline_error``, where the baseline is the Obj-Error of
        ``current_objectives`` alone (0.0 when there are none). A more negative
        score therefore means a larger improvement.

        Candidates are evaluated in sorted order, so for a given random seed
        the result, including tie-breaking, does not depend on set iteration
        order.

        Args:
            candidates: Set of candidate objectives
            current_objectives: Current discovered objectives

        Returns:
            Tuple of (best_objective, improvement_score); (None, 0.0) when no
            candidate could be selected.
        """
        print("\n--- Objective Selection (Obj-Error Evaluation) ---")
        if self.logger:
            self.logger.info("\n" + "="*60)
            self.logger.info("OBJECTIVE SELECTION (OBJ-ERROR EVALUATION)")
            self.logger.info("="*60)

        if not candidates:
            print("No candidates to evaluate")
            return None, 0.0

        ordered_candidates = sorted(candidates)

        # Sample new random samples for objective selection
        selection_samples = random.sample(
            self.dataset,
            min(self.num_samples_select_best, len(self.dataset))
        )
        print(f"Sampled {len(selection_samples)} random samples for objective selection")
        if self.logger:
            self.logger.info(f"Sampled {len(selection_samples)} random samples for objective selection")

        best_objective = None
        best_reduction = float('inf')
        obj_fit_results = []

        # One calculator serves the baseline and every candidate. The finally
        # clause releases its model cache even if a calculation raises.
        obj_fit_calc = self._build_obj_fit_calculator(selection_samples)
        try:
            baseline_error = obj_fit_calc.calculate(current_objectives) if current_objectives else 0.0

            print(f"Baseline Obj-Error: {baseline_error:.4f}")
            print(f"Evaluating {len(ordered_candidates)} candidates...")
            if self.logger:
                self.logger.info(f"Baseline Obj-Error (with {len(current_objectives)} objectives): {baseline_error:.4f}")
                self.logger.info(f"Evaluating {len(ordered_candidates)} candidate objectives...")

            for i, candidate in enumerate(ordered_candidates):
                print(f"Evaluating candidate {i+1}/{len(ordered_candidates)}: {candidate[:50]}...")

                obj_error = obj_fit_calc.calculate(current_objectives + [candidate])
                reduction = obj_error - baseline_error

                obj_fit_results.append({
                    'objective': candidate,
                    'obj_error': obj_error,
                    'reduction': reduction
                })

                # Lower Obj-Error is better; strict '<' keeps the first of equal
                # scores and never selects a NaN score.
                if reduction < best_reduction:
                    best_reduction = reduction
                    best_objective = candidate

                print(f"  Obj-Error: {obj_error:.4f}, Reduction: {reduction:.4f}")
        finally:
            obj_fit_calc.cleanup_model_cache()

        if self.logger:
            obj_fit_results.sort(key=lambda x: x['reduction'])

            self.logger.info("Obj-Error Results for All Candidates:")
            for idx, result in enumerate(obj_fit_results, 1):
                self.logger.info(f"  {idx}. Objective: {result['objective']}")
                self.logger.info(f"     Obj-Error: {result['obj_error']:.6f}")
                self.logger.info(f"     Reduction: {result['reduction']:.6f}")
                if result['objective'] == best_objective:
                    self.logger.info("     >>> SELECTED <<<")
                self.logger.info("")

        if best_objective is None:
            return None, 0.0

        print(f"\nBest objective: {best_objective[:50]}...")
        print(f"Best reduction: {best_reduction:.4f}")

        if self.logger:
            self.logger.info("\n" + "-"*40)
            self.logger.info("SELECTED BEST OBJECTIVE:")
            self.logger.info(f"  Objective: {best_objective}")
            self.logger.info(f"  Obj-Error: {baseline_error + best_reduction:.6f}")
            self.logger.info(f"  Reduction over baseline: {best_reduction:.6f}")
            self.logger.info("-"*40)

        return best_objective, best_reduction

    def obtain_objectives(self) -> Tuple[List[str], Dict[str, Any]]:
        """
        Main method implementing the fixed objective discovery algorithm.

        Returns:
            Tuple of (discovered_objectives, statistics)
        """
        start_time = time.time()

        print("\n" + "="*60)
        print("FIXED OBJECTIVES DISCOVERY")
        print("="*60)
        print(f"Target: {self.k} objectives to select")
        print(f"Fixed objectives pool: {len(self.fixed_objectives)} objectives")
        print(f"Dataset size: {len(self.dataset)} samples")
        print(f"Model sequence: {len(self.model_sequence)} models")
        print(f"Train/Test split: {self.train_test_split_idx}")
        print(f"Combination function: {self.combination_function_type}")
        print(f"Num samples for selection: {self.num_samples_select_best}")
        print("="*60)

        if self.logger:
            self.logger.info("\n" + "#"*80)
            self.logger.info("#" + "STARTING FIXED OBJECTIVES DISCOVERY".center(78) + "#")
            self.logger.info("#"*80)
            self.logger.info(f"Configuration:")
            self.logger.info(f"  Target objectives (k): {self.k}")
            self.logger.info(f"  Fixed objectives pool size: {len(self.fixed_objectives)}")
            self.logger.info(f"  Dataset size: {len(self.dataset)} samples")
            self.logger.info(f"  Model sequence: {len(self.model_sequence)} models")
            self.logger.info(f"  Train/Test split index: {self.train_test_split_idx}")
            self.logger.info(f"  Combination function: {self.combination_function_type}")
            self.logger.info(f"  Num samples for selection: {self.num_samples_select_best}")
            self.logger.info("")

        # Keep track of available objectives (not yet selected)
        available_objectives = set(self.fixed_objectives)
        iteration = 0

        while len(self.discovered_objectives) < self.k and iteration < self.max_iterations:
            iteration += 1
            self.discovery_stats['total_iterations'] = iteration

            # Check if we have any objectives left to select from
            remaining_candidates = available_objectives - set(self.discovered_objectives)
            if not remaining_candidates:
                print("\nNo more objectives available in the fixed pool")
                break

            print(f"\n{'='*60}")
            print(f"ITERATION {iteration}")
            print(f"Current objectives: {len(self.discovered_objectives)}/{self.k}")
            print(f"Remaining candidates: {len(remaining_candidates)}")
            print(f"{'='*60}")

            if self.logger:
                self.logger.info("\n" + "*"*80)
                self.logger.info("*" + f"ITERATION {iteration}".center(78) + "*")
                self.logger.info("*"*80)
                self.logger.info(f"Progress: {len(self.discovered_objectives)}/{self.k} objectives discovered")
                self.logger.info(f"Remaining candidates in pool: {len(remaining_candidates)}")
                if self.discovered_objectives:
                    self.logger.info("Currently discovered objectives:")
                    for idx, obj in enumerate(self.discovered_objectives, 1):
                        self.logger.info(f"  {idx}. {obj}")
                self.logger.info("")

            # Store current objectives for history
            iteration_start_objectives = self.discovered_objectives.copy()

            # STEP 1: SELECT BEST OBJECTIVE FROM FIXED POOL
            print("\n=== STEP 1: OBJECTIVE SELECTION ===")

            # Use remaining candidates for selection
            best_objective, improvement = self._select_best_objective(
                remaining_candidates,
                self.discovered_objectives
            )

            if best_objective is None:
                print("\nNo valid objective found in this iteration")
                continue

            # STEP 2: OBJECTIVE VERIFICATION
            print("\n=== STEP 2: OBJECTIVE VERIFICATION ===")

            is_valid, verification_details = self._verify_objective(best_objective)

            if is_valid:
                print(f"\n✓ OBJECTIVE ACCEPTED: {best_objective[:50]}...")
                print(f"  Interpretability score: {verification_details['interpretability_score']:.4f}")
                print(f"  Trend type: {verification_details['trend_type']}")
                print(f"  Trend error: {verification_details['trend_error']:.4f}")

                if self.logger:
                    self.logger.info("\n" + "="*60)
                    self.logger.info("✓ OBJECTIVE ACCEPTED")
                    self.logger.info("="*60)
                    self.logger.info(f"Objective: {best_objective}")
                    self.logger.info(f"Verification Summary:")
                    self.logger.info(f"  - Interpretability Score: {verification_details['interpretability_score']:.4f}")
                    self.logger.info(f"  - Trend Type: {verification_details['trend_type']}")
                    self.logger.info(f"  - Trend Error: {verification_details['trend_error']:.4f}")
                    self.logger.info(f"  - Improvement to Obj-Error: {improvement:.6f}")

                self.discovered_objectives.append(best_objective)
                self.current_objectives = self.discovered_objectives.copy()
                # Remove from available pool
                available_objectives.discard(best_objective)
            else:
                print(f"\n✗ OBJECTIVE REJECTED: {best_objective[:50]}...")
                if not verification_details['interpretable']:
                    print(f"  Failed interpretability check")
                if not verification_details['follows_trend']:
                    print(f"  Failed trend check")

                if self.logger:
                    self.logger.info("\n" + "="*60)
                    self.logger.info("✗ OBJECTIVE REJECTED")
                    self.logger.info("="*60)
                    self.logger.info(f"Objective: {best_objective}")
                    self.logger.info(f"Rejection Reasons:")
                    if not verification_details['interpretable']:
                        self.logger.info(f"  - Failed interpretability check (score: {verification_details['interpretability_score']:.4f})")
                    if not verification_details['follows_trend']:
                        self.logger.info(f"  - Failed trend check (error: {verification_details['trend_error']:.4f})")

                self.rejected_objectives.append({
                    'objective': best_objective,
                    'reason': verification_details
                })
                # Remove from available pool even if rejected
                available_objectives.discard(best_objective)

            # Track iteration history
            self.iteration_history.append({
                'iteration': iteration,
                'objectives_start': iteration_start_objectives,
                'objectives_end': self.discovered_objectives.copy(),
                'best_candidate': best_objective,
                'improvement': improvement,
                'accepted': is_valid
            })

            # Update objectives history for statistics
            self.objectives_per_iteration.append({
                'iteration': iteration,
                'objectives': self.discovered_objectives.copy(),
                'num_objectives': len(self.discovered_objectives),
                'candidates_evaluated': len(remaining_candidates)
            })

        # Calculate Final Obj-Error for discovered objectives
        final_obj_error, combiner_save_path = self.calculate_final_obj_error(
            discovered_objectives=self.discovered_objectives,
            ground_truth_reward=self.ground_truth_reward,
            num_samples_eval=self.num_samples_final_eval,
            train_test_split_idx=self.train_test_split_idx,
            combination_function_type=self.combination_function_type,
            combination_function_params=self.combination_function_params,
            save_dir=self.output_dir
        )

        # Calculate final statistics
        self.discovery_stats['time_elapsed'] = time.time() - start_time
        self.discovery_stats['discovered_count'] = len(self.discovered_objectives)
        self.discovery_stats['rejected_count'] = len(self.rejected_objectives)
        self.discovery_stats['total_proposals'] = len(self.fixed_objectives)
        self.discovery_stats['acceptance_rate'] = (
            self.discovery_stats['discovered_count'] / self.discovery_stats['total_proposals']
            if self.discovery_stats['total_proposals'] > 0 else 0.0
        )
        self.discovery_stats['final_obj_error'] = final_obj_error
        self.discovery_stats['reward_combiner_path'] = combiner_save_path

        # Final summary
        print("\n" + "="*60)
        print("DISCOVERY COMPLETE")
        print("="*60)
        print(f"Objectives discovered: {len(self.discovered_objectives)}/{self.k}")
        print(f"Total iterations: {iteration}")
        print(f"Rejected: {self.discovery_stats['rejected_count']}")
        print(f"Time elapsed: {self.discovery_stats['time_elapsed']:.2f} seconds")
        if final_obj_error is not None:
            print(f"Final Obj-Error: {final_obj_error:.6f}")

        if self.discovered_objectives:
            print("\nDiscovered objectives:")
            for i, obj in enumerate(self.discovered_objectives, 1):
                print(f"  {i}. {obj}")

        if self.logger:
            self.logger.info("\n" + "#"*80)
            self.logger.info("#" + "FIXED DISCOVERY COMPLETE".center(78) + "#")
            self.logger.info("#"*80)
            self.logger.info(f"Final statistics:")
            self.logger.info(f"  Objectives discovered: {len(self.discovered_objectives)}/{self.k}")
            self.logger.info(f"  Total iterations: {iteration}")
            self.logger.info(f"  Rejected: {self.discovery_stats['rejected_count']}")
            self.logger.info(f"  Time elapsed: {self.discovery_stats['time_elapsed']:.2f} seconds")
            if final_obj_error is not None:
                self.logger.info(f"  Final Obj-Error: {final_obj_error:.6f}")
            self.logger.info("")

        return self.discovered_objectives, self.discovery_stats