"""
Calculate Objectives Fit (Obj-Fit) for evaluating discovered objectives.

This module implements the Obj-Fit metric from the Methods.md document,
which measures how well a set of discovered objectives approximates
the true reward/loss function used during model training.
"""

import os
import json
import torch
import numpy as np
import asyncio
from typing import List, Dict, Tuple, Optional, Union, Callable, Any
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel
from openai import OpenAI
from tqdm import tqdm
import random
import time
from collections import OrderedDict
import logging

try:
    # Try relative imports first (when imported as a module)
    from .reward_combiner import create_reward_combiner, RewardCombiner
    from .model_generation import generate_responses_batched, apply_chat_template_to_prompt
    from .constants import OBJECTIVE_SCORING_PROMPT, OPENAI_API_KEY, SCORING_SYSTEM_PROMPT
    from .custom_rewards import RewardFunction
    from .objective_scorer import ObjectiveScorer
    # from .setup_datasets import sample_to_input_dialogue_tldr, sample_to_input_dialogue_hh
    from .setup_datasets_clean import load_normalized_dataset_samples
except ImportError:
    # Fall back to absolute imports (when run directly)
    from reward_combiner import create_reward_combiner, RewardCombiner
    from model_generation import generate_responses_batched, apply_chat_template_to_prompt
    from constants import OBJECTIVE_SCORING_PROMPT, OPENAI_API_KEY, SCORING_SYSTEM_PROMPT
    from custom_rewards import RewardFunction
    from objective_scorer import ObjectiveScorer
    # from setup_datasets import sample_to_input_dialogue_tldr, sample_to_input_dialogue_hh
    from setup_datasets_clean import load_normalized_dataset_samples


class ObjectivesFit:
    """
    Calculate the Obj-Fit metric for a set of discovered objectives.
    
    The Obj-Fit metric measures how well a set of objectives R̂ approximates
    the ground-truth reward/loss function R* using a combination function g.
    
    Obj-Error(R̂, R*) = RMSE = sqrt((1/T) Σ_t E[(r*(x,y) - g(r̂_1(x,y), ..., r̂_k(x,y)))^2])
    where the inner term is the squared residual
    """
    
    def __init__(
        self,
        dataset: Union[str, List[Dict[str, str]]],
        model_sequence: List[str],
        ground_truth_objective: Union[RewardFunction, Callable, str],
        combination_function_type: str = 'linear_regression',
        combination_function_params: Optional[Dict[str, Any]] = None,
        num_samples: int = 100,
        train_test_split_idx: Optional[int] = None,
        scorer_model: str = "gpt-4o-mini",
        device: str = "auto",
        cache_responses: bool = True,
        use_different_prompts: bool = False,
        dataset_type: str = "hh",  # Type of dataset for rubrics
        use_detailed_rubric: bool = True,  # Whether to use detailed rubrics
        batching: bool = True,  # Whether to use batched generation
        # group_scoring: bool = True,  # Whether to use group scoring for objectives
        group_scoring: bool = False,  # Whether to use group scoring for objectives
        batch_size: int = 8,  # Batch size for batched generation
        model_cache_size: int = 3,  # Number of models to keep in memory
        normalize_scores: bool = False,  # Whether to normalize scores to [0, 1]
        save_dir: Optional[str] = None,  # Directory for saving failure logs
        logger: Optional[logging.Logger] = None,  # Optional logger for detailed output
        max_concurrent: int = 20  # Maximum number of concurrent API calls
    ):
        """
        Initialize the Objectives Fit calculator.
        
        Args:
            dataset: Dataset name (HuggingFace) or list of samples
            model_sequence: List of model checkpoint paths [π_θ_1, ..., π_θ_T]
            ground_truth_objective: RewardFunction instance, callable function,
                                  or string 'ppo_reward'/'dpo_loss' for standard metrics
            combination_function_type: Type of g function ('linear', 'linear_regression', 
                                     'gradient_boosting', 'mlp')
            combination_function_params: Parameters for the combination function
            num_samples: Number of dataset samples to use for calculation
            train_test_split_idx: Index to split model sequence (default: T//2)
            scorer_model: Model to use for scoring objectives (GPT-4 or local model)
            device: Device for local models ('auto', 'cuda', 'cpu')
            cache_responses: Whether to cache model responses
            use_different_prompts: Whether to use different prompts for test set
            dataset_type: Type of dataset ('hh' or 'tldr') for appropriate rubrics
            use_detailed_rubric: Whether to use detailed scoring rubrics
            batching: Whether to use batched generation for efficiency (uses vLLM if available)
            batch_size: Batch size for batched generation (default: 8)
            model_cache_size: Number of models to keep in memory simultaneously (default: 3)
            normalize_scores: Whether to normalize scores to [0, 1] range
            logger: Optional logger instance for detailed output
        """
        self.model_sequence = model_sequence
        self.ground_truth_objective = ground_truth_objective
        self.combination_function_type = combination_function_type
        self.combination_function_params = combination_function_params or {}
        self.num_samples = num_samples
        self.scorer_model = scorer_model
        self.device = device
        self.cache_responses = cache_responses
        self.use_different_prompts = use_different_prompts  # Whether to use different prompts for test set
        self.dataset_type = dataset_type
        self.use_detailed_rubric = use_detailed_rubric
        self.batching = batching  # Whether to use batched generation
        self.group_scoring = group_scoring  # Whether to use group scoring for objectives
        self.batch_size = batch_size  # Batch size for batched generation
        self.model_cache_size = model_cache_size  # Number of models to cache
        self.normalize_scores = normalize_scores  # Whether to normalize scores
        self.save_dir = save_dir  # Directory for saving failure logs
        self.logger = logger or logging.getLogger(__name__)  # Use provided logger or default
        self.max_concurrent = max_concurrent  # Maximum number of concurrent API calls

        # Set train-test split index
        self.train_test_split_idx = train_test_split_idx
        if self.train_test_split_idx is None:
            self.train_test_split_idx = len(model_sequence) // 2
        
        # Validate split index
        if not 1 <= self.train_test_split_idx < len(model_sequence):
            raise ValueError(f"Invalid train_test_split_idx: {self.train_test_split_idx}")
        
        # Load and prepare dataset
        self.dataset = self._load_dataset(dataset)
        
        # Sample subset if needed
        # if self.num_samples < len(self.dataset):
        #     self.sampled_dataset = random.sample(self.dataset, self.num_samples)
        # else:
        #     self.sampled_dataset = self.dataset
        self.sampled_train_dataset = self._sample_dataset(self.dataset, self.num_samples)
        self.sampled_test_dataset = self._sample_dataset(self.dataset, self.num_samples)
        
        # Initialize scorer (API or local model)
        self._initialize_scorer()
        
        # Cache for model responses and scores
        self.response_cache = {} if cache_responses else None
        self.score_cache = {} if cache_responses else None
        
        # Multi-model cache using OrderedDict for LRU eviction
        self.model_cache = OrderedDict()  # {model_path: {'model': model, 'tokenizer': tokenizer}}
        self.vllm_cache = OrderedDict()  # {model_path: llm} for vLLM models
        
        # Storage for fitted combination function
        self.combination_function = None
        self.obj_coefficients = None  # Store objective coefficients if applicable

        # Storage for training data (for debugging/analysis)
        self.train_features = None  # Will store normalized training features after collection
        self.train_targets = None  # Will store normalized training targets after collection
        self.unnormalized_train_features = None  # Will store unnormalized training features [1, 10] scale
        self.denormalized_train_targets = None  # Will store denormalized training targets [1, 10] scale
        
        # Check vllm availability for batched generation
        self.vllm_available = False
        if self.batching:
            try:
                import vllm
                self.vllm_available = True
                print("vLLM is available and will be used for batched generation")
            except ImportError:
                print("vLLM not available, falling back to standard batched generation")
        
    def _unnormalize_score(self, score: float, min_val: float = 1.0, max_val: float = 10.0) -> float:
        """
        Unnormalize a score from [0, 1] range back to [min_val, max_val] range.

        Args:
            score: Normalized score in [0, 1] range
            min_val: Minimum value of the target range (default: 1.0)
            max_val: Maximum value of the target range (default: 10.0)

        Returns:
            float: Unnormalized score in [min_val, max_val] range
        """
        return score * (max_val - min_val) + min_val

    def _sample_dataset(self, dataset: List[Dict[str, str]], num_samples: int) -> List[Dict[str, str]]:
        """Sample a subset of the dataset."""
        if num_samples < len(dataset):
            return random.sample(dataset, num_samples)
        else:
            return dataset

    def _load_dataset(self, dataset: Union[str, List[Dict[str, str]]]) -> List[Dict[str, str]]:
        """Load and normalize the dataset."""
        if isinstance(dataset, str):
            normalized_samples = load_normalized_dataset_samples(dataset)
            return normalized_samples

            # # Load from HuggingFace
            # # Load from HuggingFace datasets
            # if dataset == "openai/summarize_from_feedback":
            #     ds = load_dataset(dataset, 'comparisons', split='train')
            #     normalized = []
            #     for sample in ds:
            #         # Format Reddit TLDR data similar to test_ppo_orig.py
            #         info = sample.get("info", {})
            #         # subreddit = info.get("subreddit", "")
            #         # title = info.get("title", "")
            #         # post = info.get("post", "")
                    
            #         # # Create the formatted query string
            #         # query_text = f"SUBREDDIT: r/{subreddit}\n\nTITLE: {title}\n\nPOST: {post}\n\nTL;DR:"
                    
            #         # Store as a message format that can be used with chat_template later
            #         normalized_sample = {
            #             # 'input': [{"role": "user", "content": query_text}],
            #             'input': sample_to_input_dialogue_tldr(info),
            #         }
            #         # Keep summaries if present (for reference)
            #         # if 'summaries' in sample:
            #         #     normalized_sample['summaries'] = sample['summaries']
            #         normalized.append(normalized_sample)
            #     return normalized
            # elif dataset == 'Anthropic/hh-rlhf':
            #     ds = load_dataset(dataset, split='train')
            #     normalized = []
            #     for sample in ds:
            #         chosen = sample.get('chosen', '')
            #         # Convert to multi-turn dialogue format
            #         dialogue = sample_to_input_dialogue_hh(chosen, multi_turn=False)
            #         normalized_sample = {
            #             'input': dialogue,
            #         }
            #         normalized.append(normalized_sample)
            #     return normalized
            # else:
            #     raise ValueError(f"Invalid dataset path: {dataset}")
            # raise NotImplementedError("Dataset loading from HuggingFace not implemented yet")
            # ds = load_dataset(dataset, split='train')
            # normalized = []
            # for sample in ds:
            #     normalized.append({
            #         'input': sample.get('prompt', sample.get('input', sample.get('text', '')))
            #     })
            # return normalized
        else:
            # Already a list of samples
            return [{'input': s['input']} for s in dataset]
            # return [{
            #     'input': s.get('input', s.get('prompt', s.get('text', '')))
            # } for s in dataset]
    
    def _initialize_scorer(self):
        """
        Populate ``self.objective_scorer``.

        If ``self.scorer_model`` is already an ObjectiveScorer, it is reused
        unchanged. Otherwise ``self.scorer_model`` is treated as a model name:

        - Names starting with "gpt" are configured for API scoring.
        - Anything else is configured as a local model with ``load_quantized=True``.

        The scorer's rubric, dataset type, device, normalization and failure-log
        directory all come from the corresponding attributes on ``self``.
        """
        candidate = self.scorer_model
        if isinstance(candidate, ObjectiveScorer):
            # A ready-made scorer was supplied; nothing to build.
            self.objective_scorer = candidate
        else:
            # Model-name string: route "gpt*" names through the API, load the rest locally.
            via_api = candidate.startswith("gpt")
            self.objective_scorer = ObjectiveScorer(
                use_detailed_rubric=self.use_detailed_rubric,
                dataset_type=self.dataset_type,
                use_api=via_api,
                model_name=candidate,
                device=self.device,
                max_length=4096,  # Reasonable default context budget
                load_quantized=not via_api,  # Quantize only local models
                cache_dir=None,
                normalize_scores=self.normalize_scores,
                save_dir=self.save_dir  # Directory for failure logs
            )
    
    def calculate(self, predicted_objectives: List[str]) -> float:
        """
        Calculate the Obj-Error (RMSE) metric for the predicted objectives.

        Pipeline:
          1. Collect objective-score features and ground-truth targets using the
             train models (``model_sequence[:train_test_split_idx]``).
          2. Fit the combination function g on those features/targets.
          2b. Fit a second combination function on the features mapped from
              [0, 1] to [1, 10], against the denormalized targets.
          3. Compute per-sample squared residuals using the test models
             (``model_sequence[train_test_split_idx:]``).
          4. Return sqrt(mean(residuals)).

        Side effects: stores the training data (normalized, unnormalized
        features, and denormalized targets) and both fitted combination
        functions/coefficients on the instance. Progress is also printed and
        logged.

        Args:
            predicted_objectives: List of discovered objective descriptions

        Returns:
            Obj-Error value (RMSE - lower is better). NaN when the test step
            produced no residuals.
        """
        overall_start_time = time.time()
        self.logger.info(f"\nCalculating Obj-Fit for {len(predicted_objectives)} objectives")
        self.logger.info(f"Using {len(self.model_sequence)} models with split at index {self.train_test_split_idx}")

        # Log cache statistics for efficiency monitoring
        if self.cache_responses and self.response_cache:
            self.logger.info(f"Cache Status: {len(self.response_cache)} responses cached, {len(self.score_cache) if self.score_cache else 0} scores cached")

        # Models before the split index are used to fit g; the rest evaluate it.
        train_models = self.model_sequence[:self.train_test_split_idx]
        test_models = self.model_sequence[self.train_test_split_idx:]
        self._print_and_log_message(f"Train models: {len(train_models)}, Test models: {len(test_models)}")

        # Step 1: Collect training data
        start_time = time.time()
        self._print_and_log_message("\n--- Collecting training data ---")
        train_features, train_targets, denormalized_targets = self._collect_training_data(
            train_models,
            predicted_objectives
        )
        self._print_and_log_message('--- Training data collection time: %.2f seconds ---' % (time.time() - start_time))

        # Store normalized training data in instance variables
        self.train_features = train_features
        self.train_targets = train_targets
        self.denormalized_train_targets = denormalized_targets

        # Step 2: Fit combination function g
        start_time = time.time()
        self._print_and_log_message(f"\n--- Fitting combination function ({self.combination_function_type}) ---")
        self.combination_function, self.obj_coefficients = self._fit_combination_function(
            predicted_objectives,
            train_features,
            train_targets
        )
        self._print_and_log_message('--- Combination function fitting time: %.2f seconds ---' % (time.time() - start_time))

        # Step 2b: Fit denormalized combination function on [1, 10] scale
        start_time = time.time()
        self._print_and_log_message(f"\n--- Fitting denormalized combination function ({self.combination_function_type}) ---")

        # Map every feature value from [0, 1] to [1, 10].
        # NOTE(review): this assumes the collected features are on the [0, 1]
        # scale regardless of self.normalize_scores — confirm against
        # _collect_training_data / ObjectiveScorer.
        unnormalized_train_features = [
            {
                key: self._unnormalize_score(value, min_val=1.0, max_val=10.0)
                for key, value in feature_dict.items()
            }
            for feature_dict in train_features
        ]
        self.unnormalized_train_features = unnormalized_train_features

        self.denormalized_combination_function, self.denormalized_obj_coefficients = self._fit_combination_function(
            predicted_objectives,
            unnormalized_train_features,
            denormalized_targets
        )
        self._print_and_log_message('--- Denormalized combination function fitting time: %.2f seconds ---' % (time.time() - start_time))

        # Step 3: Calculate residuals on test set
        start_time = time.time()
        self._print_and_log_message("\n--- Calculating residuals on test set ---")
        test_residuals = self._calculate_test_residuals(
            test_models,
            predicted_objectives,
            use_different_prompts=self.use_different_prompts
        )
        self._print_and_log_message('--- Test residuals calculation time: %.2f seconds ---' % (time.time() - start_time))

        # Step 4: Compute Obj-Error (RMSE).
        # An empty residual list makes the metric undefined. Report it
        # explicitly instead of letting np.mean emit a "Mean of empty slice" warning.
        if len(test_residuals) == 0:
            self.logger.warning("No test residuals were produced; Obj-Error (RMSE) is NaN")
            avg_residual = np.float64(np.nan)
        else:
            avg_residual = np.mean(test_residuals)  # Mean of squared residuals = MSE
        obj_error = np.sqrt(avg_residual)  # RMSE is returned as the Obj-Error metric

        self._print_and_log_message("\n--- Results ---")
        self._print_and_log_message(f"Average residual (MSE): {avg_residual:.4f}")
        self._print_and_log_message(f"Obj-Error (RMSE): {obj_error:.4f}")
        self._print_and_log_message('--- Total Obj-Fit calculation time: %.2f seconds ---' % (time.time() - overall_start_time))

        return obj_error

    def _print_and_log_message(self, message: str) -> None:
        """Print ``message`` to stdout and log the identical text at INFO level."""
        # Building the message once keeps printed and logged timings in agreement.
        print(message)
        self.logger.info(message)
    
    def _log_sample_details(
        self,
        idx: int,
        input_text_formatted: str,
        response: str,
        objective_scores: Dict[str, float],
        predicted_objectives: List[str],
        ground_truth: float,
        denormalized_ground_truth: float
    ):
        """
        Log detailed information about a sample's scoring results.

        Args:
            idx: Sample index
            input_text_formatted: Formatted input text
            response: Model response
            objective_scores: Dictionary of objective scores
            predicted_objectives: List of objective descriptions
            ground_truth: Normalized ground truth score
            denormalized_ground_truth: Denormalized ground truth score
        """
        self.logger.info(f"\n🔹 Sample {idx+1}:")

        # Log input (truncated for readability)
        input_preview = input_text_formatted[:2000] + "..." if len(input_text_formatted) > 2000 else input_text_formatted
        self.logger.info(f"  Input: {input_preview}")

        # Log response (truncated for readability)
        response_preview = response[:2000] + "..." if len(response) > 2000 else response
        self.logger.info(f"  Response: {response_preview}")

        # Log objective scores
        self.logger.info(f"  Objective Scores:")
        for obj_idx, (obj_key, score) in enumerate(objective_scores.items()):
            obj_name = predicted_objectives[obj_idx] if obj_idx < len(predicted_objectives) else obj_key
            self.logger.info(f"    - {obj_name[:50]}...: {score:.4f}")

        # Log ground truth
        self.logger.info(f"  Ground Truth (normalized): {ground_truth:.4f}")
        self.logger.info(f"  Ground Truth (denormalized): {denormalized_ground_truth:.4f}")

    def _log_test_sample_details(
        self,
        idx: int,
        input_text_formatted: str,
        response: str,
        residual: float,
        predicted_score: Optional[float] = None,
        ground_truth: Optional[float] = None
    ):
        """
        Log detailed information about a test sample's residual calculation.

        Args:
            idx: Sample index
            input_text_formatted: Formatted input text
            response: Model response
            residual: Calculated residual (squared error)
            predicted_score: Optional predicted score from combination function
            ground_truth: Optional ground truth score
        """
        self.logger.info(f"\n🔹 Sample {idx+1}:")

        # Log input (truncated for readability)
        input_preview = input_text_formatted[:2000] + "..." if len(input_text_formatted) > 2000 else input_text_formatted
        self.logger.info(f"  Input: {input_preview}")

        # Log response (truncated for readability)
        response_preview = response[:2000] + "..." if len(response) > 2000 else response
        self.logger.info(f"  Response: {response_preview}")

        # Log residual and its components
        self.logger.info(f"  Residual (squared error): {residual:.6f}")
        self.logger.info(f"  √Residual (absolute error): {np.sqrt(residual):.6f}")

        if predicted_score is not None:
            self.logger.info(f"  Predicted score: {predicted_score:.4f}")
        if ground_truth is not None:
            self.logger.info(f"  Ground truth: {ground_truth:.4f}")

    def _generate_batched_responses(
        self,
        model_path: str,
        sampled_dataset: List[Dict[str, str]],
        batch_size: int
    ) -> List[str]:
        """
        Generate responses for all samples in batched mode with caching.

        This method handles cache checking and batched generation efficiently.

        Args:
            model_path: Path to the model checkpoint
            sampled_dataset: List of dataset samples to generate responses for
            batch_size: Batch size for generation

        Returns:
            List of generated responses
        """
        # Collect all inputs first
        input_texts = [sample['input'] for sample in sampled_dataset]

        # Check cache first to avoid regenerating responses
        responses = []
        uncached_indices = []
        uncached_inputs = []

        if self.response_cache is not None:
            for idx, input_text in enumerate(input_texts):
                # input_text_formatted = apply_chat_template_to_prompt(self.model_sequence[0], input_text)
                input_text_formatted = str(input_text)
                cache_key = f"{model_path}:{input_text_formatted}"

                if cache_key in self.response_cache:
                    # Use cached response
                    responses.append(self.response_cache[cache_key])
                else:
                    # Mark for generation
                    responses.append(None)  # Placeholder
                    uncached_indices.append(idx)
                    uncached_inputs.append(input_text)

            # Only generate uncached responses
            cache_hits = len(input_texts) - len(uncached_inputs)
            if uncached_inputs:
                self.logger.info(f"  Response cache: {cache_hits}/{len(input_texts)} hits, generating {len(uncached_inputs)} new responses")
                new_responses = self._generate_response_batched(model_path, uncached_inputs, batch_size)
                # Fill in the generated responses
                for idx, resp in zip(uncached_indices, new_responses):
                    responses[idx] = resp
            else:
                self.logger.info(f"  Response cache: {cache_hits}/{len(input_texts)} hits (100% cache hit rate)")
        else:
            # No cache enabled, generate all
            responses = self._generate_response_batched(model_path, input_texts, batch_size)

        return responses

    def _get_group_ground_truth(
        self,
        input_text: str,
        responses: List[str],
        denormalize_scores: bool = False
    ) -> List[float]:
        """
        Score every response to one query with the ground-truth objective.

        With a RewardFunction ground truth, all responses are scored in one
        ``group_compute_reward`` call. That call yields (normalized,
        denormalized, objective_scores); one of the first two is returned
        depending on ``denormalize_scores``. Any other ground-truth type is
        scored one response at a time through ``_get_ground_truth``.

        Args:
            input_text: Formatted input query
            responses: List of responses for this query
            denormalize_scores: Whether to return denormalized scores

        Returns:
            One float score per response, in the same order as ``responses``.
        """
        objective = self.ground_truth_objective

        if not isinstance(objective, RewardFunction):
            # Per-response fallback for callables / named metrics.
            return [
                self._get_ground_truth(input_text, reply, denormalize_scores=denormalize_scores)
                for reply in responses
            ]

        normalized_rewards, denormalized_rewards, _ = objective.group_compute_reward(
            queries=[input_text],
            responses_list=[responses],
            denormalize_scores=denormalize_scores
        )
        selected = denormalized_rewards if denormalize_scores else normalized_rewards
        return [float(value.item()) for value in selected]

    def _get_group_ground_truth_both(
        self,
        input_text: Union[str, List[str]],
        responses: Union[List[str], List[List[str]]]
    ) -> Union[Tuple[List[float], List[float]], Tuple[List[List[float]], List[List[float]]]]:
        """
        Return normalized and denormalized ground-truth scores together.

        The argument shape selects the mode:
        - single query: ``input_text`` is a str and ``responses`` a List[str].
          Returns ``(normalized, denormalized)`` as flat lists.
        - batch: ``input_text`` is a List[str] and ``responses`` a List[List[str]].
          Returns ``(all_normalized, all_denormalized)`` as lists of per-query lists.

        With a RewardFunction ground truth, a single ``group_compute_reward``
        call scores everything. In batch mode its flat outputs are sliced back
        into per-query lists using each query's response count. Other
        ground-truth types fall back to one recursive call per query in batch
        mode. In single mode they use two ``_get_ground_truth`` calls per
        response: normalized first, then denormalized.
        """
        objective = self.ground_truth_objective
        uses_reward_function = isinstance(objective, RewardFunction)

        if isinstance(input_text, list):
            queries = input_text
            grouped_responses = responses

            if not uses_reward_function:
                # Score each query separately via the single-query path.
                per_query = [
                    self._get_group_ground_truth_both(query, query_responses)
                    for query, query_responses in zip(queries, grouped_responses)
                ]
                return [norm for norm, _ in per_query], [denorm for _, denorm in per_query]

            # One call for all queries; results come back flattened across queries.
            norm_rewards, denorm_rewards, _ = objective.group_compute_reward(
                queries=queries,
                responses_list=grouped_responses
            )
            all_normalized = []
            all_denormalized = []
            start = 0
            for group in grouped_responses:
                positions = range(start, start + len(group))
                all_normalized.append([float(norm_rewards[pos].item()) for pos in positions])
                all_denormalized.append([float(denorm_rewards[pos].item()) for pos in positions])
                start += len(group)
            return all_normalized, all_denormalized

        if uses_reward_function:
            norm_rewards, denorm_rewards, _ = objective.group_compute_reward(
                queries=[input_text],
                responses_list=[responses]
            )
            return (
                [float(value.item()) for value in norm_rewards],
                [float(value.item()) for value in denorm_rewards],
            )

        # Per-response fallback, keeping the normalized-then-denormalized call order.
        normalized = []
        denormalized = []
        for reply in responses:
            normalized.append(self._get_ground_truth(input_text, reply, denormalize_scores=False))
            denormalized.append(self._get_ground_truth(input_text, reply, denormalize_scores=True))
        return normalized, denormalized

    def _process_group_scoring(
        self,
        all_formatted_inputs: List[str],
        all_responses_by_query: List[List[str]],
        predicted_objectives: List[str],
        num_samples_to_log: int,
        num_train_models: int
    ) -> Tuple[List[Dict[str, float]], List[float], List[float]]:
        """
        Process all collected responses using sequential group scoring.

        This is the synchronous path (used e.g. as the fallback of
        ``_process_group_scoring_async``). For each query, every objective scores
        the whole group of responses at once, so the scorer sees all models'
        responses as context, but only responses at positions below
        ``num_train_models`` are turned into features/targets.

        Args:
            all_formatted_inputs: List of formatted input queries
            all_responses_by_query: List of lists, where each inner list contains the
                responses for one query (assumed ordered train models first, then test
                models, as collected by ``_collect_training_data``)
            predicted_objectives: List of objective descriptions
            num_samples_to_log: Number of samples to log for debugging
            num_train_models: Number of train models (to filter which responses to include in features/targets)

        Returns:
            Tuple of (features, targets, denormalized_targets) - only for train models.
            Each feature dict maps ``"obj_{i}"`` to the score for objective ``i``.
        """
        group_scoring_start_time = time.time()
        features = []
        targets = []
        denormalized_targets = []

        self.logger.info("\n📊 Group scoring all collected responses...")

        for idx, (input_text_formatted, responses_for_query) in enumerate(
            tqdm(zip(all_formatted_inputs, all_responses_by_query),
                 total=len(all_formatted_inputs),
                 desc="Group scoring samples")
        ):
            # Score the full response group once per objective; each value is a
            # list with one float per response.
            all_scores_by_objective = {
                f"obj_{i}": self._group_score_with_objective(
                    input_text_formatted,
                    responses_for_query,
                    obj_desc
                )
                for i, obj_desc in enumerate(predicted_objectives)
            }

            # One ground-truth call yields both normalized and denormalized scores.
            # This matches the async path and avoids computing the (possibly
            # expensive) ground-truth reward twice per query.
            all_ground_truths, all_denormalized_ground_truths = self._get_group_ground_truth_both(
                input_text_formatted,
                responses_for_query
            )

            for model_idx, response in enumerate(responses_for_query):
                # Responses past the train models were scored only as group context.
                if model_idx >= num_train_models:
                    break

                objective_scores = {
                    key: scores[model_idx]
                    for key, scores in all_scores_by_objective.items()
                }
                ground_truth = all_ground_truths[model_idx]
                denormalized_ground_truth = all_denormalized_ground_truths[model_idx]

                features.append(objective_scores)
                targets.append(ground_truth)
                denormalized_targets.append(denormalized_ground_truth)

                # Log first few samples, first (train) model only
                if idx < num_samples_to_log and model_idx == 0:
                    self._log_sample_details(
                        idx,
                        input_text_formatted,
                        response,
                        objective_scores,
                        predicted_objectives,
                        ground_truth,
                        denormalized_ground_truth
                    )

        # Measure once so the logged and printed durations agree.
        elapsed = time.time() - group_scoring_start_time
        self.logger.info(f"--- Group scoring time: {elapsed:.2f} seconds ---")
        print(f"--- Group scoring time: {elapsed:.2f} seconds ---")

        return features, targets, denormalized_targets

    def _process_group_scoring_async(
        self,
        all_formatted_inputs: List[str],
        all_responses_by_query: List[List[str]],
        predicted_objectives: List[str],
        num_samples_to_log: int,
        num_train_models: int,
        max_concurrent: int = 50,
        save_scores: bool = True
    ) -> Tuple[List[Dict[str, float]], List[float], List[float]]:
        """
        Process all collected responses using fully parallel async scoring.

        Parallelizes ALL objective scoring across ALL samples simultaneously:
        - Total parallel tasks = N samples × K objectives
        - Uses semaphore for rate limiting to avoid overwhelming API

        Uses:
        1. Async parallel objective scoring via asyncio.gather() for ALL (sample, objective) pairs
        2. Single call to get both normalized and denormalized ground truth scores

        Args:
            all_formatted_inputs: List of formatted input queries
            all_responses_by_query: List of lists, where each inner list contains responses for one query
            predicted_objectives: List of objective descriptions
            num_samples_to_log: Number of samples to log for debugging
            num_train_models: Number of train models (to filter which responses to include in features/targets)
            max_concurrent: Maximum number of concurrent API calls (default: 50)

        Returns:
            Tuple of (features, targets, denormalized_targets) - only for train models
        """
        group_scoring_start_time = time.time()

        # Check if we can use async (API mode with async_client available)
        use_async = hasattr(self.objective_scorer, 'async_client') and self.objective_scorer.use_api

        if not use_async:
            # Fallback to original sequential method
            self.logger.info("Async not available, falling back to sequential scoring...")
            return self._process_group_scoring(
                all_formatted_inputs,
                all_responses_by_query,
                predicted_objectives,
                num_samples_to_log,
                num_train_models
            )

        num_samples = len(all_formatted_inputs)
        num_objectives = len(predicted_objectives)
        total_tasks = num_samples * num_objectives

        self.logger.info(f"\n📊 Group scoring {num_samples} samples × {num_objectives} objectives = {total_tasks} tasks in parallel...")
        self.logger.info(f"   Max concurrent requests: {max_concurrent}")

        async def run_all_scoring_parallel():
            """Run ALL (sample, objective) scoring tasks in parallel with rate limiting."""
            semaphore = asyncio.Semaphore(max_concurrent)

            async def score_with_semaphore(sample_idx: int, obj_idx: int, input_text: str, responses: List[str], obj_desc: str):
                """Score a single (sample, objective) pair with rate limiting."""
                async with semaphore:
                    scores = await self._group_score_with_objective_async(input_text, responses, obj_desc)
                    return sample_idx, obj_idx, scores

            # Create ALL scoring tasks for all (sample, objective) combinations
            tasks = []
            for sample_idx, (input_text, responses) in enumerate(zip(all_formatted_inputs, all_responses_by_query)):
                for obj_idx, obj_desc in enumerate(predicted_objectives):
                    task = score_with_semaphore(sample_idx, obj_idx, input_text, responses, obj_desc)
                    tasks.append(task)

            # Run ALL tasks in parallel (rate-limited by semaphore)
            results = await asyncio.gather(*tasks)

            # Reorganize results: scores_by_sample[sample_idx][f"obj_{obj_idx}"] = scores
            scores_by_sample = [{} for _ in range(num_samples)]
            for sample_idx, obj_idx, scores in results:
                scores_by_sample[sample_idx][f"obj_{obj_idx}"] = scores

            return scores_by_sample

        # Run the async scoring
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            loop = None

        scoring_start = time.time()
        if loop and loop.is_running():
            # Already in an async context (e.g., Jupyter notebook)
            print("Running in an async context, applying nest_asyncio...")
            import nest_asyncio
            nest_asyncio.apply()
            all_scores_by_sample = asyncio.get_event_loop().run_until_complete(run_all_scoring_parallel())
        else:
            # Normal Python - just run
            all_scores_by_sample = asyncio.run(run_all_scoring_parallel())

        scoring_time = time.time() - scoring_start
        self.logger.info(f"   Objective scoring completed in {scoring_time:.2f}s ({total_tasks / scoring_time:.1f} tasks/sec)")

        # Get ground truth for all samples using batch mode
        # This enables full parallelization when using LLMRewardFunction with async
        gt_start = time.time()
        self.logger.info(f"   Getting ground truth for {num_samples} samples (batch mode)...")
        all_ground_truths_by_sample, all_denorm_ground_truths_by_sample = self._get_group_ground_truth_both(
            all_formatted_inputs, all_responses_by_query
        )

        gt_time = time.time() - gt_start
        self.logger.info(f"   Ground truth completed in {gt_time:.2f}s")

        # Assemble features and targets
        features = []
        targets = []
        denormalized_targets = []

        # Save all responses_for_query, all_scores_by_objective for debugging (ensure is formatted nicely for debugging)
        if save_scores and self.save_dir:
            import re
            os.makedirs(self.save_dir, exist_ok=True)

            # Create descriptive filename with objective names
            sanitized_objs = [re.sub(r'[^\w\s-]', '', obj).replace(' ', '_')[:50] for obj in predicted_objectives]
            filename = f"scores_{'_'.join(sanitized_objs)}.json"

            debug_data = {
                "objectives": predicted_objectives,
                "samples": []
            }

            for idx, (query, responses, scores_dict) in enumerate(zip(
                all_formatted_inputs, all_responses_by_query, all_scores_by_sample
            )):
                denorm_ground_truths = all_denorm_ground_truths_by_sample[idx]
                sample_data = {
                    "query": query,
                    "responses": [
                        {
                            "response": resp,
                            "scores": {
                                predicted_objectives[i]: scores_dict[f"obj_{i}"][resp_idx]
                                for i in range(len(predicted_objectives))
                            },
                            "ground_truth": denorm_ground_truths[resp_idx]
                        }
                        for resp_idx, resp in enumerate(responses)
                    ]
                }
                debug_data["samples"].append(sample_data)

            filepath = os.path.join(self.save_dir, filename)
            with open(filepath, 'w') as f:
                json.dump(debug_data, f, indent=2)

        for idx, (responses_for_query, all_scores_by_objective) in enumerate(
            zip(all_responses_by_query, all_scores_by_sample)
        ):
            all_ground_truths = all_ground_truths_by_sample[idx]
            all_denorm_ground_truths = all_denorm_ground_truths_by_sample[idx]
            input_text_formatted = all_formatted_inputs[idx]

            # Process each response but only collect features/targets for train models
            for model_idx, response in enumerate(responses_for_query):
                if model_idx < num_train_models:
                    # Collect objective scores for this response
                    objective_scores = {}
                    for i in range(num_objectives):
                        objective_scores[f"obj_{i}"] = all_scores_by_objective[f"obj_{i}"][model_idx]

                    ground_truth = all_ground_truths[model_idx]
                    denormalized_ground_truth = all_denorm_ground_truths[model_idx]

                    features.append(objective_scores)
                    targets.append(ground_truth)
                    denormalized_targets.append(denormalized_ground_truth)

                    # Log first few samples with their scores (only for train models)
                    if idx < num_samples_to_log and model_idx == 0:
                        self._log_sample_details(
                            idx,
                            input_text_formatted,
                            response,
                            objective_scores,
                            predicted_objectives,
                            ground_truth,
                            denormalized_ground_truth
                        )

        total_time = time.time() - group_scoring_start_time
        self.logger.info(f"--- Group scoring time (fully parallel): {total_time:.2f} seconds ---")
        print(f"--- Group scoring time (fully parallel): {total_time:.2f} seconds ---")

        return features, targets, denormalized_targets

    def _process_single_scoring_async(
        self,
        all_formatted_inputs: List[str],
        all_responses: List[str],
        predicted_objectives: List[str],
        num_samples_to_log: int,
        max_concurrent: int = 50,
        save_scores: bool = False
    ) -> Tuple[List[Dict[str, float]], List[float], List[float]]:
        """
        Process all collected (query, response) pairs using fully parallel async scoring.

        This is the async parallel version for non-group scoring, similar to
        _process_group_scoring_async but for single query-response pairs.

        Parallelizes ALL objective scoring across ALL samples simultaneously:
        - Total parallel tasks = N samples × K objectives
        - Uses semaphore for rate limiting to avoid overwhelming API

        When the objective scorer has no ``async_client`` attribute or is not in
        API mode, every pair is scored sequentially instead (and ``save_scores``
        is ignored).

        Args:
            all_formatted_inputs: List of formatted input queries (one per sample)
            all_responses: List of responses (one per sample, matching inputs)
            predicted_objectives: List of objective descriptions
            num_samples_to_log: Number of samples to log for debugging
            max_concurrent: Maximum number of concurrent API calls (default: 50)
            save_scores: Whether to save scores to a JSON file in ``self.save_dir``
                for debugging (default: False). A failure to write this debug file
                is logged as a warning instead of aborting the run.

        Returns:
            Tuple of (features, targets, denormalized_targets)
        """
        scoring_start_time = time.time()

        # Check if we can use async (API mode with async_client available)
        use_async = hasattr(self.objective_scorer, 'async_client') and self.objective_scorer.use_api

        if not use_async:
            # Fallback to sequential scoring
            self.logger.info("Async not available, falling back to sequential single scoring...")
            features = []
            targets = []
            denormalized_targets = []

            for idx, (input_text_formatted, response) in enumerate(
                tqdm(zip(all_formatted_inputs, all_responses),
                     total=len(all_formatted_inputs),
                     desc="Sequential single scoring")
            ):
                # Score with each objective
                objective_scores = {}
                for i, obj_desc in enumerate(predicted_objectives):
                    score = self._score_with_objective(input_text_formatted, response, obj_desc)
                    objective_scores[f"obj_{i}"] = score

                # Get ground truth
                ground_truth = self._get_ground_truth(input_text_formatted, response)
                denormalized_ground_truth = self._get_ground_truth(
                    input_text_formatted, response, denormalize_scores=True
                )

                features.append(objective_scores)
                targets.append(ground_truth)
                denormalized_targets.append(denormalized_ground_truth)

                # Log first few samples
                if idx < num_samples_to_log:
                    self._log_sample_details(
                        idx,
                        input_text_formatted,
                        response,
                        objective_scores,
                        predicted_objectives,
                        ground_truth,
                        denormalized_ground_truth
                    )

            return features, targets, denormalized_targets

        num_samples = len(all_formatted_inputs)
        num_objectives = len(predicted_objectives)
        total_tasks = num_samples * num_objectives

        self.logger.info(f"\n📊 Single scoring {num_samples} samples × {num_objectives} objectives = {total_tasks} tasks in parallel...")
        self.logger.info(f"   Max concurrent requests: {max_concurrent}")
        print(f"--- Single scoring {num_samples} samples × {num_objectives} objectives = {total_tasks} tasks in parallel...")
        print(f"   Max concurrent requests: {max_concurrent}")

        async def run_all_scoring_parallel():
            """Run ALL (sample, objective) scoring tasks in parallel with rate limiting."""
            semaphore = asyncio.Semaphore(max_concurrent)

            async def score_with_semaphore(sample_idx: int, obj_idx: int, input_text: str, response: str, obj_desc: str):
                """Score a single (sample, objective) pair with rate limiting."""
                async with semaphore:
                    score = await self._score_with_objective_async(input_text, response, obj_desc)
                    return sample_idx, obj_idx, score

            # Create ALL scoring tasks for all (sample, objective) combinations
            tasks = []
            for sample_idx, (input_text, response) in enumerate(zip(all_formatted_inputs, all_responses)):
                for obj_idx, obj_desc in enumerate(predicted_objectives):
                    tasks.append(score_with_semaphore(sample_idx, obj_idx, input_text, response, obj_desc))

            # Run ALL tasks in parallel (rate-limited by semaphore)
            results = await asyncio.gather(*tasks)

            # Reorganize results: scores_by_sample[sample_idx][f"obj_{obj_idx}"] = score
            scores_by_sample = [{} for _ in range(num_samples)]
            for sample_idx, obj_idx, score in results:
                scores_by_sample[sample_idx][f"obj_{obj_idx}"] = score

            return scores_by_sample

        # Run the async scoring
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            loop = None

        if loop and loop.is_running():
            # Already in an async context (e.g., Jupyter notebook)
            print("Running in an async context, applying nest_asyncio...")
            import nest_asyncio
            nest_asyncio.apply()
            all_scores_by_sample = asyncio.get_event_loop().run_until_complete(run_all_scoring_parallel())
        else:
            # Normal Python - just run
            all_scores_by_sample = asyncio.run(run_all_scoring_parallel())

        scoring_time = time.time() - scoring_start_time
        # A coarse clock (or an empty task list) can yield a 0s interval.
        tasks_per_sec = total_tasks / scoring_time if scoring_time > 0 else float('inf')
        self.logger.info(f"   Objective scoring completed in {scoring_time:.2f}s ({tasks_per_sec:.1f} tasks/sec)")

        # Get ground truth for all samples using batched compute_reward
        # This is more efficient than calling _get_ground_truth for each sample individually
        gt_start = time.time()
        self.logger.info(f"   Getting ground truth for {num_samples} samples...")

        if isinstance(self.ground_truth_objective, RewardFunction):
            # Use batched compute_reward which automatically uses async if available
            norm_rewards, denorm_rewards, _ = self.ground_truth_objective.compute_reward(
                queries=all_formatted_inputs,
                responses=all_responses
            )
            all_ground_truths = [float(r.item()) for r in norm_rewards]
            all_denorm_ground_truths = [float(r.item()) for r in denorm_rewards]
        else:
            # Fallback to individual scoring for non-RewardFunction ground truths
            all_ground_truths = []
            all_denorm_ground_truths = []
            for input_text, response in zip(all_formatted_inputs, all_responses):
                all_ground_truths.append(self._get_ground_truth(input_text, response))
                all_denorm_ground_truths.append(self._get_ground_truth(input_text, response, denormalize_scores=True))

        gt_time = time.time() - gt_start
        self.logger.info(f"   Ground truth completed in {gt_time:.2f}s")

        # Save all responses, scores, and ground truths for debugging
        if save_scores and self.save_dir:
            import hashlib
            import re
            os.makedirs(self.save_dir, exist_ok=True)

            # Create descriptive filename with objective names. Every whitespace
            # character (not only spaces) becomes "_", so newlines/tabs never end up
            # in the filename.
            sanitized_objs = [
                re.sub(r'\s', '_', re.sub(r'[^\w\s-]', '', obj))[:50]
                for obj in predicted_objectives
            ]
            objs_stem = '_'.join(sanitized_objs)
            filename = f"single_scores_{objs_stem}.json"
            # A single path component is commonly limited to 255 bytes. With several
            # objectives the joined name exceeds that and open() would fail only after
            # all scoring is done. Keep the readable name when it fits; otherwise
            # truncate it and append a digest of the full objective list so names
            # stay distinct.
            if len(filename.encode('utf-8')) > 255:
                digest = hashlib.sha256('\n'.join(predicted_objectives).encode('utf-8')).hexdigest()[:12]
                short_stem = objs_stem.encode('utf-8')[:200].decode('utf-8', errors='ignore')
                filename = f"single_scores_{short_stem}_{digest}.json"

            debug_data = {
                "objectives": predicted_objectives,
                "samples": []
            }

            for idx in range(num_samples):
                sample_data = {
                    "query": all_formatted_inputs[idx],
                    "response": all_responses[idx],
                    "scores": {
                        predicted_objectives[i]: all_scores_by_sample[idx][f"obj_{i}"]
                        for i in range(num_objectives)
                    },
                    "ground_truth": all_denorm_ground_truths[idx]
                }
                debug_data["samples"].append(sample_data)

            filepath = os.path.join(self.save_dir, filename)
            try:
                with open(filepath, 'w') as f:
                    json.dump(debug_data, f, indent=2)
            except (OSError, TypeError) as e:
                # Debug output only: keep the scoring results even if the dump fails.
                self.logger.warning(f"Could not save debug scores to {filepath}: {e}")

        # Assemble features and targets
        features = []
        targets = []
        denormalized_targets = []

        for idx, objective_scores in enumerate(all_scores_by_sample):
            ground_truth = all_ground_truths[idx]
            denormalized_ground_truth = all_denorm_ground_truths[idx]

            features.append(objective_scores)
            targets.append(ground_truth)
            denormalized_targets.append(denormalized_ground_truth)

            # Log first few samples with their scores
            if idx < num_samples_to_log:
                self._log_sample_details(
                    idx,
                    all_formatted_inputs[idx],
                    all_responses[idx],
                    objective_scores,
                    predicted_objectives,
                    ground_truth,
                    denormalized_ground_truth
                )

        total_time = time.time() - scoring_start_time
        self.logger.info(f"--- Single scoring time (fully parallel): {total_time:.2f} seconds ---")
        print(f"--- Single scoring time (fully parallel): {total_time:.2f} seconds ---")

        return features, targets, denormalized_targets

    # OLD: Per-sample parallel scoring (kept for reference)
    # def _score_all_objectives_parallel(
    #     self,
    #     input_text: str,
    #     responses: List[str],
    #     objectives: List[str]
    # ) -> Dict[str, List[float]]:
    #     """
    #     Score all objectives in parallel using async.
    #     """
    #     async def score_objectives_async():
    #         coroutines = [
    #             self._group_score_with_objective_async(input_text, responses, obj_desc)
    #             for obj_desc in objectives
    #         ]
    #         results = await asyncio.gather(*coroutines)
    #         return results
    #
    #     try:
    #         loop = asyncio.get_running_loop()
    #     except RuntimeError:
    #         loop = None
    #
    #     if loop and loop.is_running():
    #         import nest_asyncio
    #         nest_asyncio.apply()
    #         results = asyncio.get_event_loop().run_until_complete(score_objectives_async())
    #     else:
    #         results = asyncio.run(score_objectives_async())
    #
    #     all_scores_by_objective = {}
    #     for i, scores in enumerate(results):
    #         all_scores_by_objective[f"obj_{i}"] = scores
    #
    #     return all_scores_by_objective

    def _group_score_with_objective(
        self,
        input_text: str,
        responses: List[str],
        objective: str
    ) -> List[float]:
        """
        Score a group of responses for a single objective using group scoring.

        Args:
            input_text: Formatted input query
            responses: List of responses for this query
            objective: Single objective description

        Returns:
            List of scores for each response
        """
        # Use group scoring for all responses at once
        scores = self.objective_scorer.group_score_single_objective(
            query=input_text,
            responses=responses,
            objective=objective
        )

        # Handle normalization if needed
        if self.ground_truth_objective.normalize_scores and not self.objective_scorer.normalize_scores:
            # Normalize scores to match ground truth scale
            scores = [self.ground_truth_objective._normalize_score(score) for score in scores]

        return scores

    async def _group_score_with_objective_async(
        self,
        input_text: str,
        responses: List[str],
        objective: str
    ) -> List[float]:
        """
        Async version: Score a group of responses for a single objective using group scoring.

        Args:
            input_text: Formatted input query
            responses: List of responses for this query
            objective: Single objective description

        Returns:
            List of scores for each response
        """
        # Use async group scoring for all responses at once
        scores = await self.objective_scorer.async_group_score_single_objective(
            query=input_text,
            responses=responses,
            objective=objective
        )

        # Handle normalization if needed
        if self.ground_truth_objective.normalize_scores and not self.objective_scorer.normalize_scores:
            scores = [self.ground_truth_objective._normalize_score(score) for score in scores]

        return scores

    def _collect_training_data(
        self,
        train_models: List[str],
        predicted_objectives: List[str]
    ) -> Tuple[List[Dict[str, float]], List[float]]:
        """
        Collect training data from the training portion of model sequence.

        Returns:
            Tuple of (features, targets) where:
            - features: List of dicts mapping objective names to scores
            - targets: List of ground truth values
        """
        features = []
        targets = []
        denormalized_targets = []
        # sampled_dataset = self._sample_dataset(self.dataset, self.num_samples)
        sampled_dataset = self.sampled_train_dataset

        # For group scoring, collect all responses from both train and test models
        if self.group_scoring:
            all_responses_by_query = [[] for _ in sampled_dataset]
            all_formatted_inputs = []
            # Get all models (train + test) for group context
            test_models = self.model_sequence[self.train_test_split_idx:]
            all_models = train_models + test_models
            models_to_iterate = all_models
            self.logger.info(f"\n📊 Group scoring mode: Will collect responses from all {len(all_models)} models")
            self.logger.info(f"  - {len(train_models)} train models (for training)")
            self.logger.info(f"  - {len(test_models)} test models (for context)")
        else:
            models_to_iterate = train_models
            # For single (non-group) scoring, also collect all pairs for async processing
            all_formatted_inputs_single = []
            all_responses_single = []

        # Log sampled dataset information
        self.logger.info("="*80)
        self.logger.info("📊 TRAINING DATA COLLECTION")
        self.logger.info("="*80)
        self.logger.info(f"Sampled dataset size: {len(sampled_dataset)} samples")
        self.logger.info(f"Number of training models: {len(train_models)}")
        self.logger.info(f"Number of predicted objectives: {len(predicted_objectives)}")
        self.logger.info(f"Group scoring enabled: {self.group_scoring}")
        self.logger.info(f"Predicted objectives:")
        for i, obj in enumerate(predicted_objectives):
            self.logger.info(f"  {i+1}. {obj}")
        self.logger.info("-"*80)

        for model_idx, model_path in enumerate(models_to_iterate):
            # Determine if this is a train or test model
            if self.group_scoring:
                is_train_model = model_idx < len(train_models)
                model_type = "train" if is_train_model else "test"
                local_idx = model_idx if is_train_model else model_idx - len(train_models)
                print(f"Collecting responses from {model_type} model {local_idx + 1}/{len(train_models) if is_train_model else len(test_models)}")
                self.logger.info(f"\n🔧 Collecting from {model_type} model {local_idx + 1}: {model_path}")
            else:
                # Non-group scoring mode - only iterating over train models
                t = model_idx
                print(f"Processing train model {t+1}/{len(train_models)}")
                self.logger.info(f"\n🔧 Processing training model {t+1}/{len(train_models)}: {model_path}")

            # Log first few samples for this model
            num_samples_to_log = min(50, len(self.sampled_train_dataset))
            sample_logs = []  # Collect logs for batch display

            if self.batching:
                # Batched generation mode - use the new modular method
                responses = self._generate_batched_responses(model_path, sampled_dataset, self.batch_size)

                # Log first few input-response pairs
                if self.group_scoring:
                    if is_train_model:
                        self.logger.info(f"\n📝 First {num_samples_to_log} input-response pairs for train model {local_idx+1}:")
                    else:
                        self.logger.info(f"\n📝 First {num_samples_to_log} input-response pairs for test model {local_idx+1}:")
                else:
                    self.logger.info(f"\n📝 First {num_samples_to_log} input-response pairs for model {model_idx+1}:")
                self.logger.info("-"*60)

                # Process each sample with its response
                for idx, (sample, response) in enumerate(tqdm(zip(sampled_dataset, responses),
                                            total=len(sampled_dataset),
                                            desc="Processing samples")):
                    input_text = sample['input']
                    # Properly format multi-turn conversations using chat template
                    input_text_formatted = apply_chat_template_to_prompt(self.model_sequence[0], input_text)

                    # If group scoring, just collect responses
                    if self.group_scoring:
                        all_responses_by_query[idx].append(response)
                        if model_idx == 0:  # Only collect formatted inputs once
                            all_formatted_inputs.append(input_text_formatted)
                    else:
                        # Collect pairs for async scoring (will be scored after all models processed)
                        all_formatted_inputs_single.append(input_text_formatted)
                        all_responses_single.append(response)
            else:
                # Original non-batched mode
                # Log first few samples for this model
                if self.group_scoring:
                    if is_train_model:
                        self.logger.info(f"\n📝 First {num_samples_to_log} input-response pairs for train model {local_idx+1}:")
                    else:
                        self.logger.info(f"\n📝 First {num_samples_to_log} input-response pairs for test model {local_idx+1}:")
                else:
                    self.logger.info(f"\n📝 First {num_samples_to_log} input-response pairs for model {model_idx+1}:")
                self.logger.info("-"*60)

                # for sample in tqdm(self.sampled_dataset, desc="Samples"):
                for idx, sample in enumerate(tqdm(sampled_dataset, desc="Samples")):
                    input_text = sample['input']
                    # Properly format multi-turn conversations using chat template
                    input_text_formatted = apply_chat_template_to_prompt(self.model_sequence[0], input_text)

                    # Generate response (will use cache automatically in _generate_response)
                    response = self._generate_response(model_path, input_text)

                    # If group scoring, just collect responses
                    if self.group_scoring:
                        all_responses_by_query[idx].append(response)
                        if model_idx == 0:  # Only collect formatted inputs once
                            all_formatted_inputs.append(input_text_formatted)
                    else:
                        # Collect pairs for async scoring (will be scored after all models processed)
                        all_formatted_inputs_single.append(input_text_formatted)
                        all_responses_single.append(response)

        # After collecting all data, score all collected pairs
        if self.group_scoring:
            num_samples_to_log = min(50, len(sampled_dataset))
            # features, targets, denormalized_targets = self._process_group_scoring(
            #     all_formatted_inputs,
            #     all_responses_by_query,
            #     predicted_objectives,
            #     num_samples_to_log,
            #     len(train_models)  # Pass number of train models
            # )
            features, targets, denormalized_targets = self._process_group_scoring_async(
                all_formatted_inputs,
                all_responses_by_query,
                predicted_objectives,
                num_samples_to_log,
                len(train_models),  # Pass number of train models
                max_concurrent=self.max_concurrent
            )
        else:
            # Non-group scoring: use async parallel scoring for all collected pairs
            num_samples_to_log = min(50, len(all_formatted_inputs_single))
            features, targets, denormalized_targets = self._process_single_scoring_async(
                all_formatted_inputs_single,
                all_responses_single,
                predicted_objectives,
                num_samples_to_log,
                max_concurrent=self.max_concurrent
            )

        return features, targets, denormalized_targets
    
    def _fit_combination_function(
        self,
        predicted_objectives: List[str],
        train_features: List[Dict[str, float]],
        train_targets: List[float]
    ) -> Tuple[RewardCombiner, Optional[Dict[str, Any]]]:
        """
        Build the combination function g and fit it on training data when it is trainable.

        Args:
            predicted_objectives: Objective descriptions. The i-th description is exposed
                to the combiner under the feature name ``obj_{i}``.
            train_features: Per-example mappings from ``obj_{i}`` names to objective scores.
            train_targets: Ground-truth target for each entry of ``train_features``.

        Returns:
            Tuple of ``(combiner, coefficients_dict)``. For a ``'linear_regression'``
            combiner whose underlying ``model`` exposes ``coef_``, ``coefficients_dict``
            holds ``'coefficients'`` (objective description -> coefficient),
            ``'intercept'`` and ``'raw_coefficients'``; otherwise it is ``None``.
        """
        combiner = create_reward_combiner(
            combiner_type=self.combination_function_type,
            objective_names=[f"obj_{i}" for i in range(len(predicted_objectives))],
            **self.combination_function_params
        )
        g = combiner.combination_function

        # Only trainable combination functions expose `fit`.
        if hasattr(g, 'fit'):
            g.fit(train_features, train_targets)
            print(f"Fitted {self.combination_function_type} combination function")

        # Coefficients are only reported for a fitted linear-regression model.
        if self.combination_function_type != 'linear_regression':
            return combiner, None
        if not (hasattr(g, 'model') and hasattr(g.model, 'coef_')):
            return combiner, None

        lr_model = g.model
        coefficients = lr_model.coef_
        intercept = lr_model.intercept_

        # Keyed by objective description; identical descriptions collapse to the last one,
        # while 'raw_coefficients' keeps every positional coefficient.
        coef_mapping = {
            obj_name: float(coefficients[i])
            for i, obj_name in enumerate(predicted_objectives)
        }
        coefficients_dict = {
            'coefficients': coef_mapping,
            'intercept': float(intercept),
            'raw_coefficients': coefficients.tolist()
        }

        # Emit the same report to stdout first, then to the logger.
        for emit in (print, self.logger.info):
            emit("\n--- Linear Regression Coefficients ---")
            emit(f"Intercept: {intercept:.4f}")
            for obj_name, coef in coef_mapping.items():
                emit(f"  {obj_name}: {coef:.4f}")
            emit("--------------------------------------\n")

        return combiner, coefficients_dict
    
    def _calculate_test_residuals(
        self,
        test_models: List[str],
        predicted_objectives: List[str],
        use_different_prompts: bool = False
    ) -> List[float]:
        """
        Calculate residuals on the test portion of the model sequence.

        The residuals are calculated over the same prompts (x) as in training, but the
        responses come from the test models. This primarily tests the temporal
        generalization of the combination function g.

        When group_scoring is enabled, responses are gathered from all models (train and
        test) to provide context for group scoring, but residuals are only calculated for
        the test models.

        Args:
            test_models: Paths of the test-portion models to evaluate.
            predicted_objectives: Objective descriptions scored for each response.
            use_different_prompts: Only logged here; prompts always come from
                ``self.sampled_test_dataset``.

        Returns:
            Squared-error residuals. In group-scoring mode they are ordered sample-major
            (all test models for sample 0, then sample 1, ...). Otherwise they are
            model-major (all samples for test model 0, then test model 1, ...).

        Raises:
            ValueError: In group-scoring mode, if a model returns a number of responses
                different from the number of samples. Such a list would shift responses
                between models in the per-query response lists.
        """
        residuals = []
        sampled_dataset = self.sampled_test_dataset

        start_time = time.time()

        # Log test dataset information
        self.logger.info("="*80)
        self.logger.info("📊 TEST DATA RESIDUALS CALCULATION")
        self.logger.info("="*80)
        self.logger.info(f"Sampled dataset size: {len(sampled_dataset)} samples")
        self.logger.info(f"Number of test models: {len(test_models)}")
        self.logger.info(f"Use different prompts: {use_different_prompts}")
        self.logger.info(f"Group scoring enabled: {self.group_scoring}")
        self.logger.info("-"*80)

        if self.group_scoring:
            # Train models only provide context for group scoring; they are not evaluated.
            train_models = self.model_sequence[:self.train_test_split_idx]
            all_models = train_models + test_models

            self.logger.info(f"\n📊 Group scoring mode: Collecting responses from all {len(all_models)} models...")
            self.logger.info(f"  - {len(train_models)} train models (for context)")
            self.logger.info(f"  - {len(test_models)} test models (for evaluation)")

            # all_responses_by_query[i][j] is model j's response to sample i
            # (train models first, then test models).
            all_responses_by_query = [[] for _ in sampled_dataset]
            all_formatted_inputs = []

            for model_idx, model_path in enumerate(all_models):
                is_test_model = model_idx >= len(train_models)
                model_type = "test" if is_test_model else "train"
                local_idx = model_idx - len(train_models) if is_test_model else model_idx

                print(f"Collecting responses from {model_type} model {local_idx + 1}/{len(test_models) if is_test_model else len(train_models)}")
                self.logger.info(f"\n🔧 Collecting from {model_type} model {local_idx + 1}: {model_path}")

                if self.batching:
                    # Use modular batched response generation
                    responses = list(self._generate_batched_responses(model_path, sampled_dataset, self.batch_size))
                else:
                    responses = [
                        self._generate_response(model_path, sample['input'])
                        for sample in sampled_dataset
                    ]

                # Position j in every per-query list must belong to model j. A short or
                # long response list would silently attribute responses to the wrong model.
                if len(responses) != len(sampled_dataset):
                    raise ValueError(
                        f"{model_type} model {model_path} returned {len(responses)} responses "
                        f"for {len(sampled_dataset)} samples"
                    )

                for idx, response in enumerate(responses):
                    all_responses_by_query[idx].append(response)
                    if model_idx == 0:  # Formatted inputs are model-independent; collect once
                        all_formatted_inputs.append(
                            apply_chat_template_to_prompt(self.model_sequence[0], sampled_dataset[idx]['input'])
                        )

            self.logger.info("\n📊 Calculating residuals for test models using batch parallel scoring...")

            # Positions of the test models inside each per-query response list
            test_model_indices = [len(train_models) + i for i in range(len(test_models))]

            batch_results = self._calculate_residuals_group_batch(
                all_formatted_inputs,
                all_responses_by_query,
                predicted_objectives,
                test_model_indices,
                max_concurrent=self.max_concurrent
            )

            # Results are ordered as: [(sample_0, test_0), (sample_0, test_1), ..., (sample_1, test_0), ...]
            num_samples_to_log = min(50, len(sampled_dataset))
            result_idx = 0

            for sample_idx in range(len(sampled_dataset)):
                for test_idx, test_model_idx in enumerate(test_model_indices):
                    residual, predicted_score, ground_truth = batch_results[result_idx]
                    residuals.append(residual)
                    result_idx += 1

                    # Log first few samples (only for first test model)
                    if sample_idx < num_samples_to_log and test_idx == 0:
                        self._log_test_sample_details(
                            sample_idx,
                            all_formatted_inputs[sample_idx],
                            all_responses_by_query[sample_idx][test_model_idx],
                            residual,
                            predicted_score=predicted_score,
                            ground_truth=ground_truth
                        )
        else:
            # Collect every (input, response) pair from all test models, then score them
            # together with async parallel scoring.
            all_formatted_inputs_single = []
            all_responses_single = []
            # (test model index, sample index) of each collected pair, used for logging.
            # Recorded explicitly because zip() truncates when a model returns fewer
            # responses than samples, which would break a modulo-derived sample index.
            pair_origins = []

            for t, model_path in enumerate(test_models):
                print(f"Collecting responses from test model {t+1}/{len(test_models)}")
                self.logger.info(f"\n🔧 Collecting from test model {t+1}/{len(test_models)}: {model_path}")

                if self.batching:
                    # Batched generation mode
                    responses = self._generate_batched_responses(model_path, sampled_dataset, self.batch_size)
                    for idx, (sample, response) in enumerate(zip(sampled_dataset, responses)):
                        all_formatted_inputs_single.append(
                            apply_chat_template_to_prompt(self.model_sequence[0], sample['input'])
                        )
                        all_responses_single.append(response)
                        pair_origins.append((t, idx))
                else:
                    # Non-batched mode
                    for idx, sample in enumerate(tqdm(sampled_dataset, desc="Generating responses")):
                        input_text = sample['input']
                        input_text_formatted = apply_chat_template_to_prompt(self.model_sequence[0], input_text)
                        response = self._generate_response(model_path, input_text)
                        all_formatted_inputs_single.append(input_text_formatted)
                        all_responses_single.append(response)
                        pair_origins.append((t, idx))

            self.logger.info("\n📊 Calculating residuals for test models using batch parallel scoring...")
            batch_results = self._calculate_residuals_single_batch(
                all_formatted_inputs_single,
                all_responses_single,
                predicted_objectives,
                max_concurrent=self.max_concurrent
            )

            num_samples_to_log = min(50, len(sampled_dataset))

            for result_idx, (residual, predicted_score, ground_truth) in enumerate(batch_results):
                residuals.append(residual)

                model_idx, sample_idx = pair_origins[result_idx]

                # Log first few samples (only for first test model)
                if sample_idx < num_samples_to_log and model_idx == 0:
                    self._log_test_sample_details(
                        sample_idx,
                        all_formatted_inputs_single[result_idx],
                        all_responses_single[result_idx],
                        residual,
                        predicted_score=predicted_score,
                        ground_truth=ground_truth
                    )

        # Measure once so the logged and printed durations agree.
        elapsed = time.time() - start_time
        message = f"--- Test data residuals calculated in {elapsed:.2f} seconds ---"
        self.logger.info(message)
        print(message)
        return residuals
    
    def _calculate_residual(
        self,
        input_text: str,
        response: str,
        predicted_objectives: List[str]
    ) -> float:
        """
        Calculate residual for a single (x, y) pair.
        
        Residual(x, y, R̂, r*) = (r*(x,y) - g(r̂_1(x,y), ..., r̂_k(x,y)))^2
        """
        # Get ground truth
        ground_truth = self._get_ground_truth(input_text, response)
        
        # Score with each objective
        objective_scores = {}
        for i, obj_desc in enumerate(predicted_objectives):
            score = self._score_with_objective(input_text, response, obj_desc)
            objective_scores[f"obj_{i}"] = score
        
        # Get combined prediction
        if self.combination_function is None:
            raise ValueError("Combination function not fitted yet")
        
        predicted = self.combination_function.combine_rewards(objective_scores)
        
        # Calculate squared error
        residual = (ground_truth - predicted) ** 2
        
        # return residual, (predicted, ground_truth)
        return residual

    def _calculate_residual_group(
        self,
        input_text: str,
        response: str,
        all_responses: List[str],
        predicted_objectives: List[str],
        response_index: int
    ) -> Tuple[float, float, float]:
        """
        Calculate residual for a single response using group scoring context.

        This version uses group scoring where all responses for the same query
        are scored together, allowing the scorer to see the trajectory of responses.

        Args:
            input_text: Formatted input query
            response: The specific response to calculate residual for
            all_responses: All responses for this query (for group context)
            predicted_objectives: List of objective descriptions
            response_index: Index of the current response in all_responses

        Returns:
            Tuple of (residual, predicted_score, ground_truth)
        """
        # Get ground truth for all responses using group scoring for better calibration
        all_ground_truths = self._get_group_ground_truth(
            input_text,
            all_responses,
            denormalize_scores=False
        )

        # Extract the ground truth for our specific response
        ground_truth = all_ground_truths[response_index]

        # Score all responses together for each objective (group scoring)
        objective_scores = {}
        for i, obj_desc in enumerate(predicted_objectives):
            # Get scores for all responses using group scoring
            all_scores = self._group_score_with_objective(
                input_text,
                all_responses,
                obj_desc
            )
            # Extract the score for our specific response
            objective_scores[f"obj_{i}"] = all_scores[response_index]

        # Get combined prediction
        if self.combination_function is None:
            raise ValueError("Combination function not fitted yet")

        predicted = self.combination_function.combine_rewards(objective_scores)

        # Calculate squared error
        residual = (ground_truth - predicted) ** 2

        return residual, predicted, ground_truth

    def _calculate_residuals_group_batch(
        self,
        all_formatted_inputs: List[str],
        all_responses_by_query: List[List[str]],
        predicted_objectives: List[str],
        test_model_indices: List[int],
        max_concurrent: int = 50
    ) -> List[Tuple[float, float, float]]:
        """
        Calculate residuals for test models using parallel batch scoring.

        This method efficiently computes residuals by:
        1. Scoring all (sample, objective) pairs in parallel using async
        2. Getting ground truth for all samples in batch mode
        3. Calculating residuals for each (sample, test_model) combination

        Args:
            all_formatted_inputs: All formatted input queries
            all_responses_by_query: All responses for each query (from all models)
            predicted_objectives: Objectives to score
            test_model_indices: Indices in each response list that correspond to test models
            max_concurrent: Maximum concurrent API calls

        Returns:
            List of (residual, predicted_score, ground_truth) tuples for each (sample, test_model)
        """
        num_samples = len(all_formatted_inputs)
        num_objectives = len(predicted_objectives)
        num_test_models = len(test_model_indices)

        # Check if async is available
        use_async = hasattr(self.objective_scorer, 'async_client') and self.objective_scorer.use_api

        if not use_async:
            # Fallback to sequential scoring
            self.logger.info("Async not available, using sequential residual calculation...")
            results = []
            for sample_idx in range(num_samples):
                input_text = all_formatted_inputs[sample_idx]
                all_responses = all_responses_by_query[sample_idx]
                for test_idx in test_model_indices:
                    response = all_responses[test_idx]
                    residual, predicted, ground_truth = self._calculate_residual_group(
                        input_text, response, all_responses, predicted_objectives, test_idx
                    )
                    results.append((residual, predicted, ground_truth))
            return results

        total_tasks = num_samples * num_objectives
        self.logger.info(f"\n📊 Batch residual calculation: {num_samples} samples × {num_objectives} objectives = {total_tasks} tasks")
        self.logger.info(f"   Test models to evaluate: {num_test_models}")

        # Step 1: Score all (sample, objective) pairs in parallel
        async def run_all_scoring_parallel():
            semaphore = asyncio.Semaphore(max_concurrent)

            async def score_with_semaphore(sample_idx: int, obj_idx: int):
                async with semaphore:
                    scores = await self._group_score_with_objective_async(
                        all_formatted_inputs[sample_idx],
                        all_responses_by_query[sample_idx],
                        predicted_objectives[obj_idx]
                    )
                    return sample_idx, obj_idx, scores

            tasks = [
                score_with_semaphore(sample_idx, obj_idx)
                for sample_idx in range(num_samples)
                for obj_idx in range(num_objectives)
            ]
            return await asyncio.gather(*tasks)

        # Run async scoring
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            loop = None

        scoring_start = time.time()
        if loop and loop.is_running():
            import nest_asyncio
            nest_asyncio.apply()
            results = asyncio.get_event_loop().run_until_complete(run_all_scoring_parallel())
        else:
            results = asyncio.run(run_all_scoring_parallel())

        # Organize scores: all_scores[sample_idx][obj_idx] = scores_list
        all_scores = [[None] * num_objectives for _ in range(num_samples)]
        for sample_idx, obj_idx, scores in results:
            all_scores[sample_idx][obj_idx] = scores

        scoring_time = time.time() - scoring_start
        self.logger.info(f"   Objective scoring completed in {scoring_time:.2f}s ({total_tasks / scoring_time:.1f} tasks/sec)")

        # Step 2: Get ground truth for all samples in batch
        gt_start = time.time()
        all_ground_truths, _ = self._get_group_ground_truth_both(
            all_formatted_inputs, all_responses_by_query
        )
        self.logger.info(f"   Ground truth completed in {time.time() - gt_start:.2f}s")

        # Step 3: Calculate residuals for each (sample, test_model)
        residual_results = []
        for sample_idx in range(num_samples):
            for test_idx in test_model_indices:
                # Extract scores for this specific response
                objective_scores = {
                    f"obj_{i}": all_scores[sample_idx][i][test_idx]
                    for i in range(num_objectives)
                }
                ground_truth = all_ground_truths[sample_idx][test_idx]

                # Get combined prediction
                predicted = self.combination_function.combine_rewards(objective_scores)

                # Calculate squared error
                residual = (ground_truth - predicted) ** 2
                residual_results.append((residual, predicted, ground_truth))

        return residual_results

    def _calculate_residuals_single_batch(
        self,
        all_formatted_inputs: List[str],
        all_responses: List[str],
        predicted_objectives: List[str],
        max_concurrent: int = 50
    ) -> List[Tuple[float, float, float]]:
        """
        Calculate residuals for single (non-group) scoring using parallel batch scoring.

        This is the single-scoring equivalent of _calculate_residuals_group_batch.
        It efficiently computes residuals by:
        1. Scoring all (sample, objective) pairs in parallel using async
        2. Getting ground truth for all samples in batch mode
        3. Calculating residuals for each sample

        Args:
            all_formatted_inputs: All formatted input queries
            all_responses: All responses (one per input)
            predicted_objectives: Objectives to score
            max_concurrent: Maximum concurrent API calls

        Returns:
            List of (residual, predicted_score, ground_truth) tuples for each sample
        """
        num_samples = len(all_formatted_inputs)
        num_objectives = len(predicted_objectives)

        # Check if async is available
        use_async = hasattr(self.objective_scorer, 'async_client') and self.objective_scorer.use_api

        if not use_async:
            # Fallback to sequential scoring
            self.logger.info("Async not available, using sequential residual calculation...")
            results = []
            for sample_idx in range(num_samples):
                input_text = all_formatted_inputs[sample_idx]
                response = all_responses[sample_idx]
                residual = self._calculate_residual(input_text, response, predicted_objectives)
                # _calculate_residual returns just the residual, we need to get predicted and ground_truth
                ground_truth = self._get_ground_truth(input_text, response)
                objective_scores = {}
                for i, obj_desc in enumerate(predicted_objectives):
                    score = self._score_with_objective(input_text, response, obj_desc)
                    objective_scores[f"obj_{i}"] = score
                predicted = self.combination_function.combine_rewards(objective_scores)
                results.append((residual, predicted, ground_truth))
            return results

        total_tasks = num_samples * num_objectives
        self.logger.info(f"\n📊 Batch single residual calculation: {num_samples} samples × {num_objectives} objectives = {total_tasks} tasks")

        # Step 1: Score all (sample, objective) pairs in parallel
        async def run_all_scoring_parallel():
            semaphore = asyncio.Semaphore(max_concurrent)

            async def score_with_semaphore(sample_idx: int, obj_idx: int):
                async with semaphore:
                    score = await self._score_with_objective_async(
                        all_formatted_inputs[sample_idx],
                        all_responses[sample_idx],
                        predicted_objectives[obj_idx]
                    )
                    return sample_idx, obj_idx, score

            tasks = [
                score_with_semaphore(sample_idx, obj_idx)
                for sample_idx in range(num_samples)
                for obj_idx in range(num_objectives)
            ]
            return await asyncio.gather(*tasks)

        # Run async scoring
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            loop = None

        scoring_start = time.time()
        if loop and loop.is_running():
            import nest_asyncio
            nest_asyncio.apply()
            results = asyncio.get_event_loop().run_until_complete(run_all_scoring_parallel())
        else:
            results = asyncio.run(run_all_scoring_parallel())

        # Organize scores: all_scores[sample_idx][obj_idx] = score
        all_scores = [[None] * num_objectives for _ in range(num_samples)]
        for sample_idx, obj_idx, score in results:
            all_scores[sample_idx][obj_idx] = score

        scoring_time = time.time() - scoring_start
        self.logger.info(f"   Objective scoring completed in {scoring_time:.2f}s ({total_tasks / scoring_time:.1f} tasks/sec)")

        # Step 2: Get ground truth for all samples in batch
        gt_start = time.time()
        if isinstance(self.ground_truth_objective, RewardFunction):
            # Use batched compute_reward which automatically uses async if available
            norm_rewards, denorm_rewards, _ = self.ground_truth_objective.compute_reward(
                queries=all_formatted_inputs,
                responses=all_responses
            )
            all_ground_truths = [float(r.item()) for r in norm_rewards]
        else:
            # Fallback to individual scoring
            all_ground_truths = [
                self._get_ground_truth(inp, resp)
                for inp, resp in zip(all_formatted_inputs, all_responses)
            ]
        self.logger.info(f"   Ground truth completed in {time.time() - gt_start:.2f}s")

        # Step 3: Calculate residuals for each sample
        residual_results = []
        for sample_idx in range(num_samples):
            # Build objective scores dict for this sample
            objective_scores = {
                f"obj_{i}": all_scores[sample_idx][i]
                for i in range(num_objectives)
            }
            ground_truth = all_ground_truths[sample_idx]

            # Get combined prediction
            predicted = self.combination_function.combine_rewards(objective_scores)

            # Calculate squared error
            residual = (ground_truth - predicted) ** 2
            residual_results.append((residual, predicted, ground_truth))

        return residual_results

    def _get_or_load_model(self, model_path: str) -> Tuple:
        """Get model from cache or load it, managing LRU eviction."""
        print('Current length of model cache:', len(self.model_cache))
        # If model is already in cache, move it to the end (most recently used)
        if model_path in self.model_cache:
            print(f"Model found in cache: {model_path}")
            self.model_cache.move_to_end(model_path)
            cached = self.model_cache[model_path]
            return cached['model'], cached['tokenizer']
        
        # Check if cache is full and evict least recently used if needed
        if len(self.model_cache) >= self.model_cache_size:
            # Remove least recently used model (first item in OrderedDict)
            lru_path, lru_cache = self.model_cache.popitem(last=False)
            print(f"Evicting model from cache: {lru_path}")
            del lru_cache['model']
            del lru_cache['tokenizer']
            torch.cuda.empty_cache()
        
        # Load the new model
        print(f"Loading model into cache: {model_path}")
        model, tokenizer = self._load_model_impl(model_path)
        
        # Add to cache
        self.model_cache[model_path] = {
            'model': model,
            'tokenizer': tokenizer
        }
        
        return model, tokenizer
    
    def _generate_response(self, model_path: str, input_text) -> str:
        """Generate response from a model."""
        # Check response cache - handle both string and list formats for cache key
        # Properly format multi-turn conversations using chat template
        # input_text_formatted = apply_chat_template_to_prompt(self.model_sequence[0], input_text)
        input_text_formatted = str(input_text)
        if self.response_cache is not None:
            cache_key = f"{model_path}:{input_text_formatted}"
            if cache_key in self.response_cache:
                return self.response_cache[cache_key]
        
        # Get model and tokenizer from cache
        model, tokenizer = self._get_or_load_model(model_path)
        
        # Generate using model and tokenizer
        # Handle both string prompts and message lists
        if isinstance(input_text, list):
            # Apply chat template for message format
            input_ids = tokenizer.apply_chat_template(
                input_text,
                padding=False,
                add_generation_prompt=True,
                truncation=True,
                max_length=1024
            )
            inputs = {'input_ids': torch.tensor([input_ids], device=model.device)}
        else:
            # Handle string prompt as before
            inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=1024)
            inputs = {k: v.to(model.device) for k, v in inputs.items()}
        
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=512,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                use_cache=True,
                pad_token_id=tokenizer.eos_token_id
            )
        
        response = tokenizer.decode(
            outputs[0][inputs['input_ids'].shape[1]:],
            skip_special_tokens=True
        )

        # Cache response
        if self.response_cache is not None:
            cache_key = f"{model_path}:{input_text_formatted}"
            self.response_cache[cache_key] = response

        return response
    
    def _get_or_create_vllm_model(self, model_path: str):
        """Get vLLM model from cache or create it, managing LRU eviction."""
        # If model is already in vLLM cache, move it to the end (most recently used)
        if model_path in self.vllm_cache:
            self.vllm_cache.move_to_end(model_path)
            return self.vllm_cache[model_path]
        
        # Check if cache is full and evict least recently used if needed
        if len(self.vllm_cache) >= self.model_cache_size:
            # Remove least recently used vLLM model
            lru_path, lru_llm = self.vllm_cache.popitem(last=False)
            print(f"Evicting vLLM model from cache: {lru_path}")
            del lru_llm
            torch.cuda.empty_cache()
        
        # Create new vLLM model
        from vllm import LLM
        
        # Check if adapter or full model
        adapter_config_path = os.path.join(model_path, 'adapter_config.json')
        is_adapter = os.path.exists(adapter_config_path)
        
        if is_adapter:
            # Read adapter config to get base model
            with open(adapter_config_path, 'r') as f:
                adapter_config = json.load(f)
            base_model_name = adapter_config.get('base_model_name_or_path')
            if base_model_name is None:
                raise ValueError(f"Base model name not found in adapter config at {adapter_config_path}")
            
            print(f"Loading vLLM model with LoRA adapter into cache: {model_path}")
            
            # Initialize vLLM model with LoRA support
            llm = LLM(
                model=base_model_name,
                trust_remote_code=True,
                max_model_len=2048,
                dtype="bfloat16",
                enable_lora=True,
                max_lora_rank=256,
                gpu_memory_utilization=0.95
            )
            # Store adapter info with model
            llm._adapter_path = model_path
            llm._is_adapter = True
        else:
            print(f"Loading vLLM model into cache: {model_path}")
            llm = LLM(
                model=model_path,
                trust_remote_code=True,
                max_model_len=2048,
                dtype="bfloat16",
                quantization="awq" if "awq" in model_path.lower() else None,
                gpu_memory_utilization=0.95
            )
            llm._is_adapter = False
        
        # Add to cache
        self.vllm_cache[model_path] = llm
        return llm
    
    def _generate_response_batched(self, model_path: str, input_texts: List, batch_size: int = 8) -> List[str]:
        """Generate responses in batches using vLLM if available, otherwise standard batching."""
        responses = []
        
        if self.vllm_available:
            # Use vLLM for efficient batched generation
            from vllm import SamplingParams
            
            print(f"Using vLLM for batched generation from {model_path}")
            
            # Get or create vLLM model from cache
            llm = self._get_or_create_vllm_model(model_path)
            
            # Check if this is an adapter model
            is_adapter = getattr(llm, '_is_adapter', False)
            
            if is_adapter:
                from vllm.lora.request import LoRARequest
                # Create LoRA request using stored adapter path
                lora_request = LoRARequest("adapter", 1, lora_path=llm._adapter_path)
                base_model_name = None
                # Read adapter config to get base model for tokenizer
                adapter_config_path = os.path.join(llm._adapter_path, 'adapter_config.json')
                with open(adapter_config_path, 'r') as f:
                    adapter_config = json.load(f)
                    base_model_name = adapter_config.get('base_model_name_or_path')
            else:
                lora_request = None
                base_model_name = None
            
            # Prepare prompts
            formatted_prompts = []
            # Load tokenizer from appropriate location
            if is_adapter:
                # Check for tokenizer in adapter directory first
                tokenizer_config_path = os.path.join(model_path, 'tokenizer_config.json')
                if os.path.exists(tokenizer_config_path):
                    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
                else:
                    tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
            else:
                tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
            
            # Set pad token if needed
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
            
            for input_text in input_texts:
                if isinstance(input_text, list):
                    # Apply chat template for message format
                    prompt = tokenizer.apply_chat_template(
                        input_text,
                        tokenize=False,
                        add_generation_prompt=True
                    )
                else:
                    prompt = input_text
                formatted_prompts.append(prompt)
            
            # Set sampling parameters
            sampling_params = SamplingParams(
                temperature=0.7,
                top_p=0.9,
                max_tokens=512,
                skip_special_tokens=True
            )
            
            # Generate all responses at once
            if is_adapter:
                outputs = llm.generate(formatted_prompts, sampling_params, lora_request=lora_request)
                print("Generated with vLLM + LoRA adapter")
            else:
                outputs = llm.generate(formatted_prompts, sampling_params)
            
            # Extract responses
            for output in outputs:
                response = output.outputs[0].text
                responses.append(response)

                # Cache response if enabled
                if self.response_cache is not None:
                    input_idx = outputs.index(output)
                    # Properly format multi-turn conversations using chat template
                    # input_text_formatted = apply_chat_template_to_prompt(self.model_sequence[0], input_texts[input_idx])
                    input_text_formatted = str(input_texts[input_idx])
                    cache_key = f"{model_path}:{input_text_formatted}"
                    self.response_cache[cache_key] = response
            
            # Don't delete llm since it's now cached
            
        else:
            # Fall back to standard batched generation
            responses = self._generate_response_batched_standard(model_path, input_texts, batch_size)
        
        return responses
    
    def _generate_response_batched_standard(self, model_path: str, input_texts: List, batch_size: int = 8) -> List[str]:
        """
        Generate responses without vLLM via the shared ``generate_responses_batched`` helper.

        Only ``model_path`` is passed to the helper (``self.model_cache`` is not
        consulted here), so the helper presumably loads the model itself.
        When ``self.response_cache`` is enabled, each response is stored under
        ``"{model_path}:{str(input)}"``, the same key format used by
        ``_generate_response``.

        Args:
            model_path: Model or adapter path forwarded to the helper.
            input_texts: Prompts (plain strings or chat-message lists).
            batch_size: Generation batch size forwarded to the helper.

        Returns:
            The list returned by ``generate_responses_batched``; assumed to be
            one response per input in input order (TODO confirm in helper).
        """
        responses = generate_responses_batched(
            model_path=model_path,
            prompts=input_texts,
            max_new_tokens=512,
            batch_size=batch_size,
            temperature=0.7,
            top_p=0.9,
            return_model_and_tokenizer=False
        )

        if self.response_cache is not None:
            for input_text, response in zip(input_texts, responses):
                self.response_cache[f"{model_path}:{input_text}"] = response

        return responses
    
    def _score_with_objective(
        self,
        input_text: str,
        response: str,
        objective_description: str
    ) -> float:
        """Score a response according to an objective using ObjectiveScorer."""
        # Check cache
        # if self.score_cache is not None:
        #     cache_key = f"{objective_description[:50]}:{input_text[:50]}:{response[:50]}"
        #     if cache_key in self.score_cache:
        #         print('INFO: Using cached score')
        #         return self.score_cache[cache_key]
        
        # Use ObjectiveScorer to score the response
        score = self.objective_scorer.score_single_objective(
            query=input_text,
            response=response,
            objective=objective_description
        )

        if (self.ground_truth_objective.normalize_scores) and (not self.objective_scorer.normalize_scores):
            print("Normalizing scorer score to match ground truth scale in calc_objectives_fit")
            score = self.ground_truth_objective._normalize_score(score)
        # Cache the score
        # if self.score_cache is not None:
        #     cache_key = f"{objective_description[:50]}:{input_text[:50]}:{response[:50]}"
        #     self.score_cache[cache_key] = score

        return score

    async def _score_with_objective_async(
        self,
        input_text: str,
        response: str,
        objective_description: str
    ) -> float:
        """
        Async version: Score a response according to an objective using ObjectiveScorer.

        Args:
            input_text: Formatted input query
            response: Response to score
            objective_description: Objective description

        Returns:
            Score for the response
        """
        # Use async scoring
        score = await self.objective_scorer.async_score_single_objective(
            query=input_text,
            response=response,
            objective=objective_description
        )

        # Handle normalization if needed
        if (self.ground_truth_objective.normalize_scores) and (not self.objective_scorer.normalize_scores):
            score = self.ground_truth_objective._normalize_score(score)

        return score

    def _get_ground_truth(self, input_text: str, response: str, denormalize_scores: bool = False) -> float:
        """
        Return the ground-truth reward/loss for one (input, response) pair.

        ``self.ground_truth_objective`` may be:
          * a ``RewardFunction``: ``compute_reward`` is called on a one-item batch,
            and the normalized reward is returned (or the denormalized one when
            ``denormalize_scores`` is True), converted to a Python float;
          * any other callable: called as ``objective(input_text, response)``;
          * the string ``'ppo_reward'`` or ``'dpo_loss'``: placeholder mock values
            drawn from ``np.random`` (not a real reward).

        Results are memoized in ``self.score_cache`` (when not None), keyed on the
        input, the response and the ``denormalize_scores`` flag.

        Raises:
            ValueError: For any other ``ground_truth_objective`` value.
        """
        cache_key = f"ground_truth:{input_text}:{response}:{denormalize_scores}"
        if self.score_cache is not None and cache_key in self.score_cache:
            return self.score_cache[cache_key]

        objective = self.ground_truth_objective
        # RewardFunction must be tested before the generic callable() branch.
        if isinstance(objective, RewardFunction):
            normed, denormed, _ = objective.compute_reward(
                queries=[input_text],
                responses=[response],
                denormalize_scores=denormalize_scores
            )
            chosen = denormed if denormalize_scores else normed
            value = float(chosen[0].item())
        elif callable(objective):
            value = objective(input_text, response)
        elif objective == 'ppo_reward':
            # Mock implementation: placeholder until a real PPO reward model is wired in.
            value = np.random.randn() * 0.5 + 5.0
        elif objective == 'dpo_loss':
            # Mock implementation: placeholder for a DPO loss calculation.
            value = np.random.randn() * 0.2 + 0.5
        else:
            raise ValueError(f"Unknown ground truth type: {self.ground_truth_objective}")

        if self.score_cache is not None:
            self.score_cache[cache_key] = value
        return value
    
    def _load_model_impl(self, model_path: str) -> Tuple:
        """
        Load a causal LM and its tokenizer with 4-bit NF4 quantization.

        If ``model_path`` contains ``adapter_config.json`` it is treated as a PEFT
        adapter: the base model named by ``base_model_name_or_path`` is loaded
        (quantized) and the adapter is applied with ``PeftModel.from_pretrained``.
        Otherwise ``model_path`` is loaded directly as a full model.

        Args:
            model_path: Directory (or hub id) of the adapter or full model.

        Returns:
            ``(model, tokenizer)``; the tokenizer is guaranteed to have a pad token.

        Raises:
            ValueError: If an adapter config lacks ``base_model_name_or_path``.
        """
        # Identical quantization settings are used for both loading paths.
        quant_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type='nf4'
        )

        adapter_config_path = os.path.join(model_path, 'adapter_config.json')
        if not os.path.exists(adapter_config_path):
            # Full model
            tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
            model = AutoModelForCausalLM.from_pretrained(
                model_path,
                torch_dtype=torch.bfloat16,
                quantization_config=quant_config,
                device_map=self.device,
                trust_remote_code=True
            )
        else:
            # PEFT adapter on top of a base model
            with open(adapter_config_path, 'r') as config_file:
                base_model_name = json.load(config_file).get('base_model_name_or_path')
            if base_model_name is None:
                raise ValueError(f"Base model name not found in adapter config at {adapter_config_path}")

            # A tokenizer saved alongside the adapter takes precedence.
            if os.path.exists(os.path.join(model_path, 'tokenizer_config.json')):
                tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
            else:
                tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
                if tokenizer.pad_token is None:
                    tokenizer.add_special_tokens({"pad_token": "[PAD]"})

            base_model = AutoModelForCausalLM.from_pretrained(
                base_model_name,
                quantization_config=quant_config,
                torch_dtype=torch.bfloat16,
                device_map=self.device,
                trust_remote_code=True
            )

            # Keep the embedding table in sync with the tokenizer (e.g. after
            # adding [PAD] above).
            # NOTE(review): some checkpoints pad config.vocab_size above
            # len(tokenizer); in that case this *shrinks* the embeddings. Confirm
            # that is intended for the adapters loaded here.
            if len(tokenizer) != base_model.config.vocab_size:
                base_model.resize_token_embeddings(len(tokenizer))

            model = PeftModel.from_pretrained(base_model, model_path)

        # Final fallback: reuse EOS as the pad token.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        return model, tokenizer
    
    def cleanup_model_cache(self):
        """Clean up cached models from memory."""
        # Clean up standard model cache
        for model_path, cache_entry in self.model_cache.items():
            if 'model' in cache_entry:
                del cache_entry['model']
            if 'tokenizer' in cache_entry:
                del cache_entry['tokenizer']
        self.model_cache.clear()
        
        # Clean up vLLM cache
        for model_path, llm in self.vllm_cache.items():
            del llm
        self.vllm_cache.clear()
        
        torch.cuda.empty_cache()
    
    def _extract_score(self, text: str) -> float:
        """Extract numerical score from text."""
        import re
        # Look for numbers between 1-10
        match = re.search(r'(\d+(?:\.\d+)?)', text)
        if match:
            score = float(match.group(1))
            return min(max(score, 1.0), 10.0)
        return 5.0  # Default middle score
    
    def get_combination_weights(self) -> Optional[Dict[str, float]]:
        """
        Get the learned weights from the combination function (if applicable).
        
        Returns:
            Dictionary mapping objective indices to weights, or None
        """
        if self.combination_function is None:
            return None
        
        func = self.combination_function.combination_function
        
        # Extract weights based on function type
        if hasattr(func, 'weights'):
            # Linear function with manual weights
            return func.weights
        elif hasattr(func, 'model') and hasattr(func.model, 'coef_'):
            # Linear regression
            weights = {}
            for i, coef in enumerate(func.model.coef_):
                weights[f"obj_{i}"] = float(coef)
            return weights
        
        return None
    
    def __del__(self):
        """Cleanup when object is destroyed."""
        self.cleanup_model_cache()