from typing import List, Any, Dict, Optional, Tuple, Union
import random
from tqdm.auto import tqdm

from src.feature import Interpretation
from src.scoring import BaseScoringMethod
from src.engines import InferenceEngine
from src.samplers import BaseSampler, RandomSampler


class Evaluator:
    """Evaluates model performance using specified scoring method and sampling strategy.

    For every feature the evaluator samples positive and negative examples,
    splits them into prompts of ``examples_in_prompt`` examples, sends the
    prompts to the inference engine in batches of ``batch_size`` and finally
    asks the scoring method to turn the collected predictions into metrics.
    """

    def __init__(
        self,
        inference_engine: InferenceEngine,
        pos_sampler: BaseSampler,
        neg_sampler: Optional[BaseSampler] = None,
        examples_in_prompt: int = 5,
        batch_size: int = 8,
        show_progress: bool = True,
    ):
        """
        Initialize the evaluator with all configuration parameters.

        Args:
            inference_engine: Engine used for model inference (handles generation config)
            pos_sampler: Sampler for positive examples
            neg_sampler: Sampler for negative examples (defaults to RandomSampler
                with the same ``sample_size`` as ``pos_sampler``)
            examples_in_prompt: Number of examples to include in each prompt
            batch_size: Batch size for inference processing
            show_progress: Whether to show progress bar during evaluation

        Raises:
            ValueError: If ``examples_in_prompt`` or ``batch_size`` is below 1.
        """
        # Both values are used as ``range`` steps; a negative step would
        # silently produce no prompts at all, so reject them up front.
        if examples_in_prompt < 1:
            raise ValueError("examples_in_prompt must be a positive integer.")
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer.")

        self.inference_engine = inference_engine
        self.pos_sampler = pos_sampler
        # Compare against None explicitly so that a sampler object which
        # happens to be falsy is not replaced by the default.
        if neg_sampler is None:
            neg_sampler = RandomSampler(sample_size=pos_sampler.sample_size)
        self.neg_sampler = neg_sampler
        self.examples_in_prompt = examples_in_prompt
        self.batch_size = batch_size
        self.show_progress = show_progress

        # Set by run(); the private helpers read it from here.
        self.scoring_method = None

    def run(
        self,
        features: List[Any],
        interpretations: List[Union[Interpretation, str]],
        scoring_method: BaseScoringMethod,
        **generation_kwargs
    ) -> Dict[Any, Dict[str, Any]]:
        """
        Evaluate model performance on the given features and interpretations.

        Args:
            features: List of features to evaluate
            interpretations: Corresponding interpretations for each feature
            scoring_method: Scoring method used to build prompts, parse
                responses and compute metrics; stored on the instance
            **generation_kwargs: Forwarded unchanged to
                ``inference_engine.generate``

        Returns:
            Dictionary keyed by ``feature.index``; each value is a dict with
            the keys ``'metrics'`` and ``'score'`` (both ``None`` when no
            usable result could be computed for that feature)
        """
        self.scoring_method = scoring_method

        # Prepare all prompts and labels
        prompt_data = self._prepare_all_prompts(features, interpretations)

        # Process prompts in batches and collect results
        batch_results = self._process_prompt_batches(prompt_data, **generation_kwargs)

        # Compute final metrics for each feature
        return self._compute_final_metrics(features, batch_results)

    def _prepare_all_prompts(
        self,
        features: List[Any],
        interpretations: List[Union[Interpretation, str]],
    ) -> Dict[str, list]:
        """Prepare all prompts for evaluation.

        Returns:
            Dict of four parallel lists (``'prompts'``, ``'features'``,
            ``'labels'``, ``'lengths'``) with one entry per prompt.
        """
        prompt_data = {
            'prompts': [],
            'features': [],
            'labels': [],
            'lengths': []
        }

        # NOTE: zip stops at the shorter input, so features without a matching
        # interpretation produce no prompts and are later reported as having
        # no results by _compute_final_metrics.
        for feature, interpretation in zip(features, interpretations):
            # Sample and prepare examples
            prepared_examples, labels = self._prepare_examples(feature)

            # Create prompts from prepared examples
            self._create_prompts_for_feature(
                feature,
                interpretation,
                prepared_examples,
                labels,
                prompt_data
            )

        return prompt_data

    def _prepare_examples(
        self,
        feature: Any,
    ) -> Tuple[list, list]:
        """Sample and prepare examples for a single feature.

        Positive and negative examples are sampled, passed through the scoring
        method's ``prepare_examples`` and then shuffled together so that
        examples and labels stay aligned.

        Returns:
            Two lists of equal length: the prepared examples and their labels.
            Both are empty when there is nothing to evaluate.
        """
        pos_examples = self.pos_sampler.sample(feature.positive_examples)
        neg_examples = self.neg_sampler.sample(feature.negative_examples)
        all_examples = pos_examples + neg_examples

        prepared_examples, labels = self.scoring_method.prepare_examples(all_examples)

        # Shuffle while maintaining alignment
        combined = list(zip(prepared_examples, labels))
        if not combined:
            return [], []
        random.shuffle(combined)

        # Materialise real lists so the result matches the annotated type and
        # has the same type in the empty and non-empty case.
        shuffled_examples, shuffled_labels = zip(*combined)
        return list(shuffled_examples), list(shuffled_labels)

    def _create_prompts_for_feature(
        self,
        feature: Any,
        interpretation: Union[Interpretation, str],
        prepared_examples: list,
        labels: list,
        prompt_data: Dict[str, list]
    ):
        """Create prompts for a single feature and append them to ``prompt_data``.

        The examples are split into consecutive chunks of at most
        ``examples_in_prompt`` items; each chunk becomes one prompt.

        Raises:
            ValueError: If there is at least one example and ``interpretation``
                is neither an ``Interpretation`` nor a string.
        """
        # Nothing to do (and nothing to validate) for a feature without examples.
        if len(prepared_examples) == 0:
            return

        # The interpretation text is the same for every chunk, so resolve it once.
        if isinstance(interpretation, Interpretation):
            interp_text = interpretation.value
        elif isinstance(interpretation, str):
            interp_text = interpretation
        else:
            raise ValueError("Interpretation must be either Interpretation instance or string.")

        for i in range(0, len(prepared_examples), self.examples_in_prompt):
            examples_for_prompt = prepared_examples[i:i+self.examples_in_prompt]
            labels_for_prompt = labels[i:i+self.examples_in_prompt]

            user_prompt = self.scoring_method.create_prompt(
                interp_text,
                examples_for_prompt
            )

            prompt_data['prompts'].append(user_prompt)
            prompt_data['features'].append(feature.index)
            prompt_data['labels'].append(labels_for_prompt)
            prompt_data['lengths'].append(len(examples_for_prompt))

    def _process_prompt_batches(
        self,
        prompt_data: Dict[str, list],
        **generation_kwargs
    ) -> Dict[Any, Dict[str, list]]:
        """Process prompts in batches and collect model predictions.

        Returns:
            Dict keyed by feature index with the accumulated ``'true'`` and
            ``'pred'`` label lists. Every feature that has at least one prompt
            gets an entry, even if none of its responses could be parsed.
        """
        features_in_data = set(prompt_data['features'])
        results = {f: {'true': [], 'pred': []} for f in features_in_data}
        total_prompts = len(prompt_data['prompts'])

        with tqdm(total=total_prompts, disable=not self.show_progress) as pbar:
            for batch_start in range(0, total_prompts, self.batch_size):
                batch_results = self._process_single_batch(
                    prompt_data,
                    batch_start,
                    **generation_kwargs
                )

                for feature, preds, trues in batch_results:
                    results[feature]['pred'].extend(preds)
                    results[feature]['true'].extend(trues)

                # Advance by the number of prompts processed, not by the number
                # of successfully parsed responses; otherwise the bar never
                # reaches its total when some responses fail to parse.
                pbar.update(min(self.batch_size, total_prompts - batch_start))

        return results

    def _process_single_batch(
        self,
        prompt_data: Dict[str, list],
        batch_start: int,
        **generation_kwargs
    ) -> List[Tuple[Any, list, list]]:
        """Process a single batch of prompts.

        Returns:
            One ``(feature_index, predicted_labels, true_labels)`` tuple per
            prompt whose response could be parsed; prompts for which
            ``parse_response`` returns ``None`` are dropped.
        """
        batch_end = batch_start + self.batch_size
        batch_prompts = prompt_data['prompts'][batch_start:batch_end]
        batch_labels = prompt_data['labels'][batch_start:batch_end]
        batch_features = prompt_data['features'][batch_start:batch_end]
        batch_lengths = prompt_data['lengths'][batch_start:batch_end]

        # Prepare chat format prompts for the engine
        system_prompt = self.scoring_method.system_prompt
        chat_prompts = [[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt}
        ] for prompt in batch_prompts]

        # Generate responses using the inference engine
        outputs = self.inference_engine.generate(prompts=chat_prompts, **generation_kwargs)

        # Parse responses and collect results
        batch_results = []
        for out, true_labels, feature, length in zip(
            outputs, batch_labels, batch_features, batch_lengths
        ):
            pred_labels = self.scoring_method.parse_response(out, length)
            if pred_labels is not None:
                batch_results.append((feature, pred_labels, true_labels))

        return batch_results

    def _compute_final_metrics(
        self,
        features: List[Any],
        batch_results: Dict[Any, Dict[str, list]],
    ) -> Dict[Any, Dict[str, Any]]:
        """Compute final metrics for each feature with missing-feature handling.

        Features without collected results, with mismatching prediction/label
        counts, or for which metric computation raises are reported with
        ``'metrics'`` and ``'score'`` set to ``None`` (a message is printed).
        """
        result = {}

        for feature in features:
            index = feature.index
            # Default entry; only overwritten when metrics were computed.
            result[index] = {
                'metrics': None,
                'score': None
            }

            if index not in batch_results:
                print(f"Warning: No results for feature {index}")
                continue

            pred = batch_results[index]['pred']
            true = batch_results[index]['true']

            # Handle length mismatches
            if len(pred) != len(true):
                print(f"Label/pred length mismatch for feature {index} "
                      f"({len(true)} vs {len(pred)})")
                continue

            # Compute metrics
            try:
                full_metrics, main_metric = self.scoring_method.compute_metrics(pred, true)
            except Exception as e:  # Broad on purpose: one bad feature must not abort the run
                print(f"Error computing metrics for feature {index}: {e}")
                continue

            result[index] = {
                'metrics': full_metrics,
                'score': main_metric
            }

        return result