import re
from typing import List, Dict, Optional
from tqdm.auto import tqdm

from src.feature import Feature, Interpretation
from src.builder import PromptBuilder
from src.engines import InferenceEngine


class FeatureInterpreter:
    """Generate several interpretations per feature using an inference engine.

    For each feature a chat prompt is built with the configured prompt builder
    and submitted ``num_rounds`` times; the resulting prompts are sent to the
    engine in chunks of ``batch_size`` prompts per ``generate`` call.
    """

    # Matches "[EXPLANATION]:", "EXPLANATION:" and looser variants (a bare
    # "EXPLANATION", with or without brackets/colon) anywhere in a line; the
    # match is case-sensitive. Whitespace right after the marker (including
    # line breaks) is skipped, then the lazy capture runs up to the next
    # newline or the end of the string, so each capture is a single line.
    _EXPLANATION_PATTERN = re.compile(
        r'(?:\[?EXPLANATION\]?:?)\s*(.*?)(?=\n|$)', re.DOTALL
    )

    def __init__(
        self,
        inference_engine: InferenceEngine,
        prompt_builder: PromptBuilder,
        num_rounds: int = 3,
        batch_size: int = 4,
        show_progress: bool = True
    ):
        """
        Initialize the interpretation generator.

        Args:
            inference_engine: Engine whose ``generate(prompts=..., **kwargs)``
                is expected to return one response per prompt.
            prompt_builder: Builder whose ``build_prompt(feature)`` returns a
                mapping with ``'system'`` and ``'user'`` entries.
            num_rounds: Number of interpretations to generate per feature.
            batch_size: Number of prompts (not features) sent to the engine
                per ``generate`` call; each feature contributes
                ``num_rounds`` prompts.
            show_progress: Whether to display a tqdm progress bar.

        Raises:
            ValueError: If ``batch_size`` < 1 or ``num_rounds`` < 0.
        """
        # A non-positive batch size would otherwise make ``run`` either crash
        # inside range() (0) or silently skip every prompt (negative).
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        if num_rounds < 0:
            raise ValueError(f"num_rounds must be >= 0, got {num_rounds}")

        self.inference_engine = inference_engine
        self.prompt_builder = prompt_builder
        self.num_rounds = num_rounds
        self.batch_size = batch_size
        self.show_progress = show_progress

    def run(
        self,
        features: List[Feature],
        **generation_kwargs
    ) -> Dict[int, List[Dict[str, Optional[str]]]]:
        """
        Generate ``num_rounds`` interpretations for each feature.

        Args:
            features: Features to interpret.
            **generation_kwargs: Extra parameters forwarded unchanged to
                ``inference_engine.generate``.

        Returns:
            Mapping from ``feature.index`` to a list with one entry per round.
            Each entry is a dict with keys ``'response'`` (raw engine output)
            and ``'explanation'`` (text extracted by ``_extract_explanation``,
            or None if none was found). Features sharing an ``index`` share
            a single list.

        Raises:
            RuntimeError: If the engine returns a different number of
                responses than the prompts it was given for a batch.
        """
        results = {feature.index: [] for feature in features}

        # Flatten to one prompt per (feature, round). The same messages list
        # is reused for every round of a feature, so rounds presumably differ
        # only if the engine samples (e.g. a temperature passed through
        # generation_kwargs) — NOTE(review): confirm against the engine.
        all_prompts = []
        feature_indices = []
        for feature in features:
            messages = self._build_messages(feature)
            for _ in range(self.num_rounds):
                all_prompts.append(messages)
                feature_indices.append(feature.index)

        with tqdm(total=len(all_prompts), disable=not self.show_progress) as pbar:
            for batch_start in range(0, len(all_prompts), self.batch_size):
                batch_end = batch_start + self.batch_size
                batch_prompts = all_prompts[batch_start:batch_end]
                batch_indices = feature_indices[batch_start:batch_end]

                responses = list(self.inference_engine.generate(
                    prompts=batch_prompts,
                    **generation_kwargs
                ))
                # zip() below would silently drop (or misattribute) results if
                # the counts disagreed, so fail loudly instead.
                if len(responses) != len(batch_prompts):
                    raise RuntimeError(
                        f"Inference engine returned {len(responses)} responses "
                        f"for {len(batch_prompts)} prompts"
                    )

                for idx, response in zip(batch_indices, responses):
                    results[idx].append({
                        'response': response,
                        'explanation': self._extract_explanation(response),
                    })

                pbar.update(len(batch_prompts))

        return results

    def _build_messages(self, feature: Feature) -> List[Dict[str, str]]:
        """Build the system/user chat messages for one feature's prompt."""
        prompt = self.prompt_builder.build_prompt(feature)
        return [
            {"role": "system", "content": prompt['system']},
            {"role": "user", "content": prompt['user']}
        ]

    def _extract_explanation(self, response: Optional[str]) -> Optional[str]:
        """
        Extract the explanation text from a model response.

        Every occurrence of an ``EXPLANATION`` marker (optionally bracketed
        and/or followed by a colon) contributes the text that follows it up
        to the end of that line; the pieces are joined with spaces. The
        result then has ``<``/``>`` characters removed (their contents are
        kept), all whitespace runs collapsed to single spaces, and trailing
        ``.``, ``:`` and ``;`` stripped.

        Args:
            response: Raw text returned by the engine.

        Returns:
            The cleaned explanation, or None if the response is empty or no
            non-empty explanation follows any marker.
        """
        if not response:
            return None

        captured = self._EXPLANATION_PATTERN.findall(response)
        text = " ".join(captured)
        # Keep text written inside angle brackets; only drop the brackets.
        text = text.replace('<', '').replace('>', '')
        # split()/join collapses every whitespace run (tabs, repeated spaces).
        text = " ".join(text.split())
        # Strip trailing punctuation together with any spaces it exposes.
        text = text.rstrip('.:; ')
        return text or None