"""
Objective Verification Classes for Objectives Discovery

This module implements classes to verify that discovered objectives satisfy key properties:
1. Human-interpretability: The objective's scoring aligns with human understanding
2. Predictable trend: The objective follows a predictable pattern over model iterations
"""

import torch
import numpy as np
import asyncio
import time
from typing import List, Dict, Tuple, Optional, Union, Any
from abc import ABC, abstractmethod
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel
from openai import OpenAI, AsyncOpenAI
import re
import os
import json
import random
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('Agg')  # Use non-interactive backend

try:
    # Try relative imports first (when imported as a module)
    from .constants import (
        OBJECTIVE_SCORING_PROMPT,
        HUMAN_SCORING_PROMPT,
        TREND_ANALYSIS_PROMPT,
        OPENAI_API_KEY,
        SCORING_SYSTEM_PROMPT
    )
    from .trend_functions import (
        TREND_FUNCTIONS,
        fit_trend,
        find_best_trend,
        evaluate_trend_fit
    )
    from .objective_scorer import ObjectiveScorer
    from .model_generation import generate_responses_batched, generate_huggingface_response, apply_chat_template_to_prompt
except ImportError:
    # Fall back to absolute imports (when run directly)
    from constants import (
        OBJECTIVE_SCORING_PROMPT,
        HUMAN_SCORING_PROMPT,
        TREND_ANALYSIS_PROMPT,
        OPENAI_API_KEY,
        SCORING_SYSTEM_PROMPT
    )
    from trend_functions import (
        TREND_FUNCTIONS,
        fit_trend,
        find_best_trend,
        evaluate_trend_fit
    )
    from objective_scorer import ObjectiveScorer
    from model_generation import generate_responses_batched, generate_huggingface_response, apply_chat_template_to_prompt


class BaseObjectivesVerifier(ABC):
    """
    Common parent for every objective verifier.

    Concrete verifiers subclass this type and provide their own ``verify``
    implementation; the base only stores the shared configuration and
    offers a score-rescaling helper.
    """

    def __init__(self, objective_description: str, epsilon: float = 0.5):
        """
        Store the configuration shared by all verifiers.

        Args:
            objective_description: Natural language description of the objective
            epsilon: Threshold for verification (default 0.5)
        """
        self.epsilon = epsilon
        self.objective_description = objective_description

    @abstractmethod
    def verify(self, **kwargs) -> bool:
        """
        Decide whether the objective meets this verifier's criterion.

        Returns:
            bool: True if objective satisfies the criterion, False otherwise
        """

    def _normalize_score(self, score: float, min_val: float = 1.0, max_val: float = 10.0) -> float:
        """
        Linearly rescale a score so ``min_val`` maps to 0 and ``max_val`` to 1.

        No clamping is applied, so a score outside ``[min_val, max_val]``
        maps outside ``[0, 1]``.

        Args:
            score: Raw score to rescale
            min_val: Lower end of the raw score range
            max_val: Upper end of the raw score range

        Returns:
            float: The rescaled score
        """
        offset = score - min_val
        span = max_val - min_val
        return offset / span


class HumanInterpretableVerifier(BaseObjectivesVerifier):
    """
    Verifies if an objective is human-interpretable by comparing objective scoring
    function r(x,y) with human scoring function s_h(x,y|n).
    
    An objective is human-interpretable if:
    1/T \\sum_{t=1}^T E_{x ~ X, y ~ pi_theta_t(.|x)} [|r_n(x, y) - s_h(x,y|n)|] <= epsilon
    
    This expectation is computed across the sequence of T models.
    """
    
    def __init__(
        self,
        objective_description: str,
        objective_model: Union[str, Any],
        human_models: List[str],
        model_sequence: Optional[List[str]] = None,
        epsilon: float = 0.15,
        use_normalized_scores: bool = True,
        output_dir: Optional[str] = None,
        max_concurrent: int = 10
    ):
        """
        Initialize the human-interpretable verifier.

        Args:
            objective_description: Natural language description of the objective
            objective_model: Model or model name for objective scoring r(x,y)
            human_models: List of models for human scoring s_h(x,y|n)
                         Models with '/' are treated as HuggingFace models
                         Others are treated as API models (e.g., 'gpt-4', 'claude-3')
            model_sequence: List of paths to model checkpoints pi_theta_1, ..., pi_theta_T
                           These are the models being evaluated across the training trajectory.
                           None (the default) or an empty list is stored as an empty list.
            epsilon: Maximum allowed average difference between r and s_h (default 0.15)
            use_normalized_scores: Whether to normalize scores to [0, 1] range
            output_dir: Directory handed to the scorers this verifier creates as
                        their ``save_dir``; may be None
            max_concurrent: Maximum number of concurrent API calls (default 10);
                            stored on the instance as ``self.max_concurrent``
        """
        super().__init__(objective_description, epsilon)
        self.objective_model = objective_model
        self.human_models = human_models
        self.model_sequence = model_sequence if model_sequence else []
        self.use_normalized_scores = use_normalized_scores
        self.output_dir = output_dir
        self.max_concurrent = max_concurrent

        # An OpenAI client is only needed when at least one human model is an
        # API model. It stays None when no API key is available; API scoring
        # then falls back to a default score.
        self.api_client = None
        if any(not self._is_huggingface_model(m) for m in human_models):
            api_key = OPENAI_API_KEY or os.environ.get('OPENAI_API_KEY')
            if api_key:
                self.api_client = OpenAI(api_key=api_key)

        # Cache for loaded HuggingFace models, keyed by model name
        self.hf_models = {}
        self.hf_tokenizers = {}

        # Cache for sequence models (pi_theta_t) - only stores one at a time
        self.sequence_models = {}
        self.sequence_tokenizers = {}

        # One scorer per human model name, created lazily on first use
        self.human_model_scorers_cache = {}
    
    def _is_huggingface_model(self, model_name: str) -> bool:
        """Check if a model name refers to a HuggingFace model."""
        return '/' in model_name
    
    def _load_huggingface_model(self, model_name: str):
        """Load a HuggingFace model and tokenizer."""
        if model_name not in self.hf_models:
            print(f"Loading HuggingFace model: {model_name}")
            self.hf_tokenizers[model_name] = AutoTokenizer.from_pretrained(
                model_name,
                trust_remote_code=True
            )
            
            # Use quantization for large models to reduce memory usage
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type='nf4'
            )
            
            self.hf_models[model_name] = AutoModelForCausalLM.from_pretrained(
                model_name,
                quantization_config=bnb_config,
                torch_dtype=torch.bfloat16,
                device_map="auto",
                trust_remote_code=True
            )
            
            # Set pad token if not present
            if self.hf_tokenizers[model_name].pad_token is None:
                self.hf_tokenizers[model_name].pad_token = self.hf_tokenizers[model_name].eos_token
        
        return self.hf_models[model_name], self.hf_tokenizers[model_name]
    
    def _score_with_huggingface(
        self,
        model_name: str,
        prompt: str
    ) -> float:
        """
        Produce a score for ``prompt`` with a locally loaded HuggingFace model.

        The model generates a short completion (at most 10 new tokens) under
        the scoring system prompt, and the score is parsed out of that text.

        Args:
            model_name: HuggingFace model to load (or fetch from the cache)
            prompt: Fully formatted scoring prompt

        Returns:
            float: Score parsed from the generated text
        """
        hf_model, hf_tokenizer = self._load_huggingface_model(model_name)

        raw_output = generate_huggingface_response(
            model=hf_model,
            tokenizer=hf_tokenizer,
            prompt=prompt,
            system_prompt=SCORING_SYSTEM_PROMPT,
            max_new_tokens=10,
            do_sample=False,
            temperature=0.1
        )

        parsed_score = self._extract_score(raw_output)
        return parsed_score
    
    def _score_with_api(
        self,
        model_name: str,
        prompt: str
    ) -> float:
        """Score using an API model."""
        if not self.api_client:
            print(f"Warning: API client not initialized. Returning default score.")
            return 5.0
        
        try:
            # Map model names to API endpoints
            if model_name.startswith("gpt"):
                api_model = model_name
            elif model_name == "claude-3":
                # Note: This would require Anthropic API client
                raise NotImplementedError("Claude API not implemented")
                print(f"Claude API not implemented. Using default score.")
                return 5.0
            else:
                api_model = "gpt-4o-mini"  # Default fallback
            
            response = self.api_client.chat.completions.create(
                model=api_model,
                messages=[
                    {"role": "system", "content": SCORING_SYSTEM_PROMPT},
                    {"role": "user", "content": prompt}
                ],
                temperature=0.1,
                max_tokens=10
            )
            
            generated_text = response.choices[0].message.content.strip()
            return self._extract_score(generated_text)
            
        except Exception as e:
            print(f"Error calling API for {model_name}: {e}. Using default score.")
            return 5.0
    
    def _extract_score(self, text: str) -> float:
        """Extract numerical score from model output."""
        try:
            # Find numbers in the text
            numbers = re.findall(r'\b\d+(?:\.\d+)?\b', text.strip())
            if numbers:
                score = float(numbers[0])
                # Clamp to valid range
                return max(1.0, min(10.0, score))
            else:
                print(f"Warning: Could not extract score from '{text}'. Using default.")
                return 5.0
        except Exception as e:
            print(f"Error extracting score: {e}. Using default.")
            return 5.0
    
    def compute_objective_score(
        self,
        input_text: str,
        response_text: str,
        dataset_type: str,
    ) -> float:
        """
        Compute objective score r(x,y) using the objective model.

        How the score is obtained depends on what ``self.objective_model`` is:
        - a model name (str): it is wrapped once in an ObjectiveScorer, which
          replaces the name on the instance and is reused on later calls;
        - an ObjectiveScorer: its ``score_single_objective`` is called;
        - any other callable: called as ``objective_model(input_text, response_text)``;
        - anything else: its ``score(input_text, response_text)`` method is called.

        Args:
            input_text: Input prompt x
            response_text: Model response y
            dataset_type: Dataset type passed to the ObjectiveScorer when one
                          has to be created from a model name

        Returns:
            float: Objective score, normalized when ``use_normalized_scores`` is set
        """
        print("Computing objective score...")

        # Only a model *name* is wrapped; callables and custom model objects
        # are used directly below.
        if isinstance(self.objective_model, str):
            use_api = self.objective_model.startswith("gpt")
            self.objective_model = ObjectiveScorer(
                use_detailed_rubric=True,  # Use detailed rubrics for better scoring
                dataset_type=dataset_type,
                use_api=use_api,
                model_name=self.objective_model,
                max_length=4096,  # Use a reasonable default
                load_quantized=not use_api,  # Use quantization for local models
                cache_dir=None,
                save_dir=self.output_dir
            )

        scorer = self.objective_model
        if isinstance(scorer, ObjectiveScorer):
            score = scorer.score_single_objective(
                input_text,
                response_text,
                self.objective_description
            )
            print("Using ObjectiveScorer instance for scoring.")
        elif callable(scorer):
            score = scorer(input_text, response_text)
        else:
            # Assume it's a model object with a scoring method
            score = scorer.score(input_text, response_text)

        if self.use_normalized_scores:
            score = self._normalize_score(score)

        return score
    
    def compute_human_score(
        self,
        input_text: str,
        response_text: str,
        dataset_type: str,
        return_all_scores: bool = False
    ) -> Union[float, Tuple[float, Dict[str, float]]]:
        """
        Compute human score s_h(x,y|n) by averaging scores from multiple human models.

        Args:
            input_text: Input prompt x
            response_text: Model response y
            dataset_type: Dataset type for scoring
            return_all_scores: If True, return both average and individual scores

        Returns:
            float or Tuple[float, Dict]: Average human score, or (average, individual_scores_dict) if return_all_scores=True
        """
        print("Computing human score...")
        scores = []
        individual_scores = {}  # Store scores for each human model

        for model_name in self.human_models:
            if model_name not in self.human_model_scorers_cache:
                print('Creating new ObjectiveScorer instance for human model {} scoring...'.format(model_name))
                # Create ObjectiveScorer instance for this human model
                use_api = model_name.startswith("gpt")
                human_model_scorer = ObjectiveScorer(
                    use_detailed_rubric=True,  # Use detailed rubrics for better scoring
                    dataset_type=dataset_type,
                    use_api=use_api,
                    model_name=model_name,
                    max_length=4096,  # Use a reasonable default
                    load_quantized=not use_api,  # Use quantization for local models
                    # cache_file="./custom_rubrics_cache.json",
                    cache_dir=None,
                    save_dir=self.output_dir  # HumanInterpretableVerifier doesn't have save_dir
                )
                self.human_model_scorers_cache[model_name] = human_model_scorer

            self.objective_model = self.human_model_scorers_cache[model_name]
            score = self.objective_model.score_single_objective(
                        input_text,
                        response_text,
                        self.objective_description
                    )
            print('Using ObjectiveScorer instance for human model {} scoring...'.format(model_name))

            # prompt = HUMAN_SCORING_PROMPT.format(
            #     objective_description=self.objective_description,
            #     input_text=input_text,
            #     response_text=response_text
            # )
            # if isinstance(self.objective_model, ObjectiveScorer) and (('gpt' in model_name and self.objective_model.use_api) or (not self.objective_model.use_api and model_name==self.objective_model.model_name)):
            #     print(f"Using ObjectiveScorer instance for human scoring with model {model_name}.")
            #     score = self.objective_model.score_single_objective(
            #         input_text, 
            #         response_text, 
            #         self.objective_description
            #     )
            #     breakpoint()
            # elif self._is_huggingface_model(model_name):
            #     score = self._score_with_huggingface(model_name, prompt)
            #     breakpoint()
            # else:
            #     score = self._score_with_api(model_name, prompt)
            #     print(f"Human model {model_name} API score: {score}")
            
            if self.use_normalized_scores:
                score = self._normalize_score(score)
            
            scores.append(score)
            individual_scores[model_name] = score

        # Calculate average of all human model scores
        avg_score = float(np.mean(scores)) if scores else 5.0

        if return_all_scores:
            return avg_score, individual_scores
        else:
            return avg_score

    async def async_compute_objective_score(
        self,
        input_text: str,
        response_text: str,
        dataset_type: str,
    ) -> float:
        """
        Coroutine counterpart of compute_objective_score.

        Makes sure ``self.objective_model`` is an ObjectiveScorer (building one
        if it is not) and awaits its ``async_score_single_objective``.

        Args:
            input_text: Input prompt x
            response_text: Model response y
            dataset_type: Dataset type for scoring

        Returns:
            float: Objective score, normalized when ``use_normalized_scores`` is set
        """
        current = self.objective_model
        if not isinstance(current, ObjectiveScorer):
            # A string is used as the model name; anything else falls back to
            # "gpt-4o-mini" with API usage switched off.
            is_name = isinstance(current, str)
            if is_name:
                use_api = current.startswith("gpt")
                scorer_model_name = current
            else:
                use_api = False
                scorer_model_name = "gpt-4o-mini"

            self.objective_model = ObjectiveScorer(
                use_detailed_rubric=True,
                dataset_type=dataset_type,
                use_api=use_api,
                model_name=scorer_model_name,
                max_length=4096,
                load_quantized=not use_api,
                cache_dir=None,
                save_dir=self.output_dir,
            )

        raw_score = await self.objective_model.async_score_single_objective(
            input_text,
            response_text,
            self.objective_description
        )

        if not self.use_normalized_scores:
            return raw_score
        return self._normalize_score(raw_score)

    async def async_compute_human_score(
        self,
        input_text: str,
        response_text: str,
        dataset_type: str,
    ) -> Tuple[float, Dict[str, float]]:
        """
        Async version of compute_human_score using ObjectiveScorer.async_score_single_objective.
        Scores with all human models in parallel.

        Args:
            input_text: Input prompt x
            response_text: Model response y
            dataset_type: Dataset type for scoring

        Returns:
            Tuple[float, Dict]: (average human score, individual scores dict)
        """
        # Ensure all human model scorers are initialized
        for model_name in self.human_models:
            if model_name not in self.human_model_scorers_cache:
                use_api = model_name.startswith("gpt")
                human_model_scorer = ObjectiveScorer(
                    use_detailed_rubric=True,
                    dataset_type=dataset_type,
                    use_api=use_api,
                    model_name=model_name,
                    max_length=4096,
                    load_quantized=not use_api,
                    cache_dir=None,
                    save_dir=self.output_dir,
                    # use_async=True
                )
                self.human_model_scorers_cache[model_name] = human_model_scorer

        # Score with all human models in parallel
        async def score_with_model(model_name: str) -> Tuple[str, float]:
            scorer = self.human_model_scorers_cache[model_name]
            score = await scorer.async_score_single_objective(
                input_text,
                response_text,
                self.objective_description
            )
            if self.use_normalized_scores:
                score = self._normalize_score(score)
            return model_name, score

        # Run all human model scores in parallel
        results = await asyncio.gather(*[score_with_model(m) for m in self.human_models])

        individual_scores = {model_name: score for model_name, score in results}
        scores = list(individual_scores.values())
        avg_score = float(np.mean(scores)) if scores else 5.0

        return avg_score, individual_scores

    def _load_sequence_model(self, model_path: str):
        """
        Load a model from the training sequence (pi_theta_t).
        Only keeps one model in memory at a time to avoid GPU OOM.

        A checkpoint directory containing ``adapter_config.json`` is treated
        as a PEFT adapter: the base model named in that config is loaded with
        4-bit quantization and the adapter is applied on top. Any other path
        is loaded directly as a full model in bfloat16.

        Args:
            model_path: Path to the model checkpoint

        Returns:
            Tuple of (model, tokenizer)
        """
        if model_path not in self.sequence_models:
            # Clear existing cached models to free GPU memory
            if self.sequence_models:
                print("Clearing previous sequence model from memory...")
                # Iterate over a copy of the keys because entries are deleted in the loop
                for cached_path in list(self.sequence_models.keys()):
                    del self.sequence_models[cached_path]
                    del self.sequence_tokenizers[cached_path]
                torch.cuda.empty_cache()

            print(f"Loading sequence model: {model_path}")

            # Check if this is an adapter model
            adapter_config_path = os.path.join(model_path, 'adapter_config.json')
            is_adapter = os.path.exists(adapter_config_path)

            if is_adapter:
                # Load adapter config to get base model
                # (json is also imported at module level; this local import is redundant)
                import json
                with open(adapter_config_path, 'r') as f:
                    adapter_config = json.load(f)
                # Falls back to Llama-3.1-8B when the config does not name a base model
                base_model_name = adapter_config.get('base_model_name_or_path', 'meta-llama/Llama-3.1-8B')

                # Check if tokenizer files exist in the adapter directory
                tokenizer_config_path = os.path.join(model_path, 'tokenizer_config.json')

                # If tokenizer was saved with the adapter, use it to ensure compatibility
                if os.path.exists(tokenizer_config_path):
                    self.sequence_tokenizers[model_path] = AutoTokenizer.from_pretrained(
                        model_path,
                        trust_remote_code=True
                    )
                else:
                    # Load tokenizer from base model
                    self.sequence_tokenizers[model_path] = AutoTokenizer.from_pretrained(
                        base_model_name,
                        trust_remote_code=True
                    )

                # Add padding token if not present
                # NOTE(review): this presumably grows the vocabulary by one, which
                # is what the embedding-resize check below would then react to - confirm
                if self.sequence_tokenizers[model_path].pad_token is None:
                    self.sequence_tokenizers[model_path].add_special_tokens({"pad_token": "[PAD]"})

                # Load base model with quantization
                bnb_config = BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_compute_dtype=torch.bfloat16,
                    bnb_4bit_use_double_quant=True,
                    bnb_4bit_quant_type='nf4'
                )

                base_model = AutoModelForCausalLM.from_pretrained(
                    base_model_name,
                    quantization_config=bnb_config,
                    torch_dtype=torch.bfloat16,
                    device_map="auto",
                    trust_remote_code=True
                )

                # Only resize embeddings if vocabulary size changed
                if len(self.sequence_tokenizers[model_path]) != base_model.config.vocab_size:
                    print(f"Resizing model embeddings from {base_model.config.vocab_size} to {len(self.sequence_tokenizers[model_path])}")
                    base_model.resize_token_embeddings(len(self.sequence_tokenizers[model_path]))

                # Apply adapter (done after any resize so the adapter is attached
                # to the base model in its final shape)
                self.sequence_models[model_path] = PeftModel.from_pretrained(base_model, model_path)
            else:
                # Load full model (no quantization on this branch)
                self.sequence_tokenizers[model_path] = AutoTokenizer.from_pretrained(
                    model_path,
                    trust_remote_code=True
                )

                self.sequence_models[model_path] = AutoModelForCausalLM.from_pretrained(
                    model_path,
                    trust_remote_code=True,
                    torch_dtype=torch.bfloat16,
                    device_map="auto"
                )

            # Set pad token if not present
            # (the adapter branch already handled a missing pad token above, so
            # presumably this only takes effect for full models)
            if self.sequence_tokenizers[model_path].pad_token is None:
                self.sequence_tokenizers[model_path].pad_token = self.sequence_tokenizers[model_path].eos_token

        return self.sequence_models[model_path], self.sequence_tokenizers[model_path]
    
    def generate_response(
        self,
        model_path: str,
        input_text,
        max_new_tokens: int = 128
    ) -> str:
        """
        Generate a response y ~ pi_theta(.|x) for a given input using a model from the sequence.
        
        Args:
            model_path: Path to the model checkpoint
            input_text: Input prompt (string or list of message dicts)
            max_new_tokens: Maximum number of tokens to generate
            
        Returns:
            str: Generated response y
        """
        model, tokenizer = self._load_sequence_model(model_path)
        
        # Handle both string prompts and message lists
        if isinstance(input_text, list):
            # Apply chat template for message format
            input_ids = tokenizer.apply_chat_template(
                input_text,
                padding=False,
                add_generation_prompt=True,
                truncation=True,
                max_length=1024
            )
            inputs = {'input_ids': torch.tensor([input_ids], device=model.device)}
        else:
            # Handle string prompt as before
            inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=1024)
            inputs = {k: v.to(model.device) for k, v in inputs.items()}
        
        # Generate response
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                pad_token_id=tokenizer.eos_token_id
            )
        
        # Decode response (only the generated part)
        generated_text = tokenizer.decode(
            outputs[0][inputs['input_ids'].shape[1]:],
            skip_special_tokens=True
        )
        
        return generated_text

    def _compute_alignment_scores_async(
        self,
        all_inputs: List[str],
        all_responses: List[str],
        dataset_type: str,
        max_concurrent: int = 50
    ) -> Tuple[List[float], List[float], List[Dict[str, float]]]:
        """
        Score every (input, response) pair with the objective model and the human models.

        When the objective scorer exposes an ``async_client`` and uses the
        API, all pairs are scored concurrently, with a semaphore capping the
        number of pairs in flight. Otherwise the pairs are scored one after
        another with the synchronous methods.

        Args:
            all_inputs: List of formatted input strings
            all_responses: List of response strings
            dataset_type: Dataset type for scoring
            max_concurrent: Maximum number of pairs scored at once (default: 50)

        Returns:
            Tuple of (obj_scores, human_avg_scores, human_individual_scores),
            each ordered like the inputs
        """
        started_at = time.time()

        # The objective scorer must exist before its async capability can be checked
        objective = self.objective_model
        if not isinstance(objective, ObjectiveScorer):
            is_name = isinstance(objective, str)
            use_api = objective.startswith("gpt") if is_name else False
            self.objective_model = ObjectiveScorer(
                use_detailed_rubric=True,
                dataset_type=dataset_type,
                use_api=use_api,
                model_name=objective if is_name else "gpt-4o-mini",
                max_length=4096,
                load_quantized=not use_api,
                cache_dir=None,
                save_dir=self.output_dir,
            )

        use_async = hasattr(self.objective_model, 'async_client') and self.objective_model.use_api

        if not use_async:
            # Sequential path: one pair at a time through the sync methods
            print("Async not available, falling back to sequential scoring...")
            obj_scores, human_avg_scores, human_individual_scores = [], [], []

            for done, (input_str, response) in enumerate(zip(all_inputs, all_responses), start=1):
                obj_score = self.compute_objective_score(input_str, response, dataset_type)
                human_avg, human_individual = self.compute_human_score(
                    input_str, response, dataset_type, return_all_scores=True
                )
                obj_scores.append(obj_score)
                human_avg_scores.append(human_avg)
                human_individual_scores.append(human_individual)

                if done % 10 == 0:
                    print(f"  Scored {done}/{len(all_inputs)} pairs")

            return obj_scores, human_avg_scores, human_individual_scores

        num_samples = len(all_inputs)
        print(f"Scoring {num_samples} samples in parallel (max concurrent: {max_concurrent})...")

        async def run_all_scoring_parallel():
            """Score all pairs concurrently, limited by a semaphore."""
            limiter = asyncio.Semaphore(max_concurrent)

            async def score_sample(idx: int, input_str: str, response: str):
                """Return (idx, objective score, human average, per-model human scores)."""
                async with limiter:
                    obj_score = await self.async_compute_objective_score(input_str, response, dataset_type)
                    human_avg, human_individual = await self.async_compute_human_score(input_str, response, dataset_type)
                    return idx, obj_score, human_avg, human_individual

            gathered = await asyncio.gather(*(
                score_sample(idx, input_str, response)
                for idx, (input_str, response) in enumerate(zip(all_inputs, all_responses))
            ))

            # Order by sample index before splitting into the three result lists
            ordered = sorted(gathered, key=lambda item: item[0])
            return (
                [item[1] for item in ordered],
                [item[2] for item in ordered],
                [item[3] for item in ordered],
            )

        # Find out whether an event loop is already running in this thread
        running_loop = None
        try:
            running_loop = asyncio.get_running_loop()
        except RuntimeError:
            pass

        if running_loop is not None and running_loop.is_running():
            # Already in an async context (e.g., Jupyter notebook)
            print("Running in an async context, applying nest_asyncio...")
            import nest_asyncio
            nest_asyncio.apply()
            results = asyncio.get_event_loop().run_until_complete(run_all_scoring_parallel())
        else:
            # Plain synchronous caller: start a fresh event loop
            results = asyncio.run(run_all_scoring_parallel())

        scoring_time = time.time() - started_at
        print(f"  Scoring completed in {scoring_time:.2f}s ({num_samples / scoring_time:.1f} samples/sec)")

        return results

    def compute_alignment(
        self,
        dataset: List[Dict[str, str]],
        dataset_type: str,
        use_provided_responses: bool = False,
        batch_size: int = 8,
    ) -> Tuple[float, bool, Dict[str, Any]]:
        r"""
        Compute alignment between objective and human scoring across the model sequence.

        This implements Equation (1) from Methods.md:
        1/T \sum_{t=1}^T E_{x ~ X, y ~ pi_theta_t(.|x)} [|r_n(x, y) - s_h(x,y|n)|] <= epsilon

        The expectation is computed by:
        1. For each model pi_theta_t in the sequence
        2. Generate responses for all prompts in batch
        3. Compute |r_n(x, y) - s_h(x,y|n)| for each pair
        4. Average across all (model, prompt) pairs

        When ``use_provided_responses`` is True (or no model sequence is set) the
        responses stored in ``dataset`` are scored directly instead; samples
        with an empty response are skipped.

        Args:
            dataset: List of dicts with 'input' (or 'prompt') and optionally
                     'response' (or 'output') keys
            dataset_type: Dataset type identifier, forwarded unchanged to the
                          scoring helper
            use_provided_responses: If True, use responses from dataset instead of generating
            batch_size: Batch size for generation (only used when generating responses)

        Returns:
            Tuple[float, bool, Dict]: (average_difference, is_interpretable, score_details).
            ``average_difference`` is ``inf`` when nothing could be scored, and
            ``is_interpretable`` is ``average_difference <= self.epsilon``.

        Raises:
            ValueError: If there is no model sequence and use_provided_responses is False
        """
        if not self.model_sequence and not use_provided_responses:
            raise ValueError("No model sequence provided. Either provide model_sequence or set use_provided_responses=True")

        all_differences = []
        all_obj_scores = []  # Track all objective scores
        all_human_scores = []  # Track all average human scores
        all_individual_human_scores = []  # Track individual scores from each human model
        total_evaluations = 0

        # If using model sequence, iterate through each model
        if self.model_sequence and not use_provided_responses:
            print(f"Computing alignment across {len(self.model_sequence)} models...")

            for model_idx, model_path in enumerate(self.model_sequence):
                print(f"\nProcessing model {model_idx + 1}/{len(self.model_sequence)}: {model_path}")

                # Load model and tokenizer once for this checkpoint
                model, tokenizer = self._load_sequence_model(model_path)

                # Collect all prompts for batched generation
                prompts = [
                    sample.get('input', sample.get('prompt', ''))
                    for sample in dataset
                ]

                # Generate all responses in batch using the loaded model
                print(f"  Generating {len(prompts)} responses in batch...")
                responses = generate_responses_batched(
                    model_path=model_path,
                    model=model,
                    tokenizer=tokenizer,
                    prompts=prompts,
                    max_new_tokens=512,
                    batch_size=batch_size,
                    temperature=0.7,
                    top_p=0.9
                )

                # Prepare formatted inputs for scoring; plain strings are
                # wrapped as a single user turn before templating
                print("  Formatting inputs for scoring...")
                all_input_strs = []
                for prompt in prompts:
                    messages = prompt if isinstance(prompt, list) else [{"role": "user", "content": prompt}]
                    all_input_strs.append(
                        tokenizer.apply_chat_template(
                            messages,
                            tokenize=False,
                            add_generation_prompt=False
                        )
                    )

                # Compute all scores in parallel using async helper
                print("  Computing alignment scores (parallel)...")
                model_obj_scores, model_human_scores, model_individual_human_scores = \
                    self._compute_alignment_scores_async(
                        all_input_strs,
                        responses,
                        dataset_type,
                        max_concurrent=self.max_concurrent
                    )

                # Process results
                curr_time_differences = []
                for obj_score, human_score_avg, human_score_individual in zip(
                    model_obj_scores, model_human_scores, model_individual_human_scores
                ):
                    # Store scores for detailed logging
                    all_obj_scores.append(obj_score)
                    all_human_scores.append(human_score_avg)
                    all_individual_human_scores.append(human_score_individual)

                    # Calculate absolute difference
                    curr_time_differences.append(abs(obj_score - human_score_avg))
                    total_evaluations += 1

                # np.mean([]) would yield NaN (plus a RuntimeWarning) and poison
                # the overall average, so skip checkpoints with nothing scored
                if not curr_time_differences:
                    print(f"  Warning: Model {model_idx + 1} produced no scored samples; skipping")
                    continue

                # Store average difference for this model
                model_avg_diff = float(np.mean(curr_time_differences))
                all_differences.append(model_avg_diff)
                print(f"  Model {model_idx + 1} average difference: {model_avg_diff:.3f}")

        # Alternative: use provided responses (for testing or when responses are pre-generated)
        else:
            print("Using provided responses from dataset...")

            # Collect all (input, response) pairs first
            all_input_strs = []
            all_response_strs = []
            for sample in dataset:
                input_text = sample.get('input', sample.get('prompt', ''))
                response_text = sample.get('response', sample.get('output', ''))

                # Skip samples without responses
                if not response_text:
                    continue

                # Properly format multi-turn conversations using chat template
                if self.model_sequence:
                    input_str = apply_chat_template_to_prompt(self.model_sequence[0], input_text)
                else:
                    # Fallback if no model_sequence available: only the first
                    # message's content is used for a message list
                    input_str = input_text[0]['content'] if isinstance(input_text, list) else input_text

                all_input_strs.append(input_str)
                all_response_strs.append(response_text)

            if not all_input_strs:
                print("Warning: No valid samples found in dataset")
            else:
                # Compute all scores in parallel using async helper
                print(f"  Computing alignment scores for {len(all_input_strs)} samples (parallel)...")
                provided_obj_scores, provided_human_scores, provided_individual_human_scores = \
                    self._compute_alignment_scores_async(
                        all_input_strs,
                        all_response_strs,
                        dataset_type,
                        max_concurrent=self.max_concurrent
                    )

                # Process results
                for obj_score, human_score_avg, human_score_individual in zip(
                    provided_obj_scores, provided_human_scores, provided_individual_human_scores
                ):
                    # Store scores for detailed logging
                    all_obj_scores.append(obj_score)
                    all_human_scores.append(human_score_avg)
                    all_individual_human_scores.append(human_score_individual)

                    # Calculate absolute difference
                    all_differences.append(abs(obj_score - human_score_avg))
                    total_evaluations += 1

        # Calculate average difference across all models and samples
        avg_difference = float(np.mean(all_differences)) if all_differences else float('inf')

        # Check if interpretable
        is_interpretable = avg_difference <= self.epsilon

        print(f"\nTotal evaluations: {total_evaluations}")
        print(f"Models evaluated: {len(self.model_sequence) if self.model_sequence else 'N/A (using provided responses)'}")
        print(f"Samples per model: {len(dataset)}")

        # Prepare detailed score statistics
        score_details = {
            'avg_objective_score': float(np.mean(all_obj_scores)) if all_obj_scores else 0.0,
            'avg_human_score': float(np.mean(all_human_scores)) if all_human_scores else 0.0,
            'std_objective_score': float(np.std(all_obj_scores)) if all_obj_scores else 0.0,
            'std_human_score': float(np.std(all_human_scores)) if all_human_scores else 0.0,
            'min_objective_score': float(np.min(all_obj_scores)) if all_obj_scores else 0.0,
            'max_objective_score': float(np.max(all_obj_scores)) if all_obj_scores else 0.0,
            'min_human_score': float(np.min(all_human_scores)) if all_human_scores else 0.0,
            'max_human_score': float(np.max(all_human_scores)) if all_human_scores else 0.0,
            'num_samples': total_evaluations,
            'all_individual_human_scores': all_individual_human_scores  # All individual scores from each human model
        }

        # Calculate per-model statistics if we have individual scores
        if all_individual_human_scores:
            # Get all unique human model names
            model_names = set()
            for scores_dict in all_individual_human_scores:
                model_names.update(scores_dict.keys())

            # Calculate stats for each human model, using only the samples that
            # model actually scored (padding gaps with 0.0 would skew avg/min)
            per_model_stats = {}
            for model_name in model_names:
                model_scores = [
                    scores_dict[model_name]
                    for scores_dict in all_individual_human_scores
                    if model_name in scores_dict
                ]
                per_model_stats[model_name] = {
                    'avg': float(np.mean(model_scores)),
                    'std': float(np.std(model_scores)),
                    'min': float(np.min(model_scores)),
                    'max': float(np.max(model_scores))
                }
            score_details['per_human_model_stats'] = per_model_stats

        return avg_difference, is_interpretable, score_details
    
    def verify(
        self,
        dataset: Union[List[Dict[str, str]], str],
        dataset_type: str,
        sample_size: Optional[int] = None,
        use_provided_responses: bool = False,
        batch_size: int = 8
    ) -> bool:
        """
        Decide whether the objective is human-interpretable on ``dataset``.

        The decision is delegated to ``compute_alignment``, which estimates
        E_{x ~ X, y ~ pi_theta(.|x)} [|r_n(x, y) - s_h(x,y|n)|] averaged across
        the model sequence. This method resolves the dataset, optionally
        subsamples it, prints a report and returns the verdict.

        Args:
            dataset: Either a list of samples or a dataset name. A string
                     containing '/' is loaded with ``load_dataset`` (train split)
            dataset_type: Dataset type identifier passed through to
                          ``compute_alignment``
            sample_size: Optional cap on the number of samples; a random subset
                         is drawn when the dataset is larger
            use_provided_responses: If True, use responses from dataset instead
                                    of generating from model sequence
            batch_size: Batch size for generation (only used when generating responses)

        Returns:
            bool: True if objective is human-interpretable, False otherwise

        Raises:
            NotImplementedError: If ``dataset`` is a string without '/'
        """
        if isinstance(dataset, str):
            # Names without '/' would be local files, which are not supported
            if '/' not in dataset:
                raise NotImplementedError("Local file loading not implemented")

            # Anything else is assumed to be a HuggingFace dataset identifier;
            # normalise each row to the {'input', 'response'} layout
            hub_rows = load_dataset(dataset, split='train')
            dataset = []
            for row in hub_rows:
                dataset.append({
                    'input': row.get('prompt', row.get('input', '')),
                    'response': row.get('response', row.get('output', '')),
                })

        # Draw a random subset when a cap is given and actually exceeded
        exceeds_cap = bool(sample_size) and len(dataset) > sample_size
        if exceeds_cap:
            dataset = random.sample(dataset, sample_size)

        alignment = self.compute_alignment(
            dataset,
            dataset_type,
            use_provided_responses=use_provided_responses,
            batch_size=batch_size
        )
        avg_difference, is_interpretable, score_details = alignment

        # Summary report
        print("\nHuman-Interpretability Verification Results:")
        print(f"Objective: {self.objective_description}")
        print(f"Average difference: {avg_difference:.4f}")
        print(f"Average objective score: {score_details.get('avg_objective_score', 0):.4f}")
        print(f"Average human score: {score_details.get('avg_human_score', 0):.4f}")

        # Breakdown per human scoring model, when the details include one
        if 'per_human_model_stats' in score_details:
            print("\nPer-model human score statistics:")
            for human_model, summary in score_details['per_human_model_stats'].items():
                print(f"  {human_model}:")
                print(f"    Average: {summary['avg']:.4f}")
                print(f"    Std Dev: {summary['std']:.4f}")
                print(f"    Range: [{summary['min']:.4f}, {summary['max']:.4f}]")

        print(f"\nThreshold (epsilon): {self.epsilon}")
        print(f"Is human-interpretable: {is_interpretable}")

        return is_interpretable


class PredictableTrendVerifier(BaseObjectivesVerifier):
    """
    Verifies if an objective follows a predictable trend across model iterations.
    
    An objective follows a predictable trend if the sequence of scores
    V^1(r), ..., V^T(r) can be fit with a function from a pre-defined class F_trend.
    
    This implements Equations (2), (3), and (4) from Methods.md:
    - Eq (2): V_n^t(r) = E_{x ~ X, y ~ pi_theta_t(.|x)} [r_n(x, y)]
    - Eq (3): f* = argmin_{f in F_trend} sum_{t=2}^T L(f(t|V^{t-1}(r)), V^t(r))
    - Eq (4): (1/(T-1)) * sum_{t=2}^T L(f*(t|V^{t-1}(r)), V^t(r)) <= epsilon_trend
    """
    
    def __init__(
        self,
        objective_description: str,
        objective_model: Union[str, Any],
        model_sequence: List[str],
        epsilon: float = 0.1,
        trend_function_types: Optional[List[str]] = None,
        use_normalized_scores: bool = True,
        save_dir: Optional[str] = None,
        group_scoring: bool = False,
        max_concurrent: int = 10
    ):
        """
        Initialize the predictable trend verifier.

        Args:
            objective_description: Natural language description of the objective
            objective_model: Model or model name for objective scoring r(x,y)
            model_sequence: List of paths to model checkpoints pi_theta_1, ..., pi_theta_T
            epsilon: Maximum allowed average prediction error (default 0.1)
            trend_function_types: List of trend types to try (default:
                ['linear', 'exponential_sat', 'power_law', 'logarithmic'])
            use_normalized_scores: Whether scores are passed through
                ``_normalize_score`` before being returned (default True)
            save_dir: Optional directory to save trend plots
            group_scoring: Whether to use group scoring for objectives (default False)
            max_concurrent: Maximum number of concurrent API calls (default 10)
        """
        super().__init__(objective_description, epsilon)
        self.objective_model = objective_model
        self.model_sequence = model_sequence
        self.use_normalized_scores = use_normalized_scores
        self.save_dir = save_dir
        self.group_scoring = group_scoring
        self.max_concurrent = max_concurrent

        # Default trend functions to try (strictly increasing only);
        # 'flat' is deliberately excluded from the defaults
        if trend_function_types is None:
            self.trend_function_types = ['linear', 'exponential_sat', 'power_law', 'logarithmic']
        else:
            self.trend_function_types = trend_function_types

        # An API client is only needed when the objective model is given as a
        # name that does not look like a HuggingFace identifier. It stays None
        # when no API key is configured.
        self.api_client = None
        if isinstance(objective_model, str) and not self._is_huggingface_model(objective_model):
            api_key = OPENAI_API_KEY or os.environ.get('OPENAI_API_KEY')
            if api_key:
                self.api_client = OpenAI(api_key=api_key)

        # Cache for models and results - only stores one model at a time
        self.sequence_models = {}
        self.sequence_tokenizers = {}
        self.objective_scores = []
        self.best_trend = None
        self.trend_params = None
        self.prediction_errors = []
        self.all_trend_results = {}  # Store all trend fitting results for plotting
    
    def _is_huggingface_model(self, model_name: str) -> bool:
        """Check if a model name refers to a HuggingFace model."""
        return '/' in model_name
    
    def _score_with_huggingface_model(self, prompt: str) -> float:
        """
        Score ``prompt`` with the HuggingFace model named by ``self.objective_model``.

        The model and tokenizer are loaded (4-bit quantized) on first use and
        kept in per-instance caches. The first number found in the generated
        text is returned, clamped to [1.0, 10.0]; 5.0 is returned when no
        number can be extracted.

        Args:
            prompt: Fully formatted scoring prompt

        Returns:
            float: Score in [1.0, 10.0], or the default 5.0
        """
        # Create the caches lazily on the first call
        if not hasattr(self, 'hf_models'):
            self.hf_models = {}
            self.hf_tokenizers = {}

        scorer_name = self.objective_model

        if scorer_name not in self.hf_models:
            print(f"Loading HuggingFace objective model: {scorer_name}")
            scorer_tokenizer = AutoTokenizer.from_pretrained(
                scorer_name,
                trust_remote_code=True
            )
            self.hf_tokenizers[scorer_name] = scorer_tokenizer

            # 4-bit NF4 quantization keeps large scoring models within memory
            quantization = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type='nf4'
            )

            self.hf_models[scorer_name] = AutoModelForCausalLM.from_pretrained(
                scorer_name,
                quantization_config=quantization,
                torch_dtype=torch.bfloat16,
                device_map="auto",
                trust_remote_code=True
            )

            # Fall back to EOS as the padding token when none is defined
            if scorer_tokenizer.pad_token is None:
                scorer_tokenizer.pad_token = scorer_tokenizer.eos_token

        # Ask the model for a short completion that should contain the score
        generated_text = generate_huggingface_response(
            model=self.hf_models[scorer_name],
            tokenizer=self.hf_tokenizers[scorer_name],
            prompt=prompt,
            system_prompt=SCORING_SYSTEM_PROMPT,
            max_new_tokens=10,
            do_sample=False,
            temperature=0.1
        )

        # Take the first integer or decimal number in the completion
        try:
            first_number = re.search(r'\b\d+(?:\.\d+)?\b', generated_text.strip())
            if first_number is None:
                print(f"Warning: Could not extract score from '{generated_text}'. Using default.")
                return 5.0
            # Clamp to valid range
            return max(1.0, min(10.0, float(first_number.group(0))))
        except Exception as e:
            print(f"Error extracting score: {e}. Using default.")
            return 5.0
    
    def _load_sequence_model(self, model_path: str):
        """
        Load a model from the training sequence (pi_theta_t).
        Only keeps one model in memory at a time to avoid GPU OOM.
        
        Args:
            model_path: Path to the model checkpoint
            
        Returns:
            Tuple of (model, tokenizer)
        """
        if model_path not in self.sequence_models:
            # Clear existing cached models to free GPU memory
            if self.sequence_models:
                print("Clearing previous sequence model from memory...")
                for cached_path in list(self.sequence_models.keys()):
                    del self.sequence_models[cached_path]
                    del self.sequence_tokenizers[cached_path]
                torch.cuda.empty_cache()
            
            print(f"Loading sequence model: {model_path}")
            
            # Check if this is an adapter model
            adapter_config_path = os.path.join(model_path, 'adapter_config.json')
            is_adapter = os.path.exists(adapter_config_path)
            
            if is_adapter:
                # Load adapter config to get base model
                import json
                with open(adapter_config_path, 'r') as f:
                    adapter_config = json.load(f)
                base_model_name = adapter_config.get('base_model_name_or_path', 'meta-llama/Llama-3.1-8B')
                
                # Check if tokenizer files exist in the adapter directory
                tokenizer_config_path = os.path.join(model_path, 'tokenizer_config.json')
                
                # If tokenizer was saved with the adapter, use it to ensure compatibility
                if os.path.exists(tokenizer_config_path):
                    self.sequence_tokenizers[model_path] = AutoTokenizer.from_pretrained(
                        model_path,
                        trust_remote_code=True
                    )
                else:
                    # Load tokenizer from base model
                    self.sequence_tokenizers[model_path] = AutoTokenizer.from_pretrained(
                        base_model_name,
                        trust_remote_code=True
                    )
                    
                # Add padding token if not present
                if self.sequence_tokenizers[model_path].pad_token is None:
                    self.sequence_tokenizers[model_path].add_special_tokens({"pad_token": "[PAD]"})
                
                # Load base model with quantization
                bnb_config = BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_compute_dtype=torch.bfloat16,
                    bnb_4bit_use_double_quant=True,
                    bnb_4bit_quant_type='nf4'
                )
                
                base_model = AutoModelForCausalLM.from_pretrained(
                    base_model_name,
                    quantization_config=bnb_config,
                    torch_dtype=torch.bfloat16,
                    device_map="auto",
                    trust_remote_code=True
                )
                
                # Only resize embeddings if vocabulary size changed
                if len(self.sequence_tokenizers[model_path]) != base_model.config.vocab_size:
                    print(f"Resizing model embeddings from {base_model.config.vocab_size} to {len(self.sequence_tokenizers[model_path])}")
                    base_model.resize_token_embeddings(len(self.sequence_tokenizers[model_path]))
                
                # Apply adapter
                self.sequence_models[model_path] = PeftModel.from_pretrained(base_model, model_path)
            else:
                # Load full model
                self.sequence_tokenizers[model_path] = AutoTokenizer.from_pretrained(
                    model_path,
                    trust_remote_code=True
                )
                
                # Load model with reduced memory usage
                self.sequence_models[model_path] = AutoModelForCausalLM.from_pretrained(
                    model_path,
                    trust_remote_code=True,
                    torch_dtype=torch.bfloat16,
                    device_map="auto"
                )
            
            # Set pad token if not present
            if self.sequence_tokenizers[model_path].pad_token is None:
                self.sequence_tokenizers[model_path].pad_token = self.sequence_tokenizers[model_path].eos_token
        
        return self.sequence_models[model_path], self.sequence_tokenizers[model_path]
    
    def _generate_response(
        self,
        model: Any,
        tokenizer: Any,
        input_text,
        max_new_tokens: int = 128
    ) -> str:
        """
        Generate a response y ~ pi_theta(.|x) using a loaded model.
        
        Args:
            model: The loaded model
            tokenizer: The model's tokenizer
            input_text: Input prompt (string or list of message dicts)
            max_new_tokens: Maximum number of tokens to generate
            
        Returns:
            str: Generated response y
        """
        # Handle both string prompts and message lists
        if isinstance(input_text, list):
            # Apply chat template for message format
            input_ids = tokenizer.apply_chat_template(
                input_text,
                padding=False,
                add_generation_prompt=True,
                truncation=True,
                max_length=1024
            )
            inputs = {'input_ids': torch.tensor([input_ids], device=model.device)}
        else:
            # Handle string prompt as before
            inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=1024)
            inputs = {k: v.to(model.device) for k, v in inputs.items()}
        
        # Generate response
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                pad_token_id=tokenizer.eos_token_id
            )
        
        # Decode response (only the generated part)
        generated_text = tokenizer.decode(
            outputs[0][inputs['input_ids'].shape[1]:],
            skip_special_tokens=True
        )
        
        return generated_text
    
    def _score_with_objective_model(
        self,
        input_text: str,
        response_text: str
    ) -> float:
        """
        Compute objective score r(x,y) using the objective model.

        The scoring backend is chosen from the type of ``self.objective_model``,
        checked in this order: ``ObjectiveScorer`` instance, callable, model
        name string (HuggingFace if it contains '/', otherwise the API client),
        and finally any object exposing ``score(input_text, response_text)``.

        Scores parsed from API text are clamped to [1.0, 10.0], matching the
        HuggingFace path. 5.0 is used as the default when the API call fails,
        no number can be parsed, or no API client is available.

        Args:
            input_text: Input prompt x
            response_text: Model response y

        Returns:
            float: Objective score (passed through ``_normalize_score`` when
                   ``use_normalized_scores`` is set)
        """
        print("Scoring with objective model...")
        # Check if objective_model is ObjectiveScorer instance
        if isinstance(self.objective_model, ObjectiveScorer):
            print("Using ObjectiveScorer instance for scoring.")
            # Use ObjectiveScorer's score_single_objective method
            score = self.objective_model.score_single_objective(
                input_text,
                response_text,
                self.objective_description
            )
        # If objective_model is a callable (custom function), use it
        elif callable(self.objective_model):
            score = self.objective_model(input_text, response_text)
        # If it's a string, treat it as a model name
        elif isinstance(self.objective_model, str):
            prompt = OBJECTIVE_SCORING_PROMPT.format(
                objective_description=self.objective_description,
                input_text=input_text,
                response_text=response_text
            )

            if self._is_huggingface_model(self.objective_model):
                # Score using HuggingFace model
                score = self._score_with_huggingface_model(prompt)
            elif self.api_client:
                # Best-effort API scoring: any failure falls back to the default
                try:
                    response = self.api_client.chat.completions.create(
                        # Non-"gpt" names are replaced by the default API model
                        model=self.objective_model if self.objective_model.startswith("gpt") else "gpt-4o-mini",
                        messages=[
                            {"role": "system", "content": SCORING_SYSTEM_PROMPT},
                            {"role": "user", "content": prompt}
                        ],
                        temperature=0.1,
                        max_tokens=10
                    )
                    text = response.choices[0].message.content.strip()
                    # Extract the first number from the reply
                    numbers = re.findall(r'\b\d+(?:\.\d+)?\b', text)
                    if numbers:
                        # Clamp to the same valid range as the HuggingFace path
                        score = max(1.0, min(10.0, float(numbers[0])))
                    else:
                        print(f"Warning: Could not extract score from '{text}'. Using default.")
                        score = 5.0
                except Exception as e:
                    print(f"Error scoring with API: {e}")
                    score = 5.0
            else:
                # No client was configured (e.g. missing API key)
                print("Warning: No API client available for objective scoring. Using default.")
                score = 5.0
        else:
            # Assume it's a model object with a scoring method
            score = self.objective_model.score(input_text, response_text)

        if self.use_normalized_scores:
            score = self._normalize_score(score)

        return score

    def _group_score_with_objective_model(
        self,
        input_text: str,
        responses: List[str]
    ) -> List[float]:
        """
        Score several responses to one query in a single group-scoring call.

        Args:
            input_text: Formatted input query
            responses: List of responses for this query

        Returns:
            List of scores, one per response (each passed through
            ``_normalize_score`` when ``use_normalized_scores`` is set)

        Raises:
            ValueError: If ``self.objective_model`` is not an ObjectiveScorer
        """
        scorer = self.objective_model

        # Group scoring is only available through ObjectiveScorer
        if not isinstance(scorer, ObjectiveScorer):
            raise ValueError("Group scoring requires objective_model to be an ObjectiveScorer instance")

        # All responses are scored together against the objective
        raw_scores = scorer.group_score_single_objective(
            query=input_text,
            responses=responses,
            objective=self.objective_description
        )

        if not self.use_normalized_scores:
            return raw_scores
        return [self._normalize_score(value) for value in raw_scores]

    async def _async_score_with_objective_model(
        self,
        input_text: str,
        response_text: str
    ) -> float:
        """
        Asynchronously compute the objective score for one (input, response) pair.

        If ``self.objective_model`` is not already an ObjectiveScorer it is
        replaced, on this instance, by a newly constructed one before scoring;
        scoring then goes through ``async_score_single_objective``.

        Args:
            input_text: Input prompt x
            response_text: Model response y

        Returns:
            float: Objective score (passed through ``_normalize_score`` when
                   ``use_normalized_scores`` is set)
        """
        configured = self.objective_model

        if not isinstance(configured, ObjectiveScorer):
            # Only names starting with "gpt" are routed through the API
            is_model_name = isinstance(configured, str)
            use_api = is_model_name and configured.startswith("gpt")

            # NOTE(review): self.dataset_type and self.output_dir are not
            # assigned in this class's __init__ (which stores save_dir) -
            # confirm they are set elsewhere before this path runs.
            # NOTE(review): a non-string objective_model ends up with
            # model_name="gpt-4o-mini" but use_api=False - confirm intended.
            self.objective_model = ObjectiveScorer(
                use_detailed_rubric=True,
                dataset_type=self.dataset_type,
                use_api=use_api,
                model_name=configured if is_model_name else "gpt-4o-mini",
                max_length=4096,
                load_quantized=not use_api,
                cache_dir=None,
                save_dir=self.output_dir,
            )

        # Score through the scorer's async entry point
        raw_score = await self.objective_model.async_score_single_objective(
            input_text,
            response_text,
            self.objective_description
        )

        return self._normalize_score(raw_score) if self.use_normalized_scores else raw_score

    async def _async_group_score_with_objective_model(
        self,
        input_text: str,
        responses: List[str]
    ) -> List[float]:
        """
        Asynchronously score several responses to one query as a group.

        If ``self.objective_model`` is not already an ObjectiveScorer it is
        replaced, on this instance, by a newly constructed one before scoring;
        scoring then goes through ``async_group_score_single_objective``.

        Args:
            input_text: Formatted input query
            responses: List of responses for this query

        Returns:
            List of scores, one per response (each passed through
            ``_normalize_score`` when ``use_normalized_scores`` is set)
        """
        configured = self.objective_model

        if not isinstance(configured, ObjectiveScorer):
            # Only names starting with "gpt" are routed through the API
            is_model_name = isinstance(configured, str)
            use_api = is_model_name and configured.startswith("gpt")

            # NOTE(review): self.dataset_type and self.output_dir are not
            # assigned in this class's __init__ (which stores save_dir) -
            # confirm they are set elsewhere before this path runs.
            # NOTE(review): a non-string objective_model ends up with
            # model_name="gpt-4o-mini" but use_api=False - confirm intended.
            self.objective_model = ObjectiveScorer(
                use_detailed_rubric=True,
                dataset_type=self.dataset_type,
                use_api=use_api,
                model_name=configured if is_model_name else "gpt-4o-mini",
                max_length=4096,
                load_quantized=not use_api,
                cache_dir=None,
                save_dir=self.output_dir,
            )

        # Score the whole group through the scorer's async entry point
        raw_scores = await self.objective_model.async_group_score_single_objective(
            query=input_text,
            responses=responses,
            objective=self.objective_description
        )

        if not self.use_normalized_scores:
            return raw_scores
        return [self._normalize_score(value) for value in raw_scores]

    def _compute_v_scores_async(
        self,
        all_inputs: List[str],
        all_responses: List[str],
        max_concurrent: int = 10
    ) -> List[float]:
        """
        Compute objective scores for all (input, response) pairs in parallel.

        Inputs and responses are paired positionally with ``zip``, so surplus
        entries in the longer list are ignored. ``self.objective_model`` is
        wrapped in an ObjectiveScorer first if it is not one already. When the
        scorer has no ``async_client`` attribute or is not API-backed
        (``use_api`` is falsy), pairs are scored one at a time with
        ``_score_with_objective_model`` instead.

        Args:
            all_inputs: List of formatted input strings
            all_responses: List of response strings
            max_concurrent: Maximum concurrent scoring calls (default: 10)

        Returns:
            List of objective scores, in the same order as the input pairs
        """
        scoring_start_time = time.time()

        # Make sure we have an ObjectiveScorer before probing its async support
        if not isinstance(self.objective_model, ObjectiveScorer):
            use_api = self.objective_model.startswith("gpt") if isinstance(self.objective_model, str) else False
            self.objective_model = ObjectiveScorer(
                use_detailed_rubric=True,
                dataset_type=self.dataset_type,
                use_api=use_api,
                model_name=self.objective_model if isinstance(self.objective_model, str) else "gpt-4o-mini",
                max_length=4096,
                load_quantized=not use_api,
                cache_dir=None,
                save_dir=self.output_dir,
            )

        use_async = hasattr(self.objective_model, 'async_client') and self.objective_model.use_api

        if not use_async:
            # Fallback to sequential scoring
            print("Async not available, falling back to sequential scoring...")
            scores = []
            for idx, (input_str, response) in enumerate(zip(all_inputs, all_responses)):
                scores.append(self._score_with_objective_model(input_str, response))
                if (idx + 1) % 10 == 0:
                    print(f"    Scored {idx + 1}/{len(all_inputs)} pairs")
            return scores

        num_samples = len(all_inputs)
        print(f"  Scoring {num_samples} samples in parallel (max concurrent: {max_concurrent})...")

        async def run_all_scoring_parallel():
            """Run ALL scoring tasks in parallel with rate limiting."""
            semaphore = asyncio.Semaphore(max_concurrent)

            async def score_sample(input_str: str, response: str):
                """Score a single (input, response) pair."""
                async with semaphore:
                    return await self._async_score_with_objective_model(input_str, response)

            # asyncio.gather returns results in the order the awaitables were
            # passed, so no index bookkeeping or re-sorting is needed.
            return await asyncio.gather(*(
                score_sample(input_str, response)
                for input_str, response in zip(all_inputs, all_responses)
            ))

        # Detect whether we are being called from inside a running event loop
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            loop = None

        if loop is not None and loop.is_running():
            # asyncio.run() cannot be used while a loop is running; nest_asyncio
            # is applied so run_until_complete can be called on it instead.
            print("Running in an async context, applying nest_asyncio...")
            import nest_asyncio
            nest_asyncio.apply()
            scores = asyncio.get_event_loop().run_until_complete(run_all_scoring_parallel())
        else:
            scores = asyncio.run(run_all_scoring_parallel())

        scoring_time = time.time() - scoring_start_time
        # Elapsed time can be 0 on coarse clocks or empty workloads; avoid
        # raising ZeroDivisionError from a purely informational message.
        throughput = num_samples / scoring_time if scoring_time > 0 else float('inf')
        print(f"    Scoring completed in {scoring_time:.2f}s ({throughput:.1f} samples/sec)")

        return scores

    def _compute_group_v_scores_async(
        self,
        all_formatted_prompts: List[str],
        all_responses_by_prompt: List[List[str]],
        max_concurrent: int = 50
    ) -> List[List[float]]:
        """
        Compute group scores for all prompts in parallel.

        Prompts and response groups are paired positionally with ``zip``, so
        surplus entries in the longer list are ignored. ``self.objective_model``
        is wrapped in an ObjectiveScorer first if it is not one already. When the
        scorer has no ``async_client`` attribute or is not API-backed
        (``use_api`` is falsy), prompts are scored one at a time with
        ``_group_score_with_objective_model`` instead.

        Args:
            all_formatted_prompts: List of formatted prompt strings
            all_responses_by_prompt: List of lists, where each inner list contains responses from all models
            max_concurrent: Maximum concurrent scoring calls (default: 50)

        Returns:
            List of score lists in prompt order, where each inner list contains
            the scores for that prompt's responses
        """
        scoring_start_time = time.time()

        # Ensure objective_model is an ObjectiveScorer instance
        if not isinstance(self.objective_model, ObjectiveScorer):
            use_api = self.objective_model.startswith("gpt") if isinstance(self.objective_model, str) else False
            self.objective_model = ObjectiveScorer(
                use_detailed_rubric=True,
                dataset_type=self.dataset_type,
                use_api=use_api,
                model_name=self.objective_model if isinstance(self.objective_model, str) else "gpt-4o-mini",
                max_length=4096,
                load_quantized=not use_api,
                cache_dir=None,
                save_dir=self.output_dir,
            )

        use_async = hasattr(self.objective_model, 'async_client') and self.objective_model.use_api

        if not use_async:
            # Fallback to sequential scoring
            print("Async not available, falling back to sequential group scoring...")
            all_scores = []
            for prompt_idx, (formatted_prompt, responses) in enumerate(
                zip(all_formatted_prompts, all_responses_by_prompt)
            ):
                all_scores.append(self._group_score_with_objective_model(formatted_prompt, responses))
                if (prompt_idx + 1) % 10 == 0:
                    print(f"  Scored {prompt_idx + 1}/{len(all_formatted_prompts)} prompts")
            return all_scores

        num_prompts = len(all_formatted_prompts)
        print(f"  Scoring {num_prompts} prompts in parallel (max concurrent: {max_concurrent})...")

        async def run_all_group_scoring_parallel():
            """Run ALL group scoring tasks in parallel with rate limiting."""
            semaphore = asyncio.Semaphore(max_concurrent)

            async def score_prompt(formatted_prompt: str, responses: List[str]):
                """Score all responses for a single prompt."""
                async with semaphore:
                    return await self._async_group_score_with_objective_model(formatted_prompt, responses)

            # asyncio.gather returns results in the order the awaitables were
            # passed, so no index bookkeeping or re-sorting is needed.
            return await asyncio.gather(*(
                score_prompt(formatted_prompt, responses)
                for formatted_prompt, responses in zip(all_formatted_prompts, all_responses_by_prompt)
            ))

        # Detect whether we are being called from inside a running event loop
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            loop = None

        if loop is not None and loop.is_running():
            # asyncio.run() cannot be used while a loop is running; nest_asyncio
            # is applied so run_until_complete can be called on it instead.
            print("Running in an async context, applying nest_asyncio...")
            import nest_asyncio
            nest_asyncio.apply()
            all_scores = asyncio.get_event_loop().run_until_complete(run_all_group_scoring_parallel())
        else:
            all_scores = asyncio.run(run_all_group_scoring_parallel())

        scoring_time = time.time() - scoring_start_time
        # Elapsed time can be 0 on coarse clocks or empty workloads; avoid
        # raising ZeroDivisionError from a purely informational message.
        throughput = num_prompts / scoring_time if scoring_time > 0 else float('inf')
        print(f"    Group scoring completed in {scoring_time:.2f}s ({throughput:.1f} prompts/sec)")

        return all_scores

    def group_compute_v_scores(
        self,
        dataset: List[Dict[str, str]],
        sample_size: Optional[int] = None,
        batch_size: int = 8
    ) -> List[float]:
        """
        Estimate V_n^t(r) for every model in the sequence using group scoring.

        Corresponds to Equation (2) from Methods.md:
        V_n^t(r) = E_{x ~ X, y ~ pi_theta_t(.|x)} [r_n(x, y)]

        The work happens in two phases. First, each model in
        ``self.model_sequence`` is loaded in turn, generates one response per
        prompt, and is released again. Second, all responses for a prompt
        (one per model) are scored together, and each model's V score is the
        mean of its scores over the prompts. The result is also stored on
        ``self.objective_scores``.

        Args:
            dataset: Samples; the prompt is read from 'input', else 'prompt'
            sample_size: If set and smaller than the dataset, a random subset
                of this size is used
            batch_size: Batch size for generation

        Returns:
            List of V scores, one per model, in sequence order

        Raises:
            ValueError: If ``self.model_sequence`` is empty
        """
        if not self.model_sequence:
            raise ValueError("No model sequence provided")

        # Subsample only when the dataset is larger than requested
        if sample_size and len(dataset) > sample_size:
            dataset = random.sample(dataset, sample_size)

        num_models = len(self.model_sequence)
        print(f"Computing V scores across {num_models} models using group scoring...")

        # responses_per_prompt[i] accumulates one response per model, in model order
        responses_per_prompt = [[] for _ in dataset]
        formatted_prompts = []

        # Phase 1: generate responses with every model
        for t, model_path in enumerate(self.model_sequence, 1):
            print(f"\nCollecting responses from model {t}/{num_models}: {model_path}")

            model, tokenizer = self._load_sequence_model(model_path)

            prompts = [
                sample.get('input', sample.get('prompt', ''))
                for sample in dataset
            ]

            print(f"  Generating {len(prompts)} responses in batch...")
            responses = generate_responses_batched(
                model_path=model_path,
                model=model,
                tokenizer=tokenizer,
                prompts=prompts,
                max_new_tokens=512,
                batch_size=batch_size,
                temperature=0.7,
                top_p=0.9
            )

            is_first_model = t == 1
            for bucket, prompt, response in zip(responses_per_prompt, prompts, responses):
                bucket.append(response)

                # Prompts are formatted once, with the first model's tokenizer
                if not is_first_model:
                    continue

                # A list is taken to be a ready-made message list; anything
                # else is wrapped as a single user turn.
                messages = prompt if isinstance(prompt, list) else [{"role": "user", "content": prompt}]
                formatted_prompts.append(
                    tokenizer.apply_chat_template(
                        messages,
                        tokenize=False,
                        add_generation_prompt=False
                    )
                )

            # Release this checkpoint before loading the next one
            del model
            del tokenizer
            if model_path in self.sequence_models:
                del self.sequence_models[model_path]
                del self.sequence_tokenizers[model_path]
            torch.cuda.empty_cache()

        # Phase 2: score each prompt's responses as a group
        print(f"\nScoring all collected responses using parallel group scoring...")
        scores_per_prompt = self._compute_group_v_scores_async(
            formatted_prompts,
            responses_per_prompt,
            max_concurrent=self.max_concurrent
        )

        # Transpose [prompt][model] -> [model][prompt]
        scores_per_model = [[] for _ in self.model_sequence]
        for prompt_scores in scores_per_prompt:
            for model_idx, score in enumerate(prompt_scores):
                scores_per_model[model_idx].append(score)

        v_scores = []
        for t, model_scores in enumerate(scores_per_model, 1):
            v_t = float(np.mean(model_scores))
            v_scores.append(v_t)
            print(f"  V^{t}(r) = {v_t:.4f}")

        self.objective_scores = v_scores
        return v_scores

    def compute_v_scores(
        self,
        dataset: List[Dict[str, str]],
        sample_size: Optional[int] = None,
        batch_size: int = 8
    ) -> List[float]:
        """
        Estimate V_n^t(r) for every model in the sequence.

        Corresponds to Equation (2) from Methods.md:
        V_n^t(r) = E_{x ~ X, y ~ pi_theta_t(.|x)} [r_n(x, y)]

        When ``self.group_scoring`` is set, the call is delegated unchanged to
        ``group_compute_v_scores``. Otherwise each model is loaded in turn,
        generates one response per prompt, has those responses scored
        individually, and is released; its V score is the mean of those scores.
        The result is also stored on ``self.objective_scores``.

        Args:
            dataset: Samples; the prompt is read from 'input', else 'prompt'
            sample_size: If set and smaller than the dataset, a random subset
                of this size is used
            batch_size: Batch size for generation

        Returns:
            List of V scores, one per model, in sequence order

        Raises:
            ValueError: If ``self.model_sequence`` is empty (non-group path)
        """
        # Group scoring has its own end-to-end implementation
        if self.group_scoring:
            return self.group_compute_v_scores(dataset, sample_size, batch_size)

        if not self.model_sequence:
            raise ValueError("No model sequence provided")

        # Subsample only when the dataset is larger than requested
        if sample_size and len(dataset) > sample_size:
            dataset = random.sample(dataset, sample_size)

        num_models = len(self.model_sequence)
        print(f"Computing V scores across {num_models} models...")

        v_scores = []
        for t, model_path in enumerate(self.model_sequence, 1):
            print(f"\nProcessing model {t}/{num_models}: {model_path}")

            model, tokenizer = self._load_sequence_model(model_path)

            prompts = [
                sample.get('input', sample.get('prompt', ''))
                for sample in dataset
            ]

            print(f"  Generating {len(prompts)} responses in batch...")
            responses = generate_responses_batched(
                model_path=model_path,
                model=model,
                tokenizer=tokenizer,
                prompts=prompts,
                max_new_tokens=512,
                batch_size=batch_size,
                temperature=0.7,
                top_p=0.9
            )

            # Render each prompt through this model's chat template; a list is
            # taken to be a ready-made message list, anything else is wrapped
            # as a single user turn.
            print(f"  Formatting inputs for scoring...")
            scoring_inputs = []
            for prompt in prompts:
                messages = prompt if isinstance(prompt, list) else [{"role": "user", "content": prompt}]
                scoring_inputs.append(
                    tokenizer.apply_chat_template(
                        messages,
                        tokenize=False,
                        add_generation_prompt=False
                    )
                )

            print(f"  Scoring responses (parallel)...")
            model_scores = self._compute_v_scores_async(
                scoring_inputs,
                responses,
                max_concurrent=self.max_concurrent
            )

            # V_n^t(r) is the mean score over the sampled prompts
            v_t = float(np.mean(model_scores))
            v_scores.append(v_t)

            print(f"  V^{t}(r) = {v_t:.4f}")

            # Release this checkpoint before loading the next one
            del model
            del tokenizer
            if model_path in self.sequence_models:
                del self.sequence_models[model_path]
                del self.sequence_tokenizers[model_path]
            torch.cuda.empty_cache()

        self.objective_scores = v_scores
        return v_scores
    
    def fit_optimal_trend(
        self,
        v_scores: List[float]
    ) -> Tuple[str, np.ndarray, Dict[str, float], float]:
        """
        Find the optimal trend function f* that minimizes prediction error.

        This is based on Equation (3) from Methods.md:
        f* = argmin_{f in F_trend} sum_{t=2}^T L(f(t|V^{t-1}(r)), V^t(r))

        For every trend type in ``self.trend_function_types`` and every timestep
        t = 3..T, the trend is fitted on V^1..V^{t-1} and its RMSE is measured
        on V^1..V^t (the training points plus the held-out point t). The
        trend's error is the mean of those RMSE values, or 0.0 when there are
        no such timesteps (T == 2). Each trend is also fitted once on all T
        scores; those all-data parameters are what is stored, logged and
        returned. A trend type whose fitting raises is reported and skipped.

        Side effects: resets and fills ``self.all_trend_results`` and
        ``self.all_timestep_results``, sets ``self.best_trend`` and
        ``self.trend_params``, and calls ``_save_trend_plots`` when
        ``self.save_dir`` is set and at least one trend was fitted.

        Args:
            v_scores: Sequence of V scores from compute_v_scores

        Returns:
            Tuple of (best_trend_type, parameters, param_dict, avg_prediction_error).
            If no trend type could be fitted, this is (None, None, None, inf).

        Raises:
            ValueError: If fewer than 2 scores are given
        """
        if len(v_scores) < 2:
            raise ValueError("Need at least 2 scores to fit a trend")

        T = len(v_scores)
        t_values = np.arange(1, T + 1)
        y_values = np.array(v_scores)

        print(f"\nFitting trend functions to {T} V scores...")
        print(f"V scores: {[f'{v:.3f}' for v in v_scores]}")

        best_trend_type = None
        best_params = None
        best_param_dict = None
        best_avg_error = float('inf')

        # Clear previous results
        self.all_trend_results = {}
        self.all_timestep_results = {}  # Per-trend fits for each timestep

        # Try each trend function type
        for trend_type in self.trend_function_types:
            try:
                trend_func = TREND_FUNCTIONS[trend_type]['func']
                prediction_errors = []
                timestep_fits = {}

                # Start at t=3 so every partial fit has at least 2 training points
                for t in range(3, T + 1):
                    t_train = t_values[:t-1]  # Indices 1 to t-1
                    y_train = y_values[:t-1]  # V^1 to V^{t-1}

                    step_params, step_param_dict, _ = fit_trend(t_train, y_train, trend_type)

                    timestep_fits[t] = {
                        'params': step_params,
                        'param_dict': step_param_dict,
                        't_train': t_train,
                        'y_train': y_train
                    }

                    # RMSE over all points up to and including the held-out point t
                    predicted = trend_func(t_values[:t], *step_params)
                    actual = y_values[:t]  # V^1 to V^t
                    prediction_errors.append(np.sqrt(np.mean((predicted - actual) ** 2)))

                # Fit on all data for the parameters that are reported and returned
                final_params, final_param_dict, _ = fit_trend(t_values, y_values, trend_type)

                # Average prediction error as in Equation (4)
                avg_error = float(np.mean(prediction_errors)) if prediction_errors else 0.0

                self.all_trend_results[trend_type] = {
                    'params': final_params,
                    'param_dict': final_param_dict,
                    'avg_error': avg_error,
                    'function': trend_func
                }
                self.all_timestep_results[trend_type] = timestep_fits

                # Use the all-data fit here, not the loop variables of the last
                # partial fit: those exclude V^T and do not exist when T == 2.
                print(f"  {trend_type}: avg_error = {avg_error:.6f}, params = {final_param_dict}")

                if avg_error < best_avg_error:
                    best_trend_type = trend_type
                    best_params = final_params
                    best_param_dict = final_param_dict
                    best_avg_error = avg_error

            except Exception as e:
                # Best effort: one trend failing to fit must not abort the others
                print(f"  Failed to fit {trend_type}: {e}")
                continue

        print(f"\nBest trend: {best_trend_type} with avg error: {best_avg_error:.6f}")

        # Store results
        self.best_trend = best_trend_type
        self.trend_params = best_params

        # Save plots if save_dir is specified
        if self.save_dir and len(self.all_trend_results) > 0:
            self._save_trend_plots(t_values, y_values)

        return best_trend_type, best_params, best_param_dict, best_avg_error
    
    def _save_trend_plots(
        self,
        t_values: np.ndarray,
        y_values: np.ndarray
    ):
        """
        Save plots of all fitted trends to the specified directory.
        Now saves plots for each timestep showing the trend fitted up to that point.

        Args:
            t_values: Time/iteration indices
            y_values: Actual V scores
        """
        # Create directory for plots
        import datetime
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        objective_name_clean = re.sub(r'[^\w\s-]', '', self.objective_description)[:50]
        objective_name_clean = re.sub(r'[-\s]+', '_', objective_name_clean)

        plot_dir = os.path.join(
            self.save_dir,
            f"trend_plots_{objective_name_clean}_{timestamp}"
        )
        os.makedirs(plot_dir, exist_ok=True)

        print(f"\nSaving trend plots to: {plot_dir}")

        # First save timestep-specific plots
        if hasattr(self, 'all_timestep_results'):
            for trend_type, timestep_fits in self.all_timestep_results.items():
                # Create subdirectory for this trend type
                trend_dir = os.path.join(plot_dir, f"{trend_type}_timesteps")
                os.makedirs(trend_dir, exist_ok=True)

                for t, fit_data in timestep_fits.items():
                    fig, ax = plt.subplots(figsize=(10, 6))

                    # Plot all data points up to current timestep
                    ax.scatter(t_values[:t], y_values[:t], color='black', s=50, zorder=5,
                              label='Observed V scores', alpha=0.8)

                    # Plot future points in gray (not yet observed at this timestep)
                    if t < len(t_values):
                        ax.scatter(t_values[t:], y_values[t:], color='gray', s=30, zorder=3,
                                  label='Future V scores', alpha=0.4)

                    # Plot fitted trend line based on data up to t-1
                    trend_func = TREND_FUNCTIONS[trend_type]['func']
                    params = fit_data['params']

                    # Generate smooth line for plotting
                    t_plot = np.linspace(1, max(t_values[-1], t), 100)
                    y_fitted = trend_func(t_plot, *params)

                    ax.plot(t_plot, y_fitted, color='blue', linewidth=2,
                           label=f'{trend_type} fit (using t=1 to {t-1})', alpha=0.8)

                    # Calculate predictions and residuals for all points up to and including t
                    t_all = t_values[:t]  # Indices 1 to t
                    y_pred_all = trend_func(t_all, *params)
                    y_actual_all = y_values[:t]  # V^1 to V^t
                    residuals = np.abs(y_pred_all - y_actual_all)
                    # avg_error = np.mean(residuals)  # MAE
                    avg_error = np.sqrt(np.mean((y_pred_all - y_actual_all) ** 2))  # RMSE

                    # Draw error lines and annotate residuals for all points up to t
                    for i in range(t):
                        ti = t_values[i]
                        ax.vlines(ti, y_actual_all[i], y_pred_all[i], colors='red', alpha=0.5,
                                 linestyles='dashed')
                        mid_y = (y_actual_all[i] + y_pred_all[i]) / 2
                        ax.text(ti + 0.1, mid_y, f'{residuals[i]:.3f}', fontsize=8, color='red')

                    # Add legend entry for residual lines
                    ax.plot([], [], color='red', linestyle='dashed', alpha=0.5, label='Residuals')

                    # Highlight the prediction for timestep t
                    t_pred = y_pred_all[-1]
                    ax.scatter([t], [t_pred], color='red', s=100, marker='*', zorder=6,
                              label=f'Prediction for t={t}')

                    # Formatting
                    ax.set_xlabel('Model Iteration (t)', fontsize=12)
                    ax.set_ylabel('V^t(r) Score', fontsize=12)
                    ax.set_title(f'{trend_type} Trend - Timestep {t}\nFitted on t=1 to {t-1}',
                                fontsize=14)
                    ax.set_xticks(t_values)
                    ax.set_xticklabels([int(t) for t in t_values])
                    ax.grid(True, alpha=0.3)
                    ax.legend(loc='best')

                    # Add parameters text
                    param_text = ', '.join([f'{k}={v:.3f}' for k, v in fit_data['param_dict'].items()])
                    prediction_error = abs(t_pred - y_values[t-1])
                    stats_text = f'Params: {param_text}\nPrediction Error: {prediction_error:.6f}\nRMSE: {avg_error:.6f}'
                    ax.text(0.02, 0.98, stats_text, transform=ax.transAxes,
                           fontsize=10, verticalalignment='top',
                           bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

                    # Save timestep plot
                    filename = f"t{t:02d}_fit.png"
                    filepath = os.path.join(trend_dir, filename)
                    plt.tight_layout()
                    plt.savefig(filepath, dpi=150, bbox_inches='tight')
                    plt.close()

                print(f"  Saved timestep plots for {trend_type} in {trend_type}_timesteps/")

        # Save trend parameters to JSON files for analysis
        if hasattr(self, 'all_timestep_results'):
            # Save timestep-specific parameters
            timestep_params_file = os.path.join(plot_dir, "timestep_trend_params.json")
            timestep_params_data = {}

            for trend_type, timestep_fits in self.all_timestep_results.items():
                timestep_params_data[trend_type] = {}
                for t, fit_data in timestep_fits.items():
                    timestep_params_data[trend_type][str(t)] = {
                        'param_dict': fit_data['param_dict'],
                        'train_size': len(fit_data['t_train']),
                        'train_range': [float(fit_data['t_train'][0]), float(fit_data['t_train'][-1])]
                    }

            with open(timestep_params_file, 'w') as f:
                json.dump(timestep_params_data, f, indent=2)
            print(f"  Saved timestep parameters to: timestep_trend_params.json")

            # Save final fit parameters and errors
            final_params_file = os.path.join(plot_dir, "final_trend_params.json")
            final_params_data = {}

            for trend_type, result in self.all_trend_results.items():
                final_params_data[trend_type] = {
                    'param_dict': result['param_dict'],
                    'avg_prediction_error': result['avg_error'],
                    'is_best': trend_type == self.best_trend
                }

            with open(final_params_file, 'w') as f:
                json.dump(final_params_data, f, indent=2)
            print(f"  Saved final parameters to: final_trend_params.json")

        # Generate extended t values for smoother plotting
        t_extended = np.linspace(t_values[0], t_values[-1], 100)

        # Create individual plots for each trend (final fit using all data)
        for trend_type, result in self.all_trend_results.items():
            fig, ax = plt.subplots(figsize=(10, 6))
            
            # Plot actual data points
            ax.scatter(t_values, y_values, color='black', s=50, zorder=5, 
                      label='Actual V scores', alpha=0.8)
            
            # Plot fitted trend line
            trend_func = result['function']
            params = result['params']
            y_fitted = trend_func(t_extended, *params)
            
            # Choose color based on whether this is the best trend
            color = 'red' if trend_type == self.best_trend else 'blue'
            linewidth = 2.5 if trend_type == self.best_trend else 1.5
            
            ax.plot(t_extended, y_fitted, color=color, linewidth=linewidth,
                   label=f'{trend_type} fit', alpha=0.8)
            
            # Add prediction points for visualization
            y_pred_points = trend_func(t_values, *params)
            ax.scatter(t_values, y_pred_points, color=color, s=20, alpha=0.5,
                      marker='x', label='Predictions')
            
            # Add error bars
            errors = np.abs(y_values - y_pred_points)
            ax.vlines(t_values, y_values, y_pred_points, colors='gray', 
                     alpha=0.3, linestyles='dashed')
            
            # Formatting
            ax.set_xlabel('Model Iteration (t)', fontsize=12)
            ax.set_ylabel('V^t(r) Score', fontsize=12)
            ax.set_title(f'Trend Fit: {trend_type}\nObjective: {self.objective_description[:60]}...'
                        if len(self.objective_description) > 60 else
                        f'Trend Fit: {trend_type}\nObjective: {self.objective_description}',
                        fontsize=14)
            # Set x-axis to show only integer ticks
            ax.set_xticks(t_values)
            ax.set_xticklabels([int(t) for t in t_values])
            ax.grid(True, alpha=0.3)
            ax.legend(loc='best')
            
            # Add statistics text
            avg_error = result['avg_error']
            param_text = ', '.join([f'{k}={v:.3f}' for k, v in result['param_dict'].items()])
            stats_text = f'Avg Error: {avg_error:.6f}\nParams: {param_text}'
            ax.text(0.02, 0.98, stats_text, transform=ax.transAxes,
                   fontsize=10, verticalalignment='top',
                   bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
            
            # Save individual plot
            filename = f"{trend_type}_fit.png"
            filepath = os.path.join(plot_dir, filename)
            plt.tight_layout()
            plt.savefig(filepath, dpi=150, bbox_inches='tight')
            plt.close()
            
            print(f"  Saved: {filename}")
        
        # Create combined comparison plot
        fig, ax = plt.subplots(figsize=(12, 8))
        
        # Plot actual data points
        ax.scatter(t_values, y_values, color='black', s=60, zorder=5,
                  label='Actual V scores', alpha=0.9)
        
        # Plot all trend lines
        colors = plt.cm.Set1(np.linspace(0, 1, len(self.all_trend_results)))
        for i, (trend_type, result) in enumerate(self.all_trend_results.items()):
            trend_func = result['function']
            params = result['params']
            y_fitted = trend_func(t_extended, *params)
            
            linewidth = 2.5 if trend_type == self.best_trend else 1.5
            linestyle = '-' if trend_type == self.best_trend else '--'
            
            label = f'{trend_type} (err={result["avg_error"]:.4f})'
            if trend_type == self.best_trend:
                label += ' [BEST]'
            
            ax.plot(t_extended, y_fitted, color=colors[i], linewidth=linewidth,
                   linestyle=linestyle, label=label, alpha=0.8)
        
        # Formatting
        ax.set_xlabel('Model Iteration (t)', fontsize=12)
        ax.set_ylabel('V^t(r) Score', fontsize=12)
        ax.set_title(f'All Trend Fits Comparison\nObjective: {self.objective_description[:60]}...'
                    if len(self.objective_description) > 60 else
                    f'All Trend Fits Comparison\nObjective: {self.objective_description}',
                    fontsize=14)
        # Set x-axis to show only integer ticks
        ax.set_xticks(t_values)
        ax.set_xticklabels([int(t) for t in t_values])
        ax.grid(True, alpha=0.3)
        ax.legend(loc='best', fontsize=10)
        
        # Save combined plot
        filename = "all_trends_comparison.png"
        filepath = os.path.join(plot_dir, filename)
        plt.tight_layout()
        plt.savefig(filepath, dpi=150, bbox_inches='tight')
        plt.close()
        
        print(f"  Saved: {filename} (comparison plot)")
        
        # Create residual plot for best trend
        if self.best_trend:
            fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 8), 
                                          gridspec_kw={'height_ratios': [2, 1]})
            
            # Top plot: Best fit
            best_result = self.all_trend_results[self.best_trend]
            trend_func = best_result['function']
            params = best_result['params']
            
            ax1.scatter(t_values, y_values, color='black', s=50, zorder=5,
                       label='Actual V scores', alpha=0.8)
            y_fitted = trend_func(t_extended, *params)
            ax1.plot(t_extended, y_fitted, color='red', linewidth=2,
                    label=f'{self.best_trend} fit (best)', alpha=0.8)
            
            ax1.set_ylabel('V^t(r) Score', fontsize=12)
            ax1.set_title(f'Best Trend Fit with Residuals\nObjective: {self.objective_description[:60]}...'
                         if len(self.objective_description) > 60 else
                         f'Best Trend Fit with Residuals\nObjective: {self.objective_description}',
                         fontsize=14)
            # Set x-axis to show only integer ticks
            ax1.set_xticks(t_values)
            ax1.set_xticklabels([int(t) for t in t_values])
            ax1.grid(True, alpha=0.3)
            ax1.legend(loc='best')
            
            # Bottom plot: Residuals
            y_pred_points = trend_func(t_values, *params)
            residuals = y_values - y_pred_points
            
            ax2.scatter(t_values, residuals, color='blue', s=30, alpha=0.7)
            ax2.axhline(y=0, color='red', linestyle='--', alpha=0.5)
            ax2.set_xlabel('Model Iteration (t)', fontsize=12)
            ax2.set_ylabel('Residuals', fontsize=12)
            # Set x-axis to show only integer ticks
            ax2.set_xticks(t_values)
            ax2.set_xticklabels([int(t) for t in t_values])
            ax2.grid(True, alpha=0.3)
            
            # Add residual statistics
            residual_std = np.std(residuals)
            ax2.fill_between(t_values, -residual_std, residual_std, 
                            color='gray', alpha=0.2, label=f'±1 std ({residual_std:.4f})')
            ax2.legend(loc='best')
            
            # Save residual plot
            filename = f"{self.best_trend}_residuals.png"
            filepath = os.path.join(plot_dir, filename)
            plt.tight_layout()
            plt.savefig(filepath, dpi=150, bbox_inches='tight')
            plt.close()
            
            print(f"  Saved: {filename}")
        
        print(f"\nAll plots saved to: {plot_dir}")
    
    def check_predictable_trend_criteria(
        self,
        avg_prediction_error: float
    ) -> bool:
        """
        Check if the predictable trend criteria is satisfied.
        
        This implements Equation (4) from Methods.md:
        (1/(T-1)) * sum_{t=2}^T L(f*(t|V^{t-1}(r)), V^t(r)) <= epsilon_trend
        
        The comparison result is coerced to a built-in ``bool`` so the declared
        return type holds even when the error is a NumPy scalar (comparing a
        NumPy float yields ``numpy.bool_``, which fails ``is True`` checks and
        is rejected by ``json.dumps``).
        
        Args:
            avg_prediction_error: Average prediction error from fit_optimal_trend
            
        Returns:
            bool: True if the error is less than or equal to ``self.epsilon``,
            False otherwise. A NaN error never satisfies the criteria because
            every ordering comparison against NaN is False.
        """
        # bool(...) normalises numpy.bool_ to a plain Python bool.
        is_predictable = bool(avg_prediction_error <= self.epsilon)
        
        print("\nPredictable Trend Criteria Check:")
        print(f"  Average prediction error: {avg_prediction_error:.6f}")
        print(f"  Threshold (epsilon): {self.epsilon}")
        print(f"  Satisfies criteria: {is_predictable}")
        
        return is_predictable
    
    def verify(
        self,
        dataset: Optional[Union[List[Dict[str, str]], str]] = None,
        sample_size: Optional[int] = None,
        v_scores: Optional[List[float]] = None
    ) -> bool:
        """
        Verify if the objective follows a predictable trend.
        
        This orchestrates the full verification process:
        1. Compute V scores for each model (or use provided scores)
        2. Find optimal trend function f*
        3. Check if predictable trend criteria is satisfied
        
        Args:
            dataset: Dataset to evaluate on (if v_scores not provided). Either a
                list of sample dicts or a string; a string containing '/' is
                passed to ``load_dataset`` (split 'train') and each sample is
                reduced to ``{'input': ...}`` taken from its 'prompt' key,
                falling back to 'input', then to an empty string. May be
                omitted when ``v_scores`` is supplied, since it is then unused.
            sample_size: Optional number of samples to use per model
            v_scores: Pre-computed V scores (if available). When given, they are
                stored on ``self.objective_scores`` and no scoring is run.
            
        Returns:
            bool: True if objective follows a predictable trend, False otherwise
            
        Raises:
            ValueError: If neither ``dataset`` nor ``v_scores`` is provided.
            NotImplementedError: If ``dataset`` is a string without '/'
                (local file loading is not implemented).
        """
        # Fail fast with a clear message instead of handing None to the scorer.
        if v_scores is None and dataset is None:
            raise ValueError(
                "verify() requires either a dataset or pre-computed v_scores"
            )
        
        banner = '=' * 60
        print(f"\n{banner}")
        print("Predictable Trend Verification")
        print(f"Objective: {self.objective_description}")
        print(banner)
        
        # Step 1: Compute V scores if not provided
        if v_scores is None:
            if isinstance(dataset, str):
                # Handle loading from HuggingFace datasets or local files
                if '/' in dataset:
                    ds = load_dataset(dataset, split='train')
                    dataset = [
                        {'input': sample.get('prompt', sample.get('input', ''))}
                        for sample in ds
                    ]
                else:
                    raise NotImplementedError("Local file loading not implemented")
            
            v_scores = self.compute_v_scores(dataset, sample_size)
        else:
            self.objective_scores = v_scores
            print(f"Using provided V scores: {[f'{v:.3f}' for v in v_scores]}")
        
        # Step 2: Find optimal trend function. Only the average error feeds the
        # criteria check; the other fitted values are not used in this method.
        _best_trend, _params, _param_dict, avg_error = self.fit_optimal_trend(v_scores)
        
        # Step 3: Check if criteria is satisfied. Coerce to a built-in bool so
        # the declared return type holds even if a NumPy boolean comes back.
        is_predictable = bool(self.check_predictable_trend_criteria(avg_error))
        
        print(f"\n{banner}")
        print(f"Final Result: {'PASS' if is_predictable else 'FAIL'}")
        print(f"{banner}\n")
        
        return is_predictable
    
    def analyze_trend(
        self,
        v_scores: Optional[List[float]] = None
    ) -> Dict[str, Any]:
        """
        Summarise the trend in a sequence of objective scores.
        
        Args:
            v_scores: Sequence of objective scores; when None, the scores
                stored on ``self.objective_scores`` are analysed instead.
            
        Returns:
            Dict with the trend type, direction, stability, predictability and
            fitted parameters. When scores are available it also carries the
            mean, standard deviation, minimum, maximum and the scores
            themselves. With no scores, a reduced placeholder dict is returned.
        """
        scores = self.objective_scores if v_scores is None else v_scores
        
        # Nothing to analyse: hand back the neutral placeholder result.
        if not scores:
            return {
                "trend_type": "unknown",
                "direction": "unknown",
                "stability": 0.0,
                "predictability": 0.0,
                "parameters": {}
            }
        
        # Basic descriptive statistics
        score_array = np.array(scores)
        avg = np.mean(score_array)
        spread = np.std(score_array)
        
        # Direction from the average per-step change between first and last
        # score; a single score gives no direction.
        n_points = len(scores)
        direction = "unknown"
        if n_points > 1:
            per_step_change = (scores[-1] - scores[0]) / (n_points - 1)
            if abs(per_step_change) < 0.01:
                direction = "flat"
            elif per_step_change > 0:
                direction = "increasing"
            else:
                direction = "decreasing"
        
        # Stability shrinks as std/mean grows; defined only for a positive mean
        if avg > 0:
            stability = 1.0 / (1.0 + spread / avg)
        else:
            stability = 0
        
        # Predictability and parameters both depend on a stored best fit
        fit_available = self.best_trend and self.trend_params is not None
        if fit_available:
            spec = TREND_FUNCTIONS[self.best_trend]
            iterations = np.arange(1, n_points + 1)
            fit_metrics = evaluate_trend_fit(
                iterations, score_array, spec['func'], self.trend_params
            )
            predictability = fit_metrics['r_squared']
            fitted_parameters = dict(zip(
                TREND_FUNCTIONS[self.best_trend]['params'],
                self.trend_params
            ))
        else:
            predictability = 0.0
            fitted_parameters = {}
        
        analysis = {
            "trend_type": self.best_trend or "unknown",
            "direction": direction,
            "stability": stability,
            "predictability": predictability,
            "parameters": fitted_parameters,
            "mean_score": avg,
            "std_score": spread,
            "min_score": np.min(score_array),
            "max_score": np.max(score_array),
            "scores": scores
        }
        return analysis
