import json
import re
from typing import List, Dict, Any
import time
import os
from dotenv import load_dotenv
from openai import OpenAI
import boto3


class ClaimDecomposer:
    """
    Decomposes long-form responses into smaller, standalone claims.
    Based on the breakdown approach from the Graph-based Uncertainty paper.
    Supports both OpenAI GPT-4o and AWS Bedrock Claude 3.5 Sonnet.
    """

    def __init__(self, use_claude=False):
        """
        Set up the model client and the breakdown prompt.

        Args:
            use_claude: When truthy, a boto3 'bedrock-runtime' client is
                created (for Claude 3.5 Sonnet); otherwise an OpenAI client
                is created (for GPT-4o). Defaults to False.
        """
        # Load environment variables before any of them is read below
        load_dotenv()

        self.use_claude = use_claude

        if not use_claude:
            # OpenAI backend: the API key is read from OPENAI_API_KEY
            openai_key = os.getenv("OPENAI_API_KEY")
            self.client = OpenAI(api_key=openai_key)
        else:
            # Bedrock backend: region comes from AWS_REGION, else us-east-1
            aws_region = os.getenv('AWS_REGION', 'us-east-1')
            self.bedrock_client = boto3.client(
                'bedrock-runtime',
                region_name=aws_region,
            )

        # Only one of `client` / `bedrock_client` exists on the instance
        self.breakdown_prompt = self._get_breakdown_prompt()

    def _get_breakdown_prompt(self):
        """
        Return the claim-breakdown prompt template (following original paper).

        The template carries a single ``{text_generation}`` placeholder meant
        for ``str.format``; the doubled braces around the JSONL example render
        as literal braces after formatting.
        """
        prompt_template = """Please deconstruct the following paragraph into the smallest possible standalone self-contained facts without semantic repetition, and return the output as a jsonl, where each line is {{claim:[CLAIM], gpt-confidence:[CONF]}}.

CRITICAL: Extract ONLY the 8-12 MOST IMPORTANT claims. Be extremely selective. Focus ONLY on:
- Direct answers to the specific question asked
- Specific numerical values, percentages, or measurements
- Key causal relationships (A causes B)
- Critical scientific conclusions

STRICTLY AVOID:
- General background information
- Basic definitions (what SO2 is, what SSP scenarios are, etc.)
- Procedural explanations
- Location descriptions
- Any claim that doesn't directly address the question

Each claim must be essential to answering the question. If unsure whether to include a claim, DON'T include it.

The confidence score [CONF] should represent your confidence in the claim, where a 1 is obvious facts and results like 'The earth is round' and '1+1=2'. A 0 is for claims that are very obscure or difficult for anyone to know, like the birthdays of non-notable people. 

The input is:
'{text_generation}'"""
        return prompt_template

    def decompose_response(self, response_text: str) -> List[Dict[str, Any]]:
        """
        Decompose a response into individual claims with confidence scores.

        The text is substituted into ``self.breakdown_prompt`` and sent to the
        configured backend (Claude 3.5 Sonnet via Bedrock when
        ``self.use_claude`` is truthy, otherwise GPT-4o). The reply is parsed
        and cleaned; when more than 15 claims survive, only the 15 with the
        highest confidence are kept. Up to three attempts are made; an
        exception during an attempt is printed and followed by a 2 second
        pause before the next one. If no attempt yields a usable claim, the
        (stripped, possibly truncated) response itself is returned as a
        single claim with confidence 0.5.

        Args:
            response_text: The text response to decompose

        Returns:
            List of dictionaries with 'claim' and 'confidence' keys. The list
            is empty only when ``response_text`` is empty or whitespace.
        """
        if not response_text or not response_text.strip():
            return []

        breakdown_prompt = self.breakdown_prompt.format(
            text_generation=response_text)

        max_retries = 3
        max_claims = 15
        for attempt in range(max_retries):
            try:
                if self.use_claude:
                    # Use Claude 3.5 Sonnet via AWS Bedrock
                    request_body = {
                        "anthropic_version": "bedrock-2023-05-31",
                        "max_tokens": 800,
                        "temperature": 0.1,
                        "messages": [{"role": "user", "content": breakdown_prompt}]
                    }

                    response = self.bedrock_client.invoke_model(
                        modelId="anthropic.claude-3-5-sonnet-20240620-v1:0",
                        body=json.dumps(request_body)
                    )

                    # Parse Claude response
                    response_body = json.loads(response['body'].read())
                    response_content = response_body['content'][0]['text'].strip(
                    )
                else:
                    # Use GPT-4o for claim decomposition
                    response = self.client.chat.completions.create(
                        model="gpt-4o",
                        messages=[
                            {"role": "user", "content": breakdown_prompt}],
                        max_tokens=800,  # Limit to control output length
                        temperature=0.1,  # Low temperature for consistent extraction
                    )
                    response_content = response.choices[0].message.content

                claims = self._parse_claims_response(response_content)
                cleaned_claims = self._clean_claims(claims) if claims else []

                # Only claims that survive cleaning count as success; a reply
                # whose claims are all rejected is retried like a parse failure
                # instead of being returned as an empty list.
                if cleaned_claims:
                    # Cap the result at the most confident `max_claims` claims
                    if len(cleaned_claims) > max_claims:
                        cleaned_claims.sort(
                            key=lambda x: x['confidence'], reverse=True)
                        cleaned_claims = cleaned_claims[:max_claims]
                    return cleaned_claims

            except Exception as e:
                print(
                    f"Error in claim decomposition attempt {attempt + 1}: {e}")
                if attempt < max_retries - 1:
                    time.sleep(2)  # Wait before retry

        # Fallback: create a single claim from the entire response
        print("Fallback: Creating single claim from response")
        fallback_text = response_text.strip()
        # Measure the stripped text so "..." is added only on real truncation
        if len(fallback_text) > 500:
            fallback_text = fallback_text[:500] + "..."
        return [{
            'claim': fallback_text,
            'confidence': 0.5  # Default confidence
        }]

    def _parse_claims_response(self, response_text: str) -> List[Dict[str, Any]]:
        """
        Parse the model's reply into a list of decoded JSON values.

        The reply is first normalised by ``_extract_output_from_completion``
        and then read as JSONL, one JSON value per non-blank line. A line that
        is not valid JSON on its own (for example one ending in a comma, or
        with text around the object) is handed to ``_parse_json_lines`` so
        that any object embedded in it can still be recovered; lines with
        nothing recoverable contribute nothing.

        Args:
            response_text: Raw text returned by the model.

        Returns:
            The decoded values in order of appearance (possibly empty).
            Entries are not validated here; see ``_clean_claims``.
        """
        # Extract output from completion (following original implementation)
        output = self._extract_output_from_completion(response_text)

        claims = []
        for line in output.splitlines():
            line = line.strip()
            if not line:
                continue
            try:
                claims.append(json.loads(line))
            except json.JSONDecodeError:
                # Strict parsing failed: salvage embedded objects rather than
                # silently dropping the whole line.
                claims.extend(self._parse_json_lines(line))
        return claims

    def _extract_output_from_completion(self, response_text: str) -> str:
        """
        Normalise a raw completion so it can be parsed as JSONL.

        Steps, in order:
          1. Remove Markdown code-fence markers.
          2. Put double quotes around bare ``claim`` / ``gpt-confidence``
             keys. Only keys in key position (directly after ``{`` or ``,``,
             ignoring whitespace) are touched, so claim text that merely
             contains "claim:" is left intact.
          3. Keep only the span from the first ``{`` to the last ``}`` when
             both characters are present.

        Args:
            response_text: Raw text returned by the model.

        Returns:
            The normalised text; unchanged apart from steps 1-2 when it
            contains no ``{`` or no ``}``.
        """
        # Handle OpenAI response format (direct text)
        output = response_text

        for fence in ("```jsonl\n", "```json\n", "```"):
            output = output.replace(fence, "")

        # Already-quoted keys do not match: the opening quote sits between
        # the separator and the key name.
        output = re.sub(
            r'([{,]\s*)(claim|gpt-confidence)\s*:', r'\1"\2":', output)

        # Extract content between first { and last }
        first_brace = output.find('{')
        last_brace = output.rfind('}')
        if first_brace != -1 and last_brace != -1:
            output = output[first_brace:last_brace + 1]

        return output

    def _parse_json_lines(self, jsonl_string: str) -> List[Dict[str, Any]]:
        """
        Leniently parse a JSONL string, salvaging objects from noisy lines.

        Each non-blank line is first parsed as a whole. When that fails, the
        line is scanned for ``{`` characters and a JSON object is decoded
        starting at each one, so objects surrounded by extra text (prefixes,
        trailing commas) or containing braces inside string values are still
        recovered. Lines with nothing decodable are skipped.

        Args:
            jsonl_string: Text with (ideally) one JSON value per line.

        Returns:
            The decoded values in order of appearance (possibly empty).
        """
        decoder = json.JSONDecoder()
        subclaims = []
        for line in jsonl_string.split("\n"):
            if not line.strip():
                continue
            try:
                subclaims.append(json.loads(line))
                continue
            except json.JSONDecodeError:
                pass  # fall through to the embedded-object scan below

            # raw_decode parses one JSON value starting at `pos` and reports
            # where it ended, ignoring whatever follows on the line.
            pos = line.find('{')
            while pos != -1:
                try:
                    subclaim, end = decoder.raw_decode(line, pos)
                except json.JSONDecodeError:
                    pos = line.find('{', pos + 1)
                    continue
                subclaims.append(subclaim)
                pos = line.find('{', end)
        return subclaims

    def _clean_claims(self, claims: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Clean and validate the extracted claims.

        Entries that are not dicts, have no claim text, or whose stripped
        text is shorter than 10 characters are dropped. A list-valued
        'claim' is reduced to its first element. The confidence is read from
        'gpt-confidence', falling back to 'confidence' when that key is
        missing or null; values that are missing, non-numeric or NaN become
        0.5, and everything else is clamped to [0, 1]. The input dicts are
        not modified.

        Args:
            claims: Parsed entries as produced by the response parser.

        Returns:
            New list of ``{'claim': str, 'confidence': float}`` dicts, in
            input order.
        """
        cleaned_claims = []

        for claim_dict in claims:
            if not isinstance(claim_dict, dict):
                continue

            # Handle case where claim might be a list (from original implementation)
            raw_claim = claim_dict.get('claim', '')
            if isinstance(raw_claim, list):
                raw_claim = raw_claim[0] if raw_claim else ""
            if not raw_claim:
                continue

            claim_text = str(raw_claim).strip()

            # Skip very short claims
            if len(claim_text) < 10:
                continue

            # No default here: a missing key must reach the fallback below
            confidence = claim_dict.get('gpt-confidence')
            if confidence is None:
                confidence = claim_dict.get('confidence')

            # Ensure confidence is a float between 0 and 1
            try:
                confidence = float(confidence)
            except (ValueError, TypeError):
                confidence = 0.5
            else:
                if confidence != confidence:
                    # NaN compares unequal to itself; clamping would turn it
                    # into 1.0, so treat it as unknown instead.
                    confidence = 0.5
                else:
                    confidence = max(0.0, min(1.0, confidence))

            cleaned_claims.append({
                'claim': claim_text,
                'confidence': confidence  # Keep as 'confidence' for internal use
            })

        return cleaned_claims
