import json
from typing import Dict, List, Any
import logging

from drbench.metrics.base import DrBenchMetric
from drbench.agents.utils import prompt_llm

# Module-level logger named after this module
logger = logging.getLogger(__name__)

class DRGymReportRelevance(DrBenchMetric):
    """Score a report by the fraction of reference key points it supports.

    Each key point is sent, together with the full report text, to a judge
    model which labels it "Supported", "Omitted" or "Contradicted". The final
    score is the share of key points labelled "Supported" (0.0 - 1.0).
    """

    # The only labels accepted from the judge model.
    VALID_LABELS = ("Supported", "Omitted", "Contradicted")

    # Horizontal rule used between entries of the human-readable summary.
    _SEPARATOR = "--------------------------------\n\n"

    def __init__(self, model: str):
        """
        Initialize the DrGym Report Relevance metric.

        Args:
            model: The name of the model to use for scoring
        """
        super().__init__(name="drgym_report_relevance", model=model)
        self.model = model

    def create_prompt(self, key_point: str, answer: str) -> str:
        """Build the judge prompt for one key point.

        Args:
            key_point: The reference statement to look for.
            answer: The full report text to check against.

        Returns:
            The prompt string asking for a JSON object with "label" and
            "justification" fields.
        """
        return f"""You are given a **single key point** and a **report**.

Your job is to determine whether the report:
- **Supports** the key point (it affirms, explains, or reinforces the point),
- **Omits** the key point (it does not mention or cover this point at all), or
- **Contradicts** the key point (it says something that disagrees with or negates the point).

Carefully read the key point and the report.

Return your answer as a **JSON object** with two fields:
- "label": One of "Supported", "Omitted", or "Contradicted".
- "justification": Brief explanation on why you assigned this label.

Respond strictly in JSON format:
{{"label": "<label>", "justification": "<justification>"}}
Do **not** add any extra commentary or text outside the JSON.

---

Key Point: {key_point}
Report: {answer}
"""

    @staticmethod
    def _parse_response(response: str) -> Dict[str, Any]:
        """Extract the JSON object from a judge reply.

        Tries to parse the whole reply first. If that fails (for example the
        reply is wrapped in a Markdown code fence or has surrounding text),
        falls back to parsing the span from the first "{" to the last "}".

        Raises:
            json.JSONDecodeError: If no parseable JSON is found.
            ValueError: If the parsed JSON is not an object.
        """
        text = response.strip()
        try:
            parsed = json.loads(text)
        except json.JSONDecodeError:
            start, end = text.find("{"), text.rfind("}")
            if start == -1 or end <= start:
                raise
            parsed = json.loads(text[start:end + 1])

        if not isinstance(parsed, dict):
            raise ValueError(
                f"Expected a JSON object, got {type(parsed).__name__}"
            )
        return parsed

    def evaluate_single_key_point(self, key_point: str, answer: str) -> tuple:
        """Ask the judge model whether the report covers one key point.

        Args:
            key_point: The reference statement to look for.
            answer: The full report text.

        Returns:
            A ``(label, justification)`` pair. ``label`` is always one of
            ``VALID_LABELS``. Any failure (LLM call, parsing) or unrecognized
            label is reported as "Omitted" so that one bad reply does not
            abort the whole evaluation.
        """
        prompt = self.create_prompt(key_point, answer)

        try:
            response = prompt_llm(prompt, self.model, temperature=0)
            result = self._parse_response(response)

            label = result.get('label', 'Omitted')
            justification = result.get('justification', 'No justification provided')

            # Accept harmless variations such as "supported" or " Supported ".
            normalized = str(label).strip().capitalize()
            if normalized not in self.VALID_LABELS:
                logger.warning("Invalid label '%s', defaulting to 'Omitted'", label)
                normalized = "Omitted"

            return normalized, justification

        except Exception as e:
            # Deliberate best-effort boundary around the external LLM call.
            logger.error("Error evaluating key point: %s", e)
            return "Omitted", f"Error during evaluation: {str(e)}"

    def evaluate_answer(self, answer: str, key_points: List[str]) -> Dict:
        """Evaluate every key point against the report.

        Args:
            answer: The full report text.
            key_points: Reference statements to check, in order.

        Returns:
            A dict keyed ``"point_1"``, ``"point_2"``, ... whose values hold
            the ``key_point``, its ``label`` and the ``justification``.
        """
        results = {}

        for i, key_point in enumerate(key_points, 1):
            label, justification = self.evaluate_single_key_point(key_point, answer)
            results[f"point_{i}"] = {
                "key_point": key_point,
                "label": label,
                "justification": justification,
            }

        return results

    def calculate_metrics(self, evaluations: Dict) -> Dict:
        """Calculate support, omitted, and contradicted rates.

        Args:
            evaluations: Mapping as produced by ``evaluate_answer``.

        Returns:
            A dict with percentage rates (0-100), per-label counts and
            ``total_points``. The same keys are returned for empty input,
            with every value zero.
        """
        counts = {label: 0 for label in self.VALID_LABELS}
        for evaluation in evaluations.values():
            label = evaluation["label"]
            # Unrecognized labels still count towards the total below.
            if label in counts:
                counts[label] += 1

        total_points = len(evaluations)

        def rate(count: int) -> float:
            return count / total_points * 100 if total_points else 0.0

        return {
            "support_rate": rate(counts["Supported"]),
            "omitted_rate": rate(counts["Omitted"]),
            "contradicted_rate": rate(counts["Contradicted"]),
            "supported_count": counts["Supported"],
            "omitted_count": counts["Omitted"],
            "contradicted_count": counts["Contradicted"],
            "total_points": total_points,
        }

    @staticmethod
    def _empty_result(summary: str) -> dict:
        """Return the zero-score result used when there is nothing to evaluate."""
        return {
            "score": 0.0,
            "summary": summary,
            "metric_result": {
                "supported_points": [],
                "omitted_points": [],
                "contradicted_points": [],
                "support_rate": 0.0,
                "omitted_rate": 0.0,
                "contradicted_rate": 0.0,
                "total_points": 0,
            },
        }

    @classmethod
    def _format_section(cls, title: str, points: List[Dict[str, Any]]) -> str:
        """Render one labelled group of key points for the summary text."""
        parts = [f"**{title} ({len(points)}):**\n\n", cls._SEPARATOR]
        for i, point in enumerate(points, 1):
            parts.append(f"**{i}. Key Point:** {point['key_point']}\n\n")
            parts.append(f"**Justification:** {point['justification']}\n\n")
            parts.append(cls._SEPARATOR)
        return "".join(parts)

    def compute(self, report_dict: Dict[str, Any], task_data=None, eval_data=None) -> dict:
        """
        Compute DrGym report relevance scores using key point evaluation.

        Args:
            report_dict: Dictionary from which 'report_text' is read
            task_data: Task-specific data; its 'key_points' entry is evaluated
            eval_data: Evaluation data (unused)

        Returns:
            Dict with 'score' (fraction of supported key points, 0.0 - 1.0),
            a Markdown 'summary', and 'metric_result' holding the points
            grouped by label plus the percentage rates. A zero-score result
            is returned when no key points are available.
        """
        # Treat an explicit None the same as a missing report.
        report_text = report_dict.get("report_text") or ""

        if not task_data or "key_points" not in task_data:
            return self._empty_result("No key points provided for evaluation.")

        key_points = task_data["key_points"]
        if not key_points:
            return self._empty_result("Empty key points list provided.")

        evaluations = self.evaluate_answer(report_text, key_points)
        metrics = self.calculate_metrics(evaluations)

        # Group the evaluated points by label, preserving evaluation order.
        grouped = {label: [] for label in self.VALID_LABELS}
        for evaluation in evaluations.values():
            bucket = grouped.get(evaluation["label"])
            if bucket is not None:
                bucket.append({
                    "key_point": evaluation["key_point"],
                    "justification": evaluation["justification"],
                })

        supported_points = grouped["Supported"]
        omitted_points = grouped["Omitted"]
        contradicted_points = grouped["Contradicted"]

        # Final score is the support rate rescaled from percent to 0-1.
        final_score = metrics["support_rate"] / 100.0

        metric_result = {
            "supported_points": supported_points,
            "omitted_points": omitted_points,
            "contradicted_points": contradicted_points,
            "support_rate": metrics["support_rate"],
            "omitted_rate": metrics["omitted_rate"],
            "contradicted_rate": metrics["contradicted_rate"],
            "total_points": metrics["total_points"],
        }

        score_line = (
            f"{final_score:.4f} which is "
            f"{metrics['supported_count']}/{metrics['total_points']} key points supported"
        )
        summary = "".join([
            f"**Relevance Score:** {score_line}\n\n",
            f"**Support Rate:** {metrics['support_rate']:.2f}%\n\n",
            f"**Omitted Rate:** {metrics['omitted_rate']:.2f}%\n\n",
            f"**Contradicted Rate:** {metrics['contradicted_rate']:.2f}%\n\n",
            self._SEPARATOR,
            self._format_section("Supported Key Points", supported_points),
            self._format_section("Omitted Key Points", omitted_points),
            self._format_section("Contradicted Key Points", contradicted_points),
            self._SEPARATOR,
            f"Score: {score_line}",
        ])

        return {
            "score": final_score,
            "summary": summary,
            "metric_result": metric_result,
        }
