import os
import json
import time
import traceback
import numpy as np
from dotenv import load_dotenv
from openai import OpenAI
from tqdm import tqdm
import networkx as nx
from typing import Dict, List, Any, Tuple
import re
from sklearn.metrics.pairwise import cosine_similarity
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed

# Import uncertainty tools
from uncertainty.claim_decomposer import ClaimDecomposer
from uncertainty.uncertainty_calculator import UncertaintyCalculator
from uncertainty.graph_builder import GraphBuilder


class UncertaintyAnalysis:
    def __init__(self, num_generations: int = 3, threshold_percentile: int = 40):
        """
        Configure the OpenAI client and the pipeline components.

        Args:
            num_generations: Number of responses sampled per question.
            threshold_percentile: Percentile handed to ``select_final_claims``
                when the merged claims are filtered.
        """
        # The .env file is looked up two directories above this module.
        module_dir = os.path.dirname(__file__)
        load_dotenv(dotenv_path=os.path.join(module_dir, '..', '..', '.env'))

        self.num_generations = num_generations
        self.threshold_percentile = threshold_percentile

        # One OpenAI client, also handed to the uncertainty calculator.
        api_key = os.getenv("OPENAI_API_KEY")
        self.client = OpenAI(api_key=api_key)

        # Pipeline components
        self.claim_decomposer = ClaimDecomposer()
        self.uncertainty_calculator = UncertaintyCalculator(self.client)
        self.graph_builder = GraphBuilder()

        # Held by save_result_incrementally while it appends to the output
        self.file_lock = threading.Lock()

        print(
            f"Initialized UncertaintyAnalysis with {self.num_generations} generations per question")

    def generate_multiple_responses(self, question: str) -> List[str]:
        """Sample ``self.num_generations`` GPT-4o answers to ``question``.

        A successful attempt appends the stripped answer text; an attempt
        that raises appends an empty string instead, so callers can filter
        the failures out. The method sleeps one second after a success and
        five seconds after a failure.
        """
        prompt = f"""Please provide a comprehensive answer to the following climate-related question. 
        Focus on providing factual, well-reasoned information based on climate science.

        Question: {question}

        Answer:"""

        answers = []
        for attempt_no in range(1, self.num_generations + 1):
            try:
                completion = self.client.chat.completions.create(
                    model="gpt-4o",
                    messages=[{"role": "user", "content": prompt}],
                    max_tokens=1000,
                    temperature=0.7,  # non-zero so the samples differ
                    top_p=0.9,
                )
                answer_text = completion.choices[0].message.content
                answers.append(answer_text.strip())
                print(f"  Generated response {attempt_no}/{self.num_generations}")
                time.sleep(1)  # Rate limiting

            except Exception as err:
                print(f"Error generating response {attempt_no}: {err}")
                answers.append("")
                time.sleep(5)  # Back off longer after a failure

        return answers

    def match_claims_across_generations(self, all_claims: List[List[Dict]], responses: List[str]) -> tuple[List[Dict], Dict]:
        """
        Fold the per-response claim lists into one list of unique claims.

        The first generation seeds the merged set; every later generation is
        merged in order through ``_merge_claim_sets_llm``. Afterwards each
        merged claim dict gets a ``'verbalized_confidence'`` key copied from
        its ``'confidence'`` (0.5 when absent). The claim dicts are updated
        in place, so the dicts inside ``all_claims`` receive the key too.

        Args:
            all_claims: One list of claim dicts per response.
            responses: Not used by this method.

        Returns:
            merged_claims: List of unique claims
            claim_response_mapping: Dict mapping each merged-claim index to
                the indices of the responses recorded for it
        """
        if not all_claims:
            return [], {}

        # C^(1) = C_r1: every first-generation claim is backed by response 0.
        unified = all_claims[0].copy()
        supporters = {position: [0] for position in range(len(unified))}

        # C^(i) = M(C^(i-1), C_ri) for i in {2, ..., N}
        for generation_no, generation_claims in enumerate(all_claims[1:], start=1):
            unified, supporters = self._merge_claim_sets_llm(
                unified, generation_claims, supporters, generation_no)

        # Expose the decomposer's confidence under the verbalized name.
        for entry in unified:
            entry['verbalized_confidence'] = entry.get('confidence', 0.5)

        return unified, supporters

    def _merge_claim_sets_llm(self, existing_claims: List[Dict], new_claims: List[Dict],
                             claim_response_mapping: Dict, response_idx: int) -> tuple[List[Dict], Dict]:
        """
        Merge two claim sets using LLM entailment checking following the paper's M function.
        M : P(C) × P(C) → P(C)

        A new claim that ``_find_entailment_pairs`` reports as covered by an
        existing claim is not added; instead ``response_idx`` is recorded for
        that existing claim. Every other new claim is appended to the merged
        list and mapped to ``response_idx``.

        Args:
            existing_claims: Claims merged so far.
            new_claims: Claims extracted from the response being merged in.
            claim_response_mapping: Maps an index into ``existing_claims`` to
                the list of response indices recorded for that claim.
            response_idx: Index of the response ``new_claims`` came from.

        Returns:
            merged_claims: Updated claim list
            updated_mapping: Updated mapping of claims to responses. Each
                response index appears at most once per claim. When both
                input sets are non-empty the mapping and its lists are fresh
                objects, so the caller's mapping is left untouched.
        """
        if not new_claims:
            return existing_claims, claim_response_mapping

        if not existing_claims:
            # Nothing to compare against: every new claim is kept as-is.
            return new_claims, {i: [response_idx] for i in range(len(new_claims))}

        entailment_pairs = self._find_entailment_pairs(existing_claims, new_claims)

        # For each new claim keep the first existing claim reported for it,
        # so the lookup below is O(1) instead of a scan over all pairs.
        covered_by = {}
        for existing_idx, new_idx in entailment_pairs:
            covered_by.setdefault(new_idx, existing_idx)

        merged_claims = existing_claims.copy()
        # Copy the inner lists as well; a shallow dict copy would let the
        # appends below leak into the caller's mapping.
        updated_mapping = {
            claim_idx: list(response_ids)
            for claim_idx, response_ids in claim_response_mapping.items()
        }

        for new_idx, new_claim in enumerate(new_claims):
            existing_idx = covered_by.get(new_idx)

            if existing_idx is None:
                # Claim is not entailed, add it to the merged set
                updated_mapping[len(merged_claims)] = [response_idx]
                merged_claims.append(new_claim)
                continue

            # Several claims of one response may be covered by the same
            # existing claim; record the response only once.
            response_ids = updated_mapping.setdefault(existing_idx, [])
            if response_idx not in response_ids:
                response_ids.append(response_idx)

        return merged_claims, updated_mapping

    def _find_entailment_pairs(self, existing_claims: List[Dict], new_claims: List[Dict]) -> List[tuple]:
        """
        Find entailment pairs between existing and new claims using LLM.
        Returns list of pairs (idx_in_existing, idx_in_new) where existing[idx] entails new[idx].
        """
        # Prepare claims for LLM prompt
        existing_texts = [f"{i}: {claim['claim']}" for i, claim in enumerate(existing_claims)]
        new_texts = [f"{i}: {claim['claim']}" for i, claim in enumerate(new_claims)]

        prompt = f"""You are given two sets of claims. Find which claims in Set B are already covered by claims in Set A.

Set A (Existing claims):
{chr(10).join(existing_texts)}

Set B (New claims):
{chr(10).join(new_texts)}

For each claim in Set B, check if it says essentially the same thing as any claim in Set A (i.e. semantic equivalence even if worded differently).
Should be equivalent in meaning, if A said something turns up while B said something turns down, they are not equivalent. If A said something turns up and B said the same thing goes up, they are equivalent.

You must respond with ONLY a valid JSON array of pairs. Each pair is [existing_index, new_index] where Set A[existing_index] covers Set B[new_index].

Examples:
- If Set A[0] covers Set B[1] and Set A[2] covers Set B[0]: [[0, 1], [2, 0]]
- If no claims match: []
- If Set A[1] covers Set B[0]: [[1, 0]]

Response (JSON only):"""

        try:
            response = self.client.chat.completions.create(
                model="gpt-4o",
                messages=[{"role": "user", "content": prompt}],
                max_tokens=200,
                temperature=0.1,
            )

            result_text = response.choices[0].message.content.strip()

            # Clean up response text to extract JSON
            if "```" in result_text:
                # Remove markdown code blocks
                result_text = result_text.split("```")[1]
                if result_text.startswith("json"):
                    result_text = result_text[4:]

            # Find JSON array
            start_idx = result_text.find('[')
            end_idx = result_text.rfind(']')

            if start_idx != -1 and end_idx != -1:
                json_text = result_text[start_idx:end_idx + 1]
            else:
                json_text = result_text

            # Parse JSON response
            import json
            entailment_pairs = json.loads(json_text)

            # Validate pairs
            valid_pairs = []
            for pair in entailment_pairs:
                if (isinstance(pair, list) and len(pair) == 2 and
                    0 <= pair[0] < len(existing_claims) and
                    0 <= pair[1] < len(new_claims)):
                    valid_pairs.append(tuple(pair))

            return valid_pairs

        except Exception as e:
            print(f"Error in LLM entailment checking: {e}")
            print(f"Raw LLM response: {result_text[:200] if 'result_text' in locals() else 'No response'}")
            # Fallback: no entailments found
            return []

    def process_single_question(self, question_data: Dict, question_idx: int) -> Dict:
        """Run one question through every stage of the uncertainty pipeline.

        Stages, in order: sample responses, decompose each response into
        claims, merge the claims across responses, build the response-claim
        bipartite graph, score every merged claim, keep the claims that pass
        the percentile threshold, and generate the final answer from them.

        Args:
            question_data: Must provide ``'open_question'`` and
                ``'reference_answer'``.
            question_idx: Zero-based position, used only for the log line.

        Returns:
            The result record for this question. When no non-blank response
            is generated, or no claims survive merging, the record from
            ``_create_empty_result`` is returned instead.
        """
        print(f"\n--- Processing Question {question_idx + 1} ---")
        question = question_data['open_question']
        reference_answer = question_data['reference_answer']

        # Step 1: sample several responses and drop the blank ones
        print("Step 1: Generating multiple responses...")
        sampled = self.generate_multiple_responses(question)
        valid_responses = [text for text in sampled if text.strip()]
        if not valid_responses:
            print("Warning: No valid responses generated")
            return self._create_empty_result(question, reference_answer)

        # Step 2: one claim list per surviving response
        print("Step 2: Decomposing responses into claims...")
        all_claims_by_generation = [
            self.claim_decomposer.decompose_response(text)
            for text in valid_responses
        ]
        decomposition_results = [
            {'response_id': response_id, 'response': text, 'claims': claims}
            for response_id, (text, claims) in enumerate(
                zip(valid_responses, all_claims_by_generation))
        ]

        # Step 3: merge the claim lists into one set of unique claims
        print("Step 3: Creating unified claim set...")
        merged_claims, claim_response_mapping = self.match_claims_across_generations(
            all_claims_by_generation, valid_responses)
        if not merged_claims:
            print("Warning: No claims extracted")
            return self._create_empty_result(question, reference_answer)

        # Step 4: response-claim bipartite graph
        print("Step 4: Building response-claim bipartite graph...")
        bipartite_graph = self.graph_builder.build_bipartite_graph_with_mapping(
            valid_responses, merged_claims, claim_response_mapping)

        # Step 5: per-claim metrics computed against the graph
        print("Step 5: Calculating uncertainty metrics using graph centrality...")
        claim_uncertainties = []
        for merged_claim in merged_claims:
            metrics = self.uncertainty_calculator.calculate_uncertainty(
                merged_claim, merged_claims, bipartite_graph
            )
            scored_claim = {
                'claim': merged_claim['claim'],
                'confidence': merged_claim.get('confidence', 0.5),
                'verbalized_confidence': merged_claim.get('verbalized_confidence', 0.5),
                'uncertainty_metrics': metrics
            }
            claim_uncertainties.append(scored_claim)

        # Step 6: keep the claims that pass the percentile threshold
        print("Step 6: Selecting final claims using uncertainty-aware decoding...")
        final_claims, threshold_value = self.select_final_claims(
            claim_uncertainties, self.threshold_percentile
        )

        # Step 7: turn the kept claims into one answer
        print("Step 7: Generating final answer from selected claims...")
        final_answer = self.generate_final_answer(final_claims, question)

        # Graph summaries stored alongside the result
        graph_stats = self.graph_builder.get_graph_statistics(bipartite_graph)
        graph_info = self.graph_builder.visualize_graph_info(bipartite_graph)

        return {
            'question': question,
            'reference_answer': reference_answer,
            'multiple_responses': valid_responses,
            'decomposition_results': decomposition_results,
            'merged_claims': merged_claims,
            'claim_uncertainties': claim_uncertainties,
            'threshold_percentile': self.threshold_percentile,
            'threshold_value': threshold_value,
            'final_claims': final_claims,
            'final_answer': final_answer,
            'graph_stats': graph_stats,
            'bipartite_graph_info': graph_info
        }

    def _create_empty_result(self, question: str, reference_answer: str) -> Dict:
        """Create an empty result structure for failed processing."""
        return {
            'question': question,
            'reference_answer': reference_answer,
            'multiple_responses': [],
            'decomposition_results': [],
            'merged_claims': [],
            'claim_uncertainties': [],
            'threshold_percentile': self.threshold_percentile,
            'threshold_value': 0.0,
            'final_claims': [],
            'final_answer': 'Unable to generate answer due to processing error.',
            'graph_stats': {
                'num_response_nodes': 0,
                'num_claim_nodes': 0,
                'num_edges': 0,
                'density': 0.0,
                'avg_response_degree': 0.0,
                'avg_claim_degree': 0.0,
                'num_components': 0,
                'largest_component_size': 0
            },
            'bipartite_graph_info': 'No graph generated due to processing error.'
        }

    def select_final_claims(self, claim_uncertainties: List[Dict], percentile: int) -> Tuple[List[Dict], float]:
        """Keep the claims whose closeness centrality reaches a percentile cutoff.

        The cutoff is ``np.percentile`` of every claim's
        ``'closeness_centrality'`` metric (a missing metric counts as 0.0).
        Claims scoring at or above the cutoff are returned, highest score
        first; ties keep their input order.

        Returns:
            ``(selected_claims, cutoff)``, or ``([], 0.0)`` for empty input.
        """
        if not claim_uncertainties:
            return [], 0.0

        def centrality(entry):
            # Primary selection metric
            return entry['uncertainty_metrics'].get('closeness_centrality', 0.0)

        scores = [centrality(entry) for entry in claim_uncertainties]
        cutoff = np.percentile(scores, percentile)

        kept = [
            entry for entry, score in zip(claim_uncertainties, scores)
            if score >= cutoff
        ]
        # sorted() is stable, including with reverse=True
        ranked = sorted(kept, key=centrality, reverse=True)

        return ranked, cutoff

    def generate_final_answer(self, final_claims: List[Dict], question: str) -> str:
        """Generate final answer by combining selected claims using GPT-4o.

        Args:
            final_claims: Selected claim records; each must have a ``'claim'``
                key holding the claim text.
            question: The question the answer should address.

        Returns:
            The model's stripped answer. If ``final_claims`` is empty, a fixed
            "unable to answer" message. If the API call fails, a fallback
            summary assembled from the selected claim texts (the error is
            printed, not raised).
        """
        if not final_claims:
            return "Unable to generate a confident answer based on the analysis."

        # Extract claim texts
        claim_texts = [claim['claim'] for claim in final_claims]
        claim_lines = "\n".join(f"- {text}" for text in claim_texts)

        # Create prompt for concise final answer generation
        prompt = f"""Based on the following high-confidence claims, provide a CONCISE and DIRECT answer to the question. Keep the answer short (2-3 sentences maximum) and focus on the key conclusions.

Question: {question}

High-confidence claims:
{claim_lines}

Provide a concise, direct answer that directly addresses the specific question asked:"""

        try:
            response = self.client.chat.completions.create(
                model="gpt-4o",
                messages=[{"role": "user", "content": prompt}],
                max_tokens=200,  # Limit length
                temperature=0.3,  # Lower temperature for more focused answers
            )
            return response.choices[0].message.content.strip()
        except Exception as e:
            print(f"Error generating final answer: {e}")
            # Fallback: summarize from the claims themselves. A canned
            # sentence here would be stored as if it answered the question.
            summary = "; ".join(claim_texts)
            return f"Based on the analysis of {len(final_claims)} high-confidence claims: {summary}"

    def run_analysis(self, input_file: str = "final_questions.json",
                     output_file: str = "uncertainty_analysis_results.json"):
        """Run complete uncertainty analysis on all questions with batch processing."""
        # Load questions
        try:
            with open(input_file, 'r', encoding='utf-8') as f:
                questions = json.load(f)
        except FileNotFoundError:
            print(f"Error: {input_file} not found.")
            return

        print(
            f"Starting uncertainty analysis on {len(questions)} questions...")
        print(f"Using {self.num_generations} generations per question")
        print(f"Results will be saved to {output_file}")

        # Initialize output file
        with open(output_file, 'w', encoding='utf-8') as f:
            f.write('[\n')

        total_questions = len(questions)
        batch_size = 5
        results_buffer = {}  # Store results by index to maintain order

        print(
            f"Processing in batches of up to {batch_size} questions in parallel")

        # Process in batches
        for batch_start in range(0, total_questions, batch_size):
            batch_end = min(batch_start + batch_size, total_questions)
            batch_questions = [(i, questions[i])
                               for i in range(batch_start, batch_end)]

            print(
                f"\nProcessing batch {batch_start//batch_size + 1}: Questions {batch_start + 1}-{batch_end}")

            # Use ThreadPoolExecutor for parallel processing
            with ThreadPoolExecutor(max_workers=min(5, len(batch_questions))) as executor:
                # Submit all questions in the batch
                future_to_idx = {
                    executor.submit(self.process_single_question, question_data, idx): idx
                    for idx, question_data in batch_questions
                }

                # Collect results as they complete
                for future in as_completed(future_to_idx):
                    idx = future_to_idx[future]
                    try:
                        result = future.result()
                        results_buffer[idx] = result
                    except Exception as e:
                        print(f"Error processing question {idx + 1}: {e}")
                        traceback.print_exc()
                        # Create error result
                        error_result = {
                            'question': batch_questions[idx - batch_start][1].get('open_question', ''),
                            'reference_answer': batch_questions[idx - batch_start][1].get('reference_answer', ''),
                            'error': str(e),
                            'multiple_responses': [],
                            'claim_uncertainties': [],
                            'final_claims': [],
                            'final_answer': 'Error occurred during processing',
                            'question_idx': idx
                        }
                        results_buffer[idx] = error_result

            # Save results in order
            for i in range(batch_start, batch_end):
                if i in results_buffer:
                    self.save_result_incrementally(
                        results_buffer[i], output_file, i == 0)
                    print(f"Question {i + 1} completed and saved.")
                    # Clean up memory
                    del results_buffer[i]

        print(f"\nAnalysis complete! Results saved to {output_file}")

    def save_result_incrementally(self, result: Dict, output_file: str, is_first: bool):
        """Save result incrementally to JSON file with thread safety."""
        with self.file_lock:
            mode = 'a'  # Always append since we initialize the file separately

            with open(output_file, mode, encoding='utf-8') as f:
                if not is_first:
                    f.write(',\n')
                json.dump(result, f, ensure_ascii=False, indent=4)

    def finalize_json_file(self, output_file: str):
        """Close the JSON array in the output file."""
        with open(output_file, 'a', encoding='utf-8') as f:
            f.write('\n]')


def main():
    """Run the uncertainty analysis with the script's default configuration.

    Reads ``final_questions.json`` and writes
    ``uncertainty_analysis_results.json`` in the current directory. If the
    input file is missing, an error is printed and nothing is created or
    modified.
    """
    # Configuration: five samples per question (the class default is three)
    num_generations = 5
    threshold_percentile = 40
    input_file = "final_questions.json"
    output_file = "uncertainty_analysis_results.json"

    # Check the input before anything else. run_analysis only prints and
    # returns when the file is missing, and finalizing after that would
    # append a stray "]" to whatever results file is already on disk.
    if not os.path.isfile(input_file):
        print(f"Error: {input_file} not found.")
        return

    # Initialize and run analysis
    analyzer = UncertaintyAnalysis(
        num_generations=num_generations,
        threshold_percentile=threshold_percentile
    )

    try:
        analyzer.run_analysis(input_file, output_file)
        # Finalize JSON file
        analyzer.finalize_json_file(output_file)

    except KeyboardInterrupt:
        print("\nAnalysis interrupted by user.")
        # Still finalize the JSON file so the saved results stay loadable
        analyzer.finalize_json_file(output_file)
    except Exception as e:
        # Not finalized here: the failure may have happened before the
        # output file was (re)initialized, e.g. while parsing the input.
        print(f"Analysis failed: {e}")
        traceback.print_exc()


if __name__ == "__main__":
    main()
