"""Enhanced chunked checkpoint system with chunk-based saving and continual evaluation."""

import json
import logging
import shutil
from pathlib import Path
from typing import List, Dict, Any, Optional, Set, Tuple
from datetime import datetime

from .parallel_evaluator import EvaluationResult
from ..config.unified_config import EvaluationParams


class ChunkedCheckpointSystem:
    """
    A checkpoint system that saves results in chunks with complete information.
    Supports continual evaluation at both timestamp and checkpoint levels.

    Results are buffered in memory and written to
    ``<checkpoint_dir>/chunk_<index>/chunk_results.json`` (with a
    ``chunk_metadata.json`` summary next to it) once ``chunk_size`` results
    have accumulated. On construction the existing chunk directories are
    scanned so that problems which already completed can be skipped.
    """

    # Substring looked for in a result's stored conversation. When present the
    # run is treated as failed even though its ``success`` flag is set.
    _SERVICE_ERROR_MARKER = "Error code: 503"

    def __init__(
        self,
        checkpoint_dir: Path,
        chunk_size: int = 20,
        logger: Optional[logging.Logger] = None
    ):
        """Initialize the chunked checkpoint system.

        Creates ``checkpoint_dir`` if needed and loads any chunks already
        stored there.

        Args:
            checkpoint_dir: Base directory for checkpoints
            chunk_size: Number of samples per chunk
            logger: Logger instance (defaults to this module's logger)
        """
        self.checkpoint_dir = Path(checkpoint_dir)
        self.chunk_size = chunk_size
        self.logger = logger or logging.getLogger(__name__)

        # Current evaluation state
        self.current_results: List[EvaluationResult] = []
        self.current_chunk_index = 0
        self.completed_problem_ids: Set[str] = set()

        # Create checkpoint directory
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)

        # Load existing checkpoints
        self._load_existing_chunks()

    @staticmethod
    def _chunk_index(chunk_dir: Path) -> Optional[int]:
        """Return the numeric index encoded in a ``chunk_<n>`` directory name.

        Returns:
            The index, or None when the part after the first underscore is
            not an integer.
        """
        try:
            return int(chunk_dir.name.split('_')[1])
        except (IndexError, ValueError):
            return None

    def _list_chunk_dirs(self) -> List[Path]:
        """Return the chunk directories sorted by their numeric index.

        Directories whose name starts with ``chunk_`` but carries no numeric
        index are logged and ignored instead of aborting the scan.
        """
        indexed_dirs = []
        for entry in self.checkpoint_dir.iterdir():
            if not (entry.is_dir() and entry.name.startswith('chunk_')):
                continue
            index = self._chunk_index(entry)
            if index is None:
                self.logger.warning(f"Ignoring directory with non-numeric chunk index: {entry}")
                continue
            indexed_dirs.append((index, entry))

        indexed_dirs.sort(key=lambda pair: pair[0])
        return [entry for _, entry in indexed_dirs]

    def _record_if_completed(self, result: Dict[str, Any]) -> None:
        """Add a serialized result's task id to the completed set if it really finished.

        A result counts as completed only when it has a task id, its
        ``success`` flag is truthy, and its stored conversation does not
        contain the service-error marker.

        Args:
            result: Result dictionary in the format written to chunk_results.json
        """
        if not isinstance(result, dict) or not result.get('success', False):
            return

        # task_id is the question_id used for tracking
        task_id = result.get('task_id')
        if task_id is None:
            return

        # ``metrics`` may be stored as null, hence the ``or {}``.
        metrics = result.get('metrics') or {}
        conversation_str = json.dumps(metrics.get('full_conversation', []))
        if self._SERVICE_ERROR_MARKER not in conversation_str:
            self.completed_problem_ids.add(task_id)

    def _load_existing_chunks(self):
        """Load existing chunk data to determine what has been completed.

        Populates ``completed_problem_ids`` and sets ``current_chunk_index``
        to one past the highest chunk whose results file could be parsed.
        """
        chunk_dirs = self._list_chunk_dirs()

        if not chunk_dirs:
            self.logger.info("No existing chunks found, starting fresh")
            return

        loaded_dirs: List[Path] = []
        for chunk_dir in chunk_dirs:
            chunk_file = chunk_dir / "chunk_results.json"
            if not chunk_file.exists():
                continue
            try:
                with open(chunk_file, 'r', encoding='utf-8') as f:
                    chunk_data = json.load(f)
                # The chunk counts as loaded as soon as its JSON parses.
                loaded_dirs.append(chunk_dir)
                for result in chunk_data.get('results', []):
                    self._record_if_completed(result)
            except Exception as e:
                self.logger.warning(f"Error loading chunk {chunk_dir}: {e}")

        if loaded_dirs:
            # Directories are sorted, so the last loaded one has the highest index.
            self.current_chunk_index = self._chunk_index(loaded_dirs[-1]) + 1

            self.logger.info(f"Loaded {len(loaded_dirs)} existing chunks")
            self.logger.info(f"Completed problem IDs: {len(self.completed_problem_ids)}")
            self.logger.info(f"Next chunk index: {self.current_chunk_index}")

    def should_skip_problem(self, question_id: str) -> bool:
        """Check if a problem should be skipped because it's already completed.

        Args:
            question_id: Question ID of the problem to check

        Returns:
            True if the problem should be skipped
        """
        return question_id in self.completed_problem_ids

    def add_result(self, result: "EvaluationResult") -> bool:
        """Add a result to the current chunk.

        Args:
            result: Evaluation result to add

        Returns:
            True if chunk was saved (reached chunk_size), False otherwise
        """
        self.current_results.append(result)

        # Flush once the buffer holds a full chunk
        if len(self.current_results) >= self.chunk_size:
            self._save_current_chunk()
            return True

        return False

    def _save_current_chunk(self):
        """Save the current chunk to disk and start a new, empty one.

        Writes ``chunk_results.json`` and ``chunk_metadata.json`` into
        ``chunk_<current_chunk_index>``, updates ``completed_problem_ids``,
        clears the buffer and advances the chunk index. Does nothing when the
        buffer is empty.
        """
        if not self.current_results:
            return

        chunk_dir = self.checkpoint_dir / f"chunk_{self.current_chunk_index}"
        chunk_dir.mkdir(parents=True, exist_ok=True)

        # Convert results to serializable format
        serialized_results = [
            {
                'task_id': result.task_id,
                'problem_index': result.problem_index,
                # Stored so that _combine_all_chunks can restore it.
                'model_name': getattr(result, 'model_name', ''),
                'success': result.success,
                'answer': result.answer,
                'ground_truth': result.ground_truth,
                'is_correct': result.is_correct,
                'error': str(result.error) if result.error else None,
                'metrics': result.metrics,
                'timestamp': result.timestamp,
                'duration': result.duration,
                'attempt': result.attempt
            }
            for result in self.current_results
        ]

        chunk_data = {
            'chunk_index': self.current_chunk_index,
            'chunk_size': len(self.current_results),
            'timestamp': datetime.now().isoformat(),
            'results': serialized_results
        }

        # Save chunk data
        chunk_file = chunk_dir / "chunk_results.json"
        with open(chunk_file, 'w', encoding='utf-8') as f:
            json.dump(chunk_data, f, indent=2)

        # Save chunk metadata
        successful_count = sum(1 for r in self.current_results if r.success)
        metadata = {
            'chunk_index': self.current_chunk_index,
            'total_results': len(self.current_results),
            'successful_results': successful_count,
            'failed_results': len(self.current_results) - successful_count,
            'problem_indices': [r.problem_index for r in self.current_results],  # Keep for backward compatibility
            'question_ids': [r.task_id for r in self.current_results],  # New field with actual question IDs
            'created_timestamp': datetime.now().isoformat()
        }

        metadata_file = chunk_dir / "chunk_metadata.json"
        with open(metadata_file, 'w', encoding='utf-8') as f:
            json.dump(metadata, f, indent=2)

        self.logger.info(f"Saved chunk {self.current_chunk_index} with {len(self.current_results)} results")

        # Track completion with the same rule that is applied when chunks are
        # reloaded, so a resumed run sees the same completed set.
        for result_dict in serialized_results:
            self._record_if_completed(result_dict)

        # Reset for next chunk
        self.current_results = []
        self.current_chunk_index += 1

    def finalize_evaluation(self, setting: "EvaluationParams") -> List["EvaluationResult"]:
        """Finalize the evaluation by saving any remaining results and combining all chunks.

        Also writes ``evaluation_metadata.json`` into the checkpoint directory.

        Args:
            setting: Evaluation setting used for this run; its ``setting_id``
                and ``model_name`` attributes are recorded in the metadata

        Returns:
            Combined list of all results from all chunks
        """
        # Save any remaining results in current chunk
        if self.current_results:
            self._save_current_chunk()

        # Save evaluation metadata
        eval_metadata = {
            'setting_id': setting.setting_id,
            'model_name': setting.model_name,
            'total_chunks': self.current_chunk_index,
            'chunk_size': self.chunk_size,
            'finalization_timestamp': datetime.now().isoformat()
        }

        metadata_file = self.checkpoint_dir / "evaluation_metadata.json"
        with open(metadata_file, 'w', encoding='utf-8') as f:
            json.dump(eval_metadata, f, indent=2)

        # Combine all results from all chunks
        return self._combine_all_chunks()

    def _combine_all_chunks(self) -> List["EvaluationResult"]:
        """Combine results from all chunks into a single list.

        Chunks that cannot be read or converted are logged and skipped.

        Returns:
            Combined list of all results, in chunk-index order
        """
        all_results = []
        chunk_dirs = self._list_chunk_dirs()

        for chunk_dir in chunk_dirs:
            chunk_file = chunk_dir / "chunk_results.json"
            if not chunk_file.exists():
                continue
            try:
                with open(chunk_file, 'r', encoding='utf-8') as f:
                    chunk_data = json.load(f)

                # Convert back to EvaluationResult objects
                for result_dict in chunk_data.get('results', []):
                    all_results.append(EvaluationResult(
                        task_id=result_dict.get('task_id', ''),
                        problem_index=result_dict.get('problem_index', 0),
                        model_name=result_dict.get('model_name', ''),
                        success=result_dict.get('success', False),
                        answer=result_dict.get('answer'),
                        ground_truth=result_dict.get('ground_truth'),
                        is_correct=result_dict.get('is_correct'),
                        error=result_dict.get('error'),
                        metrics=result_dict.get('metrics', {}),
                        timestamp=result_dict.get('timestamp', ''),
                        duration=result_dict.get('duration', 0.0),
                        attempt=result_dict.get('attempt', 0)
                    ))
            except Exception as e:
                self.logger.error(f"Error loading chunk {chunk_dir}: {e}")

        self.logger.info(f"Combined {len(all_results)} results from {len(chunk_dirs)} chunks")
        return all_results

    def get_recovery_results(self) -> List["EvaluationResult"]:
        """Get all results from existing chunks for recovery purposes.

        Returns:
            List of all results that can be recovered from chunks
        """
        return self._combine_all_chunks()

    def cleanup_old_checkpoints(self, keep_chunks: bool = True):
        """Clean up old checkpoint files, optionally keeping chunk data.

        The old-style ``checkpoint.json`` file is always removed when present.

        Args:
            keep_chunks: Whether to keep chunk directories (default: True)
        """
        if not keep_chunks:
            # Remove every directory named chunk_*, numeric index or not
            for chunk_dir in self.checkpoint_dir.iterdir():
                if chunk_dir.is_dir() and chunk_dir.name.startswith('chunk_'):
                    shutil.rmtree(chunk_dir)
                    self.logger.info(f"Removed chunk directory: {chunk_dir}")

        # Remove old-style checkpoint files if they exist
        old_checkpoint_file = self.checkpoint_dir / "checkpoint.json"
        if old_checkpoint_file.exists():
            old_checkpoint_file.unlink()
            self.logger.info("Removed old-style checkpoint file")


class ContinualEvaluationManager:
    """
    Manages continual evaluation across timestamp folders and checkpoint recovery.

    A setting directory holds one sub-directory per run (named so that
    string comparison reflects recency) plus a ``most_updated`` directory
    holding the consolidated, latest result per question.
    """

    def __init__(
        self,
        setting_base_dir: Path,
        logger: Optional[logging.Logger] = None
    ):
        """Initialize continual evaluation manager.

        Args:
            setting_base_dir: Base directory for a setting (e.g., test_gpt_4o/)
            logger: Logger instance (defaults to this module's logger)
        """
        self.setting_base_dir = Path(setting_base_dir)
        self.logger = logger or logging.getLogger(__name__)
        self.most_updated_dir = self.setting_base_dir / "most_updated"

    @staticmethod
    def _resolve_question_id(result: Dict[str, Any]) -> Optional[str]:
        """Return the tracking id for a result dictionary.

        Uses ``task_id`` when it is set; otherwise falls back to
        ``problem_<problem_index>`` for backward compatibility.

        Returns:
            The id, or None when neither field is available.
        """
        question_id = result.get('task_id')
        if question_id:
            return question_id
        problem_idx = result.get('problem_index')
        if problem_idx is not None:
            return f"problem_{problem_idx}"
        return None

    @classmethod
    def _keep_latest(
        cls,
        latest_by_id: Dict[str, Dict[str, Any]],
        result: Any,
        timestamp_name: str
    ) -> None:
        """Record ``result`` in ``latest_by_id`` if it is the newest for its question.

        Entries that are not dictionaries, have no answer, or have no
        resolvable id are ignored. Recency is decided by comparing timestamp
        directory names as strings.

        Args:
            latest_by_id: Mapping question_id -> {'result', 'timestamp_dir'}, updated in place
            result: One entry of a loaded results list
            timestamp_name: Name of the timestamp directory the entry came from
        """
        if not isinstance(result, dict) or result.get('answer') is None:
            return
        question_id = cls._resolve_question_id(result)
        if not question_id:
            return
        current = latest_by_id.get(question_id)
        if current is None or timestamp_name > current['timestamp_dir']:
            latest_by_id[question_id] = {
                'result': result,
                'timestamp_dir': timestamp_name
            }

    def _list_timestamp_dirs(self) -> List[Path]:
        """Return the run directories under the setting directory.

        ``most_updated`` is excluded. An empty list is returned when the
        setting directory does not exist yet.
        """
        if not self.setting_base_dir.is_dir():
            return []
        return [
            d for d in self.setting_base_dir.iterdir()
            if d.is_dir() and d.name != "most_updated"
        ]

    def _load_timestamp_results(self, timestamp_dir: Path, verbose: bool = False) -> List[Any]:
        """Load the result entries of one timestamp directory.

        Reads ``results.json`` when it exists; otherwise falls back to the
        chunk files under ``checkpoints``. Read errors are logged and yield
        an empty list.

        Args:
            timestamp_dir: Run directory to read
            verbose: Also log where the results came from, or that none were found

        Returns:
            List of result entries (possibly empty)
        """
        results_file = timestamp_dir / "results.json"
        if results_file.exists():
            try:
                with open(results_file, 'r', encoding='utf-8') as f:
                    results_data = json.load(f)
            except Exception as e:
                self.logger.warning(f"Error reading results from {timestamp_dir}: {e}")
                return []
            if not isinstance(results_data, list):
                self.logger.warning(
                    f"Error reading results from {timestamp_dir}: "
                    f"expected a list, got {type(results_data).__name__}"
                )
                return []
            if verbose:
                self.logger.debug(f"Loaded {len(results_data)} results from {timestamp_dir.name}/results.json")
            return results_data

        # No final results.json: try the chunked checkpoints instead
        checkpoints_dir = timestamp_dir / "checkpoints"
        if not checkpoints_dir.exists():
            if verbose:
                self.logger.warning(f"No results.json or checkpoints found in {timestamp_dir.name}")
            return []

        chunk_results = self._load_chunk_results(checkpoints_dir)
        if verbose:
            if chunk_results:
                self.logger.info(f"Loaded {len(chunk_results)} results from chunks in {timestamp_dir.name} (no final results.json)")
            else:
                self.logger.warning(f"No results found in chunks for {timestamp_dir.name}")
        return chunk_results

    def get_completed_problem_ids(self) -> Tuple[Set[str], Dict[str, Dict[str, Any]]]:
        """Get all completed problem IDs and their results from existing evaluations.

        Results in ``most_updated/results.json`` are read first. All
        timestamp folders are then scanned as well, keeping the most recent
        result per question. A question counts as completed when that result
        has an answer and a truthy ``success`` flag. When a question appears
        in both sources, the result dictionary from ``most_updated`` is kept.

        The returned values are also stored on the instance as
        ``completed_ids`` and ``completed_results_by_id``.

        Returns:
            Tuple of:
              - Set of question IDs (task_ids) that have been completed
              - Dict mapping question_id -> latest completed result dict
        """
        completed_ids: Set[str] = set()
        completed_results_by_id: Dict[str, Dict[str, Any]] = {}

        # Consolidated results first: their entries take precedence below
        consolidated_file = self.most_updated_dir / "results.json"
        if consolidated_file.exists():
            try:
                with open(consolidated_file, 'r', encoding='utf-8') as f:
                    consolidated = json.load(f)
                for result in consolidated:
                    if not isinstance(result, dict) or result.get('answer') is None:
                        continue
                    question_id = self._resolve_question_id(result)
                    if question_id and result.get('success', False):
                        completed_ids.add(question_id)
                        completed_results_by_id[question_id] = result
                self.logger.info(f"Using most_updated: found {len(completed_ids)} completed problems")
            except Exception as e:
                self.logger.warning(f"Error reading most_updated results: {e}")

        # Scan timestamp folders directly (don't read from most_updated)
        timestamp_dirs = self._list_timestamp_dirs()
        if not timestamp_dirs:
            self.logger.info("No existing timestamp folders found")
        else:
            # Build question_id -> latest result mapping from all timestamps
            latest_by_id: Dict[str, Dict[str, Any]] = {}
            for timestamp_dir in timestamp_dirs:
                for result in self._load_timestamp_results(timestamp_dir):
                    self._keep_latest(latest_by_id, result, timestamp_dir.name)

            # Prefer entries from most_updated; do not overwrite them
            for question_id, data in latest_by_id.items():
                if data['result'].get('success', False):
                    completed_ids.add(question_id)
                    completed_results_by_id.setdefault(question_id, data['result'])

            self.logger.info(f"Found {len(completed_ids)} completed problems from {len(timestamp_dirs)} timestamp folders")

        # Stored on every path, including when no timestamp folder exists
        self.completed_results_by_id = completed_results_by_id
        self.completed_ids = completed_ids
        return completed_ids, completed_results_by_id

    def _load_chunk_results(self, checkpoints_dir: Path) -> List[Dict[str, Any]]:
        """Load results from chunk directories in checkpoints folder.

        Chunk directories are read in numeric index order. Directories named
        ``chunk_*`` without a numeric index, and chunk files that cannot be
        read, are logged and skipped.

        Args:
            checkpoints_dir: Directory containing chunk folders

        Returns:
            List of result dictionaries from all chunks
        """
        chunk_results: List[Dict[str, Any]] = []
        if not checkpoints_dir.is_dir():
            return chunk_results

        indexed_dirs = []
        for entry in checkpoints_dir.iterdir():
            if not (entry.is_dir() and entry.name.startswith('chunk_')):
                continue
            try:
                index = int(entry.name.split('_')[1])
            except (IndexError, ValueError):
                self.logger.warning(f"Ignoring directory with non-numeric chunk index: {entry}")
                continue
            indexed_dirs.append((index, entry))
        indexed_dirs.sort(key=lambda pair: pair[0])

        for _, chunk_dir in indexed_dirs:
            chunk_file = chunk_dir / "chunk_results.json"
            if not chunk_file.exists():
                continue
            try:
                with open(chunk_file, 'r', encoding='utf-8') as f:
                    chunk_data = json.load(f)
                chunk_results.extend(chunk_data.get('results', []))
            except Exception as e:
                self.logger.warning(f"Error loading chunk {chunk_dir}: {e}")

        return chunk_results

    def update_most_updated_folder(self):
        """Update the most_updated folder with latest results from all timestamps.

        Writes ``results.json`` (latest result per question, each tagged with
        the ``timestamp_dir`` it came from), ``config.json`` (copied from the
        newest run that has one) and ``metrics.json`` (recomputed from the
        combined results). Does nothing when there are no timestamp folders.
        """
        timestamp_dirs = self._list_timestamp_dirs()

        if not timestamp_dirs:
            self.logger.info("No timestamp folders found to update")
            return

        # Build question_id -> latest result mapping
        latest_by_id: Dict[str, Dict[str, Any]] = {}
        latest_config: Optional[Tuple[Any, str]] = None
        latest_metrics_template: Optional[Tuple[Any, str]] = None

        for timestamp_dir in sorted(timestamp_dirs, key=lambda d: d.name):
            for result in self._load_timestamp_results(timestamp_dir, verbose=True):
                self._keep_latest(latest_by_id, result, timestamp_dir.name)

            # Keep latest config and metrics template
            config_file = timestamp_dir / "config.json"
            if config_file.exists() and (latest_config is None or timestamp_dir.name > latest_config[1]):
                try:
                    with open(config_file, 'r', encoding='utf-8') as f:
                        latest_config = (json.load(f), timestamp_dir.name)
                except Exception as e:
                    self.logger.warning(f"Error reading config from {timestamp_dir}: {e}")

            metrics_file = timestamp_dir / "metrics.json"
            if metrics_file.exists() and (latest_metrics_template is None or timestamp_dir.name > latest_metrics_template[1]):
                try:
                    with open(metrics_file, 'r', encoding='utf-8') as f:
                        latest_metrics_template = (json.load(f), timestamp_dir.name)
                except Exception as e:
                    self.logger.warning(f"Error reading metrics from {timestamp_dir}: {e}")

        # Create most_updated folder
        self.most_updated_dir.mkdir(parents=True, exist_ok=True)

        # Save combined results, adding timestamp_dir info for metrics calculation
        combined_results = []
        for question_id in sorted(latest_by_id):
            entry = latest_by_id[question_id]
            result = entry['result'].copy()
            result['timestamp_dir'] = entry['timestamp_dir']
            combined_results.append(result)

        results_file = self.most_updated_dir / "results.json"
        with open(results_file, 'w', encoding='utf-8') as f:
            json.dump(combined_results, f, indent=2)

        # Save latest config with system prompt
        if latest_config:
            config_data = latest_config[0].copy()
            # NOTE(review): ``unified_config`` is never assigned in this class;
            # the branch only applies if something else attached it to the
            # instance - confirm against callers.
            if 'system_prompt' not in config_data and hasattr(self, 'unified_config'):
                config_data['system_prompt'] = self.unified_config.system_prompt
            config_file = self.most_updated_dir / "config.json"
            with open(config_file, 'w', encoding='utf-8') as f:
                json.dump(config_data, f, indent=2)

        # Calculate and save new metrics
        if combined_results:
            template = latest_metrics_template[0] if latest_metrics_template else None
            metrics = self._calculate_combined_metrics(combined_results, template)
            metrics_file = self.most_updated_dir / "metrics.json"
            with open(metrics_file, 'w', encoding='utf-8') as f:
                json.dump(metrics, f, indent=2)

        self.logger.info(f"Updated most_updated folder with {len(combined_results)} results from {len(timestamp_dirs)} timestamp folders")

    def _calculate_combined_metrics(self, results: List[Dict[str, Any]], template: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """Calculate metrics for combined results.

        ``accuracy`` and ``follow_format_rate`` are relative to the number of
        successful tasks; the two subgroup accuracies are relative to the size
        of their subgroup. Missing or null ``metrics`` / ``duration`` fields
        are treated as empty / 0.0.

        Args:
            results: List of result dictionaries
            template: Optional metrics dictionary from an earlier run; its
                ``evaluation_info`` supplies setting_id, model, dataset and embedding

        Returns:
            Calculated metrics dictionary
        """
        def metric_flag(result: Dict[str, Any], key: str) -> bool:
            # ``metrics`` may be absent, null, or not a mapping at all
            metrics_field = result.get('metrics')
            if not isinstance(metrics_field, dict):
                return False
            return bool(metrics_field.get(key, False))

        def correct_count(subset: List[Dict[str, Any]]) -> int:
            return sum(1 for r in subset if r.get('is_correct', False))

        total_tasks = len(results)
        successful_tasks = sum(1 for r in results if r.get('success', False))
        correct_tasks = correct_count(results)
        follow_format_tasks = sum(1 for r in results if metric_flag(r, 'follow_format'))
        total_duration = sum((r.get('duration') or 0.0) for r in results)

        # Calculate search completion metrics (like in unified_orchestrator.py)
        all_docs_found = [r for r in results if metric_flag(r, 'search_complete')]
        not_all_docs_found = [r for r in results if not metric_flag(r, 'search_complete')]

        metrics = {
            'total_tasks': total_tasks,
            'successful_tasks': successful_tasks,
            'failed_tasks': total_tasks - successful_tasks,
            'success_rate': successful_tasks / total_tasks if total_tasks > 0 else 0.0,
            'correct_answers': correct_tasks,
            'accuracy': correct_tasks / successful_tasks if successful_tasks > 0 else 0.0,
            'follow_format_count': follow_format_tasks,
            'follow_format_rate': follow_format_tasks / successful_tasks if successful_tasks > 0 else 0.0,
            'average_duration': total_duration / total_tasks if total_tasks > 0 else 0.0,
            'total_duration': total_duration,

            # Subgroup analysis based on search completion
            'all_docs_found_tasks': len(all_docs_found),
            'all_docs_found_accuracy': (
                correct_count(all_docs_found) / len(all_docs_found)
                if all_docs_found else 0.0
            ),
            'not_all_docs_found_tasks': len(not_all_docs_found),
            'not_all_docs_found_accuracy': (
                correct_count(not_all_docs_found) / len(not_all_docs_found)
                if not_all_docs_found else 0.0
            ),

            'evaluation_info': {
                'combined_from_timestamps': True,
                'total_timestamp_folders': len({r['timestamp_dir'] for r in results if 'timestamp_dir' in r}),
                'combination_timestamp': datetime.now().isoformat()
            }
        }

        # Add template information if available
        template_info = template.get('evaluation_info') if isinstance(template, dict) else None
        if isinstance(template_info, dict):
            metrics['evaluation_info'].update({
                'setting_id': template_info.get('setting_id'),
                'model': template_info.get('model'),
                'dataset': template_info.get('dataset'),
                'embedding': template_info.get('embedding'),
            })

        return metrics
