"""
Objective Discovery Classes for Automatically Discovering Training Objectives

This module implements various approaches to discovering implicit objectives 
that models are optimizing for during training.
"""

import torch
import pickle
import numpy as np
import random
import json
from typing import List, Dict, Tuple, Optional, Union, Any, Set
from abc import ABC, abstractmethod
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel
from openai import OpenAI, AsyncOpenAI
import asyncio
import re
import os
from collections import defaultdict
import time
from tqdm import tqdm
import logging
from pprint import pformat

try:
    # Try relative imports first (when imported as a module)
    from .constants import (
        OPENAI_API_KEY,
        OBJECTIVE_DISCOVERY_PROMPT,
        OBJECTIVE_DISCOVERY_WITH_EXISTING_PROMPT,
        TRAJECTORY_RELEVANCE_PROMPT,
        DATASET_NAMES_DICT
    )
    from .objectives_verifiers import (
        HumanInterpretableVerifier,
        PredictableTrendVerifier
    )
    # from .setup_datasets import sample_to_input_dialogue_tldr, sample_to_input_dialogue_hh, load_normalized_dataset_samples
    from .setup_datasets_clean import load_normalized_dataset_samples
    from .calc_objectives_fit import ObjectivesFit
    from .objective_scorer import ObjectiveScorer
    from .model_generation import generate_responses_batched as generate_responses_batched_util
    from .model_generation import generate_huggingface_response
    from .model_generation import apply_chat_template_to_prompt
    from .reward_combiner_logging import log_test_reward_combiner
    from .prompt_optimize import PromptOptimizer, InterpretabilityFeedback, optimize_discovery_prompt
except ImportError:
    # Fall back to absolute imports (when run directly)
    from constants import (
        OPENAI_API_KEY,
        OBJECTIVE_DISCOVERY_PROMPT,
        OBJECTIVE_DISCOVERY_WITH_EXISTING_PROMPT,
        TRAJECTORY_RELEVANCE_PROMPT,
        DATASET_NAMES_DICT
    )
    from objectives_verifiers import (
        HumanInterpretableVerifier,
        PredictableTrendVerifier
    )
    # from setup_datasets import sample_to_input_dialogue_tldr, sample_to_input_dialogue_hh, load_normalized_dataset_samples
    from setup_datasets_clean import load_normalized_dataset_samples
    from calc_objectives_fit import ObjectivesFit
    from objective_scorer import ObjectiveScorer
    from model_generation import generate_responses_batched as generate_responses_batched_util
    from model_generation import generate_huggingface_response
    from model_generation import apply_chat_template_to_prompt
    from reward_combiner_logging import log_test_reward_combiner
    from prompt_optimize import PromptOptimizer, InterpretabilityFeedback, optimize_discovery_prompt

class BaseObjectivesDiscovery(ABC):
    """
    Abstract base class for objective discovery methods.
    All objective discovery algorithms should inherit from this class.
    """
    
    def __init__(
        self,
        dataset: Union[str, List[Dict[str, str]]],
        model_sequence: List[str],
        k: int = 10,
        verifier_epsilon_interpretable: float = 0.15,
        verifier_epsilon_trend: float = 0.1,
        verification_sample_size: int = 20,
        proposer_model: str = "gpt-4o-mini",
        scorer_model_name: str = "gpt-4o-mini",
        human_scorer_models: Optional[List[str]] = None,
        use_api_proposer: bool = True,
        device: str = "auto",
        output_dir: Optional[str] = None,
        logger: Optional[logging.Logger] = None,
        multi_turn: bool = False,
        group_scoring: bool = False,
        # TextGrad prompt optimization settings
        use_prompt_optimization: bool = False,
        prompt_optimization_model: str = "gpt-4o-mini",
        max_prompt_length: Optional[int] = None,
        base_model_name: Optional[str] = None,
        max_concurrent: int = 20,
        data_dir: Optional[str] = None
    ):
        """
        Initialize the base objective discovery class.

        Construction is not cheap: it loads the dataset, initializes the
        proposer (API clients or a local model) and builds the scorer, in
        that order.

        Args:
            dataset: Either a HuggingFace dataset name/path or a list of samples
                    Each sample should have 'input'/'prompt' field
            model_sequence: List of paths to model checkpoints (pi_theta_1, ..., pi_theta_T)
                          These represent the training trajectory
            k: Number of objectives to discover
            verifier_epsilon_interpretable: Threshold for human-interpretability verification
            verifier_epsilon_trend: Threshold for predictable trend verification
            verification_sample_size: Number of samples to use for objective verification
            proposer_model: Model to use for proposing objectives
            scorer_model_name: Model name to use for scoring objectives
            human_scorer_models: Optional list of models for human scoring in verification (defaults to scorer_model_name)
            use_api_proposer: Whether to use API (e.g., OpenAI) for proposer
            device: Device to use for local models
            output_dir: Optional output directory for saving results and plots
            logger: Optional logger instance for logging (defaults to this module's logger)
            multi_turn: Forwarded to the dataset loader when `dataset` is a name/path
            group_scoring: Stored on the instance; its consumers are outside this method
            use_prompt_optimization: Whether prompt optimization is enabled
            prompt_optimization_model: Model name stored for prompt optimization
            max_prompt_length: Optional max token length for filtering prompts
            base_model_name: Model name for tokenizer (required if max_prompt_length is set)
            max_concurrent: Stored on the instance; its consumers are outside this method
            data_dir: Optional data directory for datasets that support it (e.g., 'helpful-base' for HH-RLHF)
        """
        self.k = k
        self.model_sequence = model_sequence
        self.verifier_epsilon_interpretable = verifier_epsilon_interpretable
        self.verifier_epsilon_trend = verifier_epsilon_trend
        self.verification_sample_size = verification_sample_size
        self.proposer_model = proposer_model
        self.scorer_model_name = scorer_model_name
        # An empty or missing list falls back to the single scorer model.
        self.human_scorer_models = human_scorer_models or [scorer_model_name]
        self.use_api_proposer = use_api_proposer
        self.device = device
        self.output_dir = output_dir
        self.logger = logger or logging.getLogger(__name__)
        self.multi_turn = multi_turn
        self.group_scoring = group_scoring
        self.max_concurrent = max_concurrent
        self.data_dir = data_dir

        # Load dataset. NOTE: dataset_name holds the raw argument, so it is the
        # sample list itself (not a name) when a list was passed in.
        self.dataset_name = dataset
        self.dataset = self._load_dataset(
            dataset,
            multi_turn=multi_turn,
            max_prompt_length=max_prompt_length,
            base_model_name=base_model_name,
            data_dir=data_dir
        )

        # Initialize proposer (API client or local model).
        # Every proposer attribute gets a placeholder first so that later
        # `if not self.<attr>` guards work on both the API and the local path;
        # the async client was previously left undefined for local proposers.
        self.proposer_client = None
        self.async_proposer_client = None
        self.proposer_local_model = None
        self.proposer_tokenizer = None
        self._initialize_proposer()

        # Initialize scorer model once for all components
        self.scorer_model = self._init_scorer_model()

        # Storage for discovered objectives
        self.discovered_objectives = []
        self.rejected_objectives = []
        self.objectives_per_iteration = []  # Track objectives at each iteration
        self.discovery_stats = {
            'total_proposals': 0,
            'total_iterations': 0,
            'verification_failures': {
                'interpretability': 0,
                'trend': 0
            },
            'time_elapsed': 0.0,
            'objectives_history': []  # History of objectives at each iteration
        }

        # Cache for model outputs
        self.model_outputs_cache = {}

        # Prompt optimization settings
        self.use_prompt_optimization = use_prompt_optimization
        self.prompt_optimization_model = prompt_optimization_model
        self.discovery_prompt = OBJECTIVE_DISCOVERY_PROMPT  # Modifiable copy

        if use_prompt_optimization:
            self.logger.info("Prompt optimization enabled")
        
    def generate_responses_batched(
        self,
        model_path: str,
        prompts: List[Union[str, List[Dict[str, str]]]],
        max_new_tokens: int = 512,
        batch_size: int = 8
    ) -> List[str]:
        """
        Generate responses for a batch of prompts using a specific model.

        The checkpoint is loaded for this call only. If `model_path` contains
        an `adapter_config.json` it is treated as a PEFT adapter and applied
        on top of a 4-bit (NF4) quantized base model; otherwise it is loaded
        as a full bfloat16 model. Sampling uses temperature 0.7 and top_p 0.9.
        The model and tokenizer are released again before returning, also
        when loading or generation fails.

        Args:
            model_path: Path to model checkpoint
            prompts: List of prompts (strings or message lists)
            max_new_tokens: Maximum tokens to generate
            batch_size: Batch size for generation

        Returns:
            List of responses (one per prompt); empty if `prompts` is empty
        """
        import gc  # local import: only needed for the cleanup in `finally`

        if not prompts:
            return []

        # Check if this is an adapter model
        adapter_config_path = os.path.join(model_path, 'adapter_config.json')
        is_adapter = os.path.exists(adapter_config_path)

        # Pre-bind every name the cleanup deletes so `finally` is safe no
        # matter where loading fails.
        model = None
        base_model = None
        tokenizer = None
        try:
            if is_adapter:
                print(f"Loading adapter model from: {model_path}")
                # Load adapter config to get base model
                with open(adapter_config_path, 'r') as f:
                    adapter_config = json.load(f)
                base_model_name = adapter_config.get('base_model_name_or_path', 'meta-llama/Llama-3.1-8B')

                # Prefer a tokenizer saved next to the adapter; fall back to
                # the base model's tokenizer.
                tokenizer_config_path = os.path.join(model_path, 'tokenizer_config.json')
                tokenizer_source = model_path if os.path.exists(tokenizer_config_path) else base_model_name
                tokenizer = AutoTokenizer.from_pretrained(tokenizer_source, trust_remote_code=True)

                if tokenizer.pad_token is None:
                    tokenizer.add_special_tokens({"pad_token": "[PAD]"})

                # Load base model with quantization
                bnb_config = BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_compute_dtype=torch.bfloat16,
                    bnb_4bit_use_double_quant=True,
                    bnb_4bit_quant_type='nf4'
                )

                base_model = AutoModelForCausalLM.from_pretrained(
                    base_model_name,
                    quantization_config=bnb_config,
                    torch_dtype=torch.bfloat16,
                    device_map=self.device,
                    trust_remote_code=True
                )

                # The embedding matrix must match the tokenizer size before
                # the adapter is applied.
                if len(tokenizer) != base_model.config.vocab_size:
                    print(f"Resizing model embeddings from {base_model.config.vocab_size} to {len(tokenizer)}")
                    base_model.resize_token_embeddings(len(tokenizer))

                # Apply adapter
                model = PeftModel.from_pretrained(base_model, model_path)
            else:
                print(f"Loading full model from: {model_path}")
                tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
                model = AutoModelForCausalLM.from_pretrained(
                    model_path,
                    trust_remote_code=True,
                    torch_dtype=torch.bfloat16,
                    device_map=self.device
                )

            # Set pad token if needed
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token

            # Use the centralized batched generation function
            return generate_responses_batched_util(
                model=model,
                tokenizer=tokenizer,
                prompts=prompts,
                max_new_tokens=max_new_tokens,
                batch_size=batch_size,
                temperature=0.7,
                top_p=0.9,
                return_model_and_tokenizer=False
            )
        finally:
            # Drop *all* references to the weights (including `base_model`,
            # which the adapter wraps) before emptying the CUDA cache;
            # otherwise the cache has nothing to release.
            del model, base_model, tokenizer
            gc.collect()
            torch.cuda.empty_cache()

    # COMMENTED OUT - Now using imported function from model_generation.py
    # def apply_chat_template_to_prompt(
    #     self,
    #     model_path: str,
    #     prompt: Union[str, List[Dict[str, str]]],
    #     max_length: Optional[int] = 2000
    # ) -> str:
    #     """
    #     Load tokenizer and apply chat template to a prompt without tokenizing.
    #
    #     This method properly handles multi-turn conversations by applying the model's
    #     chat template formatting, which is crucial for trajectory construction.
    #
    #     Args:
    #         model_path: Path to model checkpoint (used to load appropriate tokenizer)
    #         prompt: Either a string or a list of message dicts (e.g., [{"role": "user", "content": "..."}])
    #         max_length: Maximum character length. If exceeded, truncates from the left. Default: 2000.
    #
    #     Returns:
    #         String with chat template applied (and optionally truncated)
    #     """
    #     # Check if this is an adapter model
    #     adapter_config_path = os.path.join(model_path, 'adapter_config.json')
    #     is_adapter = os.path.exists(adapter_config_path)
    #
    #     if is_adapter:
    #         # Load adapter config to get base model
    #         with open(adapter_config_path, 'r') as f:
    #             adapter_config = json.load(f)
    #         base_model_name = adapter_config.get('base_model_name_or_path', 'meta-llama/Llama-3.1-8B')
    #
    #         # Check if tokenizer files exist in the adapter directory
    #         tokenizer_config_path = os.path.join(model_path, 'tokenizer_config.json')
    #
    #         if os.path.exists(tokenizer_config_path):
    #             tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    #         else:
    #             tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
    #     else:
    #         # Load tokenizer from full model path
    #         tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    #
    #     # Set pad token if needed
    #     if tokenizer.pad_token is None:
    #         tokenizer.pad_token = tokenizer.eos_token
    #
    #     # Apply chat template based on prompt type
    #     if isinstance(prompt, list):
    #         # Multi-turn conversation format
    #         formatted_text = tokenizer.apply_chat_template(
    #             prompt,
    #             tokenize=False,
    #             add_generation_prompt=False
    #         )
    #     else:
    #         # String prompt - wrap in user message for consistent formatting
    #         formatted_text = tokenizer.apply_chat_template(
    #             [{"role": "user", "content": prompt}],
    #             tokenize=False,
    #             add_generation_prompt=False
    #         )
    #
    #     # Truncate from the left if max_length is specified and exceeded
    #     if max_length is not None and len(formatted_text) > max_length:
    #         formatted_text = formatted_text[-max_length:]
    #
    #     # Clean up tokenizer
    #     del tokenizer
    #
    #     return formatted_text

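        # Original inline implementation of generate_responses_batched (superseded by
        # generate_responses_batched_util from model_generation.py), kept commented out below: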
        # print(f"Generating responses for {len(prompts)} prompts in batches of {batch_size}")
        #
        # # Process in batches
        # for i in range(0, len(prompts), batch_size):
        #     batch = prompts[i:i+batch_size]
        #
        #     # Prepare batch inputs
        #     batch_input_ids = []
        #     max_length = 0
        #
        #     for prompt in batch:
        #         if isinstance(prompt, list):
        #             # Apply chat template for message format
        #             input_ids = tokenizer.apply_chat_template(
        #                 prompt,
        #                 padding=False,
        #                 add_generation_prompt=True,
        #                 truncation=True,
        #                 max_length=1024
        #             )
        #         else:
        #             # Handle string prompt
        #             tokens = tokenizer(
        #                 prompt,
        #                 return_tensors="pt",
        #                 truncation=True,
        #                 max_length=1024
        #             )
        #             input_ids = tokens['input_ids'][0].tolist()
        #
        #         batch_input_ids.append(input_ids)
        #         max_length = max(max_length, len(input_ids))
        #
        #     # Pad all inputs to same length
        #     padded_inputs = []
        #     attention_masks = []
        #     for input_ids in batch_input_ids:
        #         padding_length = max_length - len(input_ids)
        #         padded_ids = input_ids + [tokenizer.pad_token_id or tokenizer.eos_token_id] * padding_length
        #         attention_mask = [1] * len(input_ids) + [0] * padding_length
        #         padded_inputs.append(padded_ids)
        #         attention_masks.append(attention_mask)
        #
        #     # Convert to tensors
        #     input_ids_tensor = torch.tensor(padded_inputs, device=model.device)
        #     attention_mask_tensor = torch.tensor(attention_masks, device=model.device)
        #
        #     # Generate for the batch
        #     with torch.no_grad():
        #         outputs = model.generate(
        #             input_ids=input_ids_tensor,
        #             attention_mask=attention_mask_tensor,
        #             max_new_tokens=max_new_tokens,
        #             do_sample=True,
        #             temperature=0.7,
        #             top_p=0.9,
        #             pad_token_id=tokenizer.eos_token_id
        #         )
        #
        #     # Decode responses
        #     for j, output in enumerate(outputs):
        #         response = tokenizer.decode(
        #             output[len(batch_input_ids[j]):],
        #             skip_special_tokens=True
        #         )
        #         responses.append(response)
        #
        # # Clean up memory
        # del model
        # del tokenizer
        # torch.cuda.empty_cache()
        #
        # return responses
    
    def _load_dataset(self, dataset: Union[str, List[Dict[str, str]]], multi_turn: bool = False, max_prompt_length: int = None, base_model_name: str = None, data_dir: str = None) -> List[Dict[str, str]]:
        """
        Load the dataset and bring it into the normalized sample format.

        A list of samples is normalized in place: the prompt is read from the
        first present key among 'input', 'prompt', 'text' (empty string if
        none is present) and stored under 'input'; 'response', 'output' and
        'label' are carried over when present, every other key is dropped.
        A string is handed to `load_normalized_dataset_samples` together with
        all remaining arguments.

        Args:
            dataset: Either a dataset name/path or list of samples
            multi_turn: Forwarded to the named-dataset loader
            max_prompt_length: Optional max token length for filtering prompts
            base_model_name: Model name for tokenizer (required if max_prompt_length is set)
            data_dir: Optional data directory for datasets that support it (e.g., 'helpful-base' for HH-RLHF)

        Returns:
            List of normalized samples with 'input' field
        """
        if not isinstance(dataset, str):
            prompt_keys = ('input', 'prompt', 'text')
            passthrough_keys = ('response', 'output', 'label')

            samples = []
            for raw in dataset:
                # First matching key wins, mirroring input > prompt > text.
                prompt_value = next((raw[name] for name in prompt_keys if name in raw), '')
                entry = {'input': prompt_value}
                entry.update({name: raw[name] for name in passthrough_keys if name in raw})
                samples.append(entry)
            return samples

        # Named dataset: loading and normalization are delegated entirely.
        return load_normalized_dataset_samples(
            dataset,
            multi_turn=multi_turn,
            max_prompt_length=max_prompt_length,
            base_model_name=base_model_name,
            data_dir=data_dir
        )
    
    def _initialize_proposer(self):
        """
        Initialize the proposer backend (OpenAI API clients or a local model).

        With `use_api_proposer` set, a synchronous and an asynchronous OpenAI
        client are created from the API key (the imported `OPENAI_API_KEY`
        constant first, then the environment variable of the same name).
        Otherwise `self.proposer_model` and its tokenizer are loaded from
        HuggingFace with 4-bit NF4 quantization and flash-attention-2.

        Raises:
            ValueError: If the API proposer is requested but no API key is found.
        """
        # Define the async client on every path. It used to be created only
        # in the API branch, so the async proposer hit an AttributeError
        # instead of its own "not initialized" error for local proposers.
        self.async_proposer_client = None

        if self.use_api_proposer:
            # Initialize API clients
            api_key = OPENAI_API_KEY or os.environ.get('OPENAI_API_KEY')
            if not api_key:
                raise ValueError("No OpenAI API key found. Set OPENAI_API_KEY environment variable.")
            self.proposer_client = OpenAI(api_key=api_key)
            self.async_proposer_client = AsyncOpenAI(api_key=api_key)
            return

        # Load local HuggingFace model
        print(f"Loading local proposer model: {self.proposer_model}")
        self.proposer_tokenizer = AutoTokenizer.from_pretrained(
            self.proposer_model,
            trust_remote_code=True
        )

        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type='nf4',
            bnb_4bit_compute_dtype=torch.bfloat16
        )
        self.proposer_local_model = AutoModelForCausalLM.from_pretrained(
            self.proposer_model,
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
            device_map=self.device,
            use_cache=False,
            attn_implementation='flash_attention_2',
            quantization_config=bnb_config,
        )

        # Set pad token if not present
        if self.proposer_tokenizer.pad_token is None:
            self.proposer_tokenizer.pad_token = self.proposer_tokenizer.eos_token
    
    def _init_scorer_model(self):
        """
        Build the single ObjectiveScorer instance shared by all components.

        Creating it once avoids loading the scorer model repeatedly. Scorer
        names starting with "gpt" are treated as API models; anything else is
        treated as a local model and loaded quantized.

        Returns:
            ObjectiveScorer: Initialized scorer model instance
        """
        scorer_name = self.scorer_model_name
        api_backed = scorer_name.startswith("gpt")

        # NOTE(review): dataset_name is whatever was passed as `dataset`; this
        # lookup presumably only works for dataset names registered in
        # DATASET_NAMES_DICT, not for an in-memory sample list — confirm.
        dataset_kind = DATASET_NAMES_DICT[self.dataset_name]

        scorer_kwargs = dict(
            use_detailed_rubric=True,      # detailed rubrics for better scoring
            dataset_type=dataset_kind,
            use_api=api_backed,
            model_name=scorer_name,
            device=self.device,
            max_length=4096,               # a reasonable default
            load_quantized=not api_backed, # quantize local models only
            cache_dir=None,
            save_dir=self.output_dir,      # directory for failure logs
        )
        scorer = ObjectiveScorer(**scorer_kwargs)

        summary_lines = (
            f"Initialized ObjectiveScorer with model: {self.scorer_model_name}",
            f"  - Use API: {api_backed}",
            f"  - Dataset type: {dataset_kind}",
            f"  - Device: {self.device}",
        )
        for line in summary_lines:
            self.logger.info(line)

        return scorer
    
    def _propose_objectives_with_api(
        self,
        prompt: str,
        num_objectives: int = 3
    ) -> List[str]:
        """
        Ask the API proposer for candidate objectives.

        Any failure while calling the API or parsing its reply is reported on
        stdout and turned into an empty result (best effort). Proposer model
        names that do not start with "gpt" are replaced by "gpt-4o-mini".

        Args:
            prompt: The prompt to send to the API
            num_objectives: Number of objectives to request

        Returns:
            List of proposed objective descriptions (at most `num_objectives`)

        Raises:
            ValueError: If the synchronous API client is not initialized
        """
        if not self.proposer_client:
            raise ValueError("API client not initialized")

        try:
            chosen_model = self.proposer_model if self.proposer_model.startswith("gpt") else "gpt-4o-mini"
            system_message = {
                "role": "system",
                "content": f"You are an expert at analyzing model behavior and identifying training objectives. Propose exactly {num_objectives} distinct objectives."
            }
            user_message = {"role": "user", "content": prompt}

            completion = self.proposer_client.chat.completions.create(
                model=chosen_model,
                messages=[system_message, user_message],
                temperature=0.7,
                max_tokens=500
            )
            reply = completion.choices[0].message.content.strip()

            # Parse, then cap at the requested count.
            parsed = self._parse_objectives_from_text(reply)
            return parsed[:num_objectives]
        except Exception as e:
            print(f"Error calling proposer API: {e}")
            return []

    async def _async_propose_objectives_with_api(
        self,
        prompt: str,
        num_objectives: int = 3
    ) -> List[str]:
        """
        Async version of _propose_objectives_with_api for parallel discovery.

        Args:
            prompt: The prompt to send to the API
            num_objectives: Number of objectives to request

        Returns:
            List of proposed objective descriptions
        """
        if not self.async_proposer_client:
            raise ValueError("Async API client not initialized")

        try:
            response = await self.async_proposer_client.chat.completions.create(
                model=self.proposer_model if self.proposer_model.startswith("gpt") else "gpt-4o-mini",
                messages=[
                    {"role": "system", "content": f"You are an expert at analyzing model behavior and identifying training objectives. Propose exactly {num_objectives} distinct objectives."},
                    {"role": "user", "content": prompt}
                ],
                temperature=0.7,
                max_tokens=500
            )

            response_text = response.choices[0].message.content.strip()

            # Parse objectives from response
            objectives = self._parse_objectives_from_text(response_text)

            return objectives[:num_objectives]

        except Exception as e:
            print(f"Error calling async proposer API: {e}")
            return []

    def _propose_objectives_with_local_model(
        self,
        prompt: str,
        num_objectives: int = 3
    ) -> List[str]:
        """
        Ask the locally loaded proposer model for candidate objectives.

        The model is prompted through `generate_huggingface_response` with a
        system instruction requesting exactly `num_objectives` objectives; the
        generated text is then parsed and capped at that count.

        Args:
            prompt: The prompt for the local model
            num_objectives: Number of objectives to request

        Returns:
            List of proposed objective descriptions (at most `num_objectives`)

        Raises:
            ValueError: If no local proposer model has been loaded
        """
        if not self.proposer_local_model:
            raise ValueError("Local model not initialized")

        # System instruction for better instruction following
        instruction = f"You are an expert at analyzing model behavior and identifying training objectives. Propose exactly {num_objectives} distinct objectives."

        raw_output = generate_huggingface_response(
            model=self.proposer_local_model,
            tokenizer=self.proposer_tokenizer,
            prompt=prompt,
            system_prompt=instruction,
            max_new_tokens=1024,
            temperature=0.7,
            top_p=0.9
        )

        return self._parse_objectives_from_text(raw_output)[:num_objectives]
    
    def _parse_objectives_from_text(self, text: str) -> List[str]:
        """
        Parse objective descriptions from generated text.

        Three formats are tried in order and the first one that yields any
        match is used:
          1. numbered items ("1. ..." or "1) ...") at the start of a line;
             the item text may also sit on the line after the number,
          2. bullet items ("- ", "* " or "•") at the start of a line,
          3. every line with more than 10 non-blank characters.

        List markers are only recognised at the start of a line, so hyphens,
        asterisks and decimals inside an objective (e.g. "child-friendly",
        "2.5 sentences") neither split it nor get mistaken for list markers.

        Each candidate is then stripped of surrounding whitespace and quotes
        and of any remaining leading number, and kept only if its length is
        between 6 and 499 characters.

        Args:
            text: Generated text containing objectives

        Returns:
            List of objective descriptions (empty for empty/None input)
        """
        if not text:
            return []

        # `(?!\d)` keeps a line starting with a decimal ("2.5x faster ...")
        # from being read as list item "2."; `\s*` may cross a newline so
        # that "1.\n<objective>" is still picked up.
        numbered_pattern = r'^[ \t]*\d+[.)](?!\d)\s*([^\n]+)'
        # "-" and "*" need a following blank so that "**bold**" or "--" lines
        # are not treated as bullets; "•" is unambiguous.
        bullet_pattern = r'^[ \t]*(?:[-*][ \t]+|•[ \t]*)([^\n]+)'

        candidates = re.findall(numbered_pattern, text, re.MULTILINE)
        if not candidates:
            candidates = re.findall(bullet_pattern, text, re.MULTILINE)
        if not candidates:
            # Fall back to plain lines, ignoring very short ones
            candidates = [line for line in text.split('\n') if len(line.strip()) > 10]

        cleaned_objectives = []
        for candidate in candidates:
            # Remove surrounding whitespace, then quotes, then whitespace
            # that the quotes were hiding at the end
            obj = candidate.strip().strip('"\'').rstrip()
            # Remove leading number and separator if still present
            obj = re.sub(r'^\d+[.)]\s*', '', obj)
            # Only keep if it's a reasonable objective description
            if 5 < len(obj) < 500:
                cleaned_objectives.append(obj)

        return cleaned_objectives
    
    def _generate_responses_for_prompt(
        self,
        prompt: Union[str, List[Dict[str, str]]],
        max_new_tokens: int = 128
    ) -> List[str]:
        """
        Generate responses from each model in the sequence for a given prompt.

        Every entry of ``self.model_sequence`` is loaded in turn, sampled once
        (temperature 0.7, top-p 0.9) and released again before the next model
        is loaded. A directory containing ``adapter_config.json`` is loaded as
        a PEFT adapter on top of a 4-bit quantized base model; anything else is
        loaded as a full model. Results are memoized in
        ``self.model_outputs_cache`` keyed by the prompt and ``max_new_tokens``.

        Args:
            prompt: Input prompt (string or list of message dicts with 'role' and 'content')
            max_new_tokens: Maximum tokens to generate

        Returns:
            List of responses (one per model in sequence)
        """
        # f-string formatting already applies str() to message lists, so a
        # single expression yields the same key for both prompt forms.
        cache_key = f"{prompt}_{max_new_tokens}"
        if cache_key in self.model_outputs_cache:
            return self.model_outputs_cache[cache_key]

        responses = []

        for model_path in self.model_sequence:
            # Only bound to a real model on the adapter path; tracked separately
            # so the underlying weights can be released after generation.
            base_model = None

            # Check if this is an adapter model
            adapter_config_path = os.path.join(model_path, 'adapter_config.json')

            if os.path.exists(adapter_config_path):
                print(f"Loading adapter model from: {model_path}")
                # Load adapter config to get base model
                with open(adapter_config_path, 'r', encoding='utf-8') as f:
                    adapter_config = json.load(f)
                base_model_name = adapter_config.get('base_model_name_or_path', 'meta-llama/Llama-3.1-8B')

                # Prefer a tokenizer saved alongside the adapter (keeps the
                # vocabulary consistent with training); otherwise fall back
                # to the base model's tokenizer.
                tokenizer_config_path = os.path.join(model_path, 'tokenizer_config.json')
                tokenizer_source = model_path if os.path.exists(tokenizer_config_path) else base_model_name
                tokenizer = AutoTokenizer.from_pretrained(
                    tokenizer_source,
                    trust_remote_code=True
                )

                if tokenizer.pad_token is None:
                    # Only add padding token if not already present
                    tokenizer.add_special_tokens({"pad_token": "[PAD]"})

                # Load base model with quantization
                bnb_config = BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_compute_dtype=torch.bfloat16,
                    bnb_4bit_use_double_quant=True,
                    bnb_4bit_quant_type='nf4'
                )
                base_model = AutoModelForCausalLM.from_pretrained(
                    base_model_name,
                    quantization_config=bnb_config,
                    torch_dtype=torch.bfloat16,
                    device_map=self.device,
                    trust_remote_code=True
                )

                # The embedding matrix must match the tokenizer before the
                # adapter weights are applied.
                if len(tokenizer) != base_model.config.vocab_size:
                    print(f"Resizing model embeddings from {base_model.config.vocab_size} to {len(tokenizer)}")
                    base_model.resize_token_embeddings(len(tokenizer))

                # Apply adapter
                model = PeftModel.from_pretrained(base_model, model_path)
            else:
                print(f"Loading full model from: {model_path}")
                # Load full model
                tokenizer = AutoTokenizer.from_pretrained(
                    model_path,
                    trust_remote_code=True
                )
                model = AutoModelForCausalLM.from_pretrained(
                    model_path,
                    trust_remote_code=True,
                    torch_dtype=torch.bfloat16,
                    device_map=self.device
                )

            # Set pad token if needed (full-model path; adapters handled above)
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token

            # Handle both string prompts and message lists
            if isinstance(prompt, list):
                # Apply chat template for message format
                input_ids = tokenizer.apply_chat_template(
                    prompt,
                    padding=False,
                    add_generation_prompt=True,
                    truncation=True,
                    max_length=1024
                )
                inputs = {'input_ids': torch.tensor([input_ids], device=model.device)}
            else:
                encoded = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024)
                inputs = {k: v.to(model.device) for k, v in encoded.items()}

            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.9,
                    pad_token_id=tokenizer.eos_token_id
                )

            # Decode only the newly generated tokens (strip the prompt prefix)
            prompt_length = inputs['input_ids'].shape[1]
            generated_text = tokenizer.decode(
                outputs[0][prompt_length:],
                skip_special_tokens=True
            )

            responses.append(generated_text)

            # Clean up memory. `base_model` must be dropped too: deleting the
            # PEFT wrapper alone leaves the base weights referenced, so they
            # would still occupy the device while the next model is loading.
            del model, base_model, tokenizer, inputs, outputs
            torch.cuda.empty_cache()

        # Cache the results
        self.model_outputs_cache[cache_key] = responses

        return responses
    
    def _verify_objective(
        self,
        objective_description: str,
        verification_samples: Optional[List[Dict[str, str]]] = None,
        sample_size: Optional[int] = None,
        iteration_candidates: Optional[List[str]] = None
    ) -> Tuple[bool, Dict[str, Any]]:
        """
        Verify if an objective satisfies both criteria:
        1. Human-interpretability
        2. Predictable trend

        The interpretability check runs first and the trend check only runs if
        it passes. Every failed or errored check increments the matching
        counter in ``self.discovery_stats['verification_failures']`` exactly
        once. Whenever the objective is accepted, ``self.discovery_prompt`` is
        reset to ``OBJECTIVE_DISCOVERY_PROMPT``.

        Args:
            objective_description: Natural language description of the objective
            verification_samples: Optional samples to use for verification
            sample_size: Number of samples to use if verification_samples not provided
            iteration_candidates: Optional list of candidate objectives from current iteration
                                  (used for prompt optimization if interpretability fails)

        Returns:
            Tuple of (is_valid, verification_details)
        """
        if self.logger:
            self.logger.info("\n" + "="*60)
            self.logger.info("OBJECTIVE VERIFICATION")
            self.logger.info("="*60)
            self.logger.info(f"Objective: {objective_description}")
            self.logger.info("")
        if sample_size is None:
            sample_size = self.verification_sample_size

        if verification_samples is None:
            # Sample from dataset
            verification_samples = random.sample(
                self.dataset,
                min(sample_size, len(self.dataset))
            )

        verification_details = {
            'objective': objective_description,
            'interpretable': False,
            'follows_trend': False,
            'interpretability_score': None,
            'trend_type': None,
            'trend_error': float('inf')
        }

        # 1. Check human-interpretability
        print(f"\nVerifying interpretability for: {objective_description[:50]}...")
        if self.logger:
            self.logger.info("-" * 40)
            self.logger.info("INTERPRETABILITY VERIFICATION")
            self.logger.info(f"Epsilon threshold: {self.verifier_epsilon_interpretable}")

        interpretability_verifier = HumanInterpretableVerifier(
            objective_description=objective_description,
            objective_model=self.scorer_model,
            human_models=self.human_scorer_models,
            model_sequence=self.model_sequence,
            epsilon=self.verifier_epsilon_interpretable,
            output_dir=self.output_dir,
            max_concurrent=self.max_concurrent
        )

        try:
            avg_difference, is_interpretable, score_details = interpretability_verifier.compute_alignment(
                dataset=verification_samples,
                dataset_type=DATASET_NAMES_DICT[self.dataset_name],
                use_provided_responses=False
            )

            verification_details['interpretable'] = is_interpretable
            verification_details['interpretability_score'] = avg_difference
            verification_details['interpretability_details'] = score_details

            if self.logger:
                self.logger.info(f"Average difference: {avg_difference:.4f}")
                self.logger.info(f"Average objective score: {score_details.get('avg_objective_score', 0):.4f}")
                self.logger.info(f"Average human score: {score_details.get('avg_human_score', 0):.4f}")

                # Log per-model human scores if available
                if 'per_human_model_stats' in score_details:
                    self.logger.info("Per-model human score statistics:")
                    for model_name, stats in score_details['per_human_model_stats'].items():
                        self.logger.info(f"  {model_name}:")
                        self.logger.info(f"    Average: {stats['avg']:.4f}")
                        self.logger.info(f"    Std Dev: {stats['std']:.4f}")
                        self.logger.info(f"    Min: {stats['min']:.4f}")
                        self.logger.info(f"    Max: {stats['max']:.4f}")

                self.logger.info(f"Is interpretable: {is_interpretable}")
                self.logger.info(f"Result: {'PASSED' if is_interpretable else 'FAILED'}")

        except Exception as e:
            print(f"Error during interpretability verification: {e}")
            if self.logger:
                self.logger.error(f"Error during interpretability verification: {e}")
            self.discovery_stats['verification_failures']['interpretability'] += 1
            return False, verification_details

        if not is_interpretable:
            self.discovery_stats['verification_failures']['interpretability'] += 1

            # Prompt optimization if enabled and candidates provided.
            # Kept outside the verification try-block so that an optimizer
            # error neither counts the same failure twice nor is reported as
            # a verification error; it stays best-effort.
            if self.use_prompt_optimization and iteration_candidates:
                try:
                    self.discovery_prompt, _ = optimize_discovery_prompt(
                        current_prompt=self.discovery_prompt,
                        objectives=list(iteration_candidates),
                        verification_samples=verification_samples,
                        model_sequence=self.model_sequence,
                        scorer_model=self.scorer_model_name,
                        human_scorer_models=self.human_scorer_models,
                        dataset_type=DATASET_NAMES_DICT[self.dataset_name],
                        epsilon=self.verifier_epsilon_interpretable,
                        optimizer_model=self.prompt_optimization_model,
                        logger=self.logger
                    )
                except Exception as e:
                    print(f"Error during discovery prompt optimization: {e}")
                    if self.logger:
                        self.logger.error(f"Error during discovery prompt optimization: {e}")

            return False, verification_details

        # 2. Check predictable trend
        if self.verifier_epsilon_trend == -1.0:
            # Skip trend verification if epsilon is -1.0 (for testing purposes)
            if self.logger:
                self.logger.info("Skipping trend verification (epsilon set to -1.0)")
            verification_details['follows_trend'] = True
            verification_details['trend_type'] = 'linear'
            verification_details['trend_error'] = 0.0
            # The objective is accepted here as well, so reset the prompt the
            # same way the regular success path below does.
            self.discovery_prompt = OBJECTIVE_DISCOVERY_PROMPT
            return True, verification_details

        print(f"Verifying trend for: {objective_description[:50]}...")
        if self.logger:
            self.logger.info("-" * 40)
            self.logger.info("TREND VERIFICATION")
            self.logger.info(f"Epsilon threshold: {self.verifier_epsilon_trend}")

        trend_verifier = PredictableTrendVerifier(
            objective_description=objective_description,
            objective_model=self.scorer_model,
            model_sequence=self.model_sequence,
            epsilon=self.verifier_epsilon_trend,
            save_dir=self.output_dir,
            max_concurrent=self.max_concurrent
        )

        try:
            # Compute V scores
            v_scores = trend_verifier.compute_v_scores(
                dataset=verification_samples,
                sample_size=sample_size
            )

            # Fit trend and check criteria
            best_trend, params, param_dict, avg_error = trend_verifier.fit_optimal_trend(v_scores)

            is_predictable = trend_verifier.check_predictable_trend_criteria(avg_error)

            verification_details['follows_trend'] = is_predictable
            verification_details['trend_type'] = best_trend
            verification_details['trend_error'] = avg_error

            if self.logger:
                self.logger.info(f"V-scores: {[f'{v:.4f}' for v in v_scores]}")
                self.logger.info(f"Best trend type: {best_trend}")
                self.logger.info(f"Trend parameters: {param_dict}")
                self.logger.info(f"Average error: {avg_error:.4f}")
                self.logger.info(f"Is predictable: {is_predictable}")
                self.logger.info(f"Result: {'PASSED' if is_predictable else 'FAILED'}")

        except Exception as e:
            print(f"Error during trend verification: {e}")
            if self.logger:
                self.logger.error(f"Error during trend verification: {e}")
            self.discovery_stats['verification_failures']['trend'] += 1
            return False, verification_details

        if not is_predictable:
            self.discovery_stats['verification_failures']['trend'] += 1
            return False, verification_details

        # Both criteria satisfied
        if self.logger:
            self.logger.info("-" * 40)
            self.logger.info("FINAL VERIFICATION RESULT: PASSED")
            self.logger.info(f"  Interpretability Score: {verification_details['interpretability_score']:.4f}")
            self.logger.info(f"  Trend Type: {verification_details['trend_type']}")
            self.logger.info(f"  Trend Error: {verification_details['trend_error']:.4f}")

        self.discovery_prompt = OBJECTIVE_DISCOVERY_PROMPT  # Reset to default prompt each iteration
        return True, verification_details

    def _group_calculate_residuals_for_samples(
        self,
        samples: List[Dict[str, str]],
        current_objectives: List[str]
    ) -> Tuple[Dict[str, float], Dict[str, List[str]], Dict[str, List[str]]]:
        """
        Calculate average residuals for given samples using current objectives with group scoring.

        This method uses group scoring where all responses (from train and test models) are
        scored together for better calibration, similar to the implementation in calc_objectives_fit.py.

        With no objectives yet, the residual of a test-model response is its
        squared ground-truth score. Otherwise a combination function is fitted
        on the train models (``model_sequence[:train_test_split_idx]``) and
        residuals are computed on the remaining test models. Dictionary keys
        are the values returned by ``apply_chat_template_to_prompt`` for the
        first model in the sequence and each sample's ``'input'``.

        Args:
            samples: List of input samples (each must provide an 'input' entry)
            current_objectives: Current set of discovered objectives

        Returns:
            Tuple of:
            - Dictionary mapping sample inputs to average residuals (averaged over test models)
            - Dictionary mapping sample inputs to LIST of test model responses (handles multiple test models)
            - Dictionary mapping sample inputs to LIST of responses from all models, in model order
        """
        start_time = time.time()
        print(f"Calculating residuals with GROUP SCORING for {len(samples)} samples...")

        # Create ObjectivesFit instance with current objectives - enable batching and group scoring
        # NOTE(review): the class is looked up on `self` rather than as the
        # module-level import - confirm `self.ObjectivesFit` is set elsewhere.
        obj_fit_calc = self.ObjectivesFit(
            dataset=samples,
            model_sequence=self.model_sequence,
            ground_truth_objective=self.ground_truth_reward,
            combination_function_type=self.combination_function_type,
            combination_function_params=self.combination_function_params,
            num_samples=len(samples),
            train_test_split_idx=self.train_test_split_idx,
            scorer_model=self.scorer_model,
            device=self.device,
            cache_responses=True,
            save_dir=self.output_dir,
            use_different_prompts=False,
            dataset_type=DATASET_NAMES_DICT[self.dataset_name],
            use_detailed_rubric=True,
            batching=True,  # Enable batched generation
            group_scoring=True,  # Enable group scoring
            batch_size=8,   # Use reasonable batch size
            model_cache_size=1,  # Keep models in memory
            normalize_scores=True,  # Normalize scores to [0, 1] for objective discovery
            logger=self.logger,
            max_concurrent=self.max_concurrent
        )

        residuals_by_sample: Dict[str, float] = {}
        # LIST of responses from ALL test models for each sample
        test_responses_by_sample: Dict[str, List[str]] = {}
        # input_key -> list of responses from all models
        all_responses_by_input_key: Dict[str, List[str]] = {}

        # Split models into train and test
        train_models = self.model_sequence[:self.train_test_split_idx]
        test_models = self.model_sequence[self.train_test_split_idx:]
        all_models = self.model_sequence  # Use all models for group context

        # Collect all input texts for processing
        input_texts = [sample['input'] for sample in samples]

        # The formatted input doubles as the dictionary key for each sample.
        # It does not depend on which model is being processed, so compute it
        # once per sample instead of once per (model, sample) pair.
        all_formatted_inputs = [
            apply_chat_template_to_prompt(self.model_sequence[0], input_text)
            for input_text in input_texts
        ]

        # First, collect ALL responses from both train and test models for group scoring
        print("Collecting responses from all models for group scoring context...")
        all_responses_by_sample = [[] for _ in samples]  # List of lists: [sample_idx][model_idx]

        for model_idx, model_path in enumerate(all_models):
            is_test_model = model_idx >= self.train_test_split_idx
            model_type = "test" if is_test_model else "train"
            local_idx = model_idx - self.train_test_split_idx if is_test_model else model_idx

            print(f"Collecting responses from {model_type} model {local_idx + 1}/{len(test_models) if is_test_model else len(train_models)}: {model_path}")

            # Generate all responses at once for this model
            responses = obj_fit_calc._generate_response_batched(model_path, input_texts, batch_size=8)

            # Store responses organized by sample
            for sample_idx, response in enumerate(responses):
                all_responses_by_sample[sample_idx].append(response)
                input_key = all_formatted_inputs[sample_idx]

                # (Re)start the per-key list on the first model, then only
                # append when the list length equals model_idx so positions
                # stay aligned with model order even for repeated inputs.
                if model_idx == 0 or input_key not in all_responses_by_input_key:
                    all_responses_by_input_key[input_key] = []
                responses_for_key = all_responses_by_input_key[input_key]
                if len(responses_for_key) == model_idx:
                    responses_for_key.append(response)

                # Test models additionally feed the per-sample test response list
                if is_test_model:
                    test_responses_by_sample.setdefault(input_key, []).append(response)

        print(f"Collected responses from all {len(all_models)} models ({len(train_models)} train, {len(test_models)} test)")

        # Indices of the test models inside each per-sample response list
        test_model_indices = list(range(self.train_test_split_idx, len(all_models)))

        # If no objectives yet, residuals are just ground truth squared (using group scoring)
        if not current_objectives:
            print("No objectives yet - calculating ground truth residuals with batch parallel scoring...")

            # Get ground truth for ALL samples in batch (enables parallelization)
            all_ground_truths_by_sample, _ = obj_fit_calc._get_group_ground_truth_both(
                all_formatted_inputs,
                all_responses_by_sample
            )

            # Calculate residuals for each sample (test models only)
            for sample_idx, input_key in enumerate(all_formatted_inputs):
                ground_truths = all_ground_truths_by_sample[sample_idx]

                # gt^2 when no objectives
                test_residuals = [ground_truths[model_idx] ** 2 for model_idx in test_model_indices]

                # Average residuals across test models
                avg_residual = sum(test_residuals) / len(test_residuals) if test_residuals else 0.0
                residuals_by_sample[input_key] = avg_residual

        else:
            print(f"Fitting combination function with {len(current_objectives)} objectives using batch parallel scoring...")

            # Phase 1: Get training features and targets using batch parallel scoring
            # Reuse _process_group_scoring_async which parallelizes all (sample, objective) pairs
            train_features, train_targets, _ = obj_fit_calc._process_group_scoring_async(
                all_formatted_inputs,
                all_responses_by_sample,
                current_objectives,
                num_samples_to_log=0,  # No logging needed here
                num_train_models=self.train_test_split_idx
            )

            print(f"Collected {len(train_features)} training samples for combination function")

            # Fit combination function
            obj_fit_calc.combination_function, obj_fit_calc.obj_coefficients = obj_fit_calc._fit_combination_function(
                current_objectives,
                train_features,
                train_targets
            )

            # Log coefficients if linear regression was used
            if obj_fit_calc.obj_coefficients is not None and self.logger:
                self.logger.info("\n--- Group Scoring Obj Coefficients ---")
                self.logger.info(f"Intercept: {obj_fit_calc.obj_coefficients['intercept']:.4f}")
                for obj_name, coef in obj_fit_calc.obj_coefficients['coefficients'].items():
                    self.logger.info(f"  {obj_name}: {coef:.4f}")
                self.logger.info("--------------------------------------")

            # Phase 2: Calculate test residuals using batch parallel scoring
            print("Calculating residuals on test models with batch parallel scoring...")

            # Use batch residual calculation (parallelizes all scoring)
            batch_results = obj_fit_calc._calculate_residuals_group_batch(
                all_formatted_inputs,
                all_responses_by_sample,
                current_objectives,
                test_model_indices
            )

            # Results are ordered as: [(sample_0, test_0), (sample_0, test_1), ..., (sample_1, test_0), ...]
            # Average residuals across test models for each sample
            num_test_models = len(test_model_indices)

            for sample_idx, input_key in enumerate(all_formatted_inputs):
                first_result = sample_idx * num_test_models

                # Collect residuals for all test models for this sample;
                # [0] extracts residual from (residual, predicted, ground_truth)
                sample_residuals = [
                    batch_results[first_result + offset][0]
                    for offset in range(num_test_models)
                ]

                # Average residuals across test models
                avg_residual = sum(sample_residuals) / len(sample_residuals) if sample_residuals else 0.0
                residuals_by_sample[input_key] = avg_residual

        # Cleanup
        obj_fit_calc.cleanup_model_cache()

        # Verify that test_responses_by_sample contains lists with correct number of responses
        if test_responses_by_sample:
            sample_key = next(iter(test_responses_by_sample))
            num_responses = len(test_responses_by_sample[sample_key])
            print(f"Each sample has {num_responses} test model responses (expected: {len(test_models)})")

        print(f"Group scoring residual calculation complete for {len(samples)} samples")
        print(f"Total time taken for _group_calculate_residuals_for_samples: {time.time() - start_time:.2f} seconds")
        return residuals_by_sample, test_responses_by_sample, all_responses_by_input_key
    
    def calculate_final_obj_error(
        self,
        discovered_objectives: List[str],
        ground_truth_reward,
        num_samples_eval: int,
        train_test_split_idx: Optional[int],
        combination_function_type: str,
        combination_function_params: Optional[Dict[str, Any]],
        save_dir: Optional[str] = None
    ) -> Tuple[Optional[float], Optional[str]]:
        """
        Calculate final Obj-Error for discovered objectives and save the fitted reward combiner.

        Args:
            discovered_objectives: List of discovered objectives
            ground_truth_reward: RewardFunction instance for ground truth
            num_samples_eval: Number of samples for evaluation
            train_test_split_idx: Index to split model sequence for train/test
            combination_function_type: Type of combination function
            combination_function_params: Parameters for combination function
            save_dir: Optional directory to save the fitted reward combiner

        Returns:
            Tuple of (final_obj_error, combiner_save_path)
                - final_obj_error: The calculated Obj-Error or None if error
                - combiner_save_path: Path where combiner was saved or None
        """
        if not ground_truth_reward or len(discovered_objectives) == 0:
            if not ground_truth_reward:
                print("\n(Skipping final Obj-Error calculation - no ground truth reward provided)")
            if self.logger and not ground_truth_reward:
                self.logger.info("Skipping final Obj-Error - no ground truth reward")
            return None, None

        print("\n=== CALCULATING FINAL OBJ-ERROR ===")

        if self.logger:
            self.logger.info("\n" + "="*60)
            self.logger.info("FINAL OBJ-ERROR CALCULATION")
            self.logger.info(f"Using num_samples_eval: {num_samples_eval}")
            self.logger.info("="*60)

        try:
            # Sample random prompts for evaluation
            eval_samples = random.sample(
                self.dataset,
                min(num_samples_eval, len(self.dataset))
            )

            # Initialize ObjectivesFit for calculation
            from src.calc_objectives_fit import ObjectivesFit
            obj_fit_calc = ObjectivesFit(
                dataset=eval_samples,
                model_sequence=self.model_sequence,
                ground_truth_objective=ground_truth_reward,
                combination_function_type=combination_function_type,
                combination_function_params=combination_function_params or {},
                num_samples=len(eval_samples),
                train_test_split_idx=train_test_split_idx,
                scorer_model=self.scorer_model,
                cache_responses=True,
                use_different_prompts=False,
                save_dir=self.output_dir,
                dataset_type=DATASET_NAMES_DICT[self.dataset_name],
                use_detailed_rubric=True,
                batching=True,
                batch_size=8,
                model_cache_size=1,
                normalize_scores=True,
                logger=self.logger,
                max_concurrent=self.max_concurrent
            )

            # Calculate Obj-Error
            final_obj_error = obj_fit_calc.calculate(discovered_objectives)

            print(f"Final Obj-Error: {final_obj_error:.6f}")
            print(f"(Lower is better - measures difference from ground truth rewards)")

            # Log the training data that was collected and used for fitting
            if self.logger and hasattr(obj_fit_calc, 'train_features') and obj_fit_calc.train_features is not None:
                self.logger.info("\n" + "="*60)
                self.logger.info("TRAINING DATA USED FOR FITTING")
                self.logger.info("="*60)

                # Log normalized training features and targets
                self.logger.info(f"\nNumber of training samples: {len(obj_fit_calc.train_features)}")
                self.logger.info("\n--- Normalized Training Features (0-1 scale) ---")
                for i, features in enumerate(obj_fit_calc.train_features[:5]):  # Show first 5 samples
                    self.logger.info(f"Sample {i+1}: {features}")
                if len(obj_fit_calc.train_features) > 5:
                    self.logger.info(f"... ({len(obj_fit_calc.train_features) - 5} more samples)")

                self.logger.info("\n--- Normalized Training Targets (0-1 scale) ---")
                for i, target in enumerate(obj_fit_calc.train_targets[:5]):  # Show first 5 targets
                    self.logger.info(f"Sample {i+1}: {target:.4f}")
                if len(obj_fit_calc.train_targets) > 5:
                    self.logger.info(f"... ({len(obj_fit_calc.train_targets) - 5} more samples)")

                # Log unnormalized training features and targets
                if hasattr(obj_fit_calc, 'unnormalized_train_features') and obj_fit_calc.unnormalized_train_features is not None:
                    self.logger.info("\n--- Unnormalized Training Features (1-10 scale) ---")
                    for i, features in enumerate(obj_fit_calc.unnormalized_train_features[:20]):  # Show first 20 samples
                        self.logger.info(f"Sample {i+1}: {features}")
                    if len(obj_fit_calc.unnormalized_train_features) > 5:
                        self.logger.info(f"... ({len(obj_fit_calc.unnormalized_train_features) - 20} more samples)")

                if hasattr(obj_fit_calc, 'denormalized_train_targets') and obj_fit_calc.denormalized_train_targets is not None:
                    self.logger.info("\n--- Denormalized Training Targets (1-10 scale) ---")
                    for i, target in enumerate(obj_fit_calc.denormalized_train_targets[:20]):  # Show first 20 targets
                        self.logger.info(f"Sample {i+1}: {target:.4f}")
                    if len(obj_fit_calc.denormalized_train_targets) > 5:
                        self.logger.info(f"... ({len(obj_fit_calc.denormalized_train_targets) - 20} more samples)")

                # Log summary statistics
                self.logger.info("\n--- Training Data Statistics ---")
                import numpy as np

                # Normalized targets stats
                normalized_mean = np.mean(obj_fit_calc.train_targets)
                normalized_std = np.std(obj_fit_calc.train_targets)
                self.logger.info(f"Normalized targets - Mean: {normalized_mean:.4f}, Std: {normalized_std:.4f}")

                # Denormalized targets stats
                if hasattr(obj_fit_calc, 'denormalized_train_targets') and obj_fit_calc.denormalized_train_targets is not None:
                    denorm_mean = np.mean(obj_fit_calc.denormalized_train_targets)
                    denorm_std = np.std(obj_fit_calc.denormalized_train_targets)
                    self.logger.info(f"Denormalized targets - Mean: {denorm_mean:.4f}, Std: {denorm_std:.4f}")

                self.logger.info("="*60 + "\n")

            if self.logger:
                self.logger.info(f"\nFinal Obj-Error with {len(discovered_objectives)} objectives: {final_obj_error:.6f}")
                self.logger.info("Discovered objectives:")
                for idx, obj in enumerate(discovered_objectives, 1):
                    self.logger.info(f"  {idx}. {obj}")

                # Log linear regression coefficients if applicable
                if combination_function_type == 'linear_regression':
                    if hasattr(obj_fit_calc, 'denormalized_obj_coefficients') and obj_fit_calc.denormalized_obj_coefficients is not None:
                        self.logger.info("\n--- Final Evaluation Denormalized Obj Coefficients ---")
                        self.logger.info(f"Intercept: {obj_fit_calc.denormalized_obj_coefficients['intercept']:.4f}")
                        for obj_name, coef in obj_fit_calc.denormalized_obj_coefficients['coefficients'].items():
                            self.logger.info(f"  {obj_name}: {coef:.4f}")
                        self.logger.info("--------------------------------------------------------")

            # Log detailed reward combiner analysis if logger is available
            if self.logger and hasattr(obj_fit_calc, 'denormalized_combination_function'):
                log_test_reward_combiner(self.logger, obj_fit_calc.denormalized_combination_function)

            # Save the fitted reward combiner (denormalized version) if directory provided
            combiner_save_path = None
            if save_dir and hasattr(obj_fit_calc, 'denormalized_combination_function'):
                combiner_save_path = os.path.join(save_dir, "fitted_reward_combiner")
                obj_fit_calc.denormalized_combination_function.save(combiner_save_path)
                print(f"Saved fitted denormalized reward combiner to: {combiner_save_path}")
                if self.logger:
                    self.logger.info(f"Saved denormalized reward combiner to: {combiner_save_path}")

            # Save unnormalized training data for later inspection
            if save_dir and hasattr(obj_fit_calc, 'unnormalized_train_features') and hasattr(obj_fit_calc, 'denormalized_train_targets'):
                training_data_path = os.path.join(save_dir, "reward_combiner_training_data.pkl")
                training_data = {
                    'unnormalized_train_features': obj_fit_calc.unnormalized_train_features,
                    'denormalized_train_targets': obj_fit_calc.denormalized_train_targets
                }
                with open(training_data_path, 'wb') as f:
                    pickle.dump(training_data, f)
                print(f"Saved training data to: {training_data_path}")
                if self.logger:
                    self.logger.info(f"Saved unnormalized training data to: {training_data_path}")

            # Cleanup
            obj_fit_calc.cleanup_model_cache()

            return final_obj_error, combiner_save_path

        except Exception as e:
            print(f"Error calculating final Obj-Error: {e}")
            if self.logger:
                self.logger.error(f"Failed to calculate final Obj-Error: {e}")
            return None, None

    def _format_single_trajectory(
        self,
        prompt: str,
        responses: List[str],
        trajectory_num: int
    ) -> str:
        """Format a single trajectory for inclusion in batch prompt."""
        trajectory_text = f"==== TRAJECTORY {trajectory_num} ====\n"
        trajectory_text += f"Input Prompt: {prompt}\n\n"
        
        for i, response in enumerate(responses):
            trajectory_text += f"Model Iteration {i+1} Response:\n{response}\n\n"
        
        return trajectory_text
    
    def _create_batch_discovery_prompt(
        self,
        batch_trajectories: List[str],
        existing_objectives: List[str]
    ) -> str:
        """Create prompt for proposer to discover objectives from multiple trajectories."""
        # Combine all trajectories
        trajectories_text = "\n".join(batch_trajectories)
        
        # Prepare existing objectives section if needed
        existing_objectives_section = ""
        if existing_objectives:
            existing_obj_list = "\n".join([f"- {obj}" for obj in existing_objectives])
            existing_objectives_section = OBJECTIVE_DISCOVERY_WITH_EXISTING_PROMPT.format(
                existing_objectives=existing_obj_list
            )
        
        # Use the modifiable discovery prompt (may be updated by prompt optimization)
        full_prompt = self.discovery_prompt.format(
            trajectory_count=len(batch_trajectories),
            trajectories=trajectories_text,
            # num_objectives=self.objectives_per_trajectory * len(batch_trajectories),
            num_objectives=self.objectives_per_trajectory,
            existing_objectives_section=existing_objectives_section
        )

        return full_prompt

    @abstractmethod
    def obtain_objectives(self) -> Tuple[List[str], Dict[str, Any]]:
        """
        Run objective discovery; every concrete subclass must override this.

        The base implementation has no body beyond this docstring and
        therefore returns None when called directly.

        Returns:
            Tuple of (discovered_objectives, statistics)
        """


class RandomObjectivesDiscovery(BaseObjectivesDiscovery):
    """
    Random baseline for objective discovery.
    
    This approach:
    1. Randomly samples prompts from the dataset
    2. Generates responses from each model in the sequence
    3. Asks a proposer model to identify objectives from the trajectories
    4. Verifies objectives against the two criteria
    5. Repeats until k valid objectives are found
    """
    
    def __init__(
        self,
        dataset: Union[str, List[Dict[str, str]]],
        model_sequence: List[str],
        k: int = 4,
        samples_per_iteration: int = 5,
        objectives_per_trajectory: int = 3,
        max_iterations: int = 100,
        ground_truth_reward=None,  # Optional ground truth reward function for final evaluation
        num_samples_final_eval: int = 25,  # Number of samples for final Obj-Error calculation
        train_test_split_idx: Optional[int] = None,
        combination_function_type: str = 'linear_regression',
        combination_function_params: Optional[Dict[str, Any]] = None,
        **kwargs
    ):
        """
        Set up the random-sampling objective discovery baseline.

        Args:
            dataset: Dataset to sample from
            model_sequence: List of model checkpoints
            k: Number of objectives to discover
            samples_per_iteration: Number of prompts to sample per iteration
            objectives_per_trajectory: Number of objectives to request per trajectory
            max_iterations: Maximum iterations before giving up
            ground_truth_reward: Optional RewardFunction for final Obj-Error calculation
            num_samples_final_eval: Number of samples for final evaluation
            train_test_split_idx: Index at which model_sequence is split into
                train/test models; defaults to len(model_sequence) // 2 when None
            combination_function_type: Type of combination function for Obj-Error
            combination_function_params: Parameters for combination function;
                a falsy value is replaced by an empty dict
            **kwargs: Additional arguments for base class
        """
        super().__init__(dataset, model_sequence, k, **kwargs)

        # Sampling / search budget.
        self.samples_per_iteration = samples_per_iteration
        self.objectives_per_trajectory = objectives_per_trajectory
        self.max_iterations = max_iterations

        # Settings used when fitting against the ground-truth reward.
        self.ground_truth_reward = ground_truth_reward
        self.num_samples_final_eval = num_samples_final_eval
        self.combination_function_type = combination_function_type
        self.combination_function_params = (
            combination_function_params if combination_function_params else {}
        )

        # The ObjectivesFit class itself is kept on the instance; it is
        # instantiated later through self.ObjectivesFit.
        self.ObjectivesFit = ObjectivesFit

        # Without an explicit split point, use the midpoint of the sequence.
        if train_test_split_idx is None:
            train_test_split_idx = len(model_sequence) // 2
        self.train_test_split_idx = train_test_split_idx

        # Track which objectives we've already seen
        self.seen_objectives = set()
    
    def _create_trajectory_prompt(
        self,
        prompt: str,
        responses: List[str],
        existing_objectives: List[str]
    ) -> str:
        """
        Create a prompt for the proposer model to analyze a trajectory.
        
        Args:
            prompt: The input prompt
            responses: List of model responses (ordered by training iteration)
            existing_objectives: Already discovered objectives to avoid
            
        Returns:
            Formatted prompt for proposer model
        """
        # Build the trajectory description
        trajectory_text = f"Input Prompt: {prompt}\n\n"
        
        for i, response in enumerate(responses):
            trajectory_text += f"Model Iteration {i+1} Response:\n{response}\n\n"
        
        # Prepare existing objectives section if needed
        existing_objectives_section = ""
        if existing_objectives:
            existing_obj_list = "\n".join([f"- {obj}" for obj in existing_objectives])
            existing_objectives_section = OBJECTIVE_DISCOVERY_WITH_EXISTING_PROMPT.format(
                existing_objectives=existing_obj_list
            )
        
        # Use the modifiable discovery prompt (may be updated by prompt optimization)
        full_prompt = self.discovery_prompt.format(
            trajectory_count=1,
            trajectories=trajectory_text,
            num_objectives=self.objectives_per_trajectory,
            existing_objectives_section=existing_objectives_section
        )

        return full_prompt

    def _calculate_residuals_for_samples(
        self,
        samples: List[Dict[str, str]],
        current_objectives: List[str]
    ) -> Tuple[Dict[str, float], Dict[str, List[str]]]:
        """
        Calculate average residuals for given samples using current objectives.

        With no objectives yet, the residual of a response is its squared
        ground-truth value. Otherwise a combination function is first fitted on
        responses from the training models (model_sequence before
        train_test_split_idx) and residuals are obtained from
        ObjectivesFit._calculate_residual. In both cases residuals are measured
        on the test models (model_sequence from train_test_split_idx onward) and
        averaged per sample.

        Args:
            samples: List of input samples; each must provide an 'input' entry
            current_objectives: Current set of discovered objectives (may be empty)

        Returns:
            Tuple of (residuals_by_sample, test_responses_by_sample)
                - residuals_by_sample: chat-templated input -> residual averaged
                  over the test models (0.0 when no residuals were collected
                  for that sample)
                - test_responses_by_sample: chat-templated input -> list of
                  test-model responses, in model order
        """
        print(f"Calculating residuals for {len(samples)} samples...")

        generation_batch_size = 8

        # Create ObjectivesFit instance with current objectives - enable batching
        obj_fit_calc = self.ObjectivesFit(
            dataset=samples,
            model_sequence=self.model_sequence,
            ground_truth_objective=self.ground_truth_reward,
            combination_function_type=self.combination_function_type,
            combination_function_params=self.combination_function_params,
            num_samples=len(samples),
            train_test_split_idx=self.train_test_split_idx,
            scorer_model=self.scorer_model,
            cache_responses=True,
            use_different_prompts=False,
            save_dir=self.output_dir,
            dataset_type=DATASET_NAMES_DICT[self.dataset_name],
            use_detailed_rubric=True,
            batching=True,  # Enable batched generation
            batch_size=generation_batch_size,
            model_cache_size=1,
            normalize_scores=True,  # Normalize scores to [0, 1] for objective discovery
            logger=self.logger,
            max_concurrent=self.max_concurrent
        )

        residuals_by_sample = {}
        test_responses_by_sample = {}  # Store responses from test models

        # try/finally guarantees the model cache is released even when
        # generation or scoring raises part-way through.
        try:
            # Collect all input texts for batching
            input_texts = [sample['input'] for sample in samples]

            # Multi-turn conversations are formatted with the first model's chat
            # template. The key depends only on the sample, so it is computed
            # once here rather than for every (model, sample) pair.
            # NOTE(review): assumes apply_chat_template_to_prompt returns the
            # same string for identical arguments - confirm.
            input_keys = [
                apply_chat_template_to_prompt(self.model_sequence[0], input_text)
                for input_text in input_texts
            ]

            if not current_objectives:
                progress_prefix = "Processing model"

                def residual_for(input_key, response):
                    # Residual is just gt^2 when no objectives
                    gt = obj_fit_calc._get_ground_truth(input_key, response)
                    return gt ** 2
            else:
                progress_prefix = "Processing test model"

                # Fit combination function on training models
                train_models = self.model_sequence[:self.train_test_split_idx]
                train_features = []
                train_targets = []

                for model_path in tqdm(train_models, desc="Training combination function"):
                    # Generate all responses at once for this model
                    responses = obj_fit_calc._generate_response_batched(
                        model_path, input_texts, batch_size=generation_batch_size
                    )

                    for input_key, response in zip(input_keys, responses):
                        # Score with current objectives
                        objective_scores = {
                            f"obj_{i}": obj_fit_calc._score_with_objective(
                                input_key,
                                response,
                                obj_desc
                            )
                            for i, obj_desc in enumerate(current_objectives)
                        }
                        gt = obj_fit_calc._get_ground_truth(input_key, response)

                        train_features.append(objective_scores)
                        train_targets.append(gt)

                    print(f"Processed {len(samples)} samples for model")

                # Fit combination function
                obj_fit_calc.combination_function, obj_fit_calc.obj_coefficients = (
                    obj_fit_calc._fit_combination_function(
                        current_objectives,
                        train_features,
                        train_targets
                    )
                )

                # Log coefficients if linear regression was used
                if obj_fit_calc.obj_coefficients is not None and self.logger:
                    self.logger.info("\n--- Obj Coefficients ---")
                    self.logger.info(f"Intercept: {obj_fit_calc.obj_coefficients['intercept']:.4f}")
                    for obj_name, coef in obj_fit_calc.obj_coefficients['coefficients'].items():
                        self.logger.info(f"  {obj_name}: {coef:.4f}")
                    self.logger.info("--------------------------------------")

                def residual_for(input_key, response):
                    return obj_fit_calc._calculate_residual(
                        input_key,
                        response,
                        current_objectives
                    )

            # Measure residuals on the test models with batched generation
            test_models = self.model_sequence[self.train_test_split_idx:]
            residuals_per_sample = [[] for _ in samples]

            for model_path in test_models:
                print(f"{progress_prefix}: {model_path}")
                # Generate all responses at once for this model
                responses = obj_fit_calc._generate_response_batched(
                    model_path, input_texts, batch_size=generation_batch_size
                )

                for sample_idx, (input_key, response) in enumerate(zip(input_keys, responses)):
                    # Store responses from test models
                    test_responses_by_sample.setdefault(input_key, []).append(response)
                    residuals_per_sample[sample_idx].append(
                        residual_for(input_key, response)
                    )

                print(f"Processed {len(samples)} samples for model")

            # Average residuals across models; samples sharing a key overwrite
            # each other, last one wins.
            for input_key, sample_residuals in zip(input_keys, residuals_per_sample):
                if sample_residuals:
                    avg_residual = sum(sample_residuals) / len(sample_residuals)
                else:
                    avg_residual = 0.0
                residuals_by_sample[input_key] = avg_residual
        finally:
            # Cleanup
            obj_fit_calc.cleanup_model_cache()

        return residuals_by_sample, test_responses_by_sample
    
    def obtain_objectives(self) -> Tuple[List[str], Dict[str, Any]]:
        """
        Discover objectives using random sampling approach.
        
        Returns:
            Tuple of (discovered_objectives, statistics)
        """
        start_time = time.time()
        iteration = 0
        
        print(f"\nStarting Random Objective Discovery")
        print(f"Target: {self.k} valid objectives")
        print(f"Dataset size: {len(self.dataset)} samples")
        print(f"Model sequence length: {len(self.model_sequence)} models")
        print("="*60)
        
        if self.logger:
            self.logger.info("\n" + "#"*80)
            self.logger.info("#" + "STARTING RANDOM OBJECTIVES DISCOVERY".center(78) + "#")
            self.logger.info("#"*80)
            self.logger.info(f"Configuration:")
            self.logger.info(f"  Target objectives (k): {self.k}")
            self.logger.info(f"  Dataset size: {len(self.dataset)} samples")
            self.logger.info(f"  Model sequence: {len(self.model_sequence)} models")
            self.logger.info(f"  Samples per iteration: {self.samples_per_iteration}")
            self.logger.info(f"  Objectives per trajectory: {self.objectives_per_trajectory}")
            self.logger.info(f"  Max iterations: {self.max_iterations}")
            self.logger.info(f"  Verification epsilon (interpretability): {self.verifier_epsilon_interpretable}")
            self.logger.info(f"  Verification epsilon (trend): {self.verifier_epsilon_trend}")
            self.logger.info("")

        # 1. Randomly sample prompts
        sampled_prompts = random.sample(
            self.dataset,
            min(self.samples_per_iteration, len(self.dataset))
        )
        
        if self.logger:
            self.logger.info("="*60)
            self.logger.info("PHASE 1: RANDOM SAMPLING")
            self.logger.info("="*60)
            self.logger.info(f"Sampled {len(sampled_prompts)} prompts from dataset")
            self.logger.info("\nSampled prompts:")
            for idx, sample in enumerate(sampled_prompts, 1):
                # prompt_text = sample['input'][0]['content'] if isinstance(sample['input'], list) else sample['input']
                # Properly format multi-turn conversations using chat template
                prompt_text = apply_chat_template_to_prompt(
                    self.model_sequence[0],
                    sample['input']
                )
                # Truncate for display
                prompt_display = prompt_text
                prompt_display = prompt_display.replace('\n', ' ')
                self.logger.info(f"  Sample {idx}: {prompt_display}")
            self.logger.info("")
        
        while len(self.discovered_objectives) < self.k and iteration < self.max_iterations:
            iteration += 1
            self.discovery_stats['total_iterations'] = iteration
            
            print(f"\n--- Iteration {iteration} ---")
            print(f"Current valid objectives: {len(self.discovered_objectives)}/{self.k}")
            
            if self.logger:
                self.logger.info("\n" + "*"*80)
                self.logger.info("*" + f"ITERATION {iteration}".center(78) + "*")
                self.logger.info("*"*80)
                self.logger.info(f"Progress: {len(self.discovered_objectives)}/{self.k} objectives discovered")
                if self.discovered_objectives:
                    self.logger.info("Currently discovered objectives:")
                    for idx, obj in enumerate(self.discovered_objectives, 1):
                        self.logger.info(f"  {idx}. {obj}")
                self.logger.info("")

            # 2. Generate trajectories and propose objectives
            iteration_proposals = []
            trajectories_log = []  # Store trajectory info for logging

            if self.logger:
                self.logger.info("="*60)
                self.logger.info("PHASE 2: TRAJECTORY GENERATION & OBJECTIVE DISCOVERY")
                self.logger.info("="*60)

            # Extract all prompts for batched generation
            all_prompts = [sample['input'] for sample in sampled_prompts]

            # Generate responses for all prompts from each model using batched generation
            # responses_by_model[model_idx] = list of responses for all prompts
            responses_by_model = []
            for model_idx, model_path in enumerate(self.model_sequence):
                print(f"\nGenerating responses from model {model_idx + 1}/{len(self.model_sequence)}: {model_path}")
                try:
                    model_responses = generate_responses_batched_util(
                        model_path=model_path,
                        prompts=all_prompts,
                        # max_new_tokens=128,
                        max_new_tokens=512,
                        batch_size=8,
                        temperature=0.7,
                        top_p=0.9,
                        # use_cache=True  # Enable disk-based caching
                    )
                    responses_by_model.append(model_responses)
                except Exception as e:
                    print(f"Error generating responses from model {model_path}: {e}")
                    if self.logger:
                        self.logger.error(f"Failed to generate responses from model {model_path}: {e}")
                    # Fill with empty responses to maintain structure
                    responses_by_model.append([""] * len(all_prompts))

            # Reorganize responses into per-prompt trajectories
            # all_trajectories[prompt_idx] = [response_model_0, response_model_1, ...]
            all_trajectories = []
            for prompt_idx in range(len(all_prompts)):
                trajectory = [responses_by_model[model_idx][prompt_idx] for model_idx in range(len(self.model_sequence))]
                all_trajectories.append(trajectory)

            # Process each prompt with its pre-generated trajectory
            for i, sample in enumerate(sampled_prompts):
                prompt = sample['input']
                trajectory_info = {}

                print(f"\nProcessing sample {i+1}/{len(sampled_prompts)}...")

                # Use pre-generated trajectory
                responses = all_trajectories[i]
                trajectory_info['prompt'] = apply_chat_template_to_prompt(
                    self.model_sequence[0],
                    prompt
                )
                trajectory_info['responses'] = responses

                # Create prompt for proposer
                # Extract string content for prompt creation
                # prompt_str = prompt[0]['content'] if isinstance(prompt, list) else prompt
                # Properly format multi-turn conversations using chat template
                prompt_str = apply_chat_template_to_prompt(
                    self.model_sequence[0],
                    prompt
                )
                proposer_prompt = self._create_trajectory_prompt(
                    prompt_str,
                    responses,
                    self.discovered_objectives
                )

                # Get objective proposals
                if self.use_api_proposer:
                    proposals = self._propose_objectives_with_api(
                        proposer_prompt,
                        self.objectives_per_trajectory
                    )
                else:
                    proposals = self._propose_objectives_with_local_model(
                        proposer_prompt,
                        self.objectives_per_trajectory
                    )
                
                trajectory_info['proposals'] = proposals
                trajectories_log.append(trajectory_info)

                # Filter out duplicates
                new_proposals_this_sample = []
                for proposal in proposals:
                    if proposal not in self.seen_objectives:
                        iteration_proposals.append(proposal)
                        new_proposals_this_sample.append(proposal)
                        self.seen_objectives.add(proposal)
                trajectory_info['new_proposals'] = new_proposals_this_sample

            self.discovery_stats['total_proposals'] += len(iteration_proposals)
            
            print(f"\nGenerated {len(iteration_proposals)} new objective proposals")
            
            if self.logger:
                # Log trajectory details
                self.logger.info("\nResponse Trajectories and Proposals:")
                for idx, traj_info in enumerate(trajectories_log):
                    self.logger.info(f"\n  Trajectory {idx+1}:")
                    # Log prompt
                    prompt_display = traj_info['prompt']
                    prompt_display = prompt_display.replace('\n', ' ')
                    self.logger.info(f"    Prompt: {prompt_display}")
                    
                    # Log response trajectory
                    self.logger.info(f"    Response trajectory ({len(traj_info['responses'])} models):")
                    for j, resp in enumerate(traj_info['responses']):  # Show first 2 responses
                        resp_display = resp
                        resp_display = resp_display.replace('\n', ' ')
                        self.logger.info(f"      Model {j+1}: {resp_display}")
                    # if len(traj_info['responses']) > 2:
                    #     self.logger.info(f"      ... and {len(traj_info['responses'])-2} more responses")
                    
                    # Log proposals
                    self.logger.info(f"    Proposals from this trajectory ({len(traj_info['proposals'])} total):")
                    for j, prop in enumerate(traj_info['proposals'], 1):
                        self.logger.info(f"      {j}. {prop}")
                    
                    # Log new unique proposals
                    if traj_info['new_proposals']:
                        self.logger.info(f"    New unique proposals added: {len(traj_info['new_proposals'])}")
                        for j, prop in enumerate(traj_info['new_proposals'], 1):
                            self.logger.info(f"      • {prop}")
                    else:
                        self.logger.info(f"    No new unique proposals (all were duplicates)")
                
                # if len(trajectories_log) > 3:
                #     self.logger.info(f"\n  ... and {len(trajectories_log)-3} more trajectories processed")
                
                # Log iteration proposals summary
                self.logger.info("\n" + "="*60)
                self.logger.info("ITERATION PROPOSALS SUMMARY")
                self.logger.info("="*60)
                self.logger.info(f"Total new unique proposals this iteration: {len(iteration_proposals)}")
                if iteration_proposals:
                    self.logger.info("\nAll proposals from this iteration:")
                    for idx, prop in enumerate(iteration_proposals, 1):  # Show up to 10
                        self.logger.info(f"  {idx}. {prop}")
                    # if len(iteration_proposals) > 10:
                    #     self.logger.info(f"  ... and {len(iteration_proposals)-10} more proposals")

            # 3. Verify each proposed objective
            # Shuffle proposals to randomize verification order
            random.shuffle(iteration_proposals)

            if self.logger and iteration_proposals:
                self.logger.info("\n" + "="*60)
                self.logger.info("PHASE 3: OBJECTIVE VERIFICATION")
                self.logger.info("="*60)
                self.logger.info(f"Verifying {len(iteration_proposals)} proposals (shuffled)...")

            for obj_idx, objective in enumerate(iteration_proposals):
                if len(self.discovered_objectives) >= self.k:
                    break
                
                print(f"\nVerifying objective {obj_idx+1}/{len(iteration_proposals)}: {objective}...")

                is_valid, details = self._verify_objective(objective)

                if is_valid:
                    print(f"✓ Valid objective found: {objective[:50]}...")
                    if self.logger:
                        self.logger.info(f"\n✓ ACCEPTED: {objective}")
                        self.logger.info(f"  Interpretability Score: {details['interpretability_score']:.4f}")
                        self.logger.info(f"  Trend Type: {details['trend_type']}")
                        self.logger.info(f"  Trend Error: {details['trend_error']:.4f}")
                    self.discovered_objectives.append(objective)
                else:
                    print(f"✗ Objective rejected")
                    if self.logger:
                        self.logger.info(f"\n✗ REJECTED: {objective}")
                        if not details['interpretable']:
                            self.logger.info(f"  Failed interpretability (score: {details.get('interpretability_score', 'N/A')})")
                        if not details['follows_trend']:
                            self.logger.info(f"  Failed trend check (error: {details.get('trend_error', 'N/A')})")
                    self.rejected_objectives.append({
                        'objective': objective,
                        'reason': details
                    })

            ################################################################################################
            # 4. Re-sample prompts with highest residuals for next iteration (JUST ADDED THIS)
            # - This is done to make it more similar to how VibeCheck works
            # - VibeCheck discovers new vibes based on the prompt-outputs which are misclassified
            ################################################################################################
            if self.group_scoring:
                residuals, _ = self._group_calculate_residuals_for_samples(sampled_prompts, self.discovered_objectives)
            else:
                residuals, _ = self._calculate_residuals_for_samples(sampled_prompts, self.discovered_objectives)
            
            # Sort by residual value and take top v samples
            sorted_samples = sorted(
                sampled_prompts,
                key=lambda s: residuals.get(
                    apply_chat_template_to_prompt(self.model_sequence[0], s['input']),
                    0.0
                ),
                reverse=True
            )

            sampled_prompts = sorted_samples[:len(sorted_samples)//2]  # Keep top half with highest residuals
            ################################################################################################
            
            # Track objectives at this iteration
            self.objectives_per_iteration.append({
                'iteration': iteration,
                'objectives': self.discovered_objectives.copy(),
                'num_objectives': len(self.discovered_objectives),
                'proposals_made': len(iteration_proposals),
                'total_proposals': self.discovery_stats['total_proposals']
            })
            
            if self.logger:
                self.logger.info("\n" + "-"*60)
                self.logger.info(f"ITERATION {iteration} SUMMARY")
                self.logger.info("-"*60)
                self.logger.info(f"Proposals generated: {len(iteration_proposals)}")
                self.logger.info(f"Valid objectives found: {len(self.discovered_objectives)}/{self.k}")
                self.logger.info(f"Total proposals so far: {self.discovery_stats['total_proposals']}")
                self.logger.info(f"Rejection rate this iteration: {len([p for p in iteration_proposals])-sum([1 for o in self.discovered_objectives if o in iteration_proposals])}/{len(iteration_proposals) if iteration_proposals else 1}")
            
            # Check if we should continue
            if len(self.discovered_objectives) >= self.k:
                print(f"\n✓ Successfully discovered {self.k} valid objectives!")
                break
            
            if iteration >= self.max_iterations:
                print(f"\n⚠ Reached maximum iterations ({self.max_iterations})")
                print(f"Only found {len(self.discovered_objectives)}/{self.k} valid objectives")
                break

            if len(sampled_prompts) == 0:
                print("\n⚠ No prompts left to sample from - ending discovery")
                break
        
        # Calculate Final Obj-Error for discovered objectives
        final_obj_error, combiner_save_path = self.calculate_final_obj_error(
            discovered_objectives=self.discovered_objectives,
            ground_truth_reward=self.ground_truth_reward,
            num_samples_eval=self.num_samples_final_eval,
            train_test_split_idx=self.train_test_split_idx,
            combination_function_type=self.combination_function_type,
            combination_function_params=self.combination_function_params,
            save_dir=self.output_dir
        )

        self.discovery_stats['final_obj_error'] = final_obj_error
        self.discovery_stats['reward_combiner_path'] = combiner_save_path

        # Calculate final statistics
        self.discovery_stats['time_elapsed'] = time.time() - start_time
        self.discovery_stats['discovered_count'] = len(self.discovered_objectives)
        self.discovery_stats['rejected_count'] = len(self.rejected_objectives)
        self.discovery_stats['acceptance_rate'] = (
            len(self.discovered_objectives) / self.discovery_stats['total_proposals']
            if self.discovery_stats['total_proposals'] > 0 else 0
        )
        self.discovery_stats['objectives_history'] = self.objectives_per_iteration
        
        # Print summary
        print("\n" + "="*60)
        print("DISCOVERY COMPLETE")
        print("="*60)
        print(f"Valid objectives discovered: {len(self.discovered_objectives)}/{self.k}")
        print(f"Total iterations: {iteration}")
        print(f"Total proposals: {self.discovery_stats['total_proposals']}")
        print(f"Acceptance rate: {self.discovery_stats['acceptance_rate']:.2%}")
        print(f"Time elapsed: {self.discovery_stats['time_elapsed']:.2f} seconds")
        print(f"Verification failures:")
        print(f"  - Interpretability: {self.discovery_stats['verification_failures']['interpretability']}")
        print(f"  - Trend: {self.discovery_stats['verification_failures']['trend']}")
        
        if self.logger:
            self.logger.info("\n" + "#"*80)
            self.logger.info("#" + "DISCOVERY COMPLETE".center(78) + "#")
            self.logger.info("#"*80)
            self.logger.info(f"\nFINAL SUMMARY:")
            self.logger.info(f"  Valid objectives discovered: {len(self.discovered_objectives)}/{self.k}")
            self.logger.info(f"  Total iterations: {iteration}")
            self.logger.info(f"  Total proposals evaluated: {self.discovery_stats['total_proposals']}")
            self.logger.info(f"  Acceptance rate: {self.discovery_stats['acceptance_rate']:.2%}")
            self.logger.info(f"  Time elapsed: {self.discovery_stats['time_elapsed']:.2f} seconds")
            self.logger.info(f"\nVerification failures:")
            self.logger.info(f"  - Interpretability: {self.discovery_stats['verification_failures']['interpretability']}")
            self.logger.info(f"  - Trend: {self.discovery_stats['verification_failures']['trend']}")
            
            self.logger.info(f"\nFINAL DISCOVERED OBJECTIVES:")
            for i, obj in enumerate(self.discovered_objectives, 1):
                self.logger.info(f"  {i}. {obj}")
            
            self.logger.info("\n" + "#"*80)
            self.logger.info("#" + "END OF RANDOM DISCOVERY LOG".center(78) + "#")
            self.logger.info("#"*80)
        
        return self.discovered_objectives, self.discovery_stats


class StaticObjectivesDiscovery(BaseObjectivesDiscovery):
    """
    Static baseline for objective discovery.

    This approach:
    1. Randomly samples prompts from the dataset
    2. Generates responses from each model in the sequence
    3. Passes ALL trajectories to the proposer model at once
    4. Asks for exactly k objectives
    5. Returns those k objectives without verification

    This is a one-shot approach with no iteration or verification.
    """

    def __init__(
        self,
        dataset: Union[str, List[Dict[str, str]]],
        model_sequence: List[str],
        k: int = 4,
        samples_per_discovery: int = 10,
        ground_truth_reward=None,  # Optional ground truth reward function for final evaluation
        num_samples_final_eval: int = 25,  # Number of samples for final Obj-Error calculation
        train_test_split_idx: Optional[int] = None,
        combination_function_type: str = 'linear_regression',
        combination_function_params: Optional[Dict[str, Any]] = None,
        **kwargs
    ):
        """
        Initialize static objective discovery.

        Args:
            dataset: Dataset to sample from
            model_sequence: List of model checkpoints
            k: Number of objectives to discover
            samples_per_discovery: Number of prompts to sample for discovery
            ground_truth_reward: Optional RewardFunction for final Obj-Error calculation
            num_samples_final_eval: Number of samples for final evaluation
            train_test_split_idx: Index to split model sequence for train/test;
                defaults to half of the model sequence length when None
            combination_function_type: Type of combination function for Obj-Error
            combination_function_params: Parameters for combination function
                (an empty dict is used when None)
            **kwargs: Additional arguments for base class
        """
        super().__init__(dataset, model_sequence, k, **kwargs)
        self.samples_per_discovery = samples_per_discovery
        self.ground_truth_reward = ground_truth_reward
        self.num_samples_final_eval = num_samples_final_eval
        self.combination_function_type = combination_function_type
        self.combination_function_params = combination_function_params or {}

        # Use the ObjectivesFit class imported at module level. A hard-coded
        # `from src.calc_objectives_fit import ...` only resolves when a
        # top-level `src` package happens to be importable.
        self.ObjectivesFit = ObjectivesFit

        self.train_test_split_idx = train_test_split_idx
        if self.train_test_split_idx is None:
            self.train_test_split_idx = len(model_sequence) // 2

    def _create_batch_trajectory_prompt(
        self,
        trajectories: List[Dict[str, Any]]
    ) -> str:
        """
        Create a prompt for the proposer model to analyze multiple trajectories at once.

        Args:
            trajectories: List of trajectory dictionaries with 'prompt' and 'responses' keys

        Returns:
            Formatted prompt for proposer model
        """
        # Collect the pieces and join once instead of repeated string +=.
        sections = []
        for i, trajectory in enumerate(trajectories, 1):
            sections.append(f"=== Trajectory {i} ===\n")
            sections.append(f"Input Prompt: {trajectory['prompt']}\n\n")
            sections.extend(
                f"Model Iteration {j} Response:\n{response}\n\n"
                for j, response in enumerate(trajectory['responses'], 1)
            )
            sections.append("\n")
        trajectories_text = "".join(sections)

        # Use the prompt template from constants.
        # No existing objectives since this is one-shot.
        return OBJECTIVE_DISCOVERY_PROMPT.format(
            trajectory_count=len(trajectories),
            trajectories=trajectories_text,
            num_objectives=self.k,  # Ask for exactly k objectives
            existing_objectives_section=""  # No existing objectives
        )

    def obtain_objectives(self) -> Tuple[List[str], Dict[str, Any]]:
        """
        Discover objectives using static one-shot approach.

        Samples prompts, generates one response per model for each prompt,
        sends all trajectories to the proposer in a single prompt, keeps at
        most ``k`` of the returned proposals, and finally computes the
        Obj-Error via ``calculate_final_obj_error``.

        Returns:
            Tuple of (discovered_objectives, statistics)
        """
        start_time = time.time()

        print("\nStarting Static Objective Discovery")
        print(f"Target: {self.k} objectives")
        print(f"Dataset size: {len(self.dataset)} samples")
        print(f"Model sequence length: {len(self.model_sequence)} models")
        print(f"Samples for discovery: {self.samples_per_discovery}")
        print("="*60)

        if self.logger:
            self.logger.info("\n" + "#"*80)
            self.logger.info("#" + "STARTING STATIC OBJECTIVES DISCOVERY".center(78) + "#")
            self.logger.info("#"*80)
            self.logger.info("Configuration:")
            self.logger.info(f"  Target objectives (k): {self.k}")
            self.logger.info(f"  Dataset size: {len(self.dataset)} samples")
            self.logger.info(f"  Model sequence: {len(self.model_sequence)} models")
            self.logger.info(f"  Samples for discovery: {self.samples_per_discovery}")
            self.logger.info("")

        # 1. Sample prompts randomly (capped at the dataset size)
        sampled_prompts = random.sample(
            self.dataset,
            min(self.samples_per_discovery, len(self.dataset))
        )

        print(f"Sampled {len(sampled_prompts)} prompts")

        if self.logger:
            self.logger.info("="*60)
            self.logger.info("PHASE 1: TRAJECTORY GENERATION")
            self.logger.info("="*60)
            self.logger.info(f"Sampled {len(sampled_prompts)} prompts from dataset")

        # 2. Generate trajectories for all sampled prompts using batched generation
        trajectories = []

        # Extract all prompts for batched generation
        all_prompts = [sample['input'] for sample in sampled_prompts]

        # responses_by_model[model_idx] = list of responses for all prompts
        responses_by_model = []
        for model_idx, model_path in enumerate(self.model_sequence):
            print(f"\nGenerating responses from model {model_idx + 1}/{len(self.model_sequence)}: {model_path}")
            try:
                model_responses = generate_responses_batched_util(
                    model_path=model_path,
                    prompts=all_prompts,
                    max_new_tokens=512,
                    batch_size=8,
                    temperature=0.7,
                    top_p=0.9,
                )
                responses_by_model.append(model_responses)
            except Exception as e:
                # Best-effort: a failing model must not abort the whole run.
                print(f"Error generating responses from model {model_path}: {e}")
                if self.logger:
                    self.logger.error(f"Failed to generate responses from model {model_path}: {e}")
                # Fill with empty responses to maintain structure
                responses_by_model.append([""] * len(all_prompts))

        # Build trajectories from the batched responses
        for i, sample in enumerate(sampled_prompts):
            # Collect responses for this prompt from all models (in model order)
            responses = [model_responses[i] for model_responses in responses_by_model]

            # Format prompt using chat template
            prompt_str = apply_chat_template_to_prompt(
                self.model_sequence[0],
                sample['input']
            )
            trajectories.append({
                'prompt': prompt_str,
                'responses': responses
            })

            if self.logger and i < 3:  # Log first 3 trajectories
                self.logger.info(f"\nTrajectory {i+1}:")
                prompt_display = prompt_str[:200].replace('\n', ' ')
                self.logger.info(f"  Prompt: {prompt_display}...")
                self.logger.info(f"  Generated {len(responses)} responses")

        print(f"\nGenerated {len(trajectories)} complete trajectories")

        if self.logger:
            self.logger.info(f"\nSuccessfully generated {len(trajectories)} trajectories")

        # 3. Create a single prompt with all trajectories
        if self.logger:
            self.logger.info("\n" + "="*60)
            self.logger.info("PHASE 2: OBJECTIVE DISCOVERY")
            self.logger.info("="*60)
            self.logger.info(f"Sending {len(trajectories)} trajectories to proposer model")
            self.logger.info(f"Requesting exactly {self.k} objectives")

        proposer_prompt = self._create_batch_trajectory_prompt(trajectories)

        # 4. Get exactly k objective proposals
        print(f"\nRequesting {self.k} objectives from proposer model...")

        if self.use_api_proposer:
            proposals = self._propose_objectives_with_api(
                proposer_prompt,
                self.k  # Request exactly k objectives
            )
        else:
            proposals = self._propose_objectives_with_local_model(
                proposer_prompt,
                self.k  # Request exactly k objectives
            )

        # 5. Take the first k proposals (or all if fewer)
        self.discovered_objectives = proposals[:self.k]

        print(f"\nDiscovered {len(self.discovered_objectives)} objectives")

        if self.logger:
            self.logger.info(f"\nProposer returned {len(proposals)} objectives")
            self.logger.info("\nDiscovered objectives:")
            for idx, obj in enumerate(self.discovered_objectives, 1):
                self.logger.info(f"  {idx}. {obj}")

        # Calculate Final Obj-Error for discovered objectives
        final_obj_error, combiner_save_path = self.calculate_final_obj_error(
            discovered_objectives=self.discovered_objectives,
            ground_truth_reward=self.ground_truth_reward,
            num_samples_eval=self.num_samples_final_eval,
            train_test_split_idx=self.train_test_split_idx,
            combination_function_type=self.combination_function_type,
            combination_function_params=self.combination_function_params,
            save_dir=self.output_dir
        )

        self.discovery_stats['final_obj_error'] = final_obj_error
        self.discovery_stats['reward_combiner_path'] = combiner_save_path

        # Calculate statistics
        self.discovery_stats['time_elapsed'] = time.time() - start_time
        self.discovery_stats['discovered_count'] = len(self.discovered_objectives)
        self.discovery_stats['total_proposals'] = len(proposals)
        self.discovery_stats['trajectories_used'] = len(trajectories)
        self.discovery_stats['prompts_sampled'] = len(sampled_prompts)

        # Final summary
        print("\n" + "="*60)
        print("DISCOVERY COMPLETE")
        print("="*60)
        print(f"Objectives discovered: {len(self.discovered_objectives)}/{self.k}")
        print(f"Time elapsed: {self.discovery_stats['time_elapsed']:.2f} seconds")
        if final_obj_error is not None:
            print(f"Final Obj-Error: {final_obj_error:.6f}")
        print("\nDiscovered objectives:")
        for i, obj in enumerate(self.discovered_objectives, 1):
            print(f"  {i}. {obj}")

        if self.logger:
            self.logger.info("\n" + "#"*80)
            self.logger.info("#" + "STATIC DISCOVERY COMPLETE".center(78) + "#")
            self.logger.info("#"*80)
            self.logger.info("Final statistics:")
            self.logger.info(f"  Objectives discovered: {len(self.discovered_objectives)}/{self.k}")
            self.logger.info(f"  Trajectories used: {len(trajectories)}")
            self.logger.info(f"  Time elapsed: {self.discovery_stats['time_elapsed']:.2f} seconds")
            if final_obj_error is not None:
                self.logger.info(f"  Final Obj-Error: {final_obj_error:.6f}")
            self.logger.info("")

        return self.discovered_objectives, self.discovery_stats


class ProposedObjectivesDiscovery(BaseObjectivesDiscovery):
    """
    Proposed method for objective discovery using greedy Matching Pursuit approach.
    
    This approach iteratively discovers objectives by:
    1. Objectives Discovery:
       - Identifying informative samples with highest residuals
       - Discovering candidate objectives from those samples
       - Selecting the best objective that maximizes Obj-Fit
    2. Objectives Verification:
       - Verifying human-interpretability
       - Verifying predictable trend
    """
    
    def __init__(
        self,
        dataset: Union[str, List[Dict[str, str]]],
        model_sequence: List[str],
        ground_truth_reward,  # RewardFunction instance
        k: int = 10,
        x_cand_size: int = 100,  # Size of X_cand for informative samples
        x_disc_size: int = 10,  # Size of X_disc (v in the paper)
        objectives_per_trajectory: int = 3,  # Number of objectives to request per trajectory
        num_parallel_trajectories: int = 1,  # Number of trajectories to process at once
        num_samples_select_best: int = 20,  # Number of samples for objective selection
        num_samples_final_eval: int = 25, # Number of samples for final Obj-Error calculation
        combination_function_type: str = 'linear_regression',
        combination_function_params: Optional[Dict[str, Any]] = None,
        train_test_split_idx: Optional[int] = None,
        max_iterations: int = 50,
        use_random_sampling: bool = False,  # If True, randomly sample instead of using residuals
        **kwargs
    ):
        """
        Set up the greedy (Matching Pursuit style) objective discovery method.

        Args:
            dataset: Dataset to sample from
            model_sequence: List of model checkpoints [π_θ_1, ..., π_θ_T]
            ground_truth_reward: RewardFunction instance representing R*
            k: Number of objectives to discover
            x_cand_size: Size of candidate sample set for informative sample selection
            x_disc_size: Number of most informative samples to use for discovery (v)
            objectives_per_trajectory: Number of objectives to request per trajectory
            num_parallel_trajectories: Number of trajectories to process together in discovery prompt
            num_samples_select_best: Number of random samples for objective selection phase
            num_samples_final_eval: Number of samples for final Obj-Error calculation
            combination_function_type: Type of g function ('linear', 'linear_regression', etc.)
            combination_function_params: Parameters for combination function
                (an empty dict is stored when this is None or empty)
            train_test_split_idx: Index to split model sequence for train/test;
                half of the model sequence length is used when None
            max_iterations: Maximum iterations before stopping
            use_random_sampling: If True, randomly sample x_disc instead of using residuals
            **kwargs: Additional arguments for base class
        """
        super().__init__(dataset, model_sequence, k, **kwargs)

        # Reward and combination-function configuration
        self.ground_truth_reward = ground_truth_reward
        self.combination_function_type = combination_function_type
        if combination_function_params:
            self.combination_function_params = combination_function_params
        else:
            self.combination_function_params = {}

        # Sampling / discovery sizes
        self.x_cand_size = x_cand_size
        self.x_disc_size = x_disc_size
        self.objectives_per_trajectory = objectives_per_trajectory
        self.num_parallel_trajectories = num_parallel_trajectories
        self.num_samples_select_best = num_samples_select_best
        self.num_samples_final_eval = num_samples_final_eval
        self.use_random_sampling = use_random_sampling
        self.max_iterations = max_iterations

        # Train/test split over the model sequence (midpoint unless given)
        self.train_test_split_idx = (
            len(model_sequence) // 2
            if train_test_split_idx is None
            else train_test_split_idx
        )

        # Keep a handle on the module-level ObjectivesFit class
        self.ObjectivesFit = ObjectivesFit

        # Per-iteration bookkeeping
        self.iteration_history = []
        self.current_objectives = []  # R̂^i
        self.residuals_cache = {}

        # In-context examples cache, read from the output directory
        self.custom_examples_cache = self._load_cache('examples')

        # Short dataset name used for cache keys ('unknown' when unmapped)
        self.dataset_short_name = DATASET_NAMES_DICT.get(self.dataset_name, 'unknown')

    def _load_cache(self, cache_type: str) -> Dict[str, str]:
        """Load custom cache from file.

        Args:
            cache_type: Either 'rubrics', 'descriptions', or 'examples'
        """
        # Modify cache filename based on type
        # cache_filename = self.cache_file.replace('.json', f'_{cache_type}.json')
        filename = "{}_cache.json".format(cache_type)
        if self.output_dir:
            cache_filename = os.path.join(self.output_dir, filename)
        else:
            cache_filename = None

        if os.path.exists(cache_filename):
            try:
                with open(cache_filename, 'r') as f:
                    return json.load(f)
            except (json.JSONDecodeError, IOError):
                print(f"Warning: Could not load {cache_type} cache from {cache_filename}, starting fresh")
                return {}
        return {}

    def _save_cache(self, cache_type: str, cache_data: Dict[str, str]):
        """Save custom cache to file.

        Args:
            cache_type: Either 'rubrics', 'descriptions', or 'examples'
            cache_data: The cache dictionary to save
        """
        # Modify cache filename based on type
        # cache_filename = self.cache_file.replace('.json', f'_{cache_type}.json')
        filename = "{}_cache.json".format(cache_type)
        if self.output_dir:
            cache_filename = os.path.join(self.output_dir, filename)
        else:
            return

        try:
            with open(cache_filename, 'w') as f:
                json.dump(cache_data, f, indent=2)
        except IOError as e:
            print(f"Warning: Could not save {cache_type} cache to {cache_filename}: {e}")

    def _generate_in_context_examples_from_trajectory(
        self,
        objective: str,
        prompt: str,
        responses: List[str]
    ) -> str:
        """Generate in-context examples from a response trajectory.

        Args:
            objective: The objective name/description
            prompt: The input prompt/query
            responses: List of responses from the trajectory (from worst to best)

        Returns:
            Formatted in-context example string
        """
        if not responses:
            return ""

        # Use the last trajectory response as it should be the best
        # Format similar to existing examples in the cache
        examples_text = f"Query: {prompt}\n\n"

        # Add first response with low score (worst in trajectory)
        if len(responses) >= 1:
            examples_text += f"Response 1 (Score: 1):\n{responses[0]}\n\n"

        # Add middle responses without scores (optional, based on trajectory length)
        for i, response in enumerate(responses[1:-1], start=2):
            examples_text += f"Response {i}:\n{response}\n\n"

        # Add last response with high score (best in trajectory)
        if len(responses) > 1:
            examples_text += f"Response {len(responses)} (Score: 10):\n{responses[-1]}"

        return examples_text

    def _identify_relevant_trajectory(
        self,
        objective: str,
        batch_trajectory_infos: List[Dict[str, Any]]
    ) -> int:
        """Identify which trajectory is most relevant to a given objective.

        Formats every trajectory with ``self._format_single_trajectory``, asks the
        proposer model (OpenAI chat API when ``self.use_api_proposer`` is truthy,
        otherwise ``generate_huggingface_response``) to name one trajectory, and
        parses the first integer found in the reply as a 1-based trajectory number.
        The query is retried up to ``MAX_NUM_SCORING_RETRIES`` times.

        Args:
            objective: The discovered objective
            batch_trajectory_infos: List of trajectory info dicts with 'prompt' and 'responses'

        Returns:
            Index of the most relevant trajectory (0-based). Falls back to 0 when
            no valid index could be obtained within the retry budget.
        """
        # Mirror the module-level import fallback so this method also works when
        # the file is executed directly instead of being imported from a package.
        try:
            from .constants import MAX_NUM_SCORING_RETRIES
        except ImportError:
            from constants import MAX_NUM_SCORING_RETRIES

        num_trajectories = len(batch_trajectory_infos)

        # Format trajectories for the prompt (numbered from 1, one per block)
        trajectories_text = "".join(
            self._format_single_trajectory(
                traj_info['prompt'],
                traj_info['responses'],
                trajectory_num=traj_num
            ) + "\n"
            for traj_num, traj_info in enumerate(batch_trajectory_infos, start=1)
        )

        prompt = TRAJECTORY_RELEVANCE_PROMPT.format(
            objective=objective,
            trajectories=trajectories_text,
            num_trajectories=num_trajectories
        )

        # Created lazily inside the try block (so construction errors are still
        # handled by the retry logic) and then reused across attempts.
        client = None

        for attempt in range(MAX_NUM_SCORING_RETRIES):
            is_last_attempt = attempt == MAX_NUM_SCORING_RETRIES - 1
            try:
                if self.use_api_proposer:
                    if client is None:
                        client = OpenAI(api_key=OPENAI_API_KEY)
                    response = client.chat.completions.create(
                        model=self.proposer_model,
                        messages=[
                            {"role": "system", "content": "You are an expert at analyzing model behavior."},
                            {"role": "user", "content": prompt}
                        ],
                        temperature=0.0,
                        max_tokens=10
                    )
                    # The API may return no content; treat that as an unparseable reply.
                    response_text = (response.choices[0].message.content or "").strip()
                else:
                    # Use local model
                    response_text = generate_huggingface_response(
                        self.proposer_model,
                        prompt,
                        max_new_tokens=10,
                        temperature=0.0
                    )

                # The first run of digits in the reply is taken as the 1-based trajectory number
                match = re.search(r'\d+', response_text)
                if match:
                    traj_idx = int(match.group()) - 1
                    if 0 <= traj_idx < num_trajectories:
                        return traj_idx

                # No number found, or the number is out of range
                if is_last_attempt:
                    print(f"  Warning: Could not parse trajectory index after {MAX_NUM_SCORING_RETRIES} attempts. Using default.")
                    return 0
                print(f"  Warning: Could not parse trajectory index from: {response_text}. Retrying...")

            except Exception as e:
                if is_last_attempt:
                    print(f"  Error identifying relevant trajectory after {MAX_NUM_SCORING_RETRIES} attempts: {e}")
                    # Default to first trajectory on error
                    return 0
                print(f"  Error identifying relevant trajectory (attempt {attempt + 1}): {e}. Retrying...")
                time.sleep(1)  # Brief delay before retry

        # Only reached when the retry budget is zero; default to the first trajectory
        return 0

    def _store_objective_examples(self, objective: str, prompt: str, responses: List[str]):
        """Generate and cache in-context examples for an objective.

        Args:
            objective: The objective name/description
            prompt: The input prompt used in discovery
            responses: Response trajectory used to discover the objective
        """
        # Create cache key
        cache_key = f"{self.dataset_short_name}_{objective.lower()}"

        # Check if already in cache
        if cache_key in self.custom_examples_cache:
            print(f"In-context examples already cached for '{objective}'")
            return

        # Generate examples from trajectory
        examples = self._generate_in_context_examples_from_trajectory(
            objective, prompt, responses
        )

        # Store in cache
        self.custom_examples_cache[cache_key] = examples
        self._save_cache('examples', self.custom_examples_cache)

        print(f"In-context examples generated and cached for '{objective}'")

    def _calculate_residuals_for_samples(
        self,
        samples: List[Dict[str, str]],
        current_objectives: List[str]
    ) -> Tuple[Dict[str, float], Dict[str, List[str]], Dict[str, List[str]]]:
        """
        Calculate average residuals for given samples using current objectives.

        Every dictionary returned is keyed by the string produced by
        ``apply_chat_template_to_prompt(self.model_sequence[0], sample['input'])``.

        With no objectives yet, the residual of a response is its squared ground
        truth. Otherwise a combination function is fitted on responses from the
        train models (``self.model_sequence[:self.train_test_split_idx]``) and the
        residual comes from ``obj_fit_calc._calculate_residual`` on responses from
        the test models. In both cases residuals are averaged over the test
        models; a sample with no residuals gets 0.0.

        Args:
            samples: List of input samples
            current_objectives: Current set of discovered objectives

        Returns:
            Tuple of:
            - Dictionary mapping input keys to average residuals
            - Dictionary mapping input keys to the list of test model responses
            - Dictionary mapping input keys to the responses collected from all
              models (reused later in discovery)
        """
        print(f"Calculating residuals for {len(samples)} samples...")

        # Create ObjectivesFit instance with current objectives - enable batching
        obj_fit_calc = self.ObjectivesFit(
            dataset=samples,
            model_sequence=self.model_sequence,
            ground_truth_objective=self.ground_truth_reward,
            combination_function_type=self.combination_function_type,
            combination_function_params=self.combination_function_params,
            num_samples=len(samples),
            train_test_split_idx=self.train_test_split_idx,
            scorer_model=self.scorer_model,
            device=self.device,
            cache_responses=True,
            save_dir=self.output_dir,
            use_different_prompts=False,
            dataset_type=DATASET_NAMES_DICT[self.dataset_name],
            use_detailed_rubric=True,
            batching=True,  # Enable batched generation
            batch_size=8,   # Use reasonable batch size
            model_cache_size=1,  # Model cache limited to a single entry
            normalize_scores=True,  # Normalize scores to [0, 1] for objective discovery
            logger=self.logger,
            max_concurrent=self.max_concurrent
        )

        residuals_by_sample = {}
        test_responses_by_sample = {}  # Store responses from test models
        all_responses_by_input_key = {}  # Store ALL model responses for reuse in discovery

        # Collect all input texts for batching
        input_texts = [sample['input'] for sample in samples]
        # The dictionary key of a sample does not depend on the generating model,
        # so compute it once per sample instead of once per (model, sample) pair.
        input_keys = [
            apply_chat_template_to_prompt(self.model_sequence[0], input_text)
            for input_text in input_texts
        ]
        train_models = self.model_sequence[:self.train_test_split_idx]
        test_models = self.model_sequence[self.train_test_split_idx:]

        # Per-sample residuals, one entry per test model
        residuals_per_sample = {i: [] for i in range(len(samples))}

        # If no objectives yet, residuals are just ground truth squared
        if not current_objectives:
            # Generate responses from ALL models (for discovery reuse)
            print("Generating responses from all models for discovery...")
            for model_idx, model_path in enumerate(self.model_sequence):
                is_test_model = model_idx >= self.train_test_split_idx
                print(f"Processing {'test' if is_test_model else 'train'} model: {model_path}")
                responses = obj_fit_calc._generate_response_batched(model_path, input_texts, batch_size=8)

                for sample_idx, (input_key, response) in enumerate(zip(input_keys, responses)):
                    all_responses_by_input_key.setdefault(input_key, []).append(response)

                    # Store test model responses separately and calc residuals
                    if is_test_model:
                        test_responses_by_sample.setdefault(input_key, []).append(response)

                        gt = obj_fit_calc._get_ground_truth(input_key, response)
                        residuals_per_sample[sample_idx].append(gt ** 2)

                print(f"Processed {len(samples)} samples for model")
        else:
            # Fit combination function on training models
            train_features = []
            train_targets = []

            for model_idx, model_path in enumerate(tqdm(train_models, desc="Training combination function")):
                # Generate all responses at once for this model
                responses = obj_fit_calc._generate_response_batched(model_path, input_texts, batch_size=8)

                for input_key, response in zip(input_keys, responses):
                    # Train models come first. Append only when this model's slot is
                    # still empty, so a key is not extended twice for one model.
                    key_responses = all_responses_by_input_key.setdefault(input_key, [])
                    if len(key_responses) == model_idx:
                        key_responses.append(response)

                    # Score with current objectives
                    objective_scores = {
                        f"obj_{i}": obj_fit_calc._score_with_objective(
                            input_key,
                            response,
                            obj_desc
                        )
                        for i, obj_desc in enumerate(current_objectives)
                    }

                    # Get ground truth
                    gt = obj_fit_calc._get_ground_truth(input_key, response)

                    train_features.append(objective_scores)
                    train_targets.append(gt)

                print(f"Processed {len(samples)} samples for model")

            # Fit combination function
            obj_fit_calc.combination_function, obj_fit_calc.obj_coefficients = obj_fit_calc._fit_combination_function(
                current_objectives,
                train_features,
                train_targets
            )

            # Log coefficients when the fit produced them
            if obj_fit_calc.obj_coefficients is not None and self.logger:
                self.logger.info("\n--- Obj Coefficients ---")
                self.logger.info(f"Intercept: {obj_fit_calc.obj_coefficients['intercept']:.4f}")
                for obj_name, coef in obj_fit_calc.obj_coefficients['coefficients'].items():
                    self.logger.info(f"  {obj_name}: {coef:.4f}")
                self.logger.info("--------------------------------------")

            # Calculate residuals on test models with batching
            for test_idx, model_path in enumerate(test_models):
                model_idx = self.train_test_split_idx + test_idx  # Global model index
                print(f"Processing test model: {model_path}")
                responses = obj_fit_calc._generate_response_batched(model_path, input_texts, batch_size=8)

                for sample_idx, (input_key, response) in enumerate(zip(input_keys, responses)):
                    # Test models follow the train models. setdefault avoids a
                    # KeyError for a key that no train model populated.
                    key_responses = all_responses_by_input_key.setdefault(input_key, [])
                    if len(key_responses) == model_idx:
                        key_responses.append(response)

                    # Store test responses separately
                    test_responses_by_sample.setdefault(input_key, []).append(response)

                    residual = obj_fit_calc._calculate_residual(
                        input_key,
                        response,
                        current_objectives
                    )
                    residuals_per_sample[sample_idx].append(residual)

                print(f"Processed {len(samples)} samples for model")

        # Average residuals across test models (0.0 when a sample has none)
        for sample_idx, input_key in enumerate(input_keys):
            sample_residuals = residuals_per_sample[sample_idx]
            avg_residual = sum(sample_residuals) / len(sample_residuals) if sample_residuals else 0.0
            residuals_by_sample[input_key] = avg_residual

        # Cleanup
        obj_fit_calc.cleanup_model_cache()

        return residuals_by_sample, test_responses_by_sample, all_responses_by_input_key

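    # NOTE(review): the group-scoring variant below is disabled, yet
    # _identify_informative_samples still calls it when self.group_scoring is set
    # and unpacks three return values, while this version returns only two.
    # Reconcile the return shape before re-enabling it.
    # def _group_calculate_residuals_for_samples(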
    #     self,
    #     samples: List[Dict[str, str]],
    #     current_objectives: List[str]
    # ) -> Tuple[Dict[str, float], Dict[str, List[str]]]:
    #     """
    #     Calculate average residuals for given samples using current objectives with group scoring.

    #     This method uses group scoring where all responses (from train and test models) are
    #     scored together for better calibration, similar to the implementation in calc_objectives_fit.py.

    #     Args:
    #         samples: List of input samples
    #         current_objectives: Current set of discovered objectives

    #     Returns:
    #         Tuple of:
    #         - Dictionary mapping sample inputs to average residuals
    #         - Dictionary mapping sample inputs to LIST of test model responses (handles multiple test models)
    #     """
    #     print(f"Calculating residuals with GROUP SCORING for {len(samples)} samples...")

    #     # Create ObjectivesFit instance with current objectives - enable batching and group scoring
    #     obj_fit_calc = self.ObjectivesFit(
    #         dataset=samples,
    #         model_sequence=self.model_sequence,
    #         ground_truth_objective=self.ground_truth_reward,
    #         combination_function_type=self.combination_function_type,
    #         combination_function_params=self.combination_function_params,
    #         num_samples=len(samples),
    #         train_test_split_idx=self.train_test_split_idx,
    #         scorer_model=self.scorer_model,
    #         device=self.device,
    #         cache_responses=True,
    #         use_different_prompts=False,
    #         dataset_type=DATASET_NAMES_DICT[self.dataset_name],
    #         use_detailed_rubric=True,
    #         batching=True,  # Enable batched generation
    #         group_scoring=True,  # Enable group scoring
    #         batch_size=8,   # Use reasonable batch size
    #         model_cache_size=1,  # Keep models in memory
    #         normalize_scores=True,  # Normalize scores to [0, 1] for objective discovery
    #         logger=self.logger
    #     )

    #     residuals_by_sample = {}
    #     test_responses_by_sample = {}  # Will store LIST of responses from ALL test models for each sample

    #     # Split models into train and test
    #     train_models = self.model_sequence[:self.train_test_split_idx]
    #     test_models = self.model_sequence[self.train_test_split_idx:]
    #     all_models = self.model_sequence  # Use all models for group context

    #     # Collect all input texts for processing
    #     input_texts = [sample['input'] for sample in samples]

    #     # First, collect ALL responses from both train and test models for group scoring
    #     print("Collecting responses from all models for group scoring context...")
    #     all_responses_by_sample = [[] for _ in samples]  # List of lists: [sample_idx][model_idx]

    #     for model_idx, model_path in enumerate(all_models):
    #         is_test_model = model_idx >= self.train_test_split_idx
    #         model_type = "test" if is_test_model else "train"
    #         local_idx = model_idx - self.train_test_split_idx if is_test_model else model_idx

    #         print(f"Collecting responses from {model_type} model {local_idx + 1}/{len(test_models) if is_test_model else len(train_models)}: {model_path}")

    #         # Generate all responses at once for this model
    #         responses = obj_fit_calc._generate_response_batched(model_path, input_texts, batch_size=8)

    #         # Store responses organized by sample
    #         for sample_idx, response in enumerate(responses):
    #             all_responses_by_sample[sample_idx].append(response)

    #             # If this is a test model, store in test_responses_by_sample
    #             # Properly handle multiple test models by appending to a list
    #             if is_test_model:
    #                 input_text = samples[sample_idx]['input']
    #                 input_key = apply_chat_template_to_prompt(self.model_sequence[0], input_text)

    #                 # Initialize list for this input_key if not present
    #                 if input_key not in test_responses_by_sample:
    #                     test_responses_by_sample[input_key] = []

    #                 # Append this test model's response to the list
    #                 test_responses_by_sample[input_key].append(response)

    #     print(f"Collected responses from all {len(all_models)} models ({len(train_models)} train, {len(test_models)} test)")

    #     # If no objectives yet, residuals are just ground truth squared (using group scoring)
    #     if not current_objectives:
    #         print("No objectives yet - calculating ground truth residuals with group scoring...")
    #         residuals_per_sample = {i: [] for i in range(len(samples))}

    #         # Process each sample
    #         for sample_idx, sample in enumerate(samples):
    #             input_text = sample['input']
    #             input_key = apply_chat_template_to_prompt(self.model_sequence[0], input_text)

    #             # Get all responses for this sample
    #             all_responses = all_responses_by_sample[sample_idx]

    #             # Use group scoring to get ground truth for all responses
    #             all_ground_truths = obj_fit_calc._get_group_ground_truth(
    #                 input_key,
    #                 all_responses,
    #                 denormalize_scores=False
    #             )

    #             # Calculate residuals only for test models
    #             for model_idx in range(self.train_test_split_idx, len(all_models)):
    #                 gt = all_ground_truths[model_idx]
    #                 # Residual is just gt^2 when no objectives
    #                 residuals_per_sample[sample_idx].append(gt ** 2)

    #         # Average residuals across test models
    #         for sample_idx, sample in enumerate(samples):
    #             input_text = sample['input']
    #             input_key = apply_chat_template_to_prompt(self.model_sequence[0], input_text)

    #             avg_residual = sum(residuals_per_sample[sample_idx]) / len(residuals_per_sample[sample_idx]) \
    #                 if residuals_per_sample[sample_idx] else 0.0
    #             residuals_by_sample[input_key] = avg_residual

    #     else:
    #         print(f"Fitting combination function with {len(current_objectives)} objectives using group scoring...")

    #         # Fit combination function on training models using group scoring
    #         train_features = []
    #         train_targets = []

    #         # Process each sample for training
    #         for sample_idx, sample in enumerate(samples):
    #             input_text = sample['input']
    #             input_key = apply_chat_template_to_prompt(self.model_sequence[0], input_text)

    #             # Get all responses for this sample (for group context)
    #             all_responses = all_responses_by_sample[sample_idx]

    #             # Score all responses together for each objective using group scoring
    #             all_scores_by_objective = {}
    #             for obj_idx, obj_desc in enumerate(current_objectives):
    #                 # Get scores for all responses using group scoring
    #                 all_scores = obj_fit_calc._group_score_with_objective(
    #                     input_key,
    #                     all_responses,
    #                     obj_desc
    #                 )
    #                 all_scores_by_objective[f"obj_{obj_idx}"] = all_scores

    #             # Get ground truth for all responses using group scoring
    #             all_ground_truths = obj_fit_calc._get_group_ground_truth(
    #                 input_key,
    #                 all_responses,
    #                 denormalize_scores=False
    #             )

    #             # Collect features and targets only for training models
    #             for model_idx in range(self.train_test_split_idx):
    #                 objective_scores = {}
    #                 for obj_idx in range(len(current_objectives)):
    #                     objective_scores[f"obj_{obj_idx}"] = all_scores_by_objective[f"obj_{obj_idx}"][model_idx]

    #                 train_features.append(objective_scores)
    #                 train_targets.append(all_ground_truths[model_idx])

    #         print(f"Collected {len(train_features)} training samples for combination function")

    #         # Fit combination function
    #         obj_fit_calc.combination_function, obj_fit_calc.obj_coefficients = obj_fit_calc._fit_combination_function(
    #             current_objectives,
    #             train_features,
    #             train_targets
    #         )

    #         # Log coefficients if linear regression was used
    #         if obj_fit_calc.obj_coefficients is not None and self.logger:
    #             self.logger.info("\n--- Group Scoring Obj Coefficients ---")
    #             self.logger.info(f"Intercept: {obj_fit_calc.obj_coefficients['intercept']:.4f}")
    #             for obj_name, coef in obj_fit_calc.obj_coefficients['coefficients'].items():
    #                 self.logger.info(f"  {obj_name}: {coef:.4f}")
    #             self.logger.info("--------------------------------------")

    #         # Calculate residuals on test models using group scoring
    #         print("Calculating residuals on test models with group scoring...")
    #         residuals_per_sample = {i: [] for i in range(len(samples))}

    #         # Process each sample for test residual calculation
    #         for sample_idx, sample in enumerate(samples):
    #             input_text = sample['input']
    #             input_key = apply_chat_template_to_prompt(self.model_sequence[0], input_text)

    #             # Get all responses for this sample (for group context)
    #             all_responses = all_responses_by_sample[sample_idx]

    #             # Calculate residuals for test models using the group-based method
    #             for model_idx in range(self.train_test_split_idx, len(all_models)):
    #                 response = all_responses[model_idx]

    #                 # Use the group-based residual calculation
    #                 residual, predicted, ground_truth = obj_fit_calc._calculate_residual_group(
    #                     input_key,
    #                     response,
    #                     all_responses,
    #                     current_objectives,
    #                     model_idx
    #                 )
    #                 residuals_per_sample[sample_idx].append(residual)

    #         # Average residuals across test models
    #         for sample_idx, sample in enumerate(samples):
    #             input_text = sample['input']
    #             input_key = apply_chat_template_to_prompt(self.model_sequence[0], input_text)

    #             avg_residual = sum(residuals_per_sample[sample_idx]) / len(residuals_per_sample[sample_idx]) \
    #                 if residuals_per_sample[sample_idx] else 0.0
    #             residuals_by_sample[input_key] = avg_residual

    #     # Cleanup
    #     obj_fit_calc.cleanup_model_cache()

    #     # Verify that test_responses_by_sample contains lists with correct number of responses
    #     if test_responses_by_sample:
    #         sample_key = next(iter(test_responses_by_sample))
    #         num_responses = len(test_responses_by_sample[sample_key])
    #         print(f"Each sample has {num_responses} test model responses (expected: {len(test_models)})")

    #     print(f"Group scoring residual calculation complete for {len(samples)} samples")
    #     return residuals_by_sample, test_responses_by_sample

    def _identify_informative_samples(
        self,
        current_objectives: List[str]
    ) -> List[Dict[str, str]]:
        """
        Phase 1: Identify the most informative samples.
        If use_random_sampling=True, randomly samples x_disc.
        Otherwise, selects samples with highest residuals (Equation 11 from paper).
        
        Args:
            current_objectives: Current set of discovered objectives R̂^{i-1}
            
        Returns:
            X_disc: List of most informative samples
        """
        print("\n--- Phase 1: Informative Samples Identification ---")
        if self.logger:
            self.logger.info("\n" + "="*60)
            self.logger.info("PHASE 1: INFORMATIVE SAMPLES IDENTIFICATION")
            self.logger.info("="*60)
            self.logger.info(f"Sampling method: {'Random' if self.use_random_sampling else 'Residual-based'}")
            self.logger.info(f"Current objectives count: {len(current_objectives)}")
            if current_objectives:
                self.logger.info("Current objectives:")
                for i, obj in enumerate(current_objectives, 1):
                    self.logger.info(f"  {i}. {obj}")
        
        # Sample X_cand from dataset
        x_cand = random.sample(
            self.dataset,
            min(self.x_cand_size, len(self.dataset))
        )
        print(f"Sampled {len(x_cand)} candidates from dataset")
        
        if self.use_random_sampling:
            # Random sampling ablation: directly sample x_disc from x_cand
            x_disc = random.sample(x_cand, min(self.x_disc_size, len(x_cand)))
            print(f"Randomly selected {len(x_disc)} samples (ablation mode)")
            
            if self.logger:
                self.logger.info(f"\nRandomly selected {len(x_disc)} samples (ablation mode)")
                self.logger.info("\nRandomly selected samples (up to 10):")
                for i, sample in enumerate(x_disc[:10], 1):
                    # input_text = sample['input'][0]['content'] if isinstance(sample['input'], list) else sample['input']
                    # Properly format multi-turn conversations using chat template
                    input_text = apply_chat_template_to_prompt(
                        self.model_sequence[0],
                        sample['input']
                    )
                    input_display = input_text.replace('\n', ' ')
                    self.logger.info(f"\n  Sample {i}:")
                    self.logger.info(f"    Input: {input_display}")
        else:
            # Original method: use residuals to select most informative samples
            # Calculate residuals for each sample in X_cand
            if self.group_scoring:
                residuals, test_responses, all_responses = self._group_calculate_residuals_for_samples(x_cand, current_objectives)
            else:
                residuals, test_responses, all_responses = self._calculate_residuals_for_samples(x_cand, current_objectives)

            # Sort by residual value and take top v samples
            sorted_samples = sorted(
                x_cand,
                key=lambda s: residuals.get(
                    apply_chat_template_to_prompt(self.model_sequence[0], s['input']),
                    0.0
                ),
                reverse=True
            )

            x_disc = sorted_samples[:self.x_disc_size]

            # Extract responses only for x_disc samples (for reuse in discovery)
            x_disc_responses = {}
            for sample in x_disc:
                input_key = apply_chat_template_to_prompt(self.model_sequence[0], sample['input'])
                if input_key in all_responses:
                    x_disc_responses[input_key] = all_responses[input_key]

            print(f"Selected top {len(x_disc)} samples with highest residuals")
            avg_residual = np.mean([
                residuals[apply_chat_template_to_prompt(self.model_sequence[0], s['input'])]
                for s in x_disc
            ])
            print(f"Average residual of selected samples: {avg_residual:.4f}")

            if self.logger:
                self.logger.info(f"\nSelected {len(x_disc)} most informative samples (highest residuals)")
                self.logger.info(f"Average residual: {avg_residual:.4f}")
                self.logger.info("\nInformative samples (up to 10):")
                for i, sample in enumerate(x_disc[:10], 1):
                    input_text = apply_chat_template_to_prompt(
                        self.model_sequence[0],
                        sample['input']
                    )
                    residual_val = residuals[input_text]
                    input_display = input_text
                    input_display = input_display.replace('\n', ' ')
                    self.logger.info(f"\n  Sample {i}:")
                    self.logger.info(f"    Input: {input_display}")
                    self.logger.info(f"    Residual: {residual_val:.6f}")

                    # Log test model responses
                    if input_text in test_responses:
                        test_models = self.model_sequence[self.train_test_split_idx:]
                        self.logger.info(f"    Test Model Responses:")
                        for model_idx, (model_path, response) in enumerate(zip(test_models, test_responses[input_text])):
                            model_name = model_path.split('/')[-1]  # Get checkpoint name
                            response_display = response.replace('\n', ' ')[:200]  # Truncate long responses
                            self.logger.info(f"      {model_name}: {response_display}")

            return x_disc, x_disc_responses

        # For random sampling, return empty responses dict
        return x_disc, {}
    
    def _discover_candidate_objectives(
        self,
        informative_samples: List[Dict[str, str]],
        current_objectives: List[str],
        pre_generated_responses: Optional[Dict[str, List[str]]] = None
    ) -> Set[str]:
        """
        Phase 2: Generate candidate objectives from informative samples.
        Efficiently generates responses for all samples using batched generation.

        Args:
            informative_samples: X_disc samples with highest residuals
            current_objectives: Already discovered objectives
            pre_generated_responses: Optional dict mapping input_key -> list of responses from all models
                                     (reused from Phase 1 to avoid redundant generation)

        Returns:
            Set of candidate objective descriptions
        """
        print("\n--- Phase 2: Candidate Objectives Discovery ---")
        if self.logger:
            self.logger.info("\n" + "="*60)
            self.logger.info("PHASE 2: CANDIDATE OBJECTIVES DISCOVERY")
            self.logger.info("="*60)
            self.logger.info(f"Analyzing {len(informative_samples)} informative samples")
            self.logger.info(f"Processing {self.num_parallel_trajectories} trajectories at a time")
            if pre_generated_responses:
                self.logger.info(f"Reusing {len(pre_generated_responses)} pre-generated response sets from Phase 1")

        candidates = set()
        trajectories_log = []
        all_batch_data = []  # Collect all batch data for parallel processing

        # Check if we can reuse pre-generated responses
        use_pre_generated = pre_generated_responses and len(pre_generated_responses) > 0
        if use_pre_generated:
            print(f"Reusing pre-generated responses from Phase 1 ({len(pre_generated_responses)} samples)")

        # Phase 1: Build all batch trajectories
        # Process samples in batches of num_parallel_trajectories
        for batch_idx in range(0, len(informative_samples), self.num_parallel_trajectories):
            batch_samples = informative_samples[batch_idx:batch_idx + self.num_parallel_trajectories]
            batch_num = batch_idx // self.num_parallel_trajectories + 1
            total_batches = (len(informative_samples) + self.num_parallel_trajectories - 1) // self.num_parallel_trajectories

            print(f"\nProcessing batch {batch_num}/{total_batches} ({len(batch_samples)} trajectories)...")

            # Extract prompts from batch samples
            batch_prompts = [sample['input'] for sample in batch_samples]

            # Build trajectories - either from pre-generated responses or by generating new ones
            batch_trajectories = []
            batch_trajectory_infos = []

            if use_pre_generated:
                # Use pre-generated responses from Phase 1
                for sample_idx, prompt in enumerate(batch_prompts):
                    prompt_formatted = apply_chat_template_to_prompt(self.model_sequence[0], prompt)

                    if prompt_formatted in pre_generated_responses:
                        responses = pre_generated_responses[prompt_formatted]
                        # Filter out empty responses
                        responses = [r for r in responses if r]
                    else:
                        print(f"  Warning: No pre-generated responses for sample, generating...")
                        responses = []
                        for model_path in self.model_sequence:
                            try:
                                # model_resp = self.generate_responses_batched(model_path, [prompt], 1024, 8)
                                model_resp = self.generate_responses_batched(model_path, [prompt], 512, 8)
                                responses.extend(model_resp)
                            except Exception as e:
                                responses.append("")

                    if not responses:
                        print(f"  No valid responses for sample {batch_idx + sample_idx + 1}")
                        continue

                    trajectory_info = {'prompt': prompt_formatted, 'responses': responses}
                    trajectory_text = self._format_single_trajectory(
                        trajectory_info['prompt'], responses, trajectory_num=len(batch_trajectories) + 1
                    )
                    batch_trajectories.append(trajectory_text)
                    batch_trajectory_infos.append(trajectory_info)
            else:
                # Generate responses for each model in sequence (legacy fallback)
                all_model_responses = []
                for model_idx, model_path in enumerate(self.model_sequence):
                    print(f"  Generating responses from model {model_idx + 1}/{len(self.model_sequence)}")
                    try:
                        # model_responses = self.generate_responses_batched(
                        #     model_path=model_path, prompts=batch_prompts, max_new_tokens=1024, batch_size=8
                        # )
                        model_responses = self.generate_responses_batched(
                            model_path=model_path, prompts=batch_prompts, max_new_tokens=512, batch_size=8
                        )
                        all_model_responses.append(model_responses)
                    except Exception as e:
                        print(f"    Error generating responses from {model_path}: {e}")
                        if self.logger:
                            self.logger.error(f"Failed to generate responses from {model_path}: {e}")
                        all_model_responses.append([""] * len(batch_prompts))

                # Reorganize responses: from [model][sample] to [sample][model]
                for sample_idx in range(len(batch_prompts)):
                    prompt = batch_prompts[sample_idx]
                    responses = [all_model_responses[model_idx][sample_idx]
                               for model_idx in range(len(self.model_sequence))]
                    responses = [r for r in responses if r]

                    if not responses:
                        print(f"  No valid responses for sample {batch_idx + sample_idx + 1}")
                        continue

                    prompt_formatted = apply_chat_template_to_prompt(self.model_sequence[0], prompt)
                    trajectory_info = {'prompt': prompt_formatted, 'responses': responses}
                    trajectory_text = self._format_single_trajectory(
                        trajectory_info['prompt'], responses, trajectory_num=len(batch_trajectories) + 1
                    )
                    batch_trajectories.append(trajectory_text)
                    batch_trajectory_infos.append(trajectory_info)
            
            # Skip if no valid trajectories in batch
            if not batch_trajectories:
                print("  No valid trajectories in this batch")
                continue

            # Create combined prompt for batch and store for later parallel processing
            proposer_prompt = self._create_batch_discovery_prompt(
                batch_trajectories,
                current_objectives
            )

            # Store batch data for parallel processing
            all_batch_data.append({
                'proposer_prompt': proposer_prompt,
                'batch_trajectory_infos': batch_trajectory_infos
            })

        # Phase 2: Call proposer API in parallel for all batches (if using API)
        if self.use_api_proposer and len(all_batch_data) > 1 and hasattr(self, 'async_proposer_client'):
            print(f"\nCalling proposer API in parallel for {len(all_batch_data)} batches...")

            async def run_parallel_proposals():
                semaphore = asyncio.Semaphore(10)  # Limit concurrent API calls

                async def propose_for_batch(batch_data):
                    async with semaphore:
                        return await self._async_propose_objectives_with_api(
                            batch_data['proposer_prompt'],
                            self.objectives_per_trajectory
                        )

                tasks = [propose_for_batch(bd) for bd in all_batch_data]
                return await asyncio.gather(*tasks)

            # Run the async proposals
            try:
                loop = asyncio.get_running_loop()
            except RuntimeError:
                loop = None

            if loop and loop.is_running():
                import nest_asyncio
                nest_asyncio.apply()
                all_proposals = asyncio.get_event_loop().run_until_complete(run_parallel_proposals())
            else:
                all_proposals = asyncio.run(run_parallel_proposals())
        else:
            # Sequential fallback (for local model or single batch)
            all_proposals = []
            for batch_data in all_batch_data:
                if self.use_api_proposer:
                    proposals = self._propose_objectives_with_api(batch_data['proposer_prompt'], self.objectives_per_trajectory)
                else:
                    proposals = self._propose_objectives_with_local_model(batch_data['proposer_prompt'], self.objectives_per_trajectory)
                all_proposals.append(proposals)

        # Phase 3: Process all results
        for batch_data, proposals in zip(all_batch_data, all_proposals):
            batch_trajectory_infos = batch_data['batch_trajectory_infos']
            all_batch_proposals = proposals.copy() if proposals else []

            for traj_info in batch_trajectory_infos:
                # Safely extract objectives for this trajectory
                traj_objectives = proposals[:self.objectives_per_trajectory] if proposals else []
                traj_info['discovered_objectives'] = traj_objectives
                proposals = proposals[self.objectives_per_trajectory:]  # Move to next set
                trajectories_log.append(traj_info)

                # Generate and store in-context examples for each discovered objective
                for obj in traj_objectives:
                    if obj and batch_trajectory_infos:
                        relevant_idx = self._identify_relevant_trajectory(obj, batch_trajectory_infos)
                        relevant_traj = batch_trajectory_infos[relevant_idx]
                        self._store_objective_examples(
                            objective=obj,
                            prompt=relevant_traj['prompt'],
                            responses=relevant_traj['responses']
                        )

            # Add all unique proposals from batch to candidates
            for proposal in all_batch_proposals:
                if proposal and len(proposal) > 10:
                    candidates.add(proposal)

        print(f"Generated {len(candidates)} unique candidate objectives")
        
        if self.logger:
            self.logger.info(f"\nGenerated {len(candidates)} unique candidate objectives")
            
            self.logger.info("\nResponse Trajectories and Discovered Objectives:")
            for idx, traj_info in enumerate(trajectories_log):
                self.logger.info(f"\n  Trajectory {idx+1}:")
                prompt_display = traj_info['prompt']
                prompt_display = prompt_display.replace('\n', ' ')
                self.logger.info(f"    Prompt: {prompt_display}")
                
                self.logger.info(f"    Response trajectory ({len(traj_info['responses'])} models):")
                for j, resp in enumerate(traj_info['responses']):
                    resp_display = resp
                    resp_display = resp_display.replace('\n', ' ')
                    self.logger.info(f"      Model {j+1}: {resp_display}")
                
                self.logger.info(f"    Discovered objectives from this trajectory and the ones below:")
                for j, obj in enumerate(traj_info['discovered_objectives'], 1):
                    self.logger.info(f"      {j}. {obj}")
            
            # Log all unique candidates
            self.logger.info("\nAll Candidate Objectives:")
            for idx, candidate in enumerate(list(candidates), 1):
                self.logger.info(f"  {idx}. {candidate}")

        self.scorer_model.reload_cache()  # Reload cache to ensure it has latest data
        
        return candidates

    # def _format_single_trajectory(
    #     self,
    #     prompt: str,
    #     responses: List[str],
    #     trajectory_num: int
    # ) -> str:
    #     """Format a single trajectory for inclusion in batch prompt."""
    #     trajectory_text = f"==== TRAJECTORY {trajectory_num} ====\n"
    #     trajectory_text += f"Input Prompt: {prompt}\n\n"
        
    #     for i, response in enumerate(responses):
    #         trajectory_text += f"Model Iteration {i+1} Response:\n{response}\n\n"
        
    #     return trajectory_text
    
    # def _create_batch_discovery_prompt(
    #     self,
    #     batch_trajectories: List[str],
    #     existing_objectives: List[str]
    # ) -> str:
    #     """Create prompt for proposer to discover objectives from multiple trajectories."""
    #     # Combine all trajectories
    #     trajectories_text = "\n".join(batch_trajectories)
        
    #     # Prepare existing objectives section if needed
    #     existing_objectives_section = ""
    #     if existing_objectives:
    #         existing_obj_list = "\n".join([f"- {obj}" for obj in existing_objectives])
    #         existing_objectives_section = OBJECTIVE_DISCOVERY_WITH_EXISTING_PROMPT.format(
    #             existing_objectives=existing_obj_list
    #         )
        
    #     # Use the updated prompt template from constants
    #     full_prompt = OBJECTIVE_DISCOVERY_PROMPT.format(
    #         trajectory_count=len(batch_trajectories),
    #         trajectories=trajectories_text,
    #         # num_objectives=self.objectives_per_trajectory * len(batch_trajectories),
    #         num_objectives=self.objectives_per_trajectory,
    #         existing_objectives_section=existing_objectives_section
    #     )
        
    #     return full_prompt
    
    def _create_discovery_prompt(
        self,
        prompt: str,
        responses: List[str],
        existing_objectives: List[str]
    ) -> str:
        """Create prompt for proposer to discover new objectives (legacy single trajectory)."""
        # This method is kept for backward compatibility but now uses batch format with single trajectory
        trajectory_text = self._format_single_trajectory(prompt, responses, 1)
        return self._create_batch_discovery_prompt([trajectory_text], existing_objectives)
    
    def _select_best_objective(
        self,
        candidates: Set[str],
        current_objectives: List[str]
    ) -> Tuple[Optional[str], float]:
        """
        Phase 3: Select the best objective that maximizes Obj-Fit improvement.
        Implements Equation 12 from the paper.
        
        Args:
            candidates: Set of candidate objectives
            current_objectives: Current discovered objectives
            
        Returns:
            Tuple of (best_objective, improvement_score)
        """
        print("\n--- Phase 3: Objectives Selection ---")
        if self.logger:
            self.logger.info("\n" + "="*60)
            self.logger.info("PHASE 3: OBJECTIVES SELECTION (OBJ-ERROR EVALUATION)")
            self.logger.info("="*60)
        
        if not candidates:
            print("No candidates to evaluate")
            return None, 0.0
        
        # Sample new random samples for objective selection
        selection_samples = random.sample(
            self.dataset,
            min(self.num_samples_select_best, len(self.dataset))
        )
        print(f"Sampled {len(selection_samples)} random samples for objective selection")
        
        if self.logger:
            self.logger.info(f"Sampled {len(selection_samples)} random samples for objective selection")
        
        # Calculate baseline Obj-Fit with current objectives
        baseline_fit = 0.0
        obj_fit_calc = None
        if current_objectives:
            obj_fit_calc = self.ObjectivesFit(
                dataset=selection_samples,
                model_sequence=self.model_sequence,
                ground_truth_objective=self.ground_truth_reward,
                combination_function_type=self.combination_function_type,
                combination_function_params=self.combination_function_params,
                num_samples=len(selection_samples),
                train_test_split_idx=self.train_test_split_idx,
                scorer_model=self.scorer_model,
                device=self.device,
                cache_responses=True,
                save_dir=self.output_dir,
                use_different_prompts=False,
                dataset_type=DATASET_NAMES_DICT[self.dataset_name],
                use_detailed_rubric=True,
                batching=True,  # Enable batched generation
                batch_size=8,   # Use reasonable batch size
                model_cache_size=1,  # Keep up to 3 models in memory
                normalize_scores=True,  # Normalize scores to [0, 1] for objective discovery
                logger=self.logger,
                max_concurrent=self.max_concurrent
            )
            baseline_fit = obj_fit_calc.calculate(current_objectives)
            # obj_fit_calc.cleanup_model_cache()

        print(f"Baseline Obj-Error: {baseline_fit:.4f}")
        print(f"Evaluating {len(candidates)} candidates...")
        
        if self.logger:
            self.logger.info(f"Baseline Obj-Error (with {len(current_objectives)} objectives): {baseline_fit:.4f}")
            self.logger.info(f"Evaluating {len(candidates)} candidate objectives...")
            self.logger.info("")
        
        best_objective = None
        # best_improvement = -float('inf')
        best_reduction = float('inf')
        obj_fit_results = []
        
        for i, candidate in enumerate(candidates):
            candidate_start_time = time.time()
            print(f"Evaluating candidate {i+1}/{len(candidates)}: {candidate}...")
            if self.logger:
                self.logger.info(f"Evaluating candidate {i+1}/{len(candidates)}: {candidate}")
            
            # Calculate Obj-Fit with this candidate added
            test_objectives = current_objectives + [candidate]
            
            if (i == 0) and (not current_objectives):
                obj_fit_calc = self.ObjectivesFit(
                    dataset=selection_samples,
                    model_sequence=self.model_sequence,
                    ground_truth_objective=self.ground_truth_reward,
                    combination_function_type=self.combination_function_type,
                    combination_function_params=self.combination_function_params,
                    num_samples=len(selection_samples),
                    train_test_split_idx=self.train_test_split_idx,
                    scorer_model=self.scorer_model,
                    device=self.device,
                    cache_responses=True,
                    save_dir=self.output_dir,
                    use_different_prompts=False,
                    dataset_type=DATASET_NAMES_DICT[self.dataset_name],
                    use_detailed_rubric=True,
                    batching=True,  # Enable batched generation
                    batch_size=8,   # Use reasonable batch size
                    model_cache_size=1,  # Keep up to 3 models in memory
                    normalize_scores=True,  # Normalize scores to [0, 1] for objective discovery
                    logger=self.logger,
                    max_concurrent=self.max_concurrent
                )

            try:
                new_fit = obj_fit_calc.calculate(test_objectives)
                reduction = float(new_fit - baseline_fit)
                
                obj_fit_results.append({
                    'objective': candidate,
                    'obj_error': new_fit,
                    'reduction': reduction
                })
                
                print(f"  Obj-Error: {new_fit:.4f}, Reduction: {reduction:.4f}")
                if self.logger:
                    self.logger.info(f"  Obj-Error: {new_fit:.6f}")
                    self.logger.info(f"  Reduction over baseline: {reduction:.6f}")
                    self.logger.info(f"  Time taken: {time.time() - candidate_start_time:.2f}s")
                
                if reduction < best_reduction:
                    best_reduction = reduction
                    best_objective = candidate

                    # Log linear regression coefficients for the best objective
                    if self.combination_function_type == 'linear_regression' and self.logger:
                        if hasattr(obj_fit_calc, 'obj_coefficients') and obj_fit_calc.obj_coefficients is not None:
                            self.logger.info("\n--- Best Objective Obj Coefficients ---")
                            self.logger.info(f"Objective being evaluated: {candidate[:50]}...")
                            self.logger.info(f"Intercept: {obj_fit_calc.obj_coefficients['intercept']:.4f}")
                            for obj_name, coef in obj_fit_calc.obj_coefficients['coefficients'].items():
                                self.logger.info(f"  {obj_name}: {coef:.4f}")
                            self.logger.info("------------------------------------------------------")

            except Exception as e:
                print(f"  Error calculating Obj-Error: {e}")
            finally:
                pass
                # obj_fit_calc.cleanup_model_cache()
        # Cleanup
        obj_fit_calc.cleanup_model_cache()
        
        if self.logger:
            # Sort results by improvement
            obj_fit_results.sort(key=lambda x: x['reduction'], reverse=False)
            
            self.logger.info("Obj-Error Results for All Candidates:")
            # Show top 10 candidates
            for idx, result in enumerate(obj_fit_results, 1):
                self.logger.info(f"  {idx}. Objective: {result['objective']}")
                self.logger.info(f"     Obj-Error: {result['obj_error']:.6f}")
                self.logger.info(f"     Reduction: {result['reduction']:.6f}")
                if result['objective'] == best_objective:
                    self.logger.info("     >>> SELECTED <<<")
                self.logger.info("")
            
            # if len(obj_fit_results) > 10:
            #     self.logger.info(f"  ... and {len(obj_fit_results)-10} more candidates evaluated")
        
        if best_objective:
            print(f"\nBest objective: {best_objective[:50]}...")
            print(f"Best reduction: {best_reduction:.4f}")
            
            if self.logger:
                self.logger.info("\n" + "-"*40)
                self.logger.info("SELECTED BEST OBJECTIVE:")
                self.logger.info(f"  Objective: {best_objective}")
                self.logger.info(f"  Obj-Error: {baseline_fit + best_reduction:.6f}")
                self.logger.info(f"  Reduction over baseline: {best_reduction:.6f}")
                self.logger.info("-"*40)

        return best_objective, best_reduction
    
    def obtain_objectives(self) -> Tuple[List[str], Dict[str, Any]]:
        """
        Run the proposed objective discovery algorithm end to end.

        The loop repeats until ``self.k`` objectives were accepted or
        ``self.max_iterations`` is reached. Each iteration identifies
        informative samples, proposes candidate objectives, selects the best
        candidate and verifies it; accepted objectives are appended to
        ``self.discovered_objectives``, rejected ones to
        ``self.rejected_objectives``. Iterations without a selectable candidate
        are skipped without a history entry. Afterwards the final Obj-Error is
        computed and ``self.discovery_stats`` is completed.

        Returns:
            Tuple of (discovered_objectives, statistics)
        """
        started_at = time.time()
        log = self.logger
        rule = "=" * 60

        # ---- configuration banner (stdout) ----
        print("\n" + rule)
        print("PROPOSED OBJECTIVES DISCOVERY")
        print(rule)
        print(f"Target: {self.k} valid objectives")
        print(f"Dataset size: {len(self.dataset)} samples")
        print(f"Model sequence: {len(self.model_sequence)} models")
        print(f"Train/Test split: {self.train_test_split_idx}")
        print(f"Combination function: {self.combination_function_type}")
        print(f"X_cand size: {self.x_cand_size}")
        print(f"X_disc size: {self.x_disc_size}")
        print(f"Num samples for selection: {self.num_samples_select_best}")
        print(rule)

        # ---- configuration banner (log file) ----
        if log:
            log.info("\n" + "#" * 80)
            log.info("#" + "STARTING PROPOSED OBJECTIVES DISCOVERY".center(78) + "#")
            log.info("#" * 80)
            log.info("Configuration:")
            log.info(f"  Target objectives (k): {self.k}")
            log.info(f"  Dataset size: {len(self.dataset)} samples")
            log.info(f"  Model sequence: {len(self.model_sequence)} models")
            log.info(f"  Train/Test split index: {self.train_test_split_idx}")
            log.info(f"  Combination function: {self.combination_function_type}")
            log.info(f"  X_cand size: {self.x_cand_size}")
            log.info(f"  X_disc size: {self.x_disc_size}")
            log.info(f"  Num samples for selection: {self.num_samples_select_best}")
            log.info(f"  Verification epsilon (interpretability): {self.verifier_epsilon_interpretable}")
            log.info(f"  Verification epsilon (trend): {self.verifier_epsilon_trend}")
            log.info("")

        round_no = 0

        while len(self.discovered_objectives) < self.k and round_no < self.max_iterations:
            round_no += 1
            self.discovery_stats['total_iterations'] = round_no

            print(f"\n{rule}")
            print(f"ITERATION {round_no}")
            print(f"Current objectives: {len(self.discovered_objectives)}/{self.k}")
            print(rule)

            if log:
                log.info("\n" + "*" * 80)
                log.info("*" + f"ITERATION {round_no}".center(78) + "*")
                log.info("*" * 80)
                log.info(f"Progress: {len(self.discovered_objectives)}/{self.k} objectives discovered")
                if self.discovered_objectives:
                    log.info("Currently discovered objectives:")
                    for position, known in enumerate(self.discovered_objectives, start=1):
                        log.info(f"  {position}. {known}")
                log.info("")

            # Snapshot taken before this iteration may extend the list
            objectives_before = self.discovered_objectives.copy()

            # STEP 1: OBJECTIVES DISCOVERY
            print("\n=== STEP 1: OBJECTIVES DISCOVERY ===")

            # Phase 1: informative samples (plus responses that Phase 2 can reuse)
            informative_samples, pre_generated_responses = self._identify_informative_samples(
                self.discovered_objectives
            )

            # Phase 2: candidate objectives proposed from those samples
            candidates = self._discover_candidate_objectives(
                informative_samples,
                self.discovered_objectives,
                pre_generated_responses=pre_generated_responses
            )
            self.discovery_stats['total_proposals'] += len(candidates)

            # Phase 3: pick the best candidate
            best_objective, improvement = self._select_best_objective(
                candidates,
                self.discovered_objectives
            )

            if best_objective is None:
                print("\nNo valid candidates found in this iteration")
                continue

            # STEP 2: OBJECTIVES VERIFICATION
            print("\n=== STEP 2: OBJECTIVES VERIFICATION ===")

            is_valid, verification_details = self._verify_objective(
                best_objective, iteration_candidates=candidates
            )

            if not is_valid:
                print(f"\n✗ OBJECTIVE REJECTED: {best_objective[:50]}...")
                if not verification_details['interpretable']:
                    print("  Failed interpretability check")
                if not verification_details['follows_trend']:
                    print("  Failed trend check")

                if log:
                    log.info("\n" + rule)
                    log.info("✗ OBJECTIVE REJECTED")
                    log.info(rule)
                    log.info(f"Objective: {best_objective}")
                    log.info("Rejection Reasons:")
                    if not verification_details['interpretable']:
                        log.info(f"  - Failed interpretability check (score: {verification_details['interpretability_score']:.4f})")
                    if not verification_details['follows_trend']:
                        log.info(f"  - Failed trend check (error: {verification_details['trend_error']:.4f})")

                self.rejected_objectives.append({
                    'objective': best_objective,
                    'reason': verification_details
                })
            else:
                print(f"\n✓ OBJECTIVE ACCEPTED: {best_objective[:50]}...")
                print(f"  Interpretability score: {verification_details['interpretability_score']:.4f}")
                print(f"  Trend type: {verification_details['trend_type']}")
                print(f"  Trend error: {verification_details['trend_error']:.4f}")

                if log:
                    log.info("\n" + rule)
                    log.info("✓ OBJECTIVE ACCEPTED")
                    log.info(rule)
                    log.info(f"Objective: {best_objective}")
                    log.info("Verification Summary:")
                    log.info(f"  - Interpretability Score: {verification_details['interpretability_score']:.4f}")
                    log.info(f"  - Trend Type: {verification_details['trend_type']}")
                    log.info(f"  - Trend Error: {verification_details['trend_error']:.4f}")
                    log.info(f"  - Improvement to Obj-Fit: {improvement:.6f}")

                self.discovered_objectives.append(best_objective)
                self.current_objectives = self.discovered_objectives.copy()

            # Per-iteration record (only written for iterations that reached verification)
            history_entry = {
                'iteration': round_no,
                'objectives_start': objectives_before,
                'objectives_end': self.discovered_objectives.copy(),
                'candidates_evaluated': len(candidates),
                'best_candidate': best_objective,
                'improvement': improvement,
                'accepted': is_valid
            }
            self.iteration_history.append(history_entry)

            # Objective-count record used for the statistics
            progress_entry = {
                'iteration': round_no,
                'objectives': self.discovered_objectives.copy(),
                'num_objectives': len(self.discovered_objectives),
                'candidates_evaluated': len(candidates)
            }
            self.objectives_per_iteration.append(progress_entry)

        # Final Obj-Error of the accepted objectives, evaluated on
        # self.num_samples_final_eval samples
        final_obj_error, combiner_save_path = self.calculate_final_obj_error(
            discovered_objectives=self.discovered_objectives,
            ground_truth_reward=self.ground_truth_reward,
            num_samples_eval=self.num_samples_final_eval,
            train_test_split_idx=self.train_test_split_idx,
            combination_function_type=self.combination_function_type,
            combination_function_params=self.combination_function_params,
            save_dir=self.output_dir
        )

        # ---- final statistics ----
        stats = self.discovery_stats
        stats['time_elapsed'] = time.time() - started_at
        stats['discovered_count'] = len(self.discovered_objectives)
        stats['rejected_count'] = len(self.rejected_objectives)
        if stats['total_proposals'] > 0:
            acceptance_rate = len(self.discovered_objectives) / stats['total_proposals']
        else:
            acceptance_rate = 0
        stats['acceptance_rate'] = acceptance_rate
        stats['objectives_history'] = self.objectives_per_iteration
        stats['iteration_details'] = self.iteration_history
        stats['final_obj_error'] = final_obj_error
        stats['reward_combiner_path'] = combiner_save_path

        # ---- final summary (stdout) ----
        print("\n" + rule)
        print("DISCOVERY COMPLETE")
        print(rule)
        print(f"Valid objectives discovered: {len(self.discovered_objectives)}/{self.k}")
        print(f"Total iterations: {round_no}")
        print(f"Total candidates evaluated: {stats['total_proposals']}")
        print(f"Acceptance rate: {stats['acceptance_rate']:.2%}")
        print(f"Time elapsed: {stats['time_elapsed']:.2f} seconds")
        if final_obj_error is not None:
            print(f"Final Obj-Error: {final_obj_error:.6f}")
        print("Verification failures:")
        print(f"  - Interpretability: {stats['verification_failures']['interpretability']}")
        print(f"  - Trend: {stats['verification_failures']['trend']}")

        print("\n--- Discovered Objectives ---")
        for position, accepted in enumerate(self.discovered_objectives, start=1):
            print(f"{position}. {accepted}")

        # ---- final summary (log file) ----
        if log:
            log.info("\n" + "#" * 80)
            log.info("#" + "DISCOVERY COMPLETE".center(78) + "#")
            log.info("#" * 80)
            log.info("\nFINAL SUMMARY:")
            log.info(f"  Valid objectives discovered: {len(self.discovered_objectives)}/{self.k}")
            log.info(f"  Total iterations: {round_no}")
            log.info(f"  Total candidates evaluated: {stats['total_proposals']}")
            log.info(f"  Acceptance rate: {stats['acceptance_rate']:.2%}")
            log.info(f"  Time elapsed: {stats['time_elapsed']:.2f} seconds")
            if final_obj_error is not None:
                log.info(f"  Final Obj-Error: {final_obj_error:.6f}")
            log.info("\nVerification failures:")
            log.info(f"  - Interpretability: {stats['verification_failures']['interpretability']}")
            log.info(f"  - Trend: {stats['verification_failures']['trend']}")

            log.info("\nFINAL DISCOVERED OBJECTIVES:")
            for position, accepted in enumerate(self.discovered_objectives, start=1):
                log.info(f"  {position}. {accepted}")

            log.info("\n" + "#" * 80)
            log.info("#" + "END OF DISCOVERY LOG".center(78) + "#")
            log.info("#" * 80)

        return self.discovered_objectives, self.discovery_stats