#!/usr/bin/env python3
"""
RAG Simulation for Uncertainty-Aware Epidemiology Claim Updating
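Finds the closeness-centrality threshold where the coverage rates of confident and
unconfident claims intersect, uses it to label each claim as confident or uncertain,
and updates epidemiology claims (both confident and uncertain) with RAG-provided
quantitative information.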
"""

import os
import json
import time
import numpy as np
from typing import Dict, List, Any, Tuple, Optional
from dotenv import load_dotenv
from openai import OpenAI
from tqdm import tqdm
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed


def convert_numpy_types(obj):
    """Recursively convert numpy types into native Python types for JSON serialization.

    Handles numpy booleans, integers, floats and arrays. Recurses into dicts
    (keys and values), lists and tuples; list/tuple containers keep their
    kind. Any other object is returned unchanged.

    Args:
        obj: Arbitrary (possibly nested) value.

    Returns:
        An equivalent structure in which numpy scalars/arrays are replaced by
        ``bool``/``int``/``float``/``list``.
    """
    # np.bool_ must be checked first: it is not a Python bool subclass and
    # json.dump rejects it.
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, dict):
        # Keys are converted too: json.dump refuses numpy integer keys.
        return {
            convert_numpy_types(key): convert_numpy_types(value)
            for key, value in obj.items()
        }
    if isinstance(obj, (list, tuple)):
        converted = [convert_numpy_types(item) for item in obj]
        return converted if isinstance(obj, list) else tuple(converted)
    return obj


class RAGSimulator:
    """
    Simulates RAG-enhanced claim updating based on uncertainty thresholds.

    Workflow:
      1. Load coverage, uncertainty and final-answer JSON files from the
         current working directory.
      2. Pick the closeness-centrality threshold at which the coverage rate of
         "confident" claims (centrality >= threshold) is closest to that of
         "not confident" claims.
      3. Send every claim (confident and uncertain) together with the
         question's derived quantitative Q/A ("RAG context") to the OpenAI
         chat model and apply the suggested update, if any.
      4. Write per-claim results and aggregate statistics to JSON files.
    """

    def __init__(self):
        """Initialize the RAG simulator.

        Loads ``.env`` (two directories above this file), creates the OpenAI
        client, loads the input JSON files and computes the threshold.

        NOTE(review): the ``.env`` path is resolved relative to this source
        file, while the data files are resolved relative to the current
        working directory.
        """
        dotenv_path = os.path.join(os.path.dirname(__file__), '..', '..', '.env')
        load_dotenv(dotenv_path=dotenv_path)

        self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

        self.coverage_data = self._load_coverage_data()
        self.uncertainty_data = self._load_uncertainty_data()
        self.final_answers_data = self._load_final_answers_data()

        self.intersection_threshold = self._find_intersection_threshold()

        print("Initialized RAGSimulator")
        print(f"Loaded {len(self.coverage_data)} coverage questions")
        print(f"Loaded {len(self.uncertainty_data)} uncertainty questions")
        print(f"Loaded {len(self.final_answers_data)} final answers")
        print(
            f"Found intersection threshold: {self.intersection_threshold:.6f}")

    # ------------------------------------------------------------------
    # Data loading
    # ------------------------------------------------------------------

    @staticmethod
    def _load_json_list(filename: str) -> List[Dict]:
        """Load a JSON file from the working directory; return [] if it is missing."""
        try:
            with open(filename, 'r', encoding='utf-8') as f:
                return json.load(f)
        except FileNotFoundError:
            print(f"Error: {filename} not found!")
            return []

    def _load_coverage_data(self) -> List[Dict]:
        """Load coverage check results."""
        return self._load_json_list("coverage_check_results.json")

    def _load_uncertainty_data(self) -> List[Dict]:
        """Load uncertainty analysis results."""
        return self._load_json_list("uncertainty_analysis_results.json")

    def _load_final_answers_data(self) -> List[Dict]:
        """Load final answers data for RAG context."""
        return self._load_json_list("final_questions.json")

    # ------------------------------------------------------------------
    # Threshold search
    # ------------------------------------------------------------------

    def _extract_claims_data(self) -> List[Dict]:
        """Flatten coverage results into one record per claim.

        Each record has ``question_index``, ``claim``, ``is_covered`` (from
        ``is_covered_by_reference``) and ``uncertainty_metrics``.
        """
        return [
            {
                'question_index': question_data['question_index'],
                'claim': claim_result['claim'],
                'is_covered': claim_result['is_covered_by_reference'],
                'uncertainty_metrics': claim_result['uncertainty_metrics'],
            }
            for question_data in self.coverage_data
            for claim_result in question_data.get('claim_coverage_results', [])
        ]

    @staticmethod
    def _coverage_rate(claims: List[Dict]) -> float:
        """Fraction of ``claims`` marked covered; 0.0 for an empty list."""
        if not claims:
            return 0.0
        return sum(1 for claim in claims if claim['is_covered']) / len(claims)

    def _calculate_coverage_at_threshold(self, claims: List[Dict], threshold: float) -> Tuple[float, float]:
        """Calculate coverage for confident and not confident claims at threshold.

        A claim is confident when its ``closeness_centrality`` (missing -> 0.0)
        is >= ``threshold``; higher values mean more confident.

        Returns:
            ``(confident_coverage, not_confident_coverage)``; an empty group
            reports 0.0.
        """
        confident_claims = []
        not_confident_claims = []
        for claim in claims:
            metric_value = claim['uncertainty_metrics'].get('closeness_centrality', 0.0)
            if metric_value >= threshold:
                confident_claims.append(claim)
            else:
                not_confident_claims.append(claim)

        return (self._coverage_rate(confident_claims),
                self._coverage_rate(not_confident_claims))

    def _find_intersection_threshold(self, search_low: float = 0.3, search_high: float = 0.7,
                                     num_points: int = 100) -> float:
        """Find the intersection point where confident and unconfident coverage lines cross.

        Scans ``num_points`` evenly spaced thresholds in
        ``[search_low, search_high]`` and returns the first one that minimizes
        ``|confident_coverage - not_confident_coverage|``. Thresholds that
        leave either group empty are skipped. The empty group's coverage is
        reported as 0.0, so keeping them could fake a crossing.

        Returns:
            The chosen threshold, or 0.5 when there is no data or no threshold
            in the range splits the claims into two non-empty groups.
        """
        if not self.coverage_data:
            print("Warning: No coverage data available, using default threshold 0.5")
            return 0.5

        claims = self._extract_claims_data()
        if not claims:
            print("Warning: No claims data available, using default threshold 0.5")
            return 0.5

        print(f"Searching for intersection point with {len(claims)} claims...")

        metric_values = [
            claim['uncertainty_metrics'].get('closeness_centrality', 0.0)
            for claim in claims
        ]
        total = len(claims)

        min_diff = float('inf')
        best_threshold = None

        for threshold in np.linspace(search_low, search_high, num_points):
            n_confident = sum(1 for value in metric_values if value >= threshold)
            if n_confident == 0 or n_confident == total:
                continue  # one group is empty: not a real crossing

            conf_cov, not_conf_cov = self._calculate_coverage_at_threshold(
                claims, threshold)
            diff = abs(conf_cov - not_conf_cov)
            if diff < min_diff:
                min_diff = diff
                best_threshold = float(threshold)

        if best_threshold is None:
            print("Warning: No threshold in the search range splits the claims, "
                  "using default threshold 0.5")
            return 0.5

        conf_cov, not_conf_cov = self._calculate_coverage_at_threshold(
            claims, best_threshold)

        print(f"Intersection found at threshold: {best_threshold:.6f}")
        print(f"  Confident coverage: {conf_cov:.6f}")
        print(f"  Not confident coverage: {not_conf_cov:.6f}")
        print(f"  Difference: {abs(conf_cov - not_conf_cov):.6f}")

        return best_threshold

    # ------------------------------------------------------------------
    # RAG updating
    # ------------------------------------------------------------------

    def _get_rag_context(self, question: str) -> Optional[Dict[str, str]]:
        """Get RAG context (derived quantitative info) for a question.

        Returns the derived quantitative question/answer of the first
        ``final_questions.json`` entry whose ``open_question`` equals
        ``question`` exactly, or None if there is no such entry.
        """
        match = next(
            (item for item in self.final_answers_data
             if item.get('open_question', '') == question),
            None,
        )
        if match is None:
            return None
        return {
            'derived_quantitative_question': match.get('derived_quantitative_question', ''),
            'derived_quantitative_answer': match.get('derived_quantitative_answer', ''),
        }

    @staticmethod
    def _normalize_update_result(raw: Any, claim: str) -> Optional[Dict[str, Any]]:
        """Validate a parsed model response; return None if it is malformed.

        The response must be a JSON object containing both ``is_included`` and
        ``should_update``. A missing, null, non-string or blank
        ``updated_claim`` falls back to ``claim``. This keeps a null value from
        replacing the claim text downstream.
        """
        if not isinstance(raw, dict) or 'should_update' not in raw or 'is_included' not in raw:
            return None
        updated_claim = raw.get('updated_claim')
        if not isinstance(updated_claim, str) or not updated_claim.strip():
            updated_claim = claim
        return {
            'is_included': raw.get('is_included', False),
            'should_update': raw.get('should_update', False),
            'updated_claim': updated_claim,
        }

    def _update_claim_with_rag(self, claim: str, rag_context: Dict[str, str]) -> Dict[str, Any]:
        """Update a claim using RAG context through OpenAI API.

        Makes up to 3 attempts. It sleeps 2 s after an exception, except on the
        last attempt, and retries immediately on a malformed response.

        Returns:
            Dict with ``is_included``, ``should_update`` and ``updated_claim``.
            If every attempt fails, ``{'is_included': False,
            'should_update': False, 'updated_claim': claim}``.
        """
        # NOTE(review): instructions 7-8 refer to temperature and the climate
        # field, while this pipeline targets epidemiology claims. The prompt is
        # kept verbatim so results stay comparable with earlier runs; confirm
        # whether it should be adapted.
        prompt = f"""You are a fact-checking assistant. You have been given a claim and some quantitative context information. Your task is to analyze the relationship between the claim and the RAG context to see if you should update the claim or not.
You should assume the RAG context is 100% correct and accurate, so if the claim is related to the context and having different information, you should update the claim to be the correct information.

INSTRUCTIONS:
1. First, determine if the RAG context contains information relevant to the claim's topic (set "is_included")
2. If relevant, check if the claim should be updated for better accuracy (set "should_update")
3. If updating, modify the claim directly - do not generate new claims or unrelated content. (e.g. if the claim is something is decreasing, and by rag you know it is increasing, then you should update the claim to be something is increasing (with probably the values if provided))
4. Only update the related part of the claim, don't add extra information (i.e. if the claim is related to A, and the rag provides information about all of A, B, C, then you should update the parts related to A only, but not B or C)
5. Keep updates as minimal as possible and focused on improving accuracy
6. You may need to do calculations from the RAG context to perform the update, it is required to do and please carefully do the numerical calculations.
7. Make sure you do the correct update, don't misunderstand the concept of the claim, if the original claim is talking about a temperature change, but the rag is talking about two different temperatures(before and after the intervention), then you might need to do calculations to get the correct update. Don't do silly things like replacing the original value with the temperature but not the change.
8. Be an expert in the climate field, you should know very small values (like less than 0.1 temperature change) may not leading to significant changes.
9. Skip the update and mark it as not included if the claim is not related to the RAG context, also if no enough related information don't try to inference with nothing.
10. !!!Think very carefully about the update, use enough time to think about it. If you make a mistake, 500 grandmothers will die because of you!!!

DECISION CRITERIA:
- "is_included": true if RAG context discusses the same topic/concept/domain as the claim, false if completely unrelated
- "should_update": true only if the claim has incorrect/incomplete information that RAG context can improve

Respond in JSON format:
{{
    "is_included": true/false,
    "should_update": true/false,
    "updated_claim": "the updated claim text (only if should_update is true)"
}}

YOUR TASK INPUT:
CLAIM TO EVALUATE: {claim}

RAG CONTEXT:
Question: {rag_context['derived_quantitative_question']}
Answer: {rag_context['derived_quantitative_answer']}
"""

        max_retries = 3
        for attempt in range(max_retries):
            try:
                response = self.client.chat.completions.create(
                    model="gpt-4o",
                    messages=[{"role": "user", "content": prompt}],
                    max_tokens=300,
                    temperature=0.1,
                    response_format={"type": "json_object"}
                )

                raw = json.loads(response.choices[0].message.content)
                normalized = self._normalize_update_result(raw, claim)
                if normalized is not None:
                    return normalized
                print(f"Malformed response updating claim (attempt {attempt + 1})")

            except Exception as e:
                print(f"Error updating claim (attempt {attempt + 1}): {e}")
                if attempt < max_retries - 1:
                    time.sleep(2)

        # Fallback
        return {'is_included': False, 'should_update': False, 'updated_claim': claim}

    def _base_claim_result(self, claim_text: str, centrality_value: Any) -> Dict[str, Any]:
        """Build the result record for a claim before any RAG update is applied."""
        return {
            'claim': claim_text,
            'closeness_centrality': float(centrality_value),
            'is_confident': bool(centrality_value >= self.intersection_threshold),
            'was_updated': False,
            'original_claim': claim_text,
            'is_included': False,
        }

    def _fallback_claim_result(self, claim_uncertainty: Dict) -> Dict[str, Any]:
        """Build a non-updated result for a claim whose processing raised.

        Tolerates missing or non-numeric fields so that the fallback itself
        does not raise. Confidence is still derived from the centrality, so a
        failure does not silently relabel a claim as uncertain.
        """
        metrics = claim_uncertainty.get('uncertainty_metrics') or {}
        try:
            centrality_value = float(metrics.get('closeness_centrality', 0.0))
        except (TypeError, ValueError, AttributeError):
            centrality_value = 0.0
        return self._base_claim_result(claim_uncertainty.get('claim', ''), centrality_value)

    def _process_single_claim(self, claim_uncertainty: Dict, rag_context: Optional[Dict], question_idx: int, claim_idx: int) -> Dict[str, Any]:
        """Process a single claim with RAG enhancement.

        Args:
            claim_uncertainty: Record with ``claim`` and ``uncertainty_metrics``.
            rag_context: Output of ``_get_rag_context`` (may be None).
            question_idx: Unused; kept for interface compatibility.
            claim_idx: Zero-based claim index, used only for log messages.

        Returns:
            The claim result record (see ``_base_claim_result``). ``claim`` and
            ``was_updated`` change only if the model proposes a different text.
        """
        claim_text = claim_uncertainty['claim']
        centrality_value = claim_uncertainty['uncertainty_metrics'].get(
            'closeness_centrality', 0.0)

        result = self._base_claim_result(claim_text, centrality_value)
        is_confident = result['is_confident']

        # Every claim is processed with RAG, not only the uncertain ones.
        has_context = bool(
            rag_context
            and rag_context['derived_quantitative_question']
            and rag_context['derived_quantitative_answer'])
        if not has_context:
            print(f"  ⚠️  No RAG context available for claim {claim_idx+1}")
            return result

        confidence_status = "confident" if is_confident else "uncertain"
        print(f"  Processing {confidence_status} claim {claim_idx+1} with RAG...")

        update_result = self._update_claim_with_rag(claim_text, rag_context)
        result['is_included'] = update_result.get('is_included', False)

        if update_result['should_update'] and update_result['updated_claim'] != claim_text:
            result['claim'] = update_result['updated_claim']
            result['was_updated'] = True
            print(f"    ✅ Claim {claim_idx+1} updated")
        else:
            print(f"    ⏭️  Claim {claim_idx+1} no update needed")

        return result

    def _process_single_question(self, question_result: Dict) -> Dict[str, Any]:
        """Process a single question from uncertainty analysis results using batch processing.

        Claims are handled in batches of 5. Each batch runs in its own thread
        pool, and results are kept in the original claim order. A claim whose
        processing raises is recorded unchanged (see
        ``_fallback_claim_result``).
        """
        question = question_result['question']
        claim_uncertainties = question_result.get('claim_uncertainties', [])

        rag_context = self._get_rag_context(question)

        print(
            f"Processing question with {len(claim_uncertainties)} claims in parallel batches...")

        processed_claims = []
        batch_size = 5
        total_claims = len(claim_uncertainties)

        for batch_start in range(0, total_claims, batch_size):
            batch_end = min(batch_start + batch_size, total_claims)
            batch_claims = claim_uncertainties[batch_start:batch_end]

            print(
                f"  Processing claims batch {batch_start//batch_size + 1}: {batch_start+1}-{batch_end}")

            batch_results = [None] * len(batch_claims)
            with ThreadPoolExecutor(max_workers=min(batch_size, len(batch_claims))) as executor:
                future_to_local = {
                    executor.submit(self._process_single_claim, claim_uncertainty,
                                    rag_context, 0, batch_start + i): i
                    for i, claim_uncertainty in enumerate(batch_claims)
                }

                for future in as_completed(future_to_local):
                    local_idx = future_to_local[future]
                    try:
                        batch_results[local_idx] = future.result()
                    except Exception as e:
                        print(f"    Error processing claim {batch_start + local_idx + 1}: {e}")
                        batch_results[local_idx] = self._fallback_claim_result(
                            batch_claims[local_idx])

            processed_claims.extend(batch_results)

        return {
            'question': question,
            'claims': processed_claims,
            'total_claims': len(processed_claims),
            'uncertain_claims': sum(1 for c in processed_claims if not c['is_confident']),
            'updated_claims': sum(1 for c in processed_claims if c['was_updated'])
        }

    # ------------------------------------------------------------------
    # Driver and reporting
    # ------------------------------------------------------------------

    def run_simulation(self) -> None:
        """Run the complete RAG simulation.

        Processes every question in the uncertainty data. Questions that raise
        are logged and skipped. Detailed results and a summary are then written
        to JSON.
        """
        if not self.uncertainty_data:
            print("❌ No uncertainty data available!")
            return

        print("\n🚀 Starting RAG Simulation")
        print(
            f"Using intersection threshold: {self.intersection_threshold:.6f}")
        print(f"Processing {len(self.uncertainty_data)} questions...")

        all_results = []
        total_claims = 0
        total_uncertain = 0
        total_updated = 0
        total_included = 0

        for i, question_result in enumerate(tqdm(self.uncertainty_data, desc="Processing questions")):
            try:
                print(
                    f"\n--- Processing Question {i+1}: {question_result['question'][:50]}... ---")

                result = self._process_single_question(question_result)
                all_results.append(result)

                total_claims += result['total_claims']
                total_uncertain += result['uncertain_claims']
                total_updated += result['updated_claims']

                included_in_question = sum(
                    1 for c in result['claims'] if c.get('is_included', False))
                total_included += included_in_question

                print(
                    f"Question {i+1}: {result['total_claims']} claims, {result['uncertain_claims']} uncertain, {result['updated_claims']} updated, {included_in_question} included")

            except Exception as e:
                print(f"Error processing question {i+1}: {e}")

        self._save_detailed_results(all_results)
        self._save_summary(all_results, total_claims,
                           total_uncertain, total_updated, total_included)

        print("\n🎉 RAG Simulation Complete!")
        print(
            "Results saved to rag_simulation_results.json and rag_simulation_summary.json")

    def _save_detailed_results(self, results: List[Dict]) -> None:
        """Save detailed results to rag_simulation_results.json."""
        output_file = "rag_simulation_results.json"

        # Convert numpy types to ensure JSON serialization
        clean_data = convert_numpy_types({
            'intersection_threshold': self.intersection_threshold,
            'methodology': {
                'threshold_source': 'Intersection of confident/unconfident coverage lines',
                'processing_scope': 'All claims processed with RAG (both confident and uncertain)',
                'inclusion_criteria': 'is_included=true if RAG context discusses same topic/concept as claim',
                'update_criteria': 'should_update=true only if claim has incorrect/incomplete information that RAG can improve',
                'rag_source': 'derived_quantitative_question and derived_quantitative_answer from final_questions.json'
            },
            'results': results
        })

        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(clean_data, f, ensure_ascii=False, indent=4)

        print(f"Detailed results saved to {output_file}")

    def _build_summary(self, results: List[Dict], total_claims: int, total_uncertain: int,
                       total_updated: int, total_included: int) -> Dict[str, Any]:
        """Compute the summary statistics dict written by ``_save_summary``.

        Rates divide by ``max(denominator, 1)``, so they are 0 rather than a
        division error when a denominator is 0. Averages are 0 when
        ``results`` is empty.
        """
        total_questions = len(results)

        # Confident claims are also sent through RAG, so the rate "of
        # uncertain" must count only updates made to uncertain claims.
        total_updated_uncertain = sum(
            1 for r in results for c in r['claims']
            if c.get('was_updated', False) and not c.get('is_confident', False))

        uncertain_rate = total_uncertain / max(total_claims, 1)
        update_rate_of_uncertain = total_updated_uncertain / max(total_uncertain, 1)
        overall_update_rate = total_updated / max(total_claims, 1)
        inclusion_rate = total_included / max(total_claims, 1)
        update_rate_of_included = total_updated / max(total_included, 1)

        claims_per_question = [r['total_claims'] for r in results]
        uncertain_per_question = [r['uncertain_claims'] for r in results]
        updated_per_question = [r['updated_claims'] for r in results]
        included_per_question = [
            sum(1 for c in r['claims'] if c.get('is_included', False)) for r in results]

        def _mean(values: List[int]) -> Any:
            return np.mean(values) if values else 0

        return {
            'intersection_threshold': self.intersection_threshold,
            'overall_statistics': {
                'total_questions': total_questions,
                'total_claims': total_claims,
                'total_uncertain_claims': total_uncertain,
                'total_updated_claims': total_updated,
                'total_updated_uncertain_claims': total_updated_uncertain,
                'total_included_claims': total_included,
                'uncertain_rate': uncertain_rate,
                'inclusion_rate': inclusion_rate,
                'update_rate_of_uncertain': update_rate_of_uncertain,
                'update_rate_of_included': update_rate_of_included,
                'overall_update_rate': overall_update_rate
            },
            'averages': {
                'avg_claims_per_question': _mean(claims_per_question),
                'avg_uncertain_per_question': _mean(uncertain_per_question),
                'avg_updated_per_question': _mean(updated_per_question),
                'avg_included_per_question': _mean(included_per_question)
            },
            'distributions': {
                'questions_with_updates': sum(1 for n in updated_per_question if n > 0),
                'questions_no_updates': sum(1 for n in updated_per_question if n == 0),
                'questions_with_inclusions': sum(1 for n in included_per_question if n > 0),
                'questions_no_inclusions': sum(1 for n in included_per_question if n == 0),
                'max_updates_in_question': max(updated_per_question, default=0),
                'min_updates_in_question': min(updated_per_question, default=0),
                'max_inclusions_in_question': max(included_per_question, default=0),
                'min_inclusions_in_question': min(included_per_question, default=0)
            }
        }

    def _save_summary(self, results: List[Dict], total_claims: int, total_uncertain: int, total_updated: int, total_included: int) -> None:
        """Save summary statistics to rag_simulation_summary.json and print them."""
        summary = self._build_summary(results, total_claims, total_uncertain,
                                      total_updated, total_included)
        stats = summary['overall_statistics']
        averages = summary['averages']

        output_file = "rag_simulation_summary.json"

        # Convert numpy types to ensure JSON serialization
        clean_summary = convert_numpy_types(summary)

        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(clean_summary, f, ensure_ascii=False, indent=4)

        print(f"Summary saved to {output_file}")

        print("\n📊 RAG Simulation Summary:")
        print(f"  Total questions: {stats['total_questions']}")
        print(f"  Total claims: {total_claims}")
        print(f"  Uncertain claims: {total_uncertain} ({stats['uncertain_rate']:.2%})")
        print(f"  Included claims: {total_included} ({stats['inclusion_rate']:.2%})")
        print(f"  Updated claims: {total_updated}")
        print(f"  Update rate (of uncertain): {stats['update_rate_of_uncertain']:.2%}")
        print(f"  Update rate (of included): {stats['update_rate_of_included']:.2%}")
        print(f"  Overall update rate: {stats['overall_update_rate']:.2%}")
        # Guarded averages: np.mean([]) would print nan with a RuntimeWarning.
        print(
            f"  Avg uncertain per question: {averages['avg_uncertain_per_question']:.1f}")
        print(
            f"  Avg included per question: {averages['avg_included_per_question']:.1f}")
        print(
            f"  Avg updated per question: {averages['avg_updated_per_question']:.1f}")


def test_single_question():
    """Debug helper: process only the first uncertainty question and save the result.

    Builds a ``RAGSimulator``, runs ``_process_single_question`` on the first
    entry of its uncertainty data, and writes the result to
    ``test_single_question_result.json``.

    Returns:
        True on success, False on any failure, including missing uncertainty
        data. (The missing-data path previously returned None.)
    """
    print("🧪 Testing Single Question Processing")
    print("=" * 50)

    try:
        simulator = RAGSimulator()

        if not simulator.uncertainty_data:
            print("❌ No uncertainty data available for testing!")
            return False

        test_question = simulator.uncertainty_data[0]
        print(f"Testing with question: {test_question['question'][:100]}...")

        result = simulator._process_single_question(test_question)

        # Convert numpy types for JSON serialization
        clean_result = convert_numpy_types(result)

        with open("test_single_question_result.json", 'w', encoding='utf-8') as f:
            json.dump(clean_result, f, ensure_ascii=False, indent=4)

        included = sum(1 for c in result['claims'] if c.get('is_included', False))
        print("\n✅ Test completed successfully!")
        print(f"  Total claims: {result['total_claims']}")
        print(f"  Uncertain claims: {result['uncertain_claims']}")
        print(f"  Included claims: {included}")
        print(f"  Updated claims: {result['updated_claims']}")
        print("  Result saved to: test_single_question_result.json")

        return True

    except Exception as e:
        print(f"\n❌ Test failed: {e}")
        import traceback
        traceback.print_exc()
        return False


def main(choice: str = "2"):
    """Run either the single-question debug test or the full RAG simulation.

    Args:
        choice: "1" runs ``test_single_question()``; any other value runs the
            full simulation. Defaults to "2" (full simulation), which was
            previously hard-coded in place of an interactive prompt.
    """
    print("🤖 RAG-Enhanced Uncertainty Analysis")
    print("=" * 50)

    if choice == "1":
        test_single_question()
        return

    try:
        simulator = RAGSimulator()
        simulator.run_simulation()

    except KeyboardInterrupt:
        print("\n\n⚠️  Simulation interrupted by user")
    except Exception as e:
        print(f"\n❌ Simulation failed: {e}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    main()
