import networkx as nx
import numpy as np
from typing import List, Dict, Any, Optional
from sklearn.metrics.pairwise import cosine_similarity
import time
import re


class UncertaintyCalculator:
    """
    Calculates uncertainty metrics for claims based on bipartite graph structure and semantic analysis.
    Implements methods from the Graph-based Uncertainty paper.

    The graph is read as follows: nodes carry a ``node_type`` attribute
    (``'claim'`` or ``'response'``) and a ``content`` attribute holding the
    text; claim-response edges may carry an ``entailment_score`` attribute.
    """

    # Model identifiers used for verbalized-confidence queries. Kept as class
    # attributes so they can be overridden per subclass/instance without
    # changing the constructor signature.
    OPENAI_MODEL = "gpt-4o"
    CLAUDE_MODEL_ID = "anthropic.claude-3-5-sonnet-20240620-v1:0"

    # A bare number such as "0.85", ".9", "1" or "0" that is not part of a
    # larger number; the look-arounds stop "2.7" being read as ".7" and
    # "10" being read as "1".
    _DECIMAL_RE = re.compile(r'(?<![\d.])(?:[01]?\.\d+|[01])(?!\d)')
    # A percentage, optionally fractional and optionally space-separated
    # from the sign ("85%", "85.5 %").
    _PERCENT_RE = re.compile(r'(\d+(?:\.\d+)?)\s*%')

    def __init__(self, client, use_claude=False):
        """
        Initialize uncertainty calculator.

        Args:
            client: OpenAI client or AWS Bedrock client for verbalized confidence assessment
            use_claude: Whether to use Claude 3.5 Sonnet via AWS Bedrock
        """
        self.use_claude = use_claude
        # Exactly one backend is used. Both attributes are always defined so
        # that accessing the unused one yields None instead of AttributeError.
        self.bedrock_client = client if use_claude else None
        self.client = None if use_claude else client

    def calculate_uncertainty(self, target_claim: Dict[str, Any],
                              all_claims: List[Dict[str, Any]],
                              bipartite_graph: "nx.Graph") -> Dict[str, float]:
        """
        Calculate comprehensive uncertainty metrics for a target claim using the bipartite graph.
        Following the Graph-based Uncertainty paper methodology.

        Args:
            target_claim: The claim to analyze (its ``'claim'`` text is matched
                against claim-node ``content`` in the graph)
            all_claims: All claims in the context (currently unused)
            bipartite_graph: Response-claim bipartite graph

        Returns:
            Dictionary of uncertainty metrics. Every graph metric is 0.0 when
            the claim is not found in the graph or the graph has at most one
            node.
        """
        target_node = self._find_claim_node(target_claim, bipartite_graph)

        return {
            # 1. Closeness Centrality (primary metric from paper)
            'closeness_centrality': self._calculate_closeness_centrality_bipartite(
                target_node, bipartite_graph),
            # 2. Degree Centrality (baseline from paper)
            'degree_centrality': self._calculate_degree_centrality_bipartite(
                target_node, bipartite_graph),
            # 3. Betweenness Centrality
            'betweenness_centrality': self._calculate_betweenness_centrality_bipartite(
                target_node, bipartite_graph),
            # 4. Eigenvalue Centrality (CE from paper)
            'eigenvalue_centrality': self._calculate_eigenvalue_centrality_bipartite(
                target_node, bipartite_graph),
            # 5. PageRank (CPR from paper)
            'pagerank': self._calculate_pagerank_bipartite(
                target_node, bipartite_graph),
            # 6. Verbalized Confidence (from model; may issue an API request)
            'verbalized_confidence': self._get_verbalized_confidence(target_claim),
        }

    def _find_claim_node(self, target_claim: Dict[str, Any], graph: "nx.Graph") -> Optional[str]:
        """
        Find the claim node ID corresponding to a claim in the bipartite graph.

        Returns the first claim node whose ``content`` equals
        ``target_claim['claim']`` exactly, or None if there is no such node.
        """
        target_text = target_claim['claim']

        for node_id, node_data in graph.nodes(data=True):
            if (node_data.get('node_type') == 'claim' and
                    node_data.get('content') == target_text):
                return node_id

        return None

    def _node_metric(self, node_id: Optional[str], graph: "nx.Graph",
                     metric_name: str, compute) -> float:
        """
        Shared guard/error handling for the per-node centrality metrics.

        Args:
            node_id: Claim node to score, or None if the claim was not found.
            graph: The bipartite graph.
            metric_name: Human-readable metric name used in error messages.
            compute: Zero-argument callable returning the metric for node_id.

        Returns:
            The metric as a float, or 0.0 when node_id is None, the graph has
            at most one node, or the computation raises (e.g. non-convergence).
        """
        if node_id is None or graph.number_of_nodes() <= 1:
            return 0.0

        try:
            return float(compute())
        except Exception as e:  # best-effort: a failed metric must not abort the run
            print(f"Error computing {metric_name} for node {node_id!r}: {e}")
            return 0.0

    def _calculate_closeness_centrality_bipartite(self, node_id: Optional[str], graph: "nx.Graph") -> float:
        """
        Calculate closeness centrality for a claim node in the bipartite graph.
        Following the Graph-based Uncertainty paper approach.

        High centrality indicates the claim is well-connected to multiple responses,
        suggesting higher confidence/lower uncertainty.
        """
        # Passing u= computes closeness for this node only instead of for
        # every node in the graph; the value is identical.
        return self._node_metric(
            node_id, graph, 'closeness centrality',
            lambda: nx.closeness_centrality(graph, u=node_id))

    def _calculate_degree_centrality_bipartite(self, node_id: Optional[str], graph: "nx.Graph") -> float:
        """
        Calculate degree centrality for a claim node in the bipartite graph.
        This corresponds to the self-consistency approach in the paper.
        """
        return self._node_metric(
            node_id, graph, 'degree centrality',
            lambda: nx.degree_centrality(graph).get(node_id, 0.0))

    def _calculate_betweenness_centrality_bipartite(self, node_id: Optional[str], graph: "nx.Graph") -> float:
        """
        Calculate betweenness centrality for a claim node in the bipartite graph.
        """
        return self._node_metric(
            node_id, graph, 'betweenness centrality',
            lambda: nx.betweenness_centrality(graph).get(node_id, 0.0))

    def _calculate_eigenvalue_centrality_bipartite(self, node_id: Optional[str], graph: "nx.Graph") -> float:
        """
        Calculate eigenvalue centrality for a claim node in the bipartite graph.
        Following the Graph-based Uncertainty paper (CE metric).
        """
        return self._node_metric(
            node_id, graph, 'eigenvector centrality',
            lambda: nx.eigenvector_centrality(graph, max_iter=5000).get(node_id, 0.0))

    def _calculate_pagerank_bipartite(self, node_id: Optional[str], graph: "nx.Graph") -> float:
        """
        Calculate PageRank for a claim node in the bipartite graph.
        Following the Graph-based Uncertainty paper (CPR metric).
        """
        return self._node_metric(
            node_id, graph, 'PageRank',
            lambda: nx.pagerank(graph, max_iter=5000).get(node_id, 0.0))

    def _get_verbalized_confidence(self, claim: Dict[str, Any]) -> float:
        """
        Get verbalized confidence from the model about the claim.

        If the claim already carries a numeric (or numeric-string)
        ``'confidence'`` value, that value is returned as a float without any
        API call. Otherwise the configured model is asked for a score, which
        is parsed by ``_parse_confidence_score``. Any error during the request
        yields the default 0.5.
        """
        # Use existing confidence if it is usable as a number.
        if 'confidence' in claim:
            try:
                return float(claim['confidence'])
            except (TypeError, ValueError):
                pass  # unusable stored value: fall back to asking the model

        claim_text = claim['claim']

        prompt = f"""Rate your confidence in the following factual claim on a scale from 0.0 to 1.0, where:
- 1.0 = completely certain, well-established fact
- 0.5 = moderately confident, likely true but some uncertainty
- 0.0 = very uncertain, difficult to verify

Claim: "{claim_text}"

Provide only a numerical confidence score (0.0 to 1.0):"""

        try:
            if self.use_claude:
                # Claude via AWS Bedrock (Anthropic messages API body format).
                import json
                request_body = {
                    "anthropic_version": "bedrock-2023-05-31",
                    "max_tokens": 10,
                    "temperature": 0.1,
                    "messages": [{"role": "user", "content": prompt}]
                }

                response = self.bedrock_client.invoke_model(
                    modelId=self.CLAUDE_MODEL_ID,
                    body=json.dumps(request_body)
                )

                response_body = json.loads(response['body'].read())
                confidence_text = response_body['content'][0]['text'].strip()
            else:
                # OpenAI chat completions API.
                response = self.client.chat.completions.create(
                    model=self.OPENAI_MODEL,
                    messages=[{"role": "user", "content": prompt}],
                    max_tokens=10,
                    temperature=0.1,
                )
                confidence_text = response.choices[0].message.content.strip()

            return self._parse_confidence_score(confidence_text)

        except Exception as e:
            print(f"Error getting verbalized confidence: {e}")
            return 0.5  # Default confidence

    def _parse_confidence_score(self, text: str) -> float:
        """
        Parse confidence score from model response.

        The first stand-alone number of the form ``0``, ``1``, ``0.85`` or
        ``.9`` wins. Failing that, the first percentage (``85%``, ``85.5 %``)
        is used, divided by 100. The result is clamped to [0, 1]. Returns 0.5
        if nothing parseable is found.
        """
        match = self._DECIMAL_RE.search(text)
        if match:
            return max(0.0, min(1.0, float(match.group(0))))

        match = self._PERCENT_RE.search(text)
        if match:
            return max(0.0, min(1.0, float(match.group(1)) / 100.0))

        return 0.5  # Default

    @staticmethod
    def _response_neighbors(graph: "nx.Graph", node) -> set:
        """Return the set of neighbours of ``node`` whose node_type is 'response'."""
        return {neighbor for neighbor in graph.neighbors(node)
                if graph.nodes[neighbor].get('node_type') == 'response'}

    def calculate_claim_support(self, target_claim: Dict[str, Any],
                                bipartite_graph: "nx.Graph") -> Dict[str, Any]:
        """
        Calculate detailed support information for a claim from the bipartite graph.

        Returns:
            Dict with keys ``supporting_responses`` (list of dicts with
            ``response_id``, ``response_content``, ``entailment_score``),
            ``total_responses`` (number of response nodes in the graph),
            ``support_ratio``, ``entailment_scores`` and
            ``avg_entailment_score``. The same keys are returned when the
            claim is not found in the graph.
        """
        total_responses = sum(1 for _, data in bipartite_graph.nodes(data=True)
                              if data.get('node_type') == 'response')

        target_node = self._find_claim_node(target_claim, bipartite_graph)

        if target_node is None:
            return {
                'supporting_responses': [],
                'total_responses': total_responses,
                'support_ratio': 0.0,
                'entailment_scores': [],
                'avg_entailment_score': 0.0,
            }

        supporting_responses = []
        entailment_scores = []

        for neighbor in bipartite_graph.neighbors(target_node):
            neighbor_data = bipartite_graph.nodes[neighbor]
            if neighbor_data.get('node_type') != 'response':
                continue
            edge_data = bipartite_graph.get_edge_data(target_node, neighbor) or {}
            entailment_score = edge_data.get('entailment_score', 0.0)

            supporting_responses.append({
                'response_id': neighbor,
                'response_content': neighbor_data.get('content', ''),
                'entailment_score': entailment_score
            })
            entailment_scores.append(entailment_score)

        return {
            'supporting_responses': supporting_responses,
            'total_responses': total_responses,
            'support_ratio': len(supporting_responses) / max(1, total_responses),
            'entailment_scores': entailment_scores,
            'avg_entailment_score': float(np.mean(entailment_scores)) if entailment_scores else 0.0
        }

    def analyze_claim_relationships_bipartite(self, bipartite_graph: "nx.Graph",
                                              similarity_threshold: float = 0.3) -> Dict[str, Any]:
        """
        Analyze relationships between claims based on shared response support.

        Args:
            bipartite_graph: Response-claim bipartite graph.
            similarity_threshold: A claim joins an existing cluster seed when
                their Jaccard similarity is strictly greater than this value.

        Returns:
            ``claim_similarities``: nested dict of pairwise Jaccard similarity
            between the response sets of each pair of distinct claims.
            ``claim_clusters``: list of clusters from a single greedy pass;
            each unprocessed claim seeds a cluster and absorbs every other
            unprocessed claim similar enough to the seed.
        """
        claim_nodes = [node for node, data in bipartite_graph.nodes(data=True)
                       if data.get('node_type') == 'claim']

        if len(claim_nodes) < 2:
            return {'claim_similarities': {}, 'claim_clusters': []}

        # Each claim's response set is computed once, not once per pair.
        response_sets = {claim: self._response_neighbors(bipartite_graph, claim)
                         for claim in claim_nodes}

        # Jaccard similarity is symmetric, so each unordered pair is computed
        # once and stored in both directions. Inner-dict key order still
        # follows claim_nodes order (minus the claim itself).
        claim_similarities = {claim: {} for claim in claim_nodes}
        for i, claim1 in enumerate(claim_nodes):
            responses1 = response_sets[claim1]
            for claim2 in claim_nodes[i + 1:]:
                responses2 = response_sets[claim2]
                union = len(responses1 | responses2)
                similarity = len(responses1 & responses2) / union if union > 0 else 0.0
                claim_similarities[claim1][claim2] = similarity
                claim_similarities[claim2][claim1] = similarity

        # Greedy single-pass clustering around seed claims.
        clusters = []
        processed = set()

        for claim in claim_nodes:
            if claim in processed:
                continue

            cluster = [claim]
            processed.add(claim)

            for other_claim in claim_nodes:
                if (other_claim not in processed and
                        claim_similarities[claim].get(other_claim, 0) > similarity_threshold):
                    cluster.append(other_claim)
                    processed.add(other_claim)

            clusters.append(cluster)

        return {
            'claim_similarities': claim_similarities,
            'claim_clusters': clusters
        }
