import networkx as nx
import numpy as np
from typing import List, Dict, Any, Tuple
from sklearn.metrics.pairwise import cosine_similarity
from sentence_transformers import SentenceTransformer
import re
import time
import json
import os
from dotenv import load_dotenv
from openai import OpenAI

class GraphBuilder:
    """
    Builds response-claim bipartite graphs based on semantic entailment.
    Follows the graph construction approach from the Graph-based Uncertainty paper.

    Response nodes are named ``response_{i}`` (``node_type='response'``,
    ``bipartite=0``) and claim nodes ``claim_{j}`` (``node_type='claim'``,
    ``bipartite=1``). Edges carry an ``entailment_score`` attribute.
    """

    # Whole-word "yes"/"no" detector for the entailment answer. Word boundaries
    # prevent substrings such as the "no" in "not"/"unknown"/"cannot" from being
    # read as a verdict; the first whole-word hit in the answer wins.
    _YES_NO_RE = re.compile(r'\b(yes|no)\b')

    # Width of the hashed bag-of-words fallback embedding (384, the output width
    # of the all-MiniLM-L6-v2 model used on the primary path).
    _FALLBACK_EMBEDDING_DIM = 384

    def __init__(self, entailment_threshold: float = 0.5):
        """
        Initialize graph builder.

        Args:
            entailment_threshold: Minimum entailment score (inclusive) required to
                add a response-claim edge in ``build_bipartite_graph``.

        Raises:
            Whatever ``OpenAI(...)`` raises when the client cannot be created
            (presumably when ``OPENAI_API_KEY`` is unset — TODO confirm for the
            installed openai version). A failure to load the sentence
            transformer is NOT raised; it is reported and ``sentence_model`` is
            set to None.
        """
        self.entailment_threshold = entailment_threshold

        # Load environment variables from a .env file (if any) so that
        # OPENAI_API_KEY is visible to os.getenv below.
        load_dotenv()
        self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

        # The sentence model is optional: without it, _compute_embeddings falls
        # back to a deterministic hashed bag-of-words embedding.
        try:
            self.sentence_model = SentenceTransformer('all-MiniLM-L6-v2')
        except Exception as e:
            print(f"Warning: Failed to initialize sentence transformer in GraphBuilder: {e}")
            self.sentence_model = None

    @staticmethod
    def _add_response_and_claim_nodes(G: nx.Graph, responses: List[str],
                                      claims: List[Dict[str, Any]]) -> None:
        """Add ``response_{i}`` (left, bipartite=0) and ``claim_{j}`` (right, bipartite=1) nodes to ``G``."""
        for i, response in enumerate(responses):
            G.add_node(f"response_{i}",
                       node_type='response',
                       content=response,
                       bipartite=0)  # Left side
        for j, claim in enumerate(claims):
            G.add_node(f"claim_{j}",
                       node_type='claim',
                       content=claim['claim'],
                       confidence=claim.get('confidence', 0.5),
                       bipartite=1)  # Right side

    @staticmethod
    def _to_valid_index(value: Any, size: int):
        """
        Coerce ``value`` to an int index into a sequence of length ``size``.

        Accepts ints and int-like values such as the string keys produced by a
        JSON round-trip. Returns None when the value is not convertible or falls
        outside ``[0, size)``.
        """
        try:
            index = int(value)
        except (TypeError, ValueError):
            return None
        return index if 0 <= index < size else None

    def build_bipartite_graph(self, responses: List[str], claims: List[Dict[str, Any]]) -> nx.Graph:
        """
        Build a response-claim bipartite graph based on entailment relationships.
        Following the Graph-based Uncertainty paper methodology.

        Every (response, claim) pair is scored with ``_check_entailment``, which
        makes one API call per pair (len(responses) * len(claims) calls). An edge
        is added when the score is >= ``self.entailment_threshold``.

        NOTE(review): ``_check_entailment`` returns 0.5 for unparseable answers
        and API errors. With the default threshold of 0.5 and the inclusive
        comparison, such pairs DO get an edge — confirm this is intended.

        Args:
            responses: List of response texts
            claims: List of claim dictionaries with a 'claim' key and an optional
                'confidence' key (defaults to 0.5)

        Returns:
            NetworkX bipartite graph with responses and claims as nodes; an empty
            graph if either input is empty.
        """
        if not responses or not claims:
            return nx.Graph()

        G = nx.Graph()
        self._add_response_and_claim_nodes(G, responses, claims)

        print("  Checking entailment relationships...")
        for i, response in enumerate(responses):
            for j, claim in enumerate(claims):
                entailment_score = self._check_entailment(response, claim['claim'])
                if entailment_score >= self.entailment_threshold:
                    G.add_edge(f"response_{i}", f"claim_{j}",
                               entailment_score=entailment_score)

        print(f"  Built bipartite graph: {len(responses)} responses, {len(claims)} claims, {G.number_of_edges()} entailment edges")
        return G

    def build_bipartite_graph_with_mapping(self, responses: List[str], claims: List[Dict[str, Any]],
                                           claim_response_mapping: Dict[int, List[int]]) -> nx.Graph:
        """
        Build a response-claim bipartite graph using pre-computed entailment mapping.
        Following the paper's approach where edges are based on LLM entailment from claim merging.

        Claim keys and response indices are coerced with ``int()``, so string
        keys (e.g. from a JSON round-trip) work. Entries whose claim or response
        index is not a valid position in ``claims``/``responses`` are skipped
        with a warning instead of creating attribute-less nodes.

        Args:
            responses: List of response texts
            claims: List of claim dictionaries
            claim_response_mapping: Dict mapping claim indices to response indices that entail them

        Returns:
            NetworkX bipartite graph with responses and claims as nodes; every
            edge has ``entailment_score=1.0``. An empty graph is returned if
            either input list is empty.
        """
        if not responses or not claims:
            return nx.Graph()

        G = nx.Graph()
        self._add_response_and_claim_nodes(G, responses, claims)

        print("  Adding edges based on LLM entailment mapping...")
        skipped_entries = 0

        for claim_key, response_indices in claim_response_mapping.items():
            claim_idx = self._to_valid_index(claim_key, len(claims))
            if claim_idx is None:
                skipped_entries += 1
                continue

            for response_key in response_indices:
                response_idx = self._to_valid_index(response_key, len(responses))
                if response_idx is None:
                    skipped_entries += 1
                    continue
                # Score 1.0: entailment was already decided by the LLM upstream.
                G.add_edge(f"response_{response_idx}", f"claim_{claim_idx}", entailment_score=1.0)

        if skipped_entries:
            print(f"  Warning: skipped {skipped_entries} out-of-range claim/response index entries in mapping")

        # Count unique edges: duplicate mapping entries collapse into one edge.
        print(f"  Built bipartite graph: {len(responses)} responses, {len(claims)} claims, {G.number_of_edges()} entailment edges")
        return G

    def _check_entailment(self, response: str, claim: str) -> float:
        """
        Check if a response entails a claim using GPT-4o.
        Following the entailment checking approach from the paper.

        The answer is lower-cased and the first whole-word "yes" or "no" decides
        the score.

        Args:
            response: The response text
            claim: The claim text

        Returns:
            1.0 for "yes", 0.0 for "no", and 0.5 when neither word appears or the
            API call fails (the error is printed).
        """
        # Use the exact prompt format from the paper's Appendix F.3
        prompt = f"""Context: {response}
Claim: {claim}
Is the claim supported by the context above?
Answer Yes or No:"""

        try:
            api_response = self.client.chat.completions.create(
                model="gpt-4o",
                messages=[{"role": "user", "content": prompt}],
                max_tokens=10,
                temperature=0.1,
            )
            content = api_response.choices[0].message.content
        except Exception as e:
            print(f"Error checking entailment: {e}")
            return 0.5  # Default uncertain score

        # content may be None (e.g. refusals); treat that as an empty answer.
        answer = (content or "").strip().lower()
        match = self._YES_NO_RE.search(answer)
        if match is None:
            return 0.5  # Uncertain cases
        return 1.0 if match.group(1) == "yes" else 0.0

    def build_similarity_graph(self, claims: List[Dict[str, Any]],
                               similarity_threshold: float = 0.5) -> nx.Graph:
        """
        Build a claim similarity graph for comparison purposes.
        This is the old method - kept for backward compatibility.

        Nodes are the integer claim positions with ``claim``, ``confidence``
        (defaults to 0.5 when absent, as in the bipartite builders) and
        ``embedding`` attributes. Edges connect pairs whose embedding cosine
        similarity is >= ``similarity_threshold`` and carry it as ``similarity``.

        Args:
            claims: List of claim dictionaries with a 'claim' key
            similarity_threshold: Inclusive cosine-similarity cutoff for edges

        Returns:
            Claim similarity graph; empty if ``claims`` is empty.
        """
        if not claims:
            return nx.Graph()

        claim_texts = [claim['claim'] for claim in claims]
        embeddings = self._compute_embeddings(claim_texts)

        G = nx.Graph()
        for i, claim in enumerate(claims):
            G.add_node(i,
                       claim=claim['claim'],
                       confidence=claim.get('confidence', 0.5),
                       embedding=embeddings[i])

        similarity_matrix = cosine_similarity(embeddings)
        n_claims = len(claims)
        for i in range(n_claims):
            for j in range(i + 1, n_claims):
                similarity = similarity_matrix[i, j]
                if similarity >= similarity_threshold:
                    G.add_edge(i, j, similarity=similarity)

        return G

    def _compute_embeddings(self, texts: List[str]) -> np.ndarray:
        """
        Compute one embedding row per text, after ``_preprocess_text``.

        Uses the sentence transformer when it loaded. Otherwise uses a
        deterministic hashed bag-of-words embedding, so cosine similarity
        reflects word overlap rather than noise.
        """
        processed_texts = [self._preprocess_text(text) for text in texts]
        if self.sentence_model is None:
            return self._hashed_bow_embeddings(processed_texts)
        return self.sentence_model.encode(processed_texts)

    @classmethod
    def _hashed_bow_embeddings(cls, texts: List[str]) -> np.ndarray:
        """Hash every word of every text into a fixed-width word-count vector (feature hashing)."""
        import zlib  # crc32 is stable across processes, unlike the salted built-in hash()

        dim = cls._FALLBACK_EMBEDDING_DIM
        embeddings = np.zeros((len(texts), dim), dtype=np.float64)
        for row, text in enumerate(texts):
            for token in re.findall(r'\w+', text):
                embeddings[row, zlib.crc32(token.encode('utf-8')) % dim] += 1.0
        return embeddings

    def _preprocess_text(self, text: str) -> str:
        """Preprocess text for better embedding computation."""
        # Remove extra whitespace
        text = re.sub(r'\s+', ' ', text).strip()

        # Convert to lowercase for consistency
        text = text.lower()

        # Remove special characters that might interfere
        text = re.sub(r'[^\w\s\-.,!?]', '', text)

        return text

    def get_claim_nodes(self, G: nx.Graph) -> List[str]:
        """Get all claim node IDs (``node_type == 'claim'``) from bipartite graph."""
        return [node for node, data in G.nodes(data=True)
                if data.get('node_type') == 'claim']

    def get_response_nodes(self, G: nx.Graph) -> List[str]:
        """Get all response node IDs (``node_type == 'response'``) from bipartite graph."""
        return [node for node, data in G.nodes(data=True)
                if data.get('node_type') == 'response']

    def find_claim_clusters(self, G: nx.Graph, similarity_threshold: float = 0.3) -> List[List[str]]:
        """
        Find clusters of claims that are connected to similar sets of responses.

        Greedy, seed-based grouping: claims are visited in node order. Each
        not-yet-assigned claim seeds a cluster. The cluster absorbs every other
        unassigned claim whose Jaccard similarity of neighbor sets WITH THE SEED
        is strictly greater than ``similarity_threshold``. This is not transitive.

        Args:
            G: Bipartite graph
            similarity_threshold: Exclusive Jaccard cutoff for joining a seed's cluster

        Returns:
            List of claim clusters (each a list of claim node IDs), largest first
        """
        claim_nodes = self.get_claim_nodes(G)
        if not claim_nodes:
            return []

        # Neighbor sets computed once; in a bipartite graph these are responses.
        neighbor_sets = {claim: set(G.neighbors(claim)) for claim in claim_nodes}

        def jaccard(a: str, b: str) -> float:
            union = neighbor_sets[a] | neighbor_sets[b]
            if not union:
                return 0.0
            return len(neighbor_sets[a] & neighbor_sets[b]) / len(union)

        clusters = []
        assigned = set()
        for seed in claim_nodes:
            if seed in assigned:
                continue
            cluster = [seed]
            assigned.add(seed)
            for other in claim_nodes:
                if other not in assigned and jaccard(seed, other) > similarity_threshold:
                    cluster.append(other)
                    assigned.add(other)
            clusters.append(cluster)

        # Stable sort: equal-size clusters keep discovery order.
        clusters.sort(key=len, reverse=True)
        return clusters

    def get_graph_statistics(self, G: nx.Graph) -> Dict[str, Any]:
        """
        Compute various graph statistics for bipartite graph.

        ``density`` is ``nx.density`` over the whole graph, not the bipartite
        density |E| / (|R| * |C|).

        Args:
            G: Bipartite graph

        Returns:
            Dictionary of graph statistics
        """
        if G.number_of_nodes() == 0:
            return {
                'num_response_nodes': 0,
                'num_claim_nodes': 0,
                'num_edges': 0,
                'density': 0.0,
                'avg_response_degree': 0.0,
                'avg_claim_degree': 0.0,
                'num_components': 0,
                'largest_component_size': 0
            }

        response_nodes = self.get_response_nodes(G)
        claim_nodes = self.get_claim_nodes(G)

        stats = {
            'num_response_nodes': len(response_nodes),
            'num_claim_nodes': len(claim_nodes),
            'num_edges': G.number_of_edges(),
            'density': nx.density(G),
            'num_components': nx.number_connected_components(G),
        }

        stats['avg_response_degree'] = (
            float(np.mean([G.degree(node) for node in response_nodes])) if response_nodes else 0.0
        )
        stats['avg_claim_degree'] = (
            float(np.mean([G.degree(node) for node in claim_nodes])) if claim_nodes else 0.0
        )

        # A non-empty graph always has at least one connected component.
        stats['largest_component_size'] = max(len(comp) for comp in nx.connected_components(G))

        return stats

    def visualize_graph_info(self, G: nx.Graph) -> str:
        """
        Generate a text-based summary of the bipartite graph structure.

        Lists the top 5 claim clusters. Clusters of at most 3 claims also show
        each claim's first 50 characters, with "..." appended only when text was cut.

        Args:
            G: Bipartite graph

        Returns:
            String description of the graph
        """
        stats = self.get_graph_statistics(G)
        clusters = self.find_claim_clusters(G)

        info = f"""Bipartite Graph Structure Summary:
- Response nodes: {stats['num_response_nodes']}
- Claim nodes: {stats['num_claim_nodes']} 
- Entailment edges: {stats['num_edges']}
- Density: {stats['density']:.3f}
- Avg response degree: {stats['avg_response_degree']:.1f}
- Avg claim degree: {stats['avg_claim_degree']:.1f}
- Connected components: {stats['num_components']}
- Largest component: {stats['largest_component_size']} nodes

Claim Clusters (based on shared response connections):"""

        for i, cluster in enumerate(clusters[:5], start=1):  # Show top 5 clusters
            info += f"\n  Cluster {i}: {len(cluster)} claims"
            if len(cluster) <= 3:
                # Show claim snippets for small clusters
                for node_id in cluster:
                    if node_id in G.nodes:
                        content = G.nodes[node_id].get('content', '')
                        snippet = content if len(content) <= 50 else content[:50] + "..."
                        info += f"\n    - {snippet}"

        if len(clusters) > 5:
            info += f"\n  ... and {len(clusters) - 5} more clusters"

        return info