"""
Usefulness Metric

Evaluates the usefulness of Chain-of-Thought (CoT) reasoning traces through
perturbation: expert thinking (scoring with reasoning traces loaded from other
models) and CoT importance (re-scoring the answer after truncating the trace).
"""

import json
import logging
import numpy as np
import os
from typing import Dict, List, Any, Optional, Tuple
from datetime import datetime

from .base_metric import BaseMetric

logger = logging.getLogger(__name__)


class UsefulnessMetric(BaseMetric):
    """Handles CoT perturbation evaluation to measure reasoning trace usefulness.
    
    Combines original CoT perturbation functionality with checkpoint tracking
    and analysis capabilities for comprehensive usefulness evaluation.
    """
    
    def __init__(self, config: Dict[str, Any], reward_calculator, prompt_manager):
        """Create the metric and its empty per-checkpoint result stores.

        Args:
            config: Full configuration dict, forwarded to the base class.
            reward_calculator: Forwarded unchanged to the base class.
            prompt_manager: Forwarded unchanged to the base class.
        """
        super().__init__(config, reward_calculator, prompt_manager)

        # Results keyed by checkpoint name; filled by add_checkpoint_result().
        self.checkpoint_results: Dict[str, Dict[str, Any]] = {}
        # Aggregated-metrics container; starts empty.
        self.aggregated_metrics: Dict[str, Any] = {}
    
    def _init_metric_config(self):
        """Initialize usefulness metric configuration."""
        requested_metrics = self.config.get('evaluation', {}).get('metrics', [])

        # Only supported perturbation types: cot_importance and expert_thinking (expert perplexity)
        self.cot_types = ['cot_importance', 'expert_thinking']

        # Perturbations with special handling (not using standard perturb_teacher_responses flow)
        self.special_perturbations = {'cot_importance'}

        self.requested_cot_metrics = [r for r in requested_metrics if r in self.cot_types]

        # Split into standard and special perturbations
        self.standard_perturbations = [m for m in self.requested_cot_metrics if m not in self.special_perturbations]
        self.special_requested = [m for m in self.requested_cot_metrics if m in self.special_perturbations]
    
    def can_run(self) -> bool:
        """Check if CoT perturbation evaluation should be run."""
        return len(self.requested_cot_metrics) > 0
    
    def evaluate(self, teacher_responses: List[Dict[str, Any]], model_manager=None, 
                openai_client=None, **kwargs) -> List[Dict[str, Any]]:
        """Run CoT perturbation evaluation on teacher responses."""
        checkpoint_path = kwargs.get('checkpoint_path', '')
        accuracy_metric = kwargs.get('accuracy_metric', None)
        
        if not accuracy_metric:
            raise ValueError("AccuracyMetric instance required for CoT perturbation evaluation")
        
        return self._run_cot_perturbation(checkpoint_path, teacher_responses, accuracy_metric, model_manager, openai_client)
    
    def _run_cot_perturbation(self, checkpoint_path: str, teacher_responses: List[Dict[str, Any]], 
                             accuracy_metric, model_manager=None, openai_client=None) -> List[Dict[str, Any]]:
        """Run the requested CoT perturbations on existing teacher responses.

        ``teacher_responses`` is modified in place and returned:

        * ``expert_thinking``: expert traces from
          ``config['evaluation']['expert_thinking_dir']`` are attached to each
          response. Then inference runs on the perturbed prompts produced by
          ``perturb_teacher_responses``. Per-expert-model results are stored under
          ``expert_thinking_perturbation``.
        * ``cot_importance``: delegated to ``_evaluate_cot_importance``.

        Args:
            checkpoint_path: Checkpoint being evaluated (not used by this method).
            teacher_responses: Teacher response dicts; modified in place.
            accuracy_metric: Object whose ``evaluate`` performs generation.
            model_manager: Forwarded to ``accuracy_metric.evaluate``.
            openai_client: Forwarded to ``accuracy_metric.evaluate``.

        Returns:
            ``teacher_responses`` (the same list object).

        Raises:
            KeyError: If ``expert_thinking`` is requested but
                ``config['evaluation']['expert_thinking_dir']`` is missing.
            ValueError: If ``expert_thinking_dir`` is not a dict.
        """
        from ..utils.cot_perturbation_utils import perturb_teacher_responses

        logger.info("Running CoT perturbation...")

        if "expert_thinking" in self.requested_cot_metrics:
            self._attach_expert_thinking_traces(teacher_responses)

        # Only standard perturbations go through perturb_teacher_responses.
        perturbed_data = perturb_teacher_responses(teacher_responses, self.standard_perturbations)

        for pert_type, pert_responses in perturbed_data["perturbed"].items():
            logger.info(f"Running {pert_type} perturbation inference ({len(pert_responses)} samples)")
            if pert_type == "expert_thinking":
                self._run_expert_thinking_inference(
                    teacher_responses, pert_responses, accuracy_metric, model_manager, openai_client
                )

        # Special perturbations use custom logic.
        if "cot_importance" in self.special_requested:
            logger.info("Running CoT importance evaluation")
            self._evaluate_cot_importance(teacher_responses, accuracy_metric, model_manager, openai_client)

        return teacher_responses

    def _attach_expert_thinking_traces(self, teacher_responses: List[Dict[str, Any]]) -> None:
        """Load expert traces per model and attach them to responses by list position."""
        expert_thinking_config = self.config['evaluation']['expert_thinking_dir']

        if not isinstance(expert_thinking_config, dict):
            logger.error("expert_thinking_dir must be a dictionary mapping model_name -> path")
            raise ValueError("expert_thinking_dir must be a dictionary format")

        for model_name, expert_path in expert_thinking_config.items():
            # JSON is UTF-8; don't depend on the platform's default encoding.
            with open(expert_path, 'r', encoding='utf-8') as f:
                expert_thinking = json.load(f)

            # Traces are aligned positionally; entries past len(teacher_responses) are ignored.
            for response, item in zip(teacher_responses, expert_thinking):
                traces = response.setdefault('expert_thinking_traces', {})
                # Prefer the full reasoning trace, falling back to the answer text.
                traces[model_name] = item.get('teacher_thinking', item.get('teacher_answer', ''))

    def _run_expert_thinking_inference(self, teacher_responses: List[Dict[str, Any]],
                                       pert_responses: List[Dict[str, Any]],
                                       accuracy_metric, model_manager=None, openai_client=None) -> None:
        """Score expert-thinking prompts per expert model and store results on the originals."""
        logger.info("Using expert thinking")

        # Group perturbed prompts by the expert model whose trace they contain.
        expert_model_groups: Dict[str, List[Dict[str, Any]]] = {}
        for resp in pert_responses:
            expert_model_groups.setdefault(resp.get('expert_model_name', 'default'), []).append(resp)

        for original_resp in teacher_responses:
            original_resp['expert_thinking_perturbation'] = {
                'expert_perplexity': {},
                'expert_token_logprobs': {},
                'reward_scores': {},
                'perturbed_outputs': {},
                'perturbed_inputs': {}
            }

        # Build the index lookup once. setdefault keeps the FIRST response per index,
        # matching the first-match semantics of a linear search.
        responses_by_index: Dict[Any, Dict[str, Any]] = {}
        for teacher_resp in teacher_responses:
            responses_by_index.setdefault(teacher_resp.get('index'), teacher_resp)

        for expert_model_name, expert_responses in expert_model_groups.items():
            logger.info(f"Running expert thinking perturbation for {expert_model_name} ({len(expert_responses)} samples)")

            expert_results = accuracy_metric.evaluate(
                expert_responses,
                model_manager=model_manager,
                openai_client=openai_client,
                own_thinking=False
            )

            for expert_resp, expert_result in zip(expert_responses, expert_results):
                original_resp = responses_by_index.get(expert_resp.get('index', -1))
                if original_resp is None:
                    continue

                k_responses = expert_result.get('k_responses', [])
                if not k_responses:
                    continue

                # Keep the highest-reward generation for this expert trace.
                best_k = max(k_responses, key=lambda x: x.get('reward_score', 0.0))
                gen_info = best_k.get('generation_info', {})
                store = original_resp['expert_thinking_perturbation']
                store['expert_perplexity'][expert_model_name] = gen_info.get('expert_perplexity', 0.0)
                store['expert_token_logprobs'][expert_model_name] = gen_info.get('expert_token_logprobs', [])
                store['reward_scores'][expert_model_name] = best_k.get('reward_score', 0.0)
                store['perturbed_outputs'][expert_model_name] = best_k.get('teacher_response', '')
                store['perturbed_inputs'][expert_model_name] = expert_result.get('question', '')

    def _js_bernoulli(self, p: float, q: float, eps: float = 1e-12) -> float:
        """
        Compute Jensen-Shannon divergence for Bernoulli distributions.

        Args:
            p: Probability for first distribution
            q: Probability for second distribution
            eps: Epsilon to avoid log(0)

        Returns:
            JS divergence value (symmetric, bounded [0, log(2)])
        """
        import math

        # Clip probabilities to avoid numerical issues
        p = min(max(p, eps), 1.0 - eps)
        q = min(max(q, eps), 1.0 - eps)

        # Compute mixture distribution
        m = 0.5 * (p + q)

        # KL(p || m)
        kl_p_m = p * math.log(p / m) + (1 - p) * math.log((1 - p) / (1 - m))

        # KL(q || m)
        kl_q_m = q * math.log(q / m) + (1 - q) * math.log((1 - q) / (1 - m))

        # JS divergence
        return 0.5 * kl_p_m + 0.5 * kl_q_m

    def _truncate_thinking(self, thinking: str, strategy: str, level: int) -> str:
        """
        Truncate thinking trace according to strategy.

        Args:
            thinking: Full thinking trace
            strategy: "word" for word-by-word, "percentage" for percentage-based
            level: Truncation level (word count for "word", percentage 0-100 for "percentage")

        Returns:
            Truncated thinking trace
        """
        if strategy == "word":
            words = thinking.split()
            return ' '.join(words[:level])
        elif strategy == "percentage":
            words = thinking.split()
            num_words = len(words)
            target_words = int(num_words * level / 100.0)
            return ' '.join(words[:target_words])
        else:
            raise ValueError(f"Unknown truncation strategy: {strategy}")

    def _evaluate_cot_importance(self, teacher_responses: List[Dict[str, Any]],
                                 accuracy_metric, model_manager=None, openai_client=None):
        """
        Evaluate CoT importance by measuring divergence between early-exit and full thinking.

        For each sample, computes:
        1. Sequence log-probability p(teacher_answer | full thinking)
        2. Sequence log-probability p(teacher_answer | truncated thinking) for each truncation level
        3. JS divergence between the two Bernoulli distributions

        NOTE: Uses the teacher's GENERATED answer (not gold answer) to measure how the model's
        confidence in its own answer changes with different amounts of thinking.

        Supports both word-by-word and percentage-based truncation strategies.

        OPTIMIZATION: Uses cached removed-answer thinking traces from informativeness
        evaluation to avoid redundant processing.

        Args:
            teacher_responses: List of teacher response dictionaries
            accuracy_metric: AccuracyMetric instance for generation
            model_manager: Model manager for generation
            openai_client: OpenAI client (if using OpenAI)
        """
        import math

        logger.info(f"Starting CoT importance evaluation for {len(teacher_responses)} samples")

        # Get truncation strategy from config (default: "both")
        truncation_strategy = self.config.get('evaluation', {}).get('cot_importance_strategy', 'both')
        logger.info(f"Using truncation strategy: {truncation_strategy}")

        # Initialize cot_importance_evaluation for all responses
        for resp in teacher_responses:
            resp['cot_importance_evaluation'] = {
                'js_divergences': [],
                'log_probs_full': [],
                'log_probs_truncated': [],
                'truncation_levels': [],
                'truncation_strategy': truncation_strategy,
                'num_words': 0
            }

        # BATCHED APPROACH: Prepare all eval items across all samples
        all_eval_items = []
        sample_metadata = []  # Track which items belong to which sample

        for idx, response in enumerate(teacher_responses):
            # OPTIMIZATION: Use cached thinking_without_answer from informativeness if available
            #if 'thinking_without_answer' in response and response['thinking_without_answer'].strip():
            #    teacher_thinking = response['thinking_without_answer']
            #    logger.debug(f"Sample {idx}: Using cached removed-answer thinking")
            if 'k_responses' in response and response['k_responses']:
                teacher_thinking = response['k_responses'][0].get('teacher_thinking', '')
                logger.debug(f"Sample {idx}: Using k_responses thinking (no cached version)")
            else:
                teacher_thinking = response.get('teacher_thinking', '')
                logger.debug(f"Sample {idx}: Using original thinking (no cached version)")

            if not teacher_thinking or not teacher_thinking.strip():
                logger.warning(f"Sample {idx}: No thinking found, skipping")
                continue

            # Get the question prompt and teacher's generated answer (not gold answer)
            full_prompt = response.get('full_prompt', response.get('question', ''))

            # Extract teacher's generated answer from k_responses
            if 'k_responses' in response and response['k_responses']:
                teacher_answer = response['k_responses'][0].get('teacher_answer', '')
            else:
                teacher_answer = response.get('teacher_answer', response.get('teacher_answer', ''))

            if not teacher_answer or not teacher_answer.strip():
                logger.warning(f"Sample {idx}: No teacher answer found, skipping")
                continue

            words = teacher_thinking.split()
            num_words = len(words)

            if num_words == 0:
                continue

            # Determine truncation levels based on strategy
            truncation_levels = []
            if truncation_strategy == "word" or truncation_strategy == "both":
                # Word-by-word: 1, 2, ..., num_words
                truncation_levels.extend([(level, "word") for level in range(1, num_words + 1)])
            if truncation_strategy == "percentage" or truncation_strategy == "both":
                # Percentage: 10%, 20%, ..., 100%
                truncation_levels.extend([(pct, "percentage") for pct in range(0, 101, 10)])

            # Create eval items for full thinking + all truncation levels
            sample_start_idx = len(all_eval_items)

            # Add full thinking item (used as reference)
            # Use teacher's answer to compute p(teacher_answer | question, full_thinking)
            prompt_without_answer_full = f"{full_prompt}{teacher_thinking}</think>\n<answer>"
            full_prompt_with_answer = f"{prompt_without_answer_full}{teacher_answer}</answer>"

            all_eval_items.append({
                'question': full_prompt_with_answer,
                'full_prompt': prompt_without_answer_full,
                'answer': teacher_answer,
                'index': response.get('index', idx),
                'truncation_level': 'full',
                'truncation_strategy': 'full',
                'metadata': response.get('metadata', {}),
            })

            # Add truncated thinking items
            for level, strategy in truncation_levels:
                truncated_thinking = self._truncate_thinking(teacher_thinking, strategy, level)
                prompt_without_answer = f"{full_prompt}{truncated_thinking}</think>\n<answer>"
                full_truncated_prompt = f"{prompt_without_answer}{teacher_answer}</answer>"

                all_eval_items.append({
                    'question': full_truncated_prompt,
                    'full_prompt': prompt_without_answer,
                    'answer': teacher_answer,
                    'index': response.get('index', idx),
                    'truncation_level': level,
                    'truncation_strategy': strategy,
                    'truncated_thinking_text': truncated_thinking,  # Store actual truncated text
                    'metadata': response.get('metadata', {}),
                })

            # Track metadata for this sample
            sample_metadata.append({
                'sample_idx': idx,
                'num_words': num_words,
                'truncation_levels': truncation_levels,
                'start_idx': sample_start_idx,
                'end_idx': len(all_eval_items)  # exclusive
            })

        if not all_eval_items:
            logger.warning("No eval items created for CoT importance")
            return

        # BATCH EVALUATE ALL ITEMS AT ONCE with vLLM prefix caching enabled
        logger.info(f"Evaluating {len(all_eval_items)} prompts across {len(sample_metadata)} samples in batch...")
        logger.info("NOTE: vLLM prefix caching should be enabled for optimal performance (5-10x speedup)")

        all_results = accuracy_metric.evaluate(
            all_eval_items,
            model_manager=model_manager,
            openai_client=openai_client,
            own_thinking=False  # Use prompt_logprobs for sequence log-prob calculation
        )

        # PROCESS RESULTS: Extract sequence log-probs and compute divergences
        for meta in sample_metadata:
            sample_idx = meta['sample_idx']
            start_idx = meta['start_idx']
            end_idx = meta['end_idx']
            num_words = meta['num_words']
            truncation_levels = meta['truncation_levels']

            sample_results = all_results[start_idx:end_idx]

            # First result is the full thinking
            full_result = sample_results[0]
            full_k_response = full_result['k_responses'][0]
            full_gen_info = full_k_response['generation_info']

            # Extract sequence log-prob for full thinking
            # We need to extract log-prob for the ANSWER tokens only
            full_prompt_logprobs = full_gen_info.get('expert_token_logprobs', [])

            # DEBUG: Show what we're working with

            
            #logger.info(f"\n{'='*80}")
            #logger.info(f"[DEBUG] Sample {sample_idx}: FULL THINKING EVALUATION")
            #logger.info(f"{'='*80}")
            full_question = response.get('question', 'N/A')
            #logger.info(f"Question: {full_question}...")
            #logger.info(f"Teacher answer: {teacher_answer}")
            #logger.info(f"Teacher thinking (length): {len(teacher_thinking.split())} words")
            ##logger.info(f"\n[DEBUG] Full input prompt sent to vLLM (last 500 chars):")
            full_eval_item = all_eval_items[start_idx]
            full_prompt_str = full_eval_item.get('question', 'N/A')
            #logger.info(f"...{full_prompt_str}")
            #logger.info(f"\nexpert_token_logprobs count: {len(full_prompt_logprobs)}")

            # Calculate full thinking sequence log-prob
            # The answer starts after the "<answer>" tag
            # We need to sum log-probs for answer tokens
            full_logprob_answer = 0.0
            if full_prompt_logprobs:
                # Sum negative log-probs for answer tokens
                answer_tokens_list = [token_info.get('selected_token', '?') for token_info in full_prompt_logprobs]
                #logger.info(f"\nSample {sample_idx}: Full thinking - processing {len(full_prompt_logprobs)} answer tokens")
                #logger.info(f"  Answer tokens: {answer_tokens_list}")

                for i, token_info in enumerate(full_prompt_logprobs):
                    neg_log_prob = token_info.get('neg_log_prob', 0.0)
                    log_prob = -neg_log_prob
                    full_logprob_answer += log_prob
                    #logger.info(f"  Token {i}: '{token_info.get('selected_token', '?')}' | log_prob: {log_prob:.6f} | cumulative: {full_logprob_answer:.6f}")

            # Convert to probability
            p_full = math.exp(full_logprob_answer)

            # Process truncated results
            js_divergences = []
            log_probs_full = []
            log_probs_truncated = []
            truncation_level_labels = []

            for result_idx, result in enumerate(sample_results[1:], start=1):
                k_response = result['k_responses'][0]
                gen_info = k_response['generation_info']

                # Extract sequence log-prob for truncated thinking
                truncated_prompt_logprobs = gen_info.get('expert_token_logprobs', [])

                # DEBUG: Show truncated thinking details for ALL levels
                if True:  # Print for all truncation levels
                    truncated_text = all_eval_items[start_idx + result_idx].get('truncated_thinking_text', 'N/A')
                    eval_item = all_eval_items[start_idx + result_idx]
                    full_prompt_for_eval = eval_item.get('question', 'N/A')

                    logger.info(f"\n{'='*80}")
                    logger.info(f"[DEBUG] Sample {sample_idx}: TRUNCATED THINKING #{result_idx}")
                    logger.info(f"{'='*80}")
                    logger.info(f"Truncated thinking text: {truncated_text}")
                    logger.info(f"\n[DEBUG] Full input prompt sent to vLLM (last 500 chars):")
                    logger.info(f"...{full_prompt_for_eval[-500:]}")
                    logger.info(f"\nexpert_token_logprobs count: {len(truncated_prompt_logprobs)}")

                    if truncated_prompt_logprobs:
                        logger.info(f"\n[DEBUG] First few token_info objects:")
                        for i, token_info in enumerate(truncated_prompt_logprobs):
                            logger.info(f"  [{i}] {token_info}")

                truncated_logprob_answer = 0.0
                if truncated_prompt_logprobs:
                    if result_idx == 1:  # Only log for first truncation level
                        answer_tokens_list = [token_info.get('selected_token', '?') for token_info in truncated_prompt_logprobs]
                        logger.info(f"\nSample {sample_idx}: Truncated thinking - processing {len(truncated_prompt_logprobs)} answer tokens")
                        logger.info(f"  Answer tokens: {answer_tokens_list}")

                    for i, token_info in enumerate(truncated_prompt_logprobs):
                        neg_log_prob = token_info.get('neg_log_prob', 0.0)
                        log_prob = -neg_log_prob
                        truncated_logprob_answer += log_prob
                        if result_idx == 1:  # Print all tokens for first truncation
                            logger.info(f"  Token {i}: '{token_info.get('selected_token', '?')}' | log_prob: {log_prob:.6f} | cumulative: {truncated_logprob_answer:.6f}")

                # Convert to probability
                p_truncated = math.exp(truncated_logprob_answer)
                print(f"Sample {sample_idx}: p(y|full) = {p_full:.6e}")
                print(f"Sample {sample_idx}: p(y|truncated) = {p_truncated:.6e}")

                # Compute JS divergence
                js_div = self._js_bernoulli(p_full, p_truncated)

                # Store results
                js_divergences.append(js_div)
                log_probs_full.append(full_logprob_answer)
                log_probs_truncated.append(truncated_logprob_answer)

                # Get actual truncated text instead of labels
                truncated_text = all_eval_items[start_idx + result_idx]['truncated_thinking_text']
                truncation_level_labels.append(truncated_text)

            # Compute final normalized JS for word-by-word strategy only
            results_dict = {
                'js_divergences': js_divergences,
                'log_probs_full': log_probs_full,
                'log_probs_truncated': log_probs_truncated,
                'truncation_levels': truncation_level_labels,  # Now contains actual truncated text
                'truncation_strategy': truncation_strategy,
                'num_words': num_words,
                'p_full': p_full,
                'full_logprob': full_logprob_answer
            }

            if truncation_strategy == "word" and js_divergences and num_words > 0:
                final_normalized_js = sum(js_divergences) / num_words
                results_dict['final_normalized_js'] = final_normalized_js
                logger.info(f"Sample {sample_idx}: final_normalized_js = {final_normalized_js:.6e} (sum={sum(js_divergences):.6e}, num_words={num_words})")

            # Store results
            teacher_responses[sample_idx]['cot_importance_evaluation'] = results_dict

            logger.info(f"Sample {sample_idx}: p(y|full) = {p_full:.6e}, JS divergences = {js_divergences[:3]}{'...' if len(js_divergences) > 3 else ''}")

        logger.info(f"CoT importance evaluation completed: {len(all_eval_items)} prompts evaluated")

    def process_teacher_responses(self, teacher_responses_path: str,
                                teacher_data: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Process teacher responses and calculate perturbation reward metrics.
        This is called AFTER the perturbation inference has been completed.

        Args:
            teacher_responses_path: Path where original responses are saved
            teacher_data: Teacher responses with perturbation results

        Returns:
            Dictionary with perturbation metrics
        """
        try:
            # Calculate original reward using average reward
            reward_scores = [item.get('reward_score', 0.0) for item in teacher_data]
            original_reward = sum(reward_scores) / len(reward_scores) if reward_scores else 0.0

            # Calculate expert thinking rewards from perturbation results
            expert_thinking_scores_by_model = {}

            for item in teacher_data:
                if 'expert_thinking_perturbation' in item:
                    expert_rewards = item['expert_thinking_perturbation'].get('reward_scores', {})
                    for expert_model, reward in expert_rewards.items():
                        if expert_model not in expert_thinking_scores_by_model:
                            expert_thinking_scores_by_model[expert_model] = []
                        expert_thinking_scores_by_model[expert_model].append(reward)

            # Calculate expert thinking rewards for each model
            expert_thinking_rewards = {}
            for expert_model, scores in expert_thinking_scores_by_model.items():
                expert_thinking_rewards[expert_model] = sum(scores) / len(scores) if scores else 0.0

            # Prepare results
            results = {
                'original_reward': original_reward,
                'expert_thinking_rewards': expert_thinking_rewards,  # Dictionary with expert model names as keys
                'num_samples': len(teacher_data),
                'perturbation_details': {
                    'expert_thinking_samples_by_model': {model: len(scores) for model, scores in expert_thinking_scores_by_model.items()},
                    'expert_thinking_improvements': {model: reward - original_reward for model, reward in expert_thinking_rewards.items()}
                }
            }

            logger.info(f"CoT Perturbation metrics: {len(teacher_data)} samples")
            logger.info(f"  Original reward: {original_reward:.4f}")

            # Log expert thinking rewards for each model
            for expert_model, expert_reward in expert_thinking_rewards.items():
                logger.info(f"  Expert thinking reward ({expert_model}): {expert_reward:.4f} (improvement: {expert_reward - original_reward:.4f})")

            return results

        except Exception as e:
            logger.error(f"Error processing CoT perturbation metrics: {e}")
            return {
                'original_reward': 0.0,
                'expert_thinking_rewards': {},
                'num_samples': 0,
                'error': str(e)
            }

    def add_checkpoint_result(self, checkpoint_name: str, results: Dict[str, Any]) -> None:
        """
        Add results for a specific checkpoint.
        
        Args:
            checkpoint_name: Name of the checkpoint
            results: Results dictionary from evaluation
        """
        self.checkpoint_results[checkpoint_name] = {
            **results,
            'timestamp': datetime.now().isoformat()
        }
        
        logger.info(f"Added usefulness results for checkpoint {checkpoint_name}")
    
    def get_checkpoint_result(self, checkpoint_name: str) -> Optional[Dict[str, Any]]:
        """
        Get results for a specific checkpoint.
        
        Args:
            checkpoint_name: Name of the checkpoint
            
        Returns:
            Results dictionary or None if not found
        """
        return self.checkpoint_results.get(checkpoint_name)
    
    def get_all_results(self) -> Dict[str, Dict[str, Any]]:
        """
        Get results for all checkpoints.
        
        Returns:
            Dictionary mapping checkpoint names to results
        """
        return self.checkpoint_results
    
    def get_sorted_results(self) -> List[Tuple[str, Dict[str, Any]]]:
        """
        Get checkpoint results sorted by checkpoint name/step.
        
        Returns:
            List of (checkpoint_name, results) tuples sorted by step
        """
        def extract_step(checkpoint_name):
            try:
                # Try to extract step number from checkpoint name
                if 'step_' in checkpoint_name:
                    return int(checkpoint_name.split('step_')[1].split('_')[0])
                elif 'epoch_' in checkpoint_name:
                    return int(checkpoint_name.split('epoch_')[1].split('_')[0])
                else:
                    return 0
            except:
                return 0
        
        results = [(name, self.checkpoint_results[name]) 
                  for name in self.checkpoint_results.keys()]
        
        results.sort(key=lambda x: extract_step(x[0]))
        return results
    
    def calculate_summary_statistics(self) -> Dict[str, Any]:
        """Calculate summary statistics across all checkpoints."""
        if not self.checkpoint_results:
            return {}

        # Extract rewards
        original_rewards = [r.get('original_reward', 0.0) for r in self.checkpoint_results.values()]

        # Extract expert thinking rewards by model
        expert_thinking_rewards_by_model = {}
        for r in self.checkpoint_results.values():
            expert_rewards = r.get('expert_thinking_rewards', {})
            for model_name, reward in expert_rewards.items():
                if model_name not in expert_thinking_rewards_by_model:
                    expert_thinking_rewards_by_model[model_name] = []
                expert_thinking_rewards_by_model[model_name].append(reward)

        # Calculate expert thinking statistics for each model
        expert_thinking_stats = {}
        for model_name, rewards in expert_thinking_rewards_by_model.items():
            if rewards:
                expert_thinking_stats[model_name] = {
                    'mean': float(np.mean(rewards)),
                    'std': float(np.std(rewards)),
                    'min': float(np.min(rewards)),
                    'max': float(np.max(rewards))
                }

        return {
            'num_checkpoints': len(self.checkpoint_results),
            'original_reward': {
                'mean': float(np.mean(original_rewards)),
                'std': float(np.std(original_rewards)),
                'min': float(np.min(original_rewards)),
                'max': float(np.max(original_rewards))
            },
            'expert_thinking_rewards': expert_thinking_stats
        }

    def analyze_performance_trends(self) -> Dict[str, Any]:
        """Analyze performance trends across checkpoints."""
        if len(self.checkpoint_results) < 2:
            return {'error': 'Need at least 2 checkpoints for trend analysis'}

        sorted_results = self.get_sorted_results()

        # Extract step numbers and rewards
        steps = []
        original_rewards = []
        expert_thinking_rewards_by_model = {}

        for checkpoint_name, results in sorted_results:
            try:
                if 'step_' in checkpoint_name:
                    step = int(checkpoint_name.split('step_')[1].split('_')[0])
                else:
                    step = 0
            except:
                step = 0

            steps.append(step)
            original_rewards.append(results.get('original_reward', 0.0))

            # Handle multiple expert thinking models
            expert_rewards = results.get('expert_thinking_rewards', {})
            for model_name, reward in expert_rewards.items():
                if model_name not in expert_thinking_rewards_by_model:
                    expert_thinking_rewards_by_model[model_name] = []
                expert_thinking_rewards_by_model[model_name].append(reward)

        return {
            'steps': steps,
            'original_rewards': original_rewards,
            'expert_thinking_rewards': expert_thinking_rewards_by_model
        }
    
    def get_detailed_analysis(self) -> Dict[str, Any]:
        """
        Get comprehensive analysis of all results.
        
        Returns:
            Dictionary with detailed analysis
        """
        analysis = {
            'metric_name': 'usefulness',
            'summary_statistics': self.calculate_summary_statistics(),
            'performance_trends': self.analyze_performance_trends(),
            'checkpoint_results': self.checkpoint_results,
            'sorted_results': self.get_sorted_results(),
            'generated_at': datetime.now().isoformat()
        }
        
        return analysis
    
    def save_metrics(self, output_path: str) -> None:
        """
        Save the detailed usefulness analysis to a JSON file.

        Missing parent directories are created. A bare filename (no directory
        part) is written to the current working directory. The analysis is
        serialized before the file is opened, so a serialization failure
        does not truncate an existing file at ``output_path``.

        Args:
            output_path: Path to save metrics

        Raises:
            TypeError: if the analysis contains values the ``json`` module
                cannot encode.
            OSError: if the directory or file cannot be created or written.
        """
        # Serialize first: json.dump writes incrementally, so failing midway
        # would leave a half-written file behind.
        payload = json.dumps(self.get_detailed_analysis(), indent=2)

        directory = os.path.dirname(output_path)
        if directory:
            # os.makedirs('') raises FileNotFoundError, which the original
            # code hit for bare filenames like "metrics.json".
            os.makedirs(directory, exist_ok=True)

        with open(output_path, 'w', encoding='utf-8') as f:
            f.write(payload)

        logger.info(f"Saved usefulness metrics to {output_path}")
    
    def load_metrics(self, input_path: str) -> None:
        """
        Restore ``checkpoint_results`` from a JSON file.

        Only the top-level ``checkpoint_results`` entry is read, for example
        from a file written by ``save_metrics``. Other keys in the file are
        not read. If the entry is missing, the current results are left
        unchanged and a warning is logged.

        Args:
            input_path: Path to load metrics from
        """
        with open(input_path, 'r') as handle:
            payload = json.load(handle)

        if 'checkpoint_results' not in payload:
            logger.warning("No checkpoint results found in loaded data")
            return

        self.checkpoint_results = payload['checkpoint_results']
        logger.info(f"Loaded usefulness metrics for {len(self.checkpoint_results)} checkpoints")
    
    def to_dict(self) -> Dict[str, Any]:
        """
        Convert metrics to dictionary format.
        
        Returns:
            Dictionary representation of metrics
        """
        return self.get_detailed_analysis()
    
    def print_summary(self) -> None:
        """Print a summary of the CoT perturbation metrics."""
        print("\n=== USEFULNESS EVALUATION SUMMARY ===")

        if not self.checkpoint_results:
            print("No checkpoint results available.")
            return

        summary = self.calculate_summary_statistics()

        print(f"Checkpoints evaluated: {summary['num_checkpoints']}")
        print(f"Original Reward: {summary['original_reward']['mean']:.4f} ± {summary['original_reward']['std']:.4f}")

        # Show expert thinking rewards for each model
        expert_rewards = summary.get('expert_thinking_rewards', {})
        for model_name, stats in expert_rewards.items():
            print(f"Expert Thinking Reward ({model_name}): {stats['mean']:.4f} ± {stats['std']:.4f}")

        # Show expert thinking improvements for each model
        orig_mean = summary['original_reward']['mean']
        for model_name, stats in expert_rewards.items():
            expert_mean = stats['mean']
            print(f"Expert Thinking Improvement ({model_name}): {expert_mean - orig_mean:.4f}")

        print("=" * 50)