"""
Generalization Metric

Tests the model's ability to generalize by evaluating on the same task
with modified dataset parameters (e.g., larger coordinate ranges, harder difficulties).
"""

import json
import logging
import os
from typing import Dict, List, Any, Optional
from datetime import datetime

from .base_metric import BaseMetric
from ..utils.data_generator import DataGenerator

logger = logging.getLogger(__name__)


class GeneralizationMetric(BaseMetric):
    """Tests model generalization on same task with modified parameters."""

    def __init__(self, config: Dict[str, Any], reward_calculator, prompt_manager):
        """Set up the generalization metric.

        Args:
            config: Configuration dictionary
            reward_calculator: RewardCalculator instance
            prompt_manager: PromptManager instance
        """
        super().__init__(config, reward_calculator, prompt_manager)

        # Task name of the teacher dataset (None when it is not configured)
        evaluation_section = config.get('evaluation', {})
        teacher_section = evaluation_section.get('teacher_dataset', {})
        self.training_task = teacher_section.get('task_name')

        # Both stay None unless _init_generalization_dataset fills them in
        self.gen_config = None
        self.generalization_generator = None
        self._init_generalization_dataset()

    def _init_metric_config(self):
        """Initialize generalization metric configuration."""
        # Check if generalization is requested
        requested_metrics = self.config.get('evaluation', {}).get('metrics', [])
        self.is_enabled = 'generalization' in requested_metrics

    def _init_generalization_dataset(self):
        """Initialize generalization dataset generator if configured.

        Reads ``evaluation.generalization_dataset`` from the config, stores it
        in ``self.gen_config`` and builds a ``DataGenerator`` from its
        ``task_name`` plus every remaining key except ``seed``, ``size`` and
        ``val_start``.

        Nothing is raised from here: when the section is missing or empty,
        when ``task_name`` is absent, or when constructing the generator
        fails, the problem is logged and ``self.generalization_generator``
        is left as ``None``.
        """
        eval_config = self.config.get('evaluation', {})

        if 'generalization_dataset' not in eval_config:
            logger.info("No generalization_dataset found in config")
            return

        self.gen_config = eval_config['generalization_dataset']

        # The key can be present with no value (e.g. an entry left blank in
        # the config file), which yields None. Treat that as "not configured"
        # rather than failing on the .get() call below.
        if self.gen_config is None:
            logger.warning("generalization_dataset is empty in config")
            return

        # Extract task name
        gen_task = self.gen_config.get('task_name')
        if not gen_task:
            logger.warning("No task_name specified in generalization_dataset")
            return

        # Everything except the task name and the sampling keys is passed to
        # the generator as a task parameter.
        non_task_keys = {'task_name', 'seed', 'size', 'val_start'}
        task_params = {
            k: v for k, v in self.gen_config.items()
            if k not in non_task_keys
        }

        logger.info(f"Initializing generalization dataset for {gen_task} with params: {task_params}")

        try:
            self.generalization_generator = DataGenerator(gen_task, task_params)
            logger.info("Generalization dataset generator initialized successfully")
        except Exception as e:
            # Best effort: a broken generator disables the metric (see can_run)
            # instead of aborting construction.
            logger.error(f"Failed to initialize generalization dataset generator: {e}")
            self.generalization_generator = None

    def can_run(self) -> bool:
        """Check if generalization evaluation should be run."""
        return self.is_enabled and self.generalization_generator is not None

    def evaluate(self, teacher_responses: List[Dict[str, Any]], model_manager=None,
                 openai_client=None, **kwargs) -> List[Dict[str, Any]]:
        """Evaluate on the generalization dataset and attach the outcome.

        Generates the generalization examples, scores them through the
        ``accuracy_metric`` passed in ``kwargs`` and, when ``teacher_responses``
        is non-empty, stores a summary under ``'generalization_results'`` and
        the per-example results under ``'generalization_detailed_results'``
        on its first element.

        Args:
            teacher_responses: List of teacher response dictionaries
            model_manager: vLLM model manager for generation
            openai_client: OpenAI client (if using OpenAI)
            **kwargs: Additional arguments; ``accuracy_metric`` is required

        Returns:
            The same ``teacher_responses`` list, returned unchanged when the
            metric cannot run, no accuracy metric was given, or no data
            could be generated
        """
        if not self.can_run():
            logger.warning("Generalization evaluation requested but not properly configured")
            return teacher_responses

        accuracy_metric = kwargs.get('accuracy_metric')
        if not accuracy_metric:
            logger.error("AccuracyMetric instance required for generalization evaluation")
            return teacher_responses

        banner = "="*80
        logger.info(banner)
        logger.info("RUNNING GENERALIZATION EVALUATION")
        logger.info(banner)

        # Build the examples to evaluate on
        examples = self._generate_generalization_data()
        if not examples:
            logger.warning("No generalization data generated")
            return teacher_responses

        # Scoring is delegated to the accuracy metric
        logger.info(f"Evaluating on {len(examples)} generalization examples...")
        detailed_results = accuracy_metric.evaluate(
            examples,
            model_manager=model_manager,
            openai_client=openai_client,
            own_thinking=True  # Use teacher's own thinking
        )

        # Tag every result with the task name for proper saving
        for entry in detailed_results:
            entry['task'] = self.training_task

        stats = self._calculate_generalization_stats(detailed_results)
        accuracy = stats['accuracy']
        correct_count = stats['correct_count']
        total_examples = stats['total_examples']

        # Checkpoint-level results are stored on the first response only
        if teacher_responses:
            skipped_keys = ('task_name', 'seed', 'size', 'val_start')
            dataset_params = {}
            for key, value in self.gen_config.items():
                if key not in skipped_keys:
                    dataset_params[key] = value

            first_response = teacher_responses[0]
            first_response['generalization_results'] = {
                'accuracy': accuracy,
                'correct_count': correct_count,
                'total_examples': total_examples,
                'task_name': self.training_task,
                'dataset_params': dataset_params,
            }

            # Full results are kept so they can be saved to a separate folder
            first_response['generalization_detailed_results'] = detailed_results

        logger.info(f"Generalization accuracy: {accuracy:.4f} "
                    f"({correct_count}/{total_examples})")
        logger.info(banner)

        return teacher_responses

    def _generate_generalization_data(self) -> List[Dict[str, Any]]:
        """Produce the generalization examples from the configured generator.

        ``seed``, ``size`` and ``val_start`` are taken from the
        generalization_dataset section, then from teacher_dataset, then from
        the built-in defaults 45, 500 and 0.

        Returns:
            Whatever the generator's ``generate_teacher_dataset`` returns, or
            an empty list when generation raises
        """
        teacher_defaults = self.config.get('evaluation', {}).get('teacher_dataset', {})

        def _setting(key, builtin_default):
            # generalization_dataset wins, teacher_dataset is the fallback
            return self.gen_config.get(key, teacher_defaults.get(key, builtin_default))

        seed = _setting('seed', 45)
        size = _setting('size', 500)
        val_start = _setting('val_start', 0)

        logger.info(f"Generating {size} generalization examples with seed={seed}, val_start={val_start}")

        try:
            examples = self.generalization_generator.generate_teacher_dataset(
                teacher_seed=seed,
                size=size,
                val_start=val_start,
            )
            logger.info(f"Successfully generated {len(examples)} generalization examples")
        except Exception as e:
            logger.error(f"Failed to generate generalization data: {e}")
            return []
        return examples

    def _calculate_generalization_stats(self, gen_results: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Calculate statistics from generalization evaluation results.

        Args:
            gen_results: List of evaluation results

        Returns:
            Dictionary with accuracy statistics
        """
        if not gen_results:
            return {'accuracy': 0.0, 'correct_count': 0, 'total_examples': 0}

        # Count correct answers
        correct_count = sum(
            1 for result in gen_results
            if result.get('reward_score', 0.0) > 0.0
        )

        total_examples = len(gen_results)
        accuracy = correct_count / total_examples if total_examples > 0 else 0.0

        return {
            'accuracy': accuracy,
            'correct_count': correct_count,
            'total_examples': total_examples
        }

    def process_teacher_responses(self, teacher_responses_path: str,
                                 teacher_data: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Process teacher responses and extract generalization metrics.

        This is called after evaluation to format metrics for saving.

        Args:
            teacher_responses_path: Path where original responses are saved
            teacher_data: Teacher responses with generalization results

        Returns:
            Dictionary with generalization metrics
        """
        if not teacher_data:
            return {}

        # Extract generalization results from first response
        gen_results = teacher_data[0].get('generalization_results', {})

        if not gen_results:
            return {}

        metrics = {
            'generalization_accuracy': gen_results.get('accuracy', 0.0),
            'generalization_correct_count': gen_results.get('correct_count', 0),
            'generalization_total_examples': gen_results.get('total_examples', 0),
            'generalization_task': gen_results.get('task_name', ''),
            'generalization_dataset_params': gen_results.get('dataset_params', {}),
            'num_samples': gen_results.get('total_examples', 0)
        }

        logger.info(f"Generalization metrics processed: accuracy={metrics['generalization_accuracy']:.4f}")

        return metrics

    def to_dict(self) -> Dict[str, Any]:
        """Convert metric to dictionary format.

        Returns:
            Dictionary representation of the metric
        """
        return {
            'metric_name': 'generalization',
            'training_task': self.training_task,
            'is_enabled': self.is_enabled,
            'has_generator': self.generalization_generator is not None,
            'config': self.gen_config
        }

    def print_summary(self) -> None:
        """Print a summary of the generalization metric configuration."""
        print("\n=== GENERALIZATION METRIC SUMMARY ===")
        print(f"Training task: {self.training_task}")
        print(f"Enabled: {self.is_enabled}")
        print(f"Generator initialized: {self.generalization_generator is not None}")

        if self.gen_config:
            print(f"Generalization config: {self.gen_config}")

        print("=" * 50)
