"""
CoT Verifier Accuracy Metric

Evaluates how well reasoning traces enable a model to extract answers
by comparing accuracy with vs without the original question.
"""

import logging
import numpy as np
import asyncio
import os
import hashlib
import time
import random
from typing import Dict, List, Any

from .base_metric import BaseMetric
from ..utils.openai_client import OpenAIClient
from ..utils.vllm_model_manager import VLLMModelManager
from ..utils.answer_removal import process_batch_remove_answers
from data.utils import extract_answer
from vllm import SamplingParams

logger = logging.getLogger(__name__)


class CotVerifierAccuracyMetric(BaseMetric):
    """Handles CoT verifier accuracy evaluation of teacher reasoning traces."""

    def _init_metric_config(self):
        """Initialize cot_verifier_accuracy metric configuration."""
        requested_metrics = self.config.get('evaluation', {}).get('metrics', [])
        # Only support cot_verifier_accuracy
        self.cot_verifier_types = ['cot_verifier_accuracy']
        self.requested_cot_verifier_metrics = [r for r in requested_metrics if r in self.cot_verifier_types]

        # Initialize student model manager (will be set up later)
        self.student_model_manager = None
        self.student_openai_client = None
        self.use_local_student_model = False
    
    def can_run(self) -> bool:
        """Check if cot_verifier_accuracy evaluation should be run."""
        return len(self.requested_cot_verifier_metrics) > 0

    def _is_local_model(self, model_path: str) -> bool:
        """Check if model path is a local path or an API model name."""
        # If it starts with /, it's a local path
        if model_path.startswith('/'):
            return True
        # If it contains known API model prefixes, it's an API model
        api_prefixes = ['gpt-', 'o1-', 'o3-', 'claude-', 'gpt-oss-', 'qwen3-30b-a3b-thinking']
        return not any(model_path.startswith(prefix) for prefix in api_prefixes)

    def evaluate(self, teacher_responses: List[Dict[str, Any]], model_manager=None,
                openai_client=None, **kwargs) -> List[Dict[str, Any]]:
        """Run cot_verifier_accuracy evaluation on teacher responses.

        Args:
            teacher_responses: Teacher response dicts to evaluate.
            model_manager: Teacher model manager; forwarded so the teacher
                model can be unloaded before a local student model is loaded.
            openai_client: Accepted for interface compatibility; not used here.
            **kwargs: ``student_data`` is read if present (defaults to None).

        Returns:
            Whatever ``_run_cot_verifier_evaluation`` returns for these inputs.
        """
        # Student data is optional for this metric, so a missing value is passed on as None.
        optional_student_data = kwargs.get('student_data')

        return self._run_cot_verifier_evaluation(
            teacher_responses,
            optional_student_data,
            teacher_model_manager=model_manager,
        )
    
    def _run_cot_verifier_evaluation(self, teacher_responses: List[Dict[str, Any]],
                                      student_data: List[Dict[str, Any]], teacher_model_manager=None) -> List[Dict[str, Any]]:
        """Synchronous wrapper around the async evaluation.

        Creates a private event loop, drives the async implementation to
        completion on it, and always closes the loop afterwards (also when the
        evaluation raises).
        """
        event_loop = asyncio.new_event_loop()
        try:
            pending_evaluation = self._run_cot_verifier_evaluation_async(
                teacher_responses, student_data, teacher_model_manager
            )
            return event_loop.run_until_complete(pending_evaluation)
        finally:
            event_loop.close()

    async def _run_cot_verifier_evaluation_async(self, teacher_responses: List[Dict[str, Any]],
                                                   student_data: List[Dict[str, Any]], teacher_model_manager=None) -> List[Dict[str, Any]]:
        """Async implementation of cot_verifier_accuracy evaluation with concurrent processing.

        For every teacher response the answer is removed from the reasoning
        trace, two verifier prompts are built (one with the original question,
        one without), the student model answers both, and each extracted answer
        is compared with the teacher answer by whitespace-insensitive exact
        match (score 1.0 or 0.0).

        Args:
            teacher_responses: Teacher response dicts, either in ``k_responses``
                format (the entry with the highest ``reward_score`` is used) or
                in the older flat format with top-level ``teacher_thinking`` /
                ``teacher_answer``. The dicts are updated in place.
            student_data: Not read by this method (the student baseline is
                disabled); kept so the signature stays stable.
            teacher_model_manager: Teacher VLLMModelManager to unload before
                loading a local student model.

        Returns:
            The same ``teacher_responses`` list. Depending on the path taken,
            items gain ``teacher_thinking``, ``teacher_answer``,
            ``reward_score``, ``metadata``, ``gold_answer``,
            ``thinking_without_answer``, ``verifier_comparison``,
            ``cot_verifier_accuracy_score``, ``cot_verifier_detailed_results``
            and ``student_accuracy_baseline``; the item at index 0 additionally
            gets ``student_baseline_detailed_results`` and
            ``verifier_comparison_stats``.
        """
        logger.info("Running cot_verifier_accuracy evaluation...")

        # Resolve the representative k-response once per item so that answer
        # removal, the convenience fields and the detailed records all agree.
        best_k_responses = [self._cot_best_k_response(response) for response in teacher_responses]

        # Step 1: Batch remove answers asynchronously using LLM
        # Only cot_verifier_accuracy needs answer removal
        needs_answer_removal = 'cot_verifier_accuracy' in self.requested_cot_verifier_metrics

        thinking_without_answer_map = {}
        if needs_answer_removal:
            logger.info("Batch removing answers from teacher responses using LLM (GPT-4o-mini) asynchronously...")
            # Prepare items for batch processing
            items_for_removal = []
            for teacher_idx, (response, best_k_response) in enumerate(zip(teacher_responses, best_k_responses)):
                # k_responses format: use the best-scoring sample; otherwise
                # fall back to the old flat format stored on the item itself.
                source = response if best_k_response is None else best_k_response
                items_for_removal.append({
                    'teacher_thinking': source.get('teacher_thinking', ''),
                    'teacher_answer': source.get('teacher_answer', ''),
                    'teacher_response': source.get('teacher_response', 'placeholder'),
                    'teacher_idx': teacher_idx  # Track original index
                })

            # Get OpenAI API key
            api_key = os.getenv("OPENAI_API_KEY")
            if not api_key:
                logger.warning("OPENAI_API_KEY not found, skipping answer removal")
                # Use the original thinking if no API key. It is taken from the
                # prepared items because the top-level 'teacher_thinking' field
                # is not populated yet for k_responses-format inputs.
                for item in items_for_removal:
                    thinking_without_answer_map[item['teacher_idx']] = item['teacher_thinking']
            else:
                # Async batch remove with controlled concurrency (default 50 concurrent)
                cleaned_items = await process_batch_remove_answers(items_for_removal, api_key, max_concurrent=50)

                # Build map of teacher_idx -> cleaned thinking
                for item in cleaned_items:
                    teacher_idx = item.get('teacher_idx', 0)
                    cleaned_thinking = item.get('teacher_thinking_without_answer', '')
                    thinking_without_answer_map[teacher_idx] = cleaned_thinking

                    # IMPORTANT: Save removed-answer thinking in teacher_responses for reuse by incremental_thinking
                    teacher_responses[teacher_idx]['thinking_without_answer'] = cleaned_thinking

                logger.info(f"Batch removed answers from {len(cleaned_items)} teacher responses")

        # Step 2: Add convenience fields to teacher responses for easier access
        for response, best_k_response in zip(teacher_responses, best_k_responses):
            if best_k_response is not None:
                response['teacher_thinking'] = best_k_response.get('teacher_thinking', '')
                response['teacher_answer'] = best_k_response.get('teacher_answer', '')
                response['reward_score'] = best_k_response.get('reward_score', 0.0)
            # Ensure metadata and gold_answer are accessible
            if 'metadata' not in response:
                response['metadata'] = {}
            if 'input_str' not in response['metadata'] and 'question' in response:
                response['metadata']['input_str'] = response['question']
            if 'source_dataset' not in response['metadata'] and 'data_source' in response:
                response['metadata']['source_dataset'] = response['data_source']
            response['gold_answer'] = response.get('gold_answer', '')

        # Step 3: Initialize student model (before baseline calculation)
        student_model_path = self.config['evaluation']['student_model']['model_path']

        # Determine if we're using local vLLM model or API model
        self.use_local_student_model = self._is_local_model(student_model_path)

        logger.info(f"\n{'='*80}")
        logger.info("STUDENT MODEL INITIALIZATION FOR COT_VERIFIER_ACCURACY METRICS")
        logger.info(f"Student model path: {student_model_path}")
        logger.info(f"Model type: {'Local vLLM' if self.use_local_student_model else 'OpenAI API'}")
        logger.info(f"Requested metrics: {self.requested_cot_verifier_metrics}")
        logger.info(f"{'='*80}\n")

        if self.use_local_student_model:
            # Unload teacher model to free up GPU memory before loading student model
            if getattr(teacher_model_manager, 'current_model', None) is not None:
                logger.info("Unloading teacher model to free GPU memory before loading student model...")
                try:
                    # Drop the manager's reference to the teacher model
                    del teacher_model_manager.current_model
                    teacher_model_manager.current_model = None

                    # Force garbage collection to free GPU memory
                    import gc
                    import torch
                    gc.collect()
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()

                    logger.info("Successfully unloaded teacher model and freed GPU memory")
                except Exception as e:
                    # Best effort: a failed unload must not abort the evaluation.
                    logger.warning(f"Error while unloading teacher model: {e}")

            logger.info(f"Loading local vLLM model for student evaluation: {student_model_path}")

            # Generate unique hash for student model to prevent conflicts with parallel jobs
            # Hash is based on model path, timestamp, PID, and random component
            unique_str = f"student_{student_model_path}_{time.time()}_{os.getpid()}_{random.randint(0, 999999)}"
            student_unique_hash = hashlib.md5(unique_str.encode()).hexdigest()[:8]
            logger.info(f"Generated unique hash for student model: {student_unique_hash}")

            # Set up local vLLM model manager
            # For student model, we load directly without checkpoint loader since it's a base model path
            vllm_config = self.config.get('evaluation', {}).get('vllm', {})
            self.student_model_manager = VLLMModelManager(vllm_config, checkpoint_loader=None, unique_hash=student_unique_hash)
            # Load model directly (base model path, not checkpoint)
            self.student_model_manager._load_vllm_model(student_model_path)
            self.student_model_manager.current_checkpoint_name = f"student_{student_model_path.split('/')[-1]}"
        else:
            logger.info(f"Using OpenAI API model for student evaluation: {student_model_path}")
            client = OpenAIClient(student_model_path)
            # Store client for cost tracking
            self.student_openai_client = client

        # Step 4: Student baseline is disabled, so neutral placeholders are
        # stored in the results instead of a computed baseline.
        student_baseline_accuracy = 0.0
        student_baseline_detailed_results = []

        cot_verifier_prompts = []
        prompt_metadata = []  # Track which teacher/student pair each prompt corresponds to

        # Only cot_verifier_accuracy is handled (no student data loop needed)
        if needs_answer_removal:
            for teacher_idx, response in enumerate(teacher_responses):
                # Use pre-cleaned thinking from batch removal; fall back to the
                # original thinking if this item has no cleaned entry.
                if teacher_idx in thinking_without_answer_map:
                    thinking_without_answer = thinking_without_answer_map[teacher_idx]
                else:
                    thinking_without_answer = response.get('teacher_thinking', '')

                teacher_answer = response.get('teacher_answer', '')
                task_name = response['metadata']['source_dataset']

                # Create TWO prompts per instance:
                # 1. answer | question, thinking (WITH question)
                prompt_with_question = self.prompt_manager.create_verifier_prompt_with_question(
                    response['metadata']['input_str'],
                    thinking_without_answer,
                    task=task_name
                )

                # 2. answer | thinking (WITHOUT question)
                prompt_without_question = self.prompt_manager.create_verifier_prompt_without_question(
                    thinking_without_answer,
                    task=task_name
                )

                # Add both prompts, each with its own bookkeeping record
                for variant_name, variant_prompt in (
                    ('with_question', prompt_with_question),
                    ('without_question', prompt_without_question),
                ):
                    cot_verifier_prompts.append(variant_prompt)
                    prompt_metadata.append({
                        'teacher_idx': teacher_idx,
                        'student_idx': 0,
                        'teacher_answer': teacher_answer,
                        'cot_verifier_type': 'cot_verifier_accuracy',
                        'prompt_variant': variant_name
                    })

        # Log prompt count
        logger.info(f"DEBUG: Number of cot_verifier prompts: {len(cot_verifier_prompts)}")

        # Step 5: Generate responses (student model already initialized in Step 3)
        cot_verifier_responses = []

        if self.use_local_student_model:
            # Generate with local vLLM model (synchronous batching)
            logger.info(f"Processing {len(cot_verifier_prompts)} prompts with local vLLM model")
            cot_verifier_responses = await self._generate_with_local_model(cot_verifier_prompts)
        else:
            # Generate with OpenAI API (async with controlled parallelism)
            max_concurrent_api_calls = self.config.get('performance', {}).get('max_concurrent_api_calls', 30)
            num_batches = (len(cot_verifier_prompts) + self.batch_size - 1) // self.batch_size
            logger.info(f"Processing {len(cot_verifier_prompts)} prompts in {num_batches} batches with max {max_concurrent_api_calls} concurrent API calls")

            # Create semaphore to limit concurrent requests
            semaphore = asyncio.Semaphore(max_concurrent_api_calls)

            async def process_batch(batch_idx, batch_prompts):
                """Process a single batch with semaphore control."""
                async with semaphore:
                    logger.info(f"Processing cot_verifier batch {batch_idx + 1}/{num_batches}")
                    try:
                        # NOTE(review): sampling settings and the model name come
                        # from teacher_config even though this is the student
                        # client - confirm that this is intended.
                        batch_responses = await client.generate_individual_async(
                            batch_prompts,
                            temperature=self.teacher_config.get('temperature', 0.7),
                            max_tokens=self.teacher_config.get('max_tokens', 512),
                            model=self.teacher_config.get('model', 'gpt-4')
                        )
                        logger.info(f"DEBUG: Received {len(batch_responses)} responses for batch {batch_idx + 1}")
                        return batch_responses
                    except Exception as e:
                        logger.error(f"DEBUG: Error in batch {batch_idx + 1}: {e}")
                        # Blank placeholders keep responses aligned with prompts.
                        return [{'text': ''} for _ in batch_prompts]

            # Create all batch tasks
            batch_tasks = []
            for i in range(0, len(cot_verifier_prompts), self.batch_size):
                batch_prompts = cot_verifier_prompts[i:i + self.batch_size]
                batch_idx = i // self.batch_size
                batch_tasks.append(process_batch(batch_idx, batch_prompts))

            # Execute all batches concurrently and gather results (gather keeps task order)
            logger.info(f"Launching {len(batch_tasks)} concurrent batch tasks...")
            all_batch_responses = await asyncio.gather(*batch_tasks)

            # Flatten results
            for batch_responses in all_batch_responses:
                cot_verifier_responses.extend(batch_responses)

        logger.info(f"DEBUG: Total cot_verifier responses: {len(cot_verifier_responses)}")

        # Debug first few responses
        for i, response in enumerate(cot_verifier_responses[:3]):
            logger.info(f"DEBUG: Response {i}: {response}")

        if len(cot_verifier_responses) != len(prompt_metadata):
            # zip() below stops at the shorter list, so make the loss visible.
            logger.warning(
                f"Got {len(cot_verifier_responses)} responses for {len(prompt_metadata)} prompts; "
                "unmatched entries will be skipped"
            )

        # Calculate cot_verifier_accuracy scores
        # Group responses by teacher and cot_verifier type
        teacher_scores = {}
        detailed_by_teacher = {}  # teacher_idx -> detailed result records, in prompt order

        # Track answers from both variants (with_question and without_question)
        verifier_results = {}  # teacher_idx -> {with_question: {answer, score}, without_question: {answer, score}}

        for i, (verifier_response, metadata) in enumerate(zip(cot_verifier_responses, prompt_metadata)):
            teacher_idx = metadata['teacher_idx']
            student_idx = metadata['student_idx']
            teacher_answer = metadata['teacher_answer']
            verifier_type = metadata['cot_verifier_type']
            prompt_variant = metadata.get('prompt_variant', 'with_question')  # Default to with_question for backward compatibility
            teacher_item = teacher_responses[teacher_idx]
            best_k_response = best_k_responses[teacher_idx]

            # Extract answer - only handling cot_verifier_accuracy now
            raw_response = verifier_response.get('text', '')
            answer = (extract_answer(raw_response, 'final answer') or
                     extract_answer(raw_response, 'answer') or
                     raw_response)

            # Normalize extracted_answer and teacher_answer for comparison (remove whitespace and newlines)
            extracted_answer_normalized = answer.replace('\n', '').replace(' ', '').strip()
            teacher_answer_normalized = teacher_item.get('teacher_answer', '').replace('\n', '').replace(' ', '').strip()
            answer_matches_teacher = (extracted_answer_normalized == teacher_answer_normalized)

            # For cot_verifier_accuracy, use binary exact matching
            score = 1.0 if answer_matches_teacher else 0.0

            # Store results for each variant
            verifier_results.setdefault(teacher_idx, {})[prompt_variant] = {
                'answer': answer,
                'answer_normalized': extracted_answer_normalized,
                'score': score,
                'raw_response': raw_response
            }

            # Store detailed result including full input prompt. The teacher
            # response text comes from the same best-scoring k-response as the
            # thinking and answer, so one record never mixes two samples.
            detailed_by_teacher.setdefault(teacher_idx, []).append({
                'teacher_idx': teacher_idx,
                'student_idx': student_idx,
                'cot_verifier_type': verifier_type,
                'prompt_variant': prompt_variant,
                'teacher_thinking': teacher_item.get('teacher_thinking', ''),
                'teacher_response': best_k_response.get('teacher_response', '') if best_k_response is not None else '',
                'teacher_answer': teacher_answer,
                'full_input_prompt': cot_verifier_prompts[i],
                'cot_verifier_response': raw_response,
                'extracted_answer': answer,
                'score': score,
                'teacher_original_score': teacher_item.get('reward_score', 0.0),
                'answer_matches_teacher': answer_matches_teacher
            })

            # Accumulate scores for each teacher by type
            teacher_scores.setdefault(teacher_idx, {}).setdefault(verifier_type, []).append(score)

        # Calculate verifier accuracy comparison metrics
        # Now compute stats comparing with_question vs without_question
        verifier_comparison_stats = {
            'total_instances': 0,
            'with_question_correct': 0,
            'without_question_correct': 0,
            'both_correct': 0,
            'both_incorrect': 0,
            'answers_match': 0,  # Number of instances where both variants extracted the same answer
            'only_with_question_correct': 0,
            'only_without_question_correct': 0
        }

        for teacher_idx, variants in verifier_results.items():
            # Only process if we have both variants
            if 'with_question' not in variants or 'without_question' not in variants:
                continue

            with_q = variants['with_question']
            without_q = variants['without_question']

            with_correct = with_q['score'] == 1.0
            with_wrong = with_q['score'] == 0.0
            without_correct = without_q['score'] == 1.0
            without_wrong = without_q['score'] == 0.0
            same_answer = with_q['answer_normalized'] == without_q['answer_normalized']

            verifier_comparison_stats['total_instances'] += 1

            # Track correctness
            if with_correct:
                verifier_comparison_stats['with_question_correct'] += 1
            if without_correct:
                verifier_comparison_stats['without_question_correct'] += 1

            # Track both correct/incorrect
            if with_correct and without_correct:
                verifier_comparison_stats['both_correct'] += 1
            elif with_wrong and without_wrong:
                verifier_comparison_stats['both_incorrect'] += 1

            # Track if answers match between variants
            if same_answer:
                verifier_comparison_stats['answers_match'] += 1

            # Track exclusive correctness
            if with_correct and without_wrong:
                verifier_comparison_stats['only_with_question_correct'] += 1
            elif with_wrong and without_correct:
                verifier_comparison_stats['only_without_question_correct'] += 1

            # Store comparison data in teacher_responses
            teacher_responses[teacher_idx]['verifier_comparison'] = {
                'with_question_answer': with_q['answer'],
                'without_question_answer': without_q['answer'],
                'with_question_score': with_q['score'],
                'without_question_score': without_q['score'],
                'answers_match': same_answer,
                'both_correct': (with_correct and without_correct),
                'both_incorrect': (with_wrong and without_wrong)
            }

        # Calculate accuracy rates (only defined when at least one instance had both variants)
        total = verifier_comparison_stats['total_instances']
        if total > 0:
            verifier_comparison_stats['with_question_accuracy'] = verifier_comparison_stats['with_question_correct'] / total
            verifier_comparison_stats['without_question_accuracy'] = verifier_comparison_stats['without_question_correct'] / total
            verifier_comparison_stats['accuracy_difference'] = (
                verifier_comparison_stats['with_question_accuracy'] - verifier_comparison_stats['without_question_accuracy']
            )
            verifier_comparison_stats['answer_match_rate'] = verifier_comparison_stats['answers_match'] / total
            verifier_comparison_stats['both_correct_rate'] = verifier_comparison_stats['both_correct'] / total
            verifier_comparison_stats['both_incorrect_rate'] = verifier_comparison_stats['both_incorrect'] / total

        # Calculate average cot_verifier_accuracy score for each teacher response by type
        for teacher_idx, type_scores in teacher_scores.items():
            teacher_item = teacher_responses[teacher_idx]
            for verifier_type, scores in type_scores.items():
                avg_score = sum(scores) / len(scores) if scores else 0.0
                teacher_item[f'{verifier_type}_score'] = avg_score

            # Store detailed results in teacher response for later saving
            teacher_item['cot_verifier_detailed_results'] = detailed_by_teacher.get(teacher_idx, [])

            # Store student baseline accuracy for comparison
            teacher_item['student_accuracy_baseline'] = student_baseline_accuracy

            # Run-level data is stored only in the first teacher response to avoid duplication
            if teacher_idx == 0:
                teacher_item['student_baseline_detailed_results'] = student_baseline_detailed_results
                teacher_item['verifier_comparison_stats'] = verifier_comparison_stats

        return teacher_responses

    @staticmethod
    def _cot_best_k_response(response: Dict[str, Any]):
        """Return the ``k_responses`` entry with the highest ``reward_score``, or None when there is none."""
        k_responses = response.get('k_responses', [])
        if not k_responses:
            return None
        # A missing score counts as 0; max() keeps the first entry on ties.
        return max(k_responses, key=lambda x: x.get('reward_score', 0))
    
    async def _generate_with_local_model(self, prompts: List[str]) -> List[Dict[str, Any]]:
        """Generate responses using local vLLM model.

        Prompts are processed in chunks of ``self.batch_size``. Sampling
        settings are read from ``self.teacher_config``.

        Args:
            prompts: Prompt strings to send to the student model.

        Returns:
            One ``{'text': ...}`` dict per generated output, in batch order. A
            batch that fails contributes exactly one blank ``{'text': ''}``
            entry per prompt, so later batches stay aligned with their prompts.
        """
        # Create sampling params
        sampling_params = SamplingParams(
            temperature=self.teacher_config.get('temperature', 0.7),
            max_tokens=self.teacher_config.get('max_tokens', 512),
            top_p=self.teacher_config.get('top_p', 1.0),
            top_k=self.teacher_config.get('top_k', -1)
        )

        all_responses = []
        num_batches = (len(prompts) + self.batch_size - 1) // self.batch_size

        # Process in batches
        for batch_number, start in enumerate(range(0, len(prompts), self.batch_size), start=1):
            batch_prompts = prompts[start:start + self.batch_size]
            logger.info(f"Processing local vLLM batch {batch_number}/{num_batches}")

            try:
                # Generate with vLLM
                outputs = self.student_model_manager.current_model.generate(
                    batch_prompts,
                    sampling_params,
                    use_tqdm=False
                )

                # Convert vLLM outputs to OpenAI-style format. The batch is
                # built separately so a failure part-way through cannot leave
                # partial entries in all_responses.
                batch_responses = [
                    {'text': output.outputs[0].text if output.outputs else ""}
                    for output in outputs
                ]
            except Exception as e:
                logger.error(f"Error in local vLLM batch {batch_number}: {e}")
                # Use empty responses for the whole failed batch
                batch_responses = [{'text': ''} for _ in batch_prompts]

            all_responses.extend(batch_responses)

        return all_responses

    async def _calculate_student_baseline_accuracy(self, student_data: List[Dict[str, Any]]) -> tuple[float, List[Dict[str, Any]]]:
        """Calculate baseline accuracy of student model on the student dataset and return detailed responses."""
        logger.info("Calculating student baseline accuracy...")

        # For teacher_student_accuracy, we only need 1 sample for baseline
        if student_data and len(student_data) > 1:
            if 'teacher_student_accuracy' in self.requested_informativeness_metrics and \
               len(self.requested_informativeness_metrics) == 1:
                student_data = student_data[:1]
                logger.info("Using only 1 student sample for baseline (teacher_student_accuracy mode)")

        # Create prompts for student model using the same developer prompt structure as teacher
        student_prompts = []
        for student_item in student_data:
            # Use the same prompt structure as teacher - just the question with developer prompt
            question = student_item['question']
            prompt = self.prompt_manager.create_openai_prompt(question)
            student_prompts.append(prompt)

        # Generate responses using the already-initialized student model (local or API)
        # NOTE: This is called AFTER the student model has been set up in the main async function
        student_responses = []

        if self.use_local_student_model:
            # Use local vLLM model
            logger.info(f"Calculating baseline with local vLLM model")
            student_responses = await self._generate_with_local_model(student_prompts)
        else:
            # Use OpenAI API client
            logger.info(f"Calculating baseline with OpenAI API model")
            for i in range(0, len(student_prompts), self.batch_size):
                batch_prompts = student_prompts[i:i + self.batch_size]

                logger.info(f"Processing student baseline batch {i//self.batch_size + 1}/{(len(student_prompts) + self.batch_size - 1)//self.batch_size}")

                try:
                    batch_responses = await self.student_openai_client.generate_individual_async(
                        batch_prompts,
                        temperature=self.teacher_config.get('temperature', 0.7),
                        max_tokens=self.teacher_config.get('max_tokens', 512),
                        model=self.teacher_config.get('model', 'gpt-4')
                    )
                    student_responses.extend(batch_responses)

                except Exception as e:
                    logger.error(f"Error in student baseline batch {i//self.batch_size + 1}: {e}")
                    # Add empty responses for failed batch
                    student_responses.extend([{'text': ''} for _ in batch_prompts])
        
        # Calculate accuracy and collect detailed results
        correct_count = 0
        total_count = len(student_data)
        detailed_baseline_results = []
        
        for i, (student_response, student_item) in enumerate(zip(student_responses, student_data)):
            # Extract answer from student response
            answer = (extract_answer(student_response.get('text', ''), 'final answer') or 
                     extract_answer(student_response.get('text', ''), 'answer') or 
                     student_response.get('text', ''))
            
            # Calculate score using the same scoring method as informativeness
            entry = {
                'answer': student_item['answer'],
                'metadata': student_item.get('metadata', {}),
                'data_source': student_item.get('data_source', ''),
                'index': i
            }
            score = self.reward_calculator.calculate_score(answer, entry)
            
            # Store detailed result
            detailed_result = {
                'index': i,
                'question': student_item['question'],
                'correct_answer': student_item['answer'],
                'full_input_prompt': student_prompts[i],
                'student_baseline_response': student_response.get('text', ''),
                'extracted_answer': answer,
                'score': score,
                'metadata': student_item.get('metadata', {}),
                'data_source': student_item.get('data_source', ''),
                'seed': student_item.get('seed', 0)
            }
            detailed_baseline_results.append(detailed_result)
            
            if score > 0:  # Assuming binary scoring (1 for correct, 0 for incorrect)
                correct_count += 1
        
        baseline_accuracy = correct_count / total_count if total_count > 0 else 0.0
        return baseline_accuracy, detailed_baseline_results
    
    def process_teacher_responses(self, teacher_responses_path: str,
                                teacher_data: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Aggregate per-item cot_verifier_accuracy results into summary metrics.
        Intended to run AFTER the cot_verifier_accuracy inference has completed.

        Args:
            teacher_responses_path: Path where original responses are saved
                (not read by this method).
            teacher_data: Teacher responses with cot_verifier_accuracy results

        Returns:
            Dictionary with ``num_samples``, per-metric mean and detail
            statistics, ``mean_teacher_accuracy`` and, when any item carries a
            positive value, ``student_accuracy_baseline``. If anything fails, a
            zeroed dictionary with an ``error`` message is returned instead.
        """
        try:
            summary = {'num_samples': len(teacher_data)}

            # Per-metric statistics over the stored '<metric>_score' values
            for metric_name in self.requested_cot_verifier_metrics:
                score_key = f'{metric_name}_score'
                metric_scores = [entry.get(score_key, 0.0) for entry in teacher_data]

                if metric_scores:
                    metric_mean = sum(metric_scores) / len(metric_scores)
                    lowest = min(metric_scores)
                    highest = max(metric_scores)
                    spread = float(np.std(metric_scores))
                else:
                    metric_mean = 0.0
                    lowest = 0.0
                    highest = 0.0
                    spread = 0.0

                summary[f'mean_{metric_name}_score'] = metric_mean
                summary[f'{metric_name}_details'] = {
                    'scores': metric_scores,
                    'min_score': lowest,
                    'max_score': highest,
                    'std_score': spread
                }

            # Teacher accuracy, for comparison with the verifier scores
            reward_values = [entry.get('reward_score', 0.0) for entry in teacher_data]
            if reward_values:
                mean_teacher_accuracy = sum(reward_values) / len(reward_values)
            else:
                mean_teacher_accuracy = 0.0
            summary['mean_teacher_accuracy'] = mean_teacher_accuracy

            # Report the first positive student baseline, if any item has one
            baseline_values = [entry.get('student_accuracy_baseline', 0.0) for entry in teacher_data]
            first_positive_baseline = next((value for value in baseline_values if value > 0), None)
            if first_positive_baseline is not None:
                summary['student_accuracy_baseline'] = first_positive_baseline

            logger.info(f"CoT verifier accuracy metrics: {len(teacher_data)} samples")
            for metric_name in self.requested_cot_verifier_metrics:
                mean_key = f'mean_{metric_name}_score'
                if mean_key in summary:
                    logger.info(f"  Mean {metric_name} score: {summary[mean_key]:.4f}")
            logger.info(f"  Mean teacher accuracy: {mean_teacher_accuracy:.4f}")
            if 'student_accuracy_baseline' in summary:
                logger.info(f"  Student baseline accuracy: {summary['student_accuracy_baseline']:.4f}")

            return summary

        except Exception as e:
            logger.error(f"Error processing cot_verifier_accuracy metrics: {e}")
            # Zeroed fallback so callers still receive the expected keys
            fallback = {
                'mean_teacher_accuracy': 0.0,
                'student_accuracy_baseline': 0.0,
                'num_samples': 0,
                'error': str(e)
            }
            for metric_name in self.requested_cot_verifier_metrics:
                fallback[f'mean_{metric_name}_score'] = 0.0
            return fallback