import numpy as np
from typing import Dict, List

from src.dataset_processing.common.models.dataset_result import DatasetResult
from src.dataset_processing.factory.dataset_factory import DatasetFactory
from src.loggers.setup_logging import setup_logging
from src.reliability_eval.common.enums.score_types import DatasetMetricTypes, QuestionAggregateTypes
from src.reliability_eval.common.models.scores import QuestionAggregateScores
from src.reliability_eval.pipeline.evaluation_pipelines.accessors import ACCURACY_ACCESS_CONFIGS, SCORE_ACCESS_CONFIGS, TOKEN_ID_ACCESS_CONFIGS
from src.reliability_eval.pipeline.evaluation_pipelines.registry import EVALUATION_PIPELINES_DICT
from src.reliability_eval.pipeline.evaluation_pipelines.types import PipelineType
from src.reliability_eval.evaluator.accessor import ScoreAccessor
from src.reliability_eval.pipeline.calculator.metrics import MetricsCalculator
from src.reliability_eval.pipeline.config import BatchConfig, IntegratedPipelineConfig
from src.reliability_eval.pipeline.context import EvaluationContext, MetricsConfig
from src.reliability_eval.pipeline.excel_writer import ExcelWriter
from src.reliability_eval.pipeline.processor.batch import DataLoader
from src.reliability_eval.pipeline.processor.merger import ScoreMerger
from src.reliability_eval.pipeline.processor.processor import BatchProcessor
from src.reliability_eval.pipeline.results import EvaluationResults
from src.reliability_eval.utils.model_loading import load_model_for_evaluation


# Module-level logger shared by everything in this file, obtained from the
# project's logging setup helper.
logger = setup_logging()

class IntegratedEvaluationPipeline:
    """Manages the complete evaluation pipeline.

    A run goes through these stages (see ``run_evaluation``): dataset
    preparation, ``DataLoader`` construction, model/context initialization,
    batch processing with score accumulation, and final score/metric
    extraction.
    """

    # (result key, ``ScoreAccessor`` getter name) pairs for the outputs that
    # are looked up pipeline by pipeline until a non-None value is found.
    # The order matches the order in which the getters are invoked.
    _SHARED_OUTPUT_GETTERS = (
        ("token_info", "get_token_info_tuples"),
        ("full_text", "get_full_decoded_text"),
        ("effective_lengths", "get_effective_lengths"),
        ("padded_sequences", "get_padded_sequences"),
        ("semantic_entropy", "get_semantic_entropy_scores"),
    )

    def __init__(self):
        """Initialize pipeline components (metrics calculator and Excel writer)."""
        self.metrics_calculator = MetricsCalculator(
            MetricsConfig(metric_types=DatasetMetricTypes, round_accuracy=True)
        )
        self.excel_writer = ExcelWriter()

    def run_evaluation(
        self,
        config: IntegratedPipelineConfig
    ) -> EvaluationResults:
        """Runs the complete evaluation pipeline.

        Args:
            config: Pipeline configuration; its dataset, batch, model and
                pipeline-type settings are consumed by the individual stages.

        Returns:
            The scores and metrics produced by ``_compute_final_results``.

        Raises:
            Exception: Any exception raised by a stage is logged and then
                re-raised unchanged.
        """
        logger.info("Starting evaluation pipeline")
        try:
            dataset = self._prepare_dataset(config)
            dataloader = self._prepare_dataloader(dataset, config.batch_config)
            evaluation_context = self._initialize_context(config)
            scores = self._process_batches(dataloader, evaluation_context)
            results = self._compute_final_results(scores, config)
            # NOTE: Excel report generation is currently disabled; the
            # original call is kept below for reference.
            # if config.num_excel_entries > 0:
            #     # Generate Excel report if specified
            #     logger.info("Generating Excel report")
            #     self.excel_writer._generate_evaluation_excel(
            #         results=results,
            #         config=config,
            #         dataset_entries=dataset.entries,
            #         num_max_entries=config.num_excel_entries
            #     )

            logger.info("Evaluation pipeline completed successfully")
            return results

        except Exception as e:
            # Top-level boundary: record the failure, then let the caller
            # handle the original exception.
            logger.error(f"Pipeline execution failed: {e}")
            raise

    def _prepare_dataset(self, config: IntegratedPipelineConfig) -> DatasetResult:
        """Prepares dataset for processing.

        Args:
            config: Pipeline configuration; only ``config.dataset_config`` is
                read here.

        Returns:
            The result of the dataset processor selected by
            ``config.dataset_config.dataset_type``. It is consumed by
            ``_prepare_dataloader``, which reads its ``entries``.
        """
        dataset_config = config.dataset_config
        logger.info(f"Processing dataset of type {dataset_config.dataset_type}")
        processor = DatasetFactory.create_processor(
            dataset_type=dataset_config.dataset_type
        )
        return processor.process_dataset(dataset_config)

    def _prepare_dataloader(self, dataset: DatasetResult, batch_config: BatchConfig) -> DataLoader:
        """Prepares DataLoader for processing.

        Args:
            dataset: Processed dataset whose entries expose ``question`` and
                ``answer`` attributes.
            batch_config: Batching settings forwarded to the ``DataLoader``.

        Returns:
            A ``DataLoader`` over the questions (as a list) and the answers
            (as a NumPy array), in entry order.
        """
        entries = dataset.entries
        return DataLoader(
            queries=[entry.question for entry in entries],
            answers=np.array([entry.answer for entry in entries]),
            batch_config=batch_config
        )

    def _merge_batch_scores(
        self,
        accumulated: Dict[str, QuestionAggregateScores],
        new_scores: Dict[str, QuestionAggregateScores]
    ) -> None:
        """Merges new batch scores into accumulated scores (in place).

        Args:
            accumulated: Scores gathered so far, keyed by pipeline name.
                Mutated by this call.
            new_scores: Scores of the batch being merged, keyed by pipeline
                name.
        """
        logger.debug("Merging batch scores")
        for pipeline_name, scores in new_scores.items():
            if pipeline_name in accumulated:
                # Delegate the merge of an already-seen pipeline to ScoreMerger.
                ScoreMerger.merge_question_scores(accumulated[pipeline_name], scores)
            else:
                # First time this pipeline shows up: adopt its scores as-is.
                accumulated[pipeline_name] = scores

    def _collect_shared_outputs(
        self,
        accessor: ScoreAccessor,
        pipeline_types: List[PipelineType],
        final_scores: dict
    ) -> None:
        """Fills ``final_scores`` with the first non-None shared outputs.

        For each pipeline type, in order, every output listed in
        ``_SHARED_OUTPUT_GETTERS`` that is still missing is requested from
        the accessor. Iteration stops early once all outputs are present.
        """
        for pipeline_type in pipeline_types:
            config = TOKEN_ID_ACCESS_CONFIGS[pipeline_type]

            for key, getter_name in self._SHARED_OUTPUT_GETTERS:
                if key in final_scores:
                    # Already obtained from an earlier pipeline.
                    continue
                value = getattr(accessor, getter_name)(config.pipeline_name)
                if value is not None:
                    final_scores[key] = value

            if all(key in final_scores for key, _ in self._SHARED_OUTPUT_GETTERS):
                break

    def _extract_scores_and_metrics(
        self,
        accessor: ScoreAccessor,
        pipeline_types: List[PipelineType]
    ) -> EvaluationResults:
        """Extracts final scores and computes metrics.

        Args:
            accessor: Accessor over the accumulated evaluation scores.
            pipeline_types: Pipeline types to report on. Each must be a key
                of the token-id, accuracy and score access config tables.

        Returns:
            ``EvaluationResults`` whose ``scores`` holds the shared outputs
            that were available, ``"accuracy"`` (if an accuracy config was
            found) and one entry per non-accuracy pipeline name, and whose
            ``metrics`` holds one entry per pipeline name that is present in
            ``EVALUATION_PIPELINES_DICT``.

        Raises:
            KeyError: If a pipeline type is missing from one of the access
                config tables.
        """
        logger.debug("Extracting scores and computing metrics")
        final_scores = {}
        metrics_dict = {}

        # Token info and related outputs come first; per the original author
        # these are pipeline-independent, so the first available value wins.
        self._collect_shared_outputs(accessor, pipeline_types, final_scores)

        # De-duplicate pipeline types while keeping first-occurrence order.
        unique_types = list(dict.fromkeys(pipeline_types))

        # Accuracy: taken from the first config whose aggregate type is ACCURACY.
        accuracy_values = None
        accuracy_configs = [ACCURACY_ACCESS_CONFIGS[pipeline_type] for pipeline_type in unique_types]
        for config in accuracy_configs:
            if config.question_aggregate_type == QuestionAggregateTypes.ACCURACY:
                accuracy_values = accessor.get_score(
                    pipeline_name=config.pipeline_name,
                    token_score_type=None,
                    sequence_aggregate_type=None,
                    question_aggregate_type=QuestionAggregateTypes.ACCURACY
                )
                final_scores["accuracy"] = accuracy_values
                break

        # Per-pipeline scores, plus metrics for registered evaluation pipelines.
        score_configs = [SCORE_ACCESS_CONFIGS[pipeline_type] for pipeline_type in unique_types]
        for config in score_configs:
            if config.question_aggregate_type == QuestionAggregateTypes.ACCURACY:
                continue

            score = accessor.get_score(
                pipeline_name=config.pipeline_name,
                token_score_type=config.token_score_type,
                sequence_aggregate_type=config.sequence_aggregate_type,
                question_aggregate_type=config.question_aggregate_type
            )
            final_scores[config.pipeline_name] = score

            if config.pipeline_name in EVALUATION_PIPELINES_DICT:
                # NOTE(review): accuracy_values is None when no accuracy
                # config matched above; it is forwarded unchanged - confirm
                # that compute_metrics accepts that.
                metrics_dict[config.pipeline_name] = self.metrics_calculator.compute_metrics(
                    score, accuracy_values
                )

        return EvaluationResults(scores=final_scores, metrics=metrics_dict)

    def _initialize_context(self, config: IntegratedPipelineConfig) -> EvaluationContext:
        """Initializes evaluation context with model and configurations.

        Args:
            config: Pipeline configuration supplying the model identifier,
                device, memory/compile options and the batch, generation and
                evaluation sub-configurations.

        Returns:
            An ``EvaluationContext`` bundling the loaded model and tokenizer
            with the relevant configuration and this pipeline's metrics
            configuration.
        """
        logger.info(f"Loading model {config.model_identifier}")
        model, tokenizer = load_model_for_evaluation(
            model_identifier=config.model_identifier,
            device=config.device,
            max_memory=config.max_memory,
            apply_compile=config.apply_compile,
        )

        return EvaluationContext(
            model=model,
            tokenizer=tokenizer,
            device=config.device,
            exp_id=config.exp_id,
            model_identifier=config.model_identifier,
            batch_config=config.batch_config,
            generation_config=config.generation_config,
            evaluation_config=config.evaluation_config,
            metrics_config=self.metrics_calculator.config
        )

    def _process_batches(
        self,
        dataloader: DataLoader,
        context: EvaluationContext
    ) -> Dict[str, QuestionAggregateScores]:
        """Processes all batches and aggregates scores.

        Args:
            dataloader: Iterable yielding the batches to evaluate.
            context: Evaluation context handed to the ``BatchProcessor``.

        Returns:
            Accumulated scores keyed by pipeline name; an empty dict when the
            dataloader yields no batches.
        """
        logger.info("Processing batches")
        processor = BatchProcessor(context)
        accumulated_scores = {}

        for batch_number, batch in enumerate(dataloader, start=1):
            batch_scores = processor.process_batch(batch)
            if accumulated_scores:
                self._merge_batch_scores(accumulated_scores, batch_scores)
            else:
                # Nothing accumulated yet: adopt this batch's result as the
                # accumulator (later merges mutate this same object).
                accumulated_scores = batch_scores
            logger.debug(f"Completed batch {batch_number}")

        return accumulated_scores

    def _compute_final_results(
        self,
        evaluation_scores: Dict[str, QuestionAggregateScores],
        config: IntegratedPipelineConfig
    ) -> EvaluationResults:
        """Computes final evaluation results including metrics.

        Args:
            evaluation_scores: Accumulated scores keyed by pipeline name.
            config: Pipeline configuration; only ``config.pipeline_types`` is
                read here.

        Returns:
            The ``EvaluationResults`` built by ``_extract_scores_and_metrics``.
        """
        logger.info("Computing final results")
        accessor = ScoreAccessor(evaluation_scores)
        return self._extract_scores_and_metrics(accessor, config.pipeline_types)