import json
import re
import time
from typing import List, Dict, Any
import requests
from bs4 import BeautifulSoup
import logging
import tqdm

from drbench.metrics.base import DrBenchMetric
from drbench.agents.utils import prompt_llm

# Module-level logger for this metric
logger = logging.getLogger(__name__)

class DRGymFactuality(DrBenchMetric):
    def __init__(self, model: str) -> None:
        """
        Initialize the DrGym Factuality metric.

        Calls the base-class initializer with the metric name
        ``"drgym_factuality"`` and the given model, then keeps the model name
        on the instance.

        Args:
            model: The name of the model to use for scoring. It is the value
                later passed to ``prompt_llm`` by the claim-extraction and
                citation-check methods of this class.
        """
        super().__init__(name="drgym_factuality", model=model)
        # Set explicitly because the methods of this class read self.model;
        # NOTE(review): the base class may already store it - confirm.
        self.model = model

    def _get_support_score(self, support_level: str) -> float:
        """Convert support level to numeric score"""
        scores = {
            "no_support": 0.0,
            "partial_support": 0.5,
            "full_support": 1.0,
        }
        return scores.get(support_level, 0.0)

    def crawl_url(self, url: str) -> str:
        """Crawl a single URL and return its text content"""
        try:
            headers = {
                "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36"
            }
            
            response = requests.get(url, headers=headers, timeout=10)
            if response.status_code == 200:
                soup = BeautifulSoup(response.text, 'html.parser')
                
                # Remove script and style elements
                for script in soup(["script", "style"]):
                    script.decompose()
                
                # Extract text content
                text = soup.get_text(separator=' ', strip=True)
                return text[:5000] if text else f"**Error: No content extracted from {url}**"
            else:
                return f"**Error for {url}: HTTP {response.status_code}**"
                
        except Exception as e:
            return f"**Error for {url}: {str(e)}**"

    def crawl_urls(self, urls: List[str]) -> List[str]:
        """Crawl multiple URLs"""
        results = []
        for url in urls:
            result = self.crawl_url(url)
            results.append(result)
            time.sleep(0.1)  # Small delay to be respectful
        return results

    def create_prompt_extractor(self, answer: str) -> str:
        """Create prompt for extracting claims and citations from report.

        The prompt asks the model to return a JSON object whose ``"claims"``
        list holds one entry per source-backed claim, each with ``claim_id``,
        ``claim`` and ``sources`` (a list of URLs taken from the report).
        Doubled braces in the template are f-string escapes that render as
        literal ``{`` / ``}`` in the JSON example.

        Args:
            answer: Full report text; it is inserted verbatim into the prompt.

        Returns:
            The complete prompt string.
        """
        return f"""You are an information extraction expert.

Given a structured report containing claims and their supporting sources (usually in the form of inline hyperlinks or referenced URLs), extract all distinct factual or argumentative claims that are explicitly supported by a specific reference in the text.

Return a JSON object like this:
{{
  "claims": [
    {{
      "claim_id": 1,
      "claim": "<claim_1>",
      "sources": ["<url_1>","<url_2>"]
    }},
    {{
      "claim_id": 2,
      "claim": "<claim_2>",
      "sources": ["<url_1>"]
    }}
  ]
}}

Where:
- The root is "claims", which contains a list of json claim objects.
- Each claim json object has: 
    - claim_id: an identifier (sequential integer starting from 1).
    - claim: a concise but complete sentence restating the claim.
    - sources: a list of URLs, which are the sources that explicitly support the claim (**IMPORTANT**: must be taken directly from the report, can be one or more).

**IMPORTANT**: Only include claims that are directly and explicitly supported by a source in the report. Do not include general summaries, opinions, or claims that lack citation.

Process the full report carefully to ensure all source-supported claims are included and accurately captured.

Now extract the claims from the report below:

{answer}

Return the JSON object, nothing else.
"""

    def create_prompt_citation_checker(self, claim: str, docs: List[str]) -> str:
        """Create prompt for checking citation quality.

        The prompt presents one statement together with its citation texts
        and asks for a JSON verdict with a ``"support"`` label
        (``full_support``, ``partial_support`` or ``no_support``) and a
        short ``"justification"``.

        Args:
            claim: The statement to verify.
            docs: Citation texts; they are numbered from 1 in the prompt.

        Returns:
            The complete prompt string.
        """
        # Number the citations as "[1] ...", "[2] ..." separated by blank lines.
        citations_text = "\n\n".join(f"[{i+1}] {doc}" for i, doc in enumerate(docs))
        return f"""In this task, you will evaluate whether each statement is
        supported by its corresponding citations. Note that the system
        responses may appear very fluent and well-formed, but contain
        slight inaccuracies that are not easy to discern at first glance.
        Pay close attention to the text.

        You will be provided with a statement and its corresponding
        citations. It may be helpful to ask yourself whether it is
        accurate to say "according to the citation" with a
        statement following this phrase. Be sure to check all of the
        information in the statement. You will be given three options:

        - Full Support: All of the information in the statement is
        supported in the citations.

        - Partial Support: Some parts of the information are supported in
        the citations, but other parts are missing from the citations.

        - No Support: These citations does not support any part of the
        statement.

        Please provide your response based on the information in the
        citations. If you are unsure, use your best judgment. Respond as
        either ``full_support'', ``partial_support'', or ``no_support''
        with no additional information. 
        You should also provide a very brief justification to your assessment.

        Statement: {claim}

        Citations: {citations_text}

        Return your response as JSON in this format:
        {{
            "support": "full_support|partial_support|no_support",
            "justification": "your brief explanation"
        }}
    """

    def extract_claims_and_urls(self, answer: str) -> List[Dict]:
        """Extract claims and their supporting URLs from report text.

        Sends the extraction prompt to the model and parses the reply as
        JSON. Replies that wrap the JSON object in a Markdown code fence or
        surround it with prose are tolerated: if the raw reply does not
        parse, the text between the first ``{`` and the last ``}`` is tried.

        Args:
            answer: Full report text to extract claims from.

        Returns:
            A list of claim dictionaries as produced by the model (expected
            keys: ``claim_id``, ``claim``, ``sources``). Entries that are not
            dictionaries are dropped. An empty list is returned when the
            model call fails or the reply cannot be interpreted; the failure
            is logged rather than raised.
        """
        prompt = self.create_prompt_extractor(answer)

        try:
            response = prompt_llm(prompt, self.model)

            # Parse JSON response
            try:
                result = json.loads(response)
            except json.JSONDecodeError:
                # Fall back to the outermost {...} span (code fences, chatter).
                first, last = response.find("{"), response.rfind("}")
                if first == -1 or last <= first:
                    raise
                result = json.loads(response[first:last + 1])

        except Exception as e:
            logger.error(f"Error extracting claims: {e}")
            return []

        # Accept both {"claims": [...]} and a bare list of claim objects.
        claims = result.get("claims", []) if isinstance(result, dict) else result
        if not isinstance(claims, list):
            logger.error(f"Error extracting claims: unexpected payload type {type(claims).__name__}")
            return []
        return [entry for entry in claims if isinstance(entry, dict)]

    def check_citation_quality(self, claim: str, docs: List[str]) -> dict:
        """Check how well citations support a given claim.

        Sends the citation-check prompt to the model (temperature 0) and
        parses the JSON verdict. Parsing tolerates a reply whose JSON object
        is wrapped in a code fence or prose. The ``"support"`` label is
        normalised (trimmed, lower-cased, spaces/hyphens turned into
        underscores) so that answers such as ``"Full Support"`` map onto the
        canonical ``"full_support"``.

        Args:
            claim: The statement to verify.
            docs: Citation texts the statement is checked against.

        Returns:
            A dictionary that always contains ``"support"`` and
            ``"justification"``. On any failure the verdict is
            ``"no_support"`` and the justification carries the error text;
            the failure is logged rather than raised.
        """
        prompt = self.create_prompt_citation_checker(claim, docs)

        try:
            response = prompt_llm(prompt, self.model, temperature=0)

            # Parse JSON response
            try:
                result = json.loads(response)
            except json.JSONDecodeError:
                # Fall back to the outermost {...} span (code fences, chatter).
                first, last = response.find("{"), response.rfind("}")
                if first == -1 or last <= first:
                    raise
                result = json.loads(response[first:last + 1])

            if not isinstance(result, dict):
                raise ValueError(f"expected a JSON object, got {type(result).__name__}")

            # Canonicalise the label; the prompt itself spells the options as
            # "Full Support" etc., so such variants are common in replies.
            raw_label = result.get("support")
            label = ""
            if isinstance(raw_label, str):
                label = re.sub(r"[\s\-]+", "_", raw_label.strip(" \t\r\n`'\"").lower())
            # A missing or empty verdict cannot count as support.
            result["support"] = label or "no_support"
            result.setdefault("justification", "")
            return result

        except Exception as e:
            logger.error(f"Error checking citation quality: {e}")
            return {"support": "no_support", "justification": f"Error: {str(e)}"}

    def compute(self, report_dict: Dict[str, Any], task_data=None, eval_data=None) -> dict:
        """
        Compute DrGym factuality scores using citation quality evaluation.

        Args:
            report_dict: Dictionary containing 'report_text' and 'report_insights'
            task_data: Task-specific data (unused)
            eval_data: Evaluation data (unused)

        Returns:
            Dict: Standardized result with citation-based factuality scores
        """
        report_text = report_dict.get("report_text", "")
        
        # Check if report contains URLs
        url_pattern = r'https?://\S+|www\.\S+'
        if not re.search(url_pattern, report_text):
            return {
                "score": 0.0,
                "summary": "No URLs found in text.",
                "metric_result": {
                    "factual_claims": [],
                    "unfactual_claims": [],
                    "factuality_percentage": 0.0,
                    "total_claims": 0,
                }
            }

        # Extract claims and their citations
        claims_data = self.extract_claims_and_urls(report_text)
        
        if not claims_data:
            return {
                "score": 0.0,
                "summary": "No claims with citations found.",
                "metric_result": {
                    "factual_claims": [],
                    "unfactual_claims": [],
                    "factuality_percentage": 0.0,
                    "total_claims": 0,
                }
            }

        # Evaluate each claim
        scores = {}
        factual_claims = []
        unfactual_claims = []
        
        for claim in tqdm.tqdm(claims_data, desc="Evaluating claims", leave=False):
            claim_text = claim["claim"]
            urls = claim["sources"]
            claim_id = claim["claim_id"]
            
            if not urls:
                continue
            
            # Crawl URLs to get source content
            docs = self.crawl_urls(urls)
            if not docs:
                continue
            
            # Clean URLs from documents
            clean_docs = [re.sub(url_pattern, '', doc) for doc in docs if doc.strip()]
            
            try:
                # Check citation quality
                quality_result = self.check_citation_quality(claim_text, clean_docs)
                support_level = quality_result["support"]
                justification = quality_result["justification"]
                
                if support_level:
                    score = self._get_support_score(support_level)
                    scores[f"claim_{claim_id}"] = {
                        "claim": claim_text,
                        "urls": urls,
                        "score": score,
                        "support_level": support_level,
                        "justification": justification,
                    }
                    
                    # Categorize claims
                    if support_level == "full_support":
                        factual_claims.append(claim_text)
                    else:
                        unfactual_claims.append(claim_text)
            
            except Exception as e:
                scores[f"claim_{claim_id}"] = {
                    "claim": claim_text,
                    "urls": urls,
                    "score": 0.0,
                    "support_level": "no_support",
                    "justification": f"Error: {str(e)}",
                }
                unfactual_claims.append(claim_text)

        # Calculate final score
        final_score = sum(s["score"] for s in scores.values()) / len(scores) if scores else 0.0
        factuality_percentage = final_score * 100
        total_claims = len(scores)

        # Prepare metric results
        metric_result = {
            "factual_claims": factual_claims,
            "unfactual_claims": unfactual_claims,
            "factuality_percentage": factuality_percentage,
            "total_claims": total_claims,
        }

        # Create detailed summary
        summary = f"**Factuality Score:** {final_score:.4f} which is {len(factual_claims)}/{total_claims} claims\n\n"
        summary += f"--------------------------------\n\n"
        
        # Add factual claims section
        summary += f"**Factual Claims:**\n\n--------------------------------\n\n"
        for claim_id, claim_info in scores.items():
            if claim_info["support_level"] == "full_support":
                summary += f"**Claim:** {claim_info['claim']}\n\n"
                summary += f"**URLs:** {', '.join(claim_info['urls'])}\n\n"
                summary += f"**Support Level:** {claim_info['support_level']}\n\n"
                summary += f"**Justification:** {claim_info['justification']}\n\n"
                summary += f"--------------------------------\n\n"
        
        # Add unfactual claims section
        summary += f"**Unfactual Claims:**\n\n--------------------------------\n\n"
        for claim_id, claim_info in scores.items():
            if claim_info["support_level"] != "full_support":
                summary += f"**Claim:** {claim_info['claim']}\n\n"
                summary += f"**URLs:** {', '.join(claim_info['urls'])}\n\n"
                summary += f"**Support Level:** {claim_info['support_level']}\n\n"
                summary += f"**Justification:** {claim_info['justification']}\n\n"
                summary += f"--------------------------------\n\n"

        # Add final summary footer
        summary += f"--------------------------------\n\n"
        summary += f"Score: {final_score:.4f} which is {len(factual_claims)}/{total_claims} claims"

        return {
            "score": final_score,
            "summary": summary,
            "metric_result": metric_result
        }