import os
import json
import time
import traceback
import numpy as np
from dotenv import load_dotenv
from openai import OpenAI
from tqdm import tqdm
import networkx as nx
from typing import Dict, List, Any, Tuple
import re
from sklearn.metrics.pairwise import cosine_similarity
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed

# Import uncertainty tools
from uncertainty.claim_decomposer import ClaimDecomposer
from uncertainty.uncertainty_calculator import UncertaintyCalculator
from uncertainty.graph_builder import GraphBuilder


class UncertaintyAnalysis:
    """Graph-based uncertainty analysis pipeline for open-ended questions.

    For each question: sample several GPT-4o answers, decompose them into
    claims, merge equivalent claims across answers with LLM equivalence checks,
    build a response-claim bipartite graph, score each claim with
    ``UncertaintyCalculator``, keep the claims whose ``closeness_centrality`` is
    at or above a percentile threshold, and synthesize a final answer from them.
    """

    def __init__(self, num_generations: int = 3, threshold_percentile: int = 40):
        """
        Initialize uncertainty analysis system following the paper's approach.

        Args:
            num_generations: Number of API calls per question (default: 3)
            threshold_percentile: Percentile for claim filtering threshold
        """
        self.num_generations = num_generations
        self.threshold_percentile = threshold_percentile

        # Load environment (.env two directories above this file) and configure API
        dotenv_path = os.path.join(os.path.dirname(__file__), '..', '..', '.env')
        load_dotenv(dotenv_path=dotenv_path)

        # Initialize OpenAI client for GPT-4o
        self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

        # Initialize components
        self.claim_decomposer = ClaimDecomposer()
        self.uncertainty_calculator = UncertaintyCalculator(self.client)
        self.graph_builder = GraphBuilder()

        # Thread lock for safe file writing
        self.file_lock = threading.Lock()

        # Absolute paths of output files whose JSON array run_analysis opened and
        # that have not been closed yet; makes finalize_json_file run once per file.
        self._open_json_outputs = set()

        print(
            f"Initialized UncertaintyAnalysis with {num_generations} generations per question")

    def generate_multiple_responses(self, question: str) -> List[str]:
        """Sample ``num_generations`` GPT-4o answers to ``question``.

        Calls are made sequentially, pausing 1 s after each success and 5 s after
        each failure. A failed call contributes an empty string, so the returned
        list always has ``num_generations`` entries; callers filter out blanks.
        """
        responses = []

        prompt = f"""Please provide a comprehensive answer to the following epidemiology-related question.
        Focus on providing factual, well-reasoned information based on epidemiology and public health science, 
        also try predict quantitative information of the peak's severity, peak's timing, and the initial speed of the outbreak in your answer to justify your result.

        Question: {question}

        Answer:"""

        for i in range(self.num_generations):
            try:
                response = self.client.chat.completions.create(
                    model="gpt-4o",
                    messages=[{"role": "user", "content": prompt}],
                    max_tokens=1000,
                    temperature=0.7,  # Some temperature for diversity
                    top_p=0.9,
                )
                responses.append(response.choices[0].message.content.strip())
                print(f"  Generated response {i+1}/{self.num_generations}")
                time.sleep(1)  # Rate limiting

            except Exception as e:
                print(f"Error generating response {i+1}: {e}")
                responses.append("")
                time.sleep(5)  # Longer wait on error

        return responses

    def match_claims_across_generations(self, all_claims: List[List[Dict]], responses: List[str]) -> tuple[List[Dict], Dict]:
        """
        Merge per-response claim lists into one deduplicated claim set.

        Follows the paper's sequential merge C^(1) = C_r1,
        C^(i) = M(C^(i-1), C_ri) for i in {2, ..., N}, where M is
        ``_merge_claim_sets_llm`` (LLM equivalence checks instead of embedding
        similarity). Every merged claim dict also receives a
        ``verbalized_confidence`` key copied from its ``confidence`` (default
        0.5); this mutates the claim dicts passed in.

        Args:
            all_claims: One claim list per response, in response order.
            responses: The responses themselves (unused; kept for API compatibility).

        Returns:
            merged_claims: List of unique claims
            claim_response_mapping: Dict mapping merged-claim index to the list of
                response indices (positions in ``all_claims``) that contain it.
        """
        if not all_claims:
            return [], {}

        # C^(1): every claim of the first response, supported by response 0.
        merged_claims = list(all_claims[0])
        claim_response_mapping = {idx: [0] for idx in range(len(merged_claims))}

        for gen_idx in range(1, len(all_claims)):
            merged_claims, claim_response_mapping = self._merge_claim_sets_llm(
                merged_claims, all_claims[gen_idx], claim_response_mapping, gen_idx)

        for claim in merged_claims:
            claim['verbalized_confidence'] = claim.get('confidence', 0.5)

        return merged_claims, claim_response_mapping

    def _merge_claim_sets_llm(self, existing_claims: List[Dict], new_claims: List[Dict],
                             claim_response_mapping: Dict, response_idx: int) -> tuple[List[Dict], Dict]:
        """
        Merge ``new_claims`` (from response ``response_idx``) into ``existing_claims``.

        Implements the paper's M : P(C) x P(C) -> P(C). A new claim that an existing
        claim covers only records ``response_idx`` as a supporter of that existing
        claim; any other new claim is appended to the merged list. A response
        index is recorded at most once per claim, even when several of its claims
        are covered by the same existing claim. Neither input is mutated.

        Returns:
            merged_claims: Updated claim list
            updated_mapping: Updated mapping of claim index -> supporting response indices
        """
        if not new_claims:
            return existing_claims, claim_response_mapping

        if not existing_claims:
            # All new claims are added, map them to current response
            return new_claims, {i: [response_idx] for i in range(len(new_claims))}

        # new claim index -> first existing claim reported to cover it (keeps the
        # LLM's order, matching a first-match scan of the pair list).
        covering = {}
        for existing_idx, new_idx in self._find_entailment_pairs(existing_claims, new_claims):
            covering.setdefault(new_idx, existing_idx)

        merged_claims = list(existing_claims)
        # Copy the lists too so appending never alters the caller's mapping.
        updated_mapping = {k: list(v) for k, v in claim_response_mapping.items()}

        for new_idx, new_claim in enumerate(new_claims):
            existing_idx = covering.get(new_idx)
            if existing_idx is None:
                # Claim is not covered: add it to the merged set
                updated_mapping[len(merged_claims)] = [response_idx]
                merged_claims.append(new_claim)
                continue
            supporters = updated_mapping.setdefault(existing_idx, [])
            if response_idx not in supporters:
                supporters.append(response_idx)

        return merged_claims, updated_mapping

    @staticmethod
    def _extract_json_array(text: str) -> str:
        """Return the outermost ``[...]`` span of ``text`` after stripping a Markdown fence.

        Falls back to the (fence-stripped) text when no brackets are found, so
        that ``json.loads`` raises on it.
        """
        if "```" in text:
            # Keep only the first fenced section, dropping an optional "json" tag.
            text = text.split("```")[1]
            if text.startswith("json"):
                text = text[4:]

        start_idx = text.find('[')
        end_idx = text.rfind(']')
        if start_idx != -1 and end_idx != -1:
            return text[start_idx:end_idx + 1]
        return text

    def _find_entailment_pairs(self, existing_claims: List[Dict], new_claims: List[Dict]) -> List[tuple]:
        """
        Ask GPT-4o which new claims are semantically covered by existing claims.

        Returns:
            List of pairs (idx_in_existing, idx_in_new) where existing[idx] covers
            new[idx]. Pairs that are not two in-range integers are dropped
            individually; an API or JSON parsing failure yields ``[]``.
        """
        if not existing_claims or not new_claims:
            return []

        # Prepare claims for LLM prompt
        existing_texts = [f"{i}: {claim['claim']}" for i, claim in enumerate(existing_claims)]
        new_texts = [f"{i}: {claim['claim']}" for i, claim in enumerate(new_claims)]

        prompt = f"""You are given two sets of claims. Find which claims in Set B are already covered by claims in Set A.

Set A (Existing claims):
{chr(10).join(existing_texts)}

Set B (New claims):
{chr(10).join(new_texts)}

For each claim in Set B, check if it says essentially the same thing as any claim in Set A (i.e. semantic equivalence even if worded differently).
Should be equivalent in meaning, if A said something turns up while B said something turns down, they are not equivalent. If A said something turns up and B said the same thing goes up, they are equivalent.

You must respond with ONLY a valid JSON array of pairs. Each pair is [existing_index, new_index] where Set A[existing_index] covers Set B[new_index].

Examples:
- If Set A[0] covers Set B[1] and Set A[2] covers Set B[0]: [[0, 1], [2, 0]]
- If no claims match: []
- If Set A[1] covers Set B[0]: [[1, 0]]

Response (JSON only):"""

        result_text = ""
        try:
            response = self.client.chat.completions.create(
                model="gpt-4o",
                messages=[{"role": "user", "content": prompt}],
                # A pair costs roughly 5-8 tokens; a fixed 200-token cap truncates
                # long answers into unparseable JSON, so scale with Set B's size.
                max_tokens=max(200, 16 * len(new_claims)),
                temperature=0.1,
            )
            result_text = (response.choices[0].message.content or "").strip()
            entailment_pairs = json.loads(self._extract_json_array(result_text))
        except Exception as e:
            print(f"Error in LLM entailment checking: {e}")
            print(f"Raw LLM response: {result_text[:200] if result_text else 'No response'}")
            # Fallback: no entailments found
            return []

        if not isinstance(entailment_pairs, list):
            return []

        valid_pairs = []
        for pair in entailment_pairs:
            if not (isinstance(pair, list) and len(pair) == 2):
                continue
            # Reject non-integers (bool is an int subclass) instead of letting a
            # comparison TypeError discard every other valid pair.
            if not all(isinstance(v, int) and not isinstance(v, bool) for v in pair):
                continue
            existing_idx, new_idx = pair
            if 0 <= existing_idx < len(existing_claims) and 0 <= new_idx < len(new_claims):
                valid_pairs.append((existing_idx, new_idx))

        return valid_pairs

    def process_single_question(self, question_data: Dict, question_idx: int) -> Dict:
        """Run the full Graph-based Uncertainty pipeline for one question.

        Args:
            question_data: Dict with ``open_question`` and ``reference_answer`` keys.
            question_idx: Zero-based position of the question (used for logging).

        Returns:
            Result dict with the sampled responses, claims, uncertainty metrics,
            selected claims and final answer; ``_create_empty_result`` when no
            response or no claim could be produced.
        """
        print(f"\n--- Processing Question {question_idx + 1} ---")
        question = question_data['open_question']
        reference_answer = question_data['reference_answer']

        # Step 1: Generate multiple responses (following paper methodology)
        print("Step 1: Generating multiple responses...")
        responses = self.generate_multiple_responses(question)

        # Filter out empty responses (failed API calls)
        valid_responses = [r for r in responses if r.strip()]
        if not valid_responses:
            print("Warning: No valid responses generated")
            return self._create_empty_result(question, reference_answer)

        # Step 2: Decompose each response into claims
        print("Step 2: Decomposing responses into claims...")
        all_claims_by_generation = []
        decomposition_results = []

        for i, response in enumerate(valid_responses):
            claims = self.claim_decomposer.decompose_response(response)
            decomposition_results.append({
                'response_id': i,
                'response': response,
                'claims': claims
            })
            all_claims_by_generation.append(claims)

        # Step 3: Create unified claim set; mapping indices refer to valid_responses
        print("Step 3: Creating unified claim set...")
        merged_claims, claim_response_mapping = self.match_claims_across_generations(
            all_claims_by_generation, valid_responses)

        if not merged_claims:
            print("Warning: No claims extracted")
            return self._create_empty_result(question, reference_answer)

        # Step 4: Build response-claim bipartite graph
        print("Step 4: Building response-claim bipartite graph...")
        bipartite_graph = self.graph_builder.build_bipartite_graph_with_mapping(
            valid_responses, merged_claims, claim_response_mapping)

        # Step 5: Calculate uncertainty metrics for every merged claim
        print("Step 5: Calculating uncertainty metrics using graph centrality...")
        claim_uncertainties = []
        for claim in merged_claims:
            uncertainty_metrics = self.uncertainty_calculator.calculate_uncertainty(
                claim, merged_claims, bipartite_graph
            )
            claim_uncertainties.append({
                'claim': claim['claim'],
                'confidence': claim.get('confidence', 0.5),
                'verbalized_confidence': claim.get('verbalized_confidence', 0.5),
                'uncertainty_metrics': uncertainty_metrics
            })

        # Step 6: Select final claims based on the percentile threshold
        print("Step 6: Selecting final claims using uncertainty-aware decoding...")
        final_claims, threshold_value = self.select_final_claims(
            claim_uncertainties, self.threshold_percentile
        )

        # Step 7: Generate final answer from selected claims
        print("Step 7: Generating final answer from selected claims...")
        final_answer = self.generate_final_answer(final_claims, question)

        graph_stats = self.graph_builder.get_graph_statistics(bipartite_graph)

        return {
            'question': question,
            'reference_answer': reference_answer,
            'multiple_responses': valid_responses,
            'decomposition_results': decomposition_results,
            'merged_claims': merged_claims,
            'claim_uncertainties': claim_uncertainties,
            'threshold_percentile': self.threshold_percentile,
            'threshold_value': threshold_value,
            'final_claims': final_claims,
            'final_answer': final_answer,
            'graph_stats': graph_stats,
            'bipartite_graph_info': self.graph_builder.visualize_graph_info(bipartite_graph)
        }

    def _create_empty_result(self, question: str, reference_answer: str) -> Dict:
        """Create an empty result structure for failed processing."""
        return {
            'question': question,
            'reference_answer': reference_answer,
            'multiple_responses': [],
            'decomposition_results': [],
            'merged_claims': [],
            'claim_uncertainties': [],
            'threshold_percentile': self.threshold_percentile,
            'threshold_value': 0.0,
            'final_claims': [],
            'final_answer': 'Unable to generate answer due to processing error.',
            'graph_stats': {
                'num_response_nodes': 0,
                'num_claim_nodes': 0,
                'num_edges': 0,
                'density': 0.0,
                'avg_response_degree': 0.0,
                'avg_claim_degree': 0.0,
                'num_components': 0,
                'largest_component_size': 0
            },
            'bipartite_graph_info': 'No graph generated due to processing error.'
        }

    def select_final_claims(self, claim_uncertainties: List[Dict], percentile: int) -> Tuple[List[Dict], float]:
        """Keep claims whose closeness centrality is at or above a percentile cutoff.

        Args:
            claim_uncertainties: Entries with an ``uncertainty_metrics`` dict; a
                missing ``closeness_centrality`` counts as 0.0.
            percentile: Percentile (0-100) of the centrality scores used as cutoff.

        Returns:
            (selected entries sorted by centrality, highest first; threshold value).
            ``([], 0.0)`` for empty input.
        """
        if not claim_uncertainties:
            return [], 0.0

        def centrality(entry: Dict) -> float:
            return entry['uncertainty_metrics'].get('closeness_centrality', 0.0)

        # Plain float so the threshold serializes like any other number.
        threshold_value = float(np.percentile(
            [centrality(entry) for entry in claim_uncertainties], percentile))

        final_claims = [entry for entry in claim_uncertainties
                        if centrality(entry) >= threshold_value]
        final_claims.sort(key=centrality, reverse=True)

        return final_claims, threshold_value

    def generate_final_answer(self, final_claims: List[Dict], question: str) -> str:
        """Ask GPT-4o for a short answer to ``question`` built from the selected claims.

        Returns a fixed message when no claims were selected, and a generic
        fallback sentence when the API call fails.
        """
        if not final_claims:
            return "Unable to generate a confident answer based on the analysis."

        claim_texts = [claim['claim'] for claim in final_claims]

        prompt = f"""Based on the following high-confidence claims, provide a CONCISE and DIRECT answer to the question. Keep the answer short (2-3 sentences maximum) and focus on the key conclusions.

Question: {question}

High-confidence claims:
{chr(10).join(f"- {claim}" for claim in claim_texts)}

Provide a concise, direct answer that directly addresses the specific question asked:"""

        try:
            response = self.client.chat.completions.create(
                model="gpt-4o",
                messages=[{"role": "user", "content": prompt}],
                max_tokens=200,  # Limit length
                temperature=0.3,  # Lower temperature for more focused answers
            )
            return response.choices[0].message.content.strip()
        except Exception as e:
            print(f"Error generating final answer: {e}")
            # Fallback: create a simple summary
            return f"Based on the analysis of {len(final_claims)} high-confidence claims, the effects described in the question would have epidemiological implications."

    def run_analysis(self, input_file: str = "final_questions.json",
                     output_file: str = "uncertainty_analysis_results.json"):
        """Run the pipeline on every question in ``input_file``.

        Questions are processed in batches of 5, each batch in parallel on a
        thread pool. Results are appended to ``output_file`` in input order once
        their batch finishes. A question that raises is stored as an error entry
        instead of aborting the run.

        ``output_file`` is written as a JSON array. It is truncated and opened
        with ``[`` after the input loads. The array is closed through
        ``finalize_json_file`` however processing stops (success, exception or
        KeyboardInterrupt), so the file is always valid JSON afterwards. If
        ``input_file`` does not exist, nothing is written.

        Args:
            input_file: JSON file containing a list of question dicts with
                ``open_question`` and ``reference_answer`` keys.
            output_file: Destination JSON file for the per-question results.
        """
        try:
            with open(input_file, 'r', encoding='utf-8') as f:
                questions = json.load(f)
        except FileNotFoundError:
            print(f"Error: {input_file} not found.")
            return

        print(
            f"Starting uncertainty analysis on {len(questions)} questions...")
        print(f"Using {self.num_generations} generations per question")
        print(f"Results will be saved to {output_file}")

        # Open the JSON array; finalize_json_file closes it exactly once.
        with open(output_file, 'w', encoding='utf-8') as f:
            f.write('[\n')
        self._open_json_outputs.add(os.path.abspath(output_file))

        try:
            self._process_in_batches(questions, output_file)
        finally:
            # Close the array even on errors/interrupts so the file stays parseable.
            self.finalize_json_file(output_file)

        print(f"\nAnalysis complete! Results saved to {output_file}")

    def _process_in_batches(self, questions: List[Dict], output_file: str, batch_size: int = 5):
        """Process ``questions`` in parallel batches, appending results to ``output_file`` in order."""
        total_questions = len(questions)

        print(
            f"Processing in batches of up to {batch_size} questions in parallel")

        for batch_start in range(0, total_questions, batch_size):
            batch_end = min(batch_start + batch_size, total_questions)

            print(
                f"\nProcessing batch {batch_start//batch_size + 1}: Questions {batch_start + 1}-{batch_end}")

            # idx -> result; completion order is arbitrary, saving order is not.
            results_buffer = {}
            with ThreadPoolExecutor(max_workers=batch_end - batch_start) as executor:
                future_to_idx = {
                    executor.submit(self.process_single_question, questions[idx], idx): idx
                    for idx in range(batch_start, batch_end)
                }

                for future in as_completed(future_to_idx):
                    idx = future_to_idx[future]
                    try:
                        results_buffer[idx] = future.result()
                    except Exception as e:
                        print(f"Error processing question {idx + 1}: {e}")
                        traceback.print_exc()
                        results_buffer[idx] = self._error_result(questions[idx], idx, e)

            # Save results in input order
            for i in range(batch_start, batch_end):
                self.save_result_incrementally(results_buffer[i], output_file, i == 0)
                print(f"Question {i + 1} completed and saved.")

    @staticmethod
    def _error_result(question_data: Dict, idx: int, error: Exception) -> Dict:
        """Build the placeholder result stored for a question whose processing raised."""
        return {
            'question': question_data.get('open_question', ''),
            'reference_answer': question_data.get('reference_answer', ''),
            'error': str(error),
            'multiple_responses': [],
            'claim_uncertainties': [],
            'final_claims': [],
            'final_answer': 'Error occurred during processing',
            'question_idx': idx
        }

    def save_result_incrementally(self, result: Dict, output_file: str, is_first: bool):
        """Append one result object to the open JSON array in ``output_file``.

        A ``,`` separator is written before every element except the first.
        Writes are serialized with ``self.file_lock``.
        """
        with self.file_lock:
            # Always append: run_analysis has already created the file and written '['.
            with open(output_file, 'a', encoding='utf-8') as f:
                if not is_first:
                    f.write(',\n')
                json.dump(result, f, ensure_ascii=False, indent=4)

    def finalize_json_file(self, output_file: str):
        """Close the JSON array that run_analysis opened in ``output_file``.

        Idempotent: the closing ``]`` is written only for a file this instance's
        run_analysis opened and has not closed yet. Calling it again, or for a
        file that was never opened, does nothing. This covers run_analysis
        returning early because the input file was missing; appending there
        would create or corrupt the file.
        """
        key = os.path.abspath(output_file)
        with self.file_lock:
            if key not in self._open_json_outputs:
                return
            with open(output_file, 'a', encoding='utf-8') as f:
                f.write('\n]')
            self._open_json_outputs.discard(key)


def main():
    """Run the uncertainty analysis with this script's default configuration.

    Processes run_analysis's default input file and writes results to
    ``uncertainty_analysis_results.json``. ``finalize_json_file`` is called
    after normal completion and after Ctrl-C. Any other exception is reported
    with a traceback.
    """
    # Configuration: 5 sampled generations per question (multi-sampling approach)
    num_generations = 5
    threshold_percentile = 40
    # Single source of truth for the output path, shared by run and finalize.
    output_file = "uncertainty_analysis_results.json"

    # Initialize and run analysis
    analyzer = UncertaintyAnalysis(
        num_generations=num_generations,
        threshold_percentile=threshold_percentile
    )

    try:
        analyzer.run_analysis(output_file=output_file)
        # Finalize JSON file
        analyzer.finalize_json_file(output_file)

    except KeyboardInterrupt:
        print("\nAnalysis interrupted by user.")
        # Still finalize the JSON file
        analyzer.finalize_json_file(output_file)
    except Exception as e:
        print(f"Analysis failed: {e}")
        traceback.print_exc()


if __name__ == "__main__":
    # Run the pipeline only when executed as a script, not when imported.
    main()
