#!/usr/bin/env python3
"""
RAG Claim Correctness Analysis for Epidemiology - Evaluate all claims before and after RAG updates
Analyzes whether epidemiology claims help correctly derive the ground truth answer
"""

import json
import os
import re
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Dict, List, Tuple

from dotenv import load_dotenv
from openai import OpenAI
from tqdm import tqdm

# Load variables from a local .env file into the process environment
# (python-dotenv), so OPENAI_API_KEY -- read with os.getenv() when the
# analyzer builds its OpenAI client -- can be supplied there.
load_dotenv()


def convert_numpy_types(obj):
    """Recursively replace numpy values with JSON-serializable Python values.

    numpy booleans, integers and floats become ``bool``/``int``/``float``;
    numpy arrays become (nested) lists via ``tolist()``; dicts and lists are
    walked recursively. Any other value (including tuples) is returned as-is.
    """
    import numpy as np

    # Containers first: rebuild them with every element converted.
    if isinstance(obj, dict):
        return {k: convert_numpy_types(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [convert_numpy_types(v) for v in obj]
    if isinstance(obj, np.ndarray):
        return obj.tolist()

    # numpy scalars map onto the matching builtin constructor.
    scalar_casts = (
        (np.bool_, bool),
        (np.integer, int),
        (np.floating, float),
    )
    for np_type, cast in scalar_casts:
        if isinstance(obj, np_type):
            return cast(obj)
    return obj


class RAGClaimCorrectnessAnalyzer:
    """Analyze correctness of all claims before and after RAG updates.

    Reads ``rag_simulation_results.json`` (questions with their claims) and
    ``final_questions.json`` (reference answers). For every claim it asks an
    OpenAI chat model whether the original and the final (possibly RAG-updated)
    claim text supports the reference answer. Per-claim details and an
    aggregate summary are written to ``rag_claim_correctness_results.json`` and
    ``rag_claim_correctness_summary.json``.
    """

    # Model used for every correctness judgement.
    MODEL_NAME = "gpt-4o"
    # Number of claims evaluated concurrently per batch (also the pool size).
    CLAIM_BATCH_SIZE = 4
    # Attempts per model call before falling back to "incorrect".
    MAX_RETRIES = 3
    # Seconds to wait between failed attempts.
    RETRY_DELAY_SECONDS = 2

    def __init__(self):
        """Create the OpenAI client and load both input JSON files.

        The API key comes from the ``OPENAI_API_KEY`` environment variable.
        A missing input file is reported and treated as an empty list.
        """
        self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

        # Load data
        self.rag_results = self._load_rag_results()
        self.final_answers = self._load_final_answers()

        print("Initialized RAGClaimCorrectnessAnalyzer")
        print(f"Loaded {len(self.rag_results)} questions from RAG results")
        print(f"Loaded {len(self.final_answers)} final answers")

    def _load_rag_results(self) -> List[Dict]:
        """Load RAG simulation results (the ``results`` list of the file, or [])."""
        try:
            with open("rag_simulation_results.json", 'r', encoding='utf-8') as f:
                data = json.load(f)
                return data.get('results', [])
        except FileNotFoundError:
            print("❌ rag_simulation_results.json not found!")
            return []

    def _load_final_answers(self) -> List[Dict]:
        """Load the final-questions file (its top-level JSON value, or [] if missing)."""
        try:
            with open("final_questions.json", 'r', encoding='utf-8') as f:
                return json.load(f)
        except FileNotFoundError:
            print("❌ final_questions.json not found!")
            return []

    def _get_ground_truth_data(self, question: str) -> Tuple[str, str, str]:
        """Look up the reference data for *question* (matched on ``open_question``).

        Returns:
            ``(reference_answer, derived_quantitative_question,
            derived_quantitative_answer)`` of the first matching entry, or three
            empty strings when no entry matches.
        """
        for item in self.final_answers:
            if item.get('open_question', '') == question:
                return (
                    item.get('reference_answer', ''),
                    item.get('derived_quantitative_question', ''),
                    item.get('derived_quantitative_answer', ''),
                )
        return "", "", ""

    @staticmethod
    def _build_rag_info_section(quantitative_question: str, quantitative_answer: str) -> str:
        """Return the optional RAG-info prompt section, or "" if either part is empty."""
        if not (quantitative_question and quantitative_answer):
            return ""
        return f"""
RAG INFO (Ground Truth Quantitative Analysis):
QUANTITATIVE QUESTION: {quantitative_question}
QUANTITATIVE ANSWER: {quantitative_answer}
"""

    @staticmethod
    def _to_bool(value: Any) -> bool:
        """Coerce a model-provided verdict to bool.

        Strings are compared to "true" case-insensitively, so the string
        "false" maps to False (``bool("false")`` would be True).
        """
        if isinstance(value, str):
            return value.strip().lower() == "true"
        return bool(value)

    @classmethod
    def _parse_verdicts(cls, result_text: str, keys: List[str]) -> Dict[str, bool]:
        """Extract one boolean verdict per key from a model response.

        The parsing steps are applied in order:

        1. Strip a surrounding Markdown code fence (```json ... ```).
        2. Parse the remainder as JSON. For a JSON object, missing keys
           default to False.
        3. Otherwise, match each key individually as ``key: true|false``.
        4. For a single key, fall back to the bare word true/false. This
           applies only when exactly one of the two words appears.

        Raises:
            ValueError: if a verdict cannot be found for every key.
        """
        cleaned = result_text.strip()
        fence = re.match(r"^```(?:json)?\s*(.*?)\s*```$", cleaned, re.DOTALL | re.IGNORECASE)
        if fence:
            cleaned = fence.group(1)

        try:
            parsed = json.loads(cleaned)
        except json.JSONDecodeError:
            parsed = None
        if isinstance(parsed, dict):
            return {key: cls._to_bool(parsed.get(key, False)) for key in keys}

        verdicts = {}
        for key in keys:
            match = re.search(
                rf"""["']?{re.escape(key)}["']?\s*[:=]\s*["']?(true|false)\b""",
                cleaned, re.IGNORECASE)
            if match is not None:
                verdicts[key] = match.group(1).lower() == "true"
        if len(verdicts) == len(keys):
            return verdicts

        if len(keys) == 1:
            words = {w.lower() for w in re.findall(r"\b(true|false)\b", cleaned, re.IGNORECASE)}
            if len(words) == 1:
                return {keys[0]: words.pop() == "true"}

        raise ValueError(f"Cannot parse response: {result_text}")

    def _query_verdicts(self, prompt: str, keys: List[str], max_tokens: int, label: str) -> Tuple[Dict[str, bool], str]:
        """Send *prompt* to the model and parse one boolean per key, with retries.

        API errors and unparseable responses are printed and retried up to
        ``MAX_RETRIES`` times, with ``RETRY_DELAY_SECONDS`` between attempts.

        Returns:
            ``(verdicts, raw_response)``. If every attempt fails, all verdicts
            are False and ``raw_response`` is 'ERROR'.
        """
        for attempt in range(self.MAX_RETRIES):
            try:
                response = self.client.chat.completions.create(
                    model=self.MODEL_NAME,
                    messages=[{"role": "user", "content": prompt}],
                    max_tokens=max_tokens,
                    temperature=0.1
                )
                # content can be None (e.g. a refusal); treat that as empty text.
                result_text = (response.choices[0].message.content or "").strip()
                return self._parse_verdicts(result_text, keys), result_text
            except Exception as e:
                print(
                    f"Error checking {label} (attempt {attempt + 1}): {e}")
                if attempt < self.MAX_RETRIES - 1:
                    time.sleep(self.RETRY_DELAY_SECONDS)

        # Fallback
        return {key: False for key in keys}, 'ERROR'

    def _check_claim_correctness(self, claim: str, ground_truth: str, question: str, quantitative_question: str = "", quantitative_answer: str = "") -> Dict[str, Any]:
        """Check if a single claim helps correctly derive the ground truth answer.

        Returns:
            ``{'is_correct': bool, 'raw_response': str}``. After repeated
            failures ``is_correct`` is False and ``raw_response`` is 'ERROR'.
        """
        rag_info_section = self._build_rag_info_section(
            quantitative_question, quantitative_answer)

        prompt = f"""You are evaluating whether a claim helps correctly derive the ground truth answer to an climate question.

TASK: Determine if this claim can help someone correctly derive or support the ground truth answer, or if it is consistent with the RAG info provided.

EVALUATION CRITERIA:
- true: The claim provides information that correctly supports or helps derive the ground truth answer
- true: The claim contains facts that would lead someone to the correct conclusion
- true: The claim's quantitative information is consistent with the RAG info trends/patterns
- false: The claim is unrelated to the question/answer
- false: The claim contains information that would mislead the inference of the correct answer
- false: The claim's information contradicts the ground truth answer or RAG info

NOTE: 
- If the claim is repeating the question's setting or information, you should mark it as true though there might be slightly different representation.
- You may see quantitative information in the claim from RAG simulators. Judge the correctness of trends/patterns rather than only exact numerical matches.
- The RAG info contains ground truth quantitative analysis from climate simulations - use this as reference for evaluating claim accuracy.
- For climate claims, consider intervention effectiveness, health outcomes, and policy implications.

Respond with JSON format containing only one key "is_correct" with value true or false.

Example response: {{"is_correct": true}}

YOUR TASK INPUT:
QUESTION: {question}

GROUND TRUTH ANSWER: {ground_truth}
{rag_info_section}
CLAIM TO EVALUATE: {claim}
"""

        verdicts, raw_response = self._query_verdicts(
            prompt, ['is_correct'], max_tokens=50, label="claim correctness")
        return {
            'is_correct': verdicts['is_correct'],
            'raw_response': raw_response
        }

    def _check_both_claims_correctness(self, original_claim: str, final_claim: str, ground_truth: str, question: str, quantitative_question: str = "", quantitative_answer: str = "") -> Dict[str, Any]:
        """Check correctness of both original and final claims in a single API call.

        Returns:
            ``{'original_is_correct': bool, 'final_is_correct': bool,
            'raw_response': str}``. After repeated failures both flags are
            False and ``raw_response`` is 'ERROR'.
        """
        rag_info_section = self._build_rag_info_section(
            quantitative_question, quantitative_answer)

        prompt = f"""You are evaluating whether two claims help correctly derive the ground truth answer to an climate question.

TASK: Determine if each claim can help someone correctly derive or support the ground truth answer, or if it is consistent with the RAG info provided.

EVALUATION CRITERIA:
- true: The claim provides information that supports or helps derive the ground truth answer
- true: The claim contains facts that would lead someone to the correct conclusion
- true: The claim's quantitative information is consistent with the RAG info trends/patterns
- true: The claim accurately represents epidemiological data or simulation results
- false: The claim is unrelated to the question/answer
- false: The claim contains information that would mislead the inference of the correct answer
- false: The claim's information contradicts the ground truth answer or RAG info
- false: The claim contains quantitative data that contradicts the simulation results

NOTE: 
- If the claim is repeating the question's setting or information, you should mark it as true though there might be slightly different representation.
- You may see quantitative information in the final claim from RAG simulators. Judge the correctness of trends/patterns rather than exact numerical matches.
- The RAG info contains ground truth quantitative analysis from climate simulations - use this as reference for evaluating claim accuracy.
- For climate claims, consider intervention effectiveness, health outcomes, and policy implications.
- The rag update is not always correct, it has very low chance of adding the incorrect value/description violating the ground truth answer, but it is possible.
- !!!Think very carefully about the decision, use enough time to think about it. If you make a mistake, 500 grandmothers will die because of you!!!

USEFUL INFO FOR EPIDEMIOLOGY ANALYSIS:
1.  **Determine Intervention Effectiveness (Direction):**
    - If the change is **negative (< 0)**, the intervention is **effective** or **beneficial** (reduces negative health outcomes).
    - If the change is **positive (> 0)**, the intervention is **ineffective** or **detrimental** (increases negative health outcomes).
    - If the change is **zero (= 0)**, the intervention is **ineffective** or has **no impact**.
2.  **Determine Impact Magnitude:**
    - If `abs(change) < 0.02`, the impact is **"negligible"** or **"marginal"**.
    - If `0.02 <= abs(change) <= 0.10`, the impact is **"noticeable"** or **"modest"**.
    - If `abs(change) > 0.10`, the impact is **"significant"** or **"substantial"**.

Respond with JSON format containing two keys:
- "original_claim_correctness": true or false for the original claim
- "final_claim_correctness": true or false for the final claim

Example response: {{"original_claim_correctness": false, "final_claim_correctness": true}}

YOUR TASK INPUT:
QUESTION: {question}

GROUND TRUTH ANSWER: {ground_truth}
{rag_info_section}
ORIGINAL CLAIM: {original_claim}

FINAL CLAIM: {final_claim}
"""

        verdicts, raw_response = self._query_verdicts(
            prompt,
            ['original_claim_correctness', 'final_claim_correctness'],
            max_tokens=100,
            label="both claims correctness")
        return {
            'original_is_correct': verdicts['original_claim_correctness'],
            'final_is_correct': verdicts['final_claim_correctness'],
            'raw_response': raw_response
        }

    def _get_original_claim(self, claim_data: Dict) -> str:
        """Return the pre-RAG claim text.

        Uses ``original_claim`` when ``was_updated`` is true, else ``claim``.
        """
        if claim_data.get('was_updated', False):
            return claim_data.get('original_claim', '')
        return claim_data.get('claim', '')

    def _get_final_claim(self, claim_data: Dict) -> str:
        """Return the final (possibly RAG-updated) claim text from ``claim``."""
        return claim_data.get('claim', '')

    @staticmethod
    def _build_claim_analysis(claim_data: Dict, original_claim: str, final_claim: str,
                              original_is_correct: bool, final_is_correct: bool) -> Dict[str, Any]:
        """Build the per-claim output record (one shared schema for success and failure)."""
        return {
            # A missing or None centrality is recorded as 0.0.
            'closeness_centrality': float(claim_data.get('closeness_centrality') or 0),
            'is_confident': bool(claim_data.get('is_confident', False)),
            'was_updated': bool(claim_data.get('was_updated', False)),
            'is_included': bool(claim_data.get('is_included', False)),
            'original_claim': original_claim,
            'final_claim': final_claim,
            'original_is_correct': original_is_correct,
            'final_is_correct': final_is_correct,
            # NOTE: despite its name this is True whenever the verdict changed in
            # either direction; the key is kept for output compatibility.
            'change_improved': original_is_correct != final_is_correct,
            'got_better': not original_is_correct and final_is_correct,
            'got_worse': original_is_correct and not final_is_correct
        }

    def _process_single_claim(self, claim_data: Dict, claim_idx: int, ground_truth: str, question: str, quantitative_question: str = "", quantitative_answer: str = "") -> tuple[int, Dict]:
        """Evaluate one claim's original and final text against the ground truth.

        Returns:
            ``(claim_idx, claim_analysis)``. The index is returned so that
            pooled callers can restore input order.
        """
        print(f"    Analyzing claim {claim_idx+1}...")

        original_claim = self._get_original_claim(claim_data)
        final_claim = self._get_final_claim(claim_data)

        if original_claim and final_claim:
            # One API call judges both versions of the claim.
            correctness_result = self._check_both_claims_correctness(
                original_claim, final_claim, ground_truth, question, quantitative_question, quantitative_answer)
            original_is_correct = correctness_result['original_is_correct']
            final_is_correct = correctness_result['final_is_correct']
        else:
            # An empty original or final text is recorded as incorrect without an API call.
            original_is_correct = False
            final_is_correct = False

        claim_analysis = self._build_claim_analysis(
            claim_data, original_claim, final_claim, original_is_correct, final_is_correct)
        return claim_idx, claim_analysis

    def _process_single_question(self, question_result: Dict) -> Dict[str, Any]:
        """Run the correctness analysis for every claim of one question.

        Claims are evaluated in batches of ``CLAIM_BATCH_SIZE`` on a thread
        pool. ``claims_analysis`` keeps the input claim order. A claim whose
        evaluation raises is recorded as incorrect in both versions.

        Returns:
            Per-question counts and rates plus ``claims_analysis``. A question
            with no reference answer yields zero counts, zero rates and no
            claims. It carries the same keys as a normal result, so the summary
            code can read them unconditionally.
        """
        question = question_result['question']
        claims = question_result.get('claims', [])
        ground_truth, quantitative_question, quantitative_answer = self._get_ground_truth_data(
            question)

        if not ground_truth:
            print("⚠️  No ground truth answer found for question")
            return {
                'question': question,
                'ground_truth_answer': '',
                'total_claims': 0,
                'original_correct_count': 0,
                'final_correct_count': 0,
                # _save_summary reads 'correctness_improvement' from every result.
                'original_correctness_rate': 0.0,
                'final_correctness_rate': 0.0,
                'correctness_improvement': 0,
                'claims_analysis': []
            }

        print(f"  Processing {len(claims)} claims with batch processing...")

        claims_analysis = []
        batch_size = self.CLAIM_BATCH_SIZE
        total_claims = len(claims)

        for batch_start in range(0, total_claims, batch_size):
            batch_end = min(batch_start + batch_size, total_claims)
            batch_claims = claims[batch_start:batch_end]

            print(
                f"    Processing claims batch {batch_start//batch_size + 1}: {batch_start+1}-{batch_end}")

            with ThreadPoolExecutor(max_workers=len(batch_claims)) as executor:
                future_to_idx = {
                    executor.submit(self._process_single_claim, claim_data, batch_start + i, ground_truth, question, quantitative_question, quantitative_answer): batch_start + i
                    for i, claim_data in enumerate(batch_claims)
                }

                # Slots indexed by position in the batch keep results in input order.
                batch_results = [None] * len(batch_claims)
                for future in as_completed(future_to_idx):
                    claim_idx = future_to_idx[future]
                    local_idx = claim_idx - batch_start
                    try:
                        _, claim_analysis = future.result()
                    except Exception as e:
                        print(f"      Error processing claim {claim_idx+1}: {e}")
                        failed_claim = batch_claims[local_idx]
                        claim_analysis = self._build_claim_analysis(
                            failed_claim,
                            self._get_original_claim(failed_claim),
                            self._get_final_claim(failed_claim),
                            False,
                            False)
                    batch_results[local_idx] = claim_analysis

            claims_analysis.extend(batch_results)

        original_correct_count = sum(1 for c in claims_analysis if c['original_is_correct'])
        final_correct_count = sum(1 for c in claims_analysis if c['final_is_correct'])
        denominator = max(len(claims), 1)

        return {
            'question': question,
            'ground_truth_answer': ground_truth,
            'total_claims': len(claims),
            'original_correct_count': original_correct_count,
            'final_correct_count': final_correct_count,
            'original_correctness_rate': original_correct_count / denominator,
            'final_correctness_rate': final_correct_count / denominator,
            'correctness_improvement': final_correct_count - original_correct_count,
            'claims_analysis': claims_analysis
        }

    def run_correctness_analysis(self) -> None:
        """Analyze every loaded question, then write detailed and summary JSON files.

        A question that raises is reported and left out of both files.
        """
        if not self.rag_results:
            print("❌ No RAG results available!")
            return

        print("\n🔍 Starting RAG Claim Correctness Analysis")
        print(f"Processing {len(self.rag_results)} questions...")

        all_results = []
        total_claims = 0
        total_original_correct = 0
        total_final_correct = 0

        for i, question_result in enumerate(tqdm(self.rag_results, desc="Analyzing correctness")):
            try:
                print(
                    f"\n--- Processing Question {i+1}: {question_result['question'][:50]}... ---")

                result = self._process_single_question(question_result)
                all_results.append(result)

                total_claims += result['total_claims']
                total_original_correct += result['original_correct_count']
                total_final_correct += result['final_correct_count']

                print(
                    f"Question {i+1}: {result['total_claims']} claims, {result['original_correct_count']}→{result['final_correct_count']} correct")

            except Exception as e:
                print(f"Error processing question {i+1}: {e}")

        self._save_detailed_results(all_results)
        self._save_summary(all_results, total_claims,
                           total_original_correct, total_final_correct)

        print("\n🎉 Claim Correctness Analysis Complete!")
        print("Results saved to rag_claim_correctness_results.json and rag_claim_correctness_summary.json")

    def _save_detailed_results(self, results: List[Dict]) -> None:
        """Write per-question, per-claim results to rag_claim_correctness_results.json."""
        output_file = "rag_claim_correctness_results.json"

        data = {
            'methodology': 'Claim correctness analysis comparing original vs final claims against ground truth',
            'evaluation_criteria': 'Whether claims help correctly derive ground truth answers',
            'claim_selection': {
                'original_claim': 'original_claim if was_updated=true, else claim',
                'final_claim': 'claim field (potentially updated by RAG)'
            },
            'results': results
        }

        clean_data = convert_numpy_types(data)
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(clean_data, f, ensure_ascii=False, indent=4)

        print(f"Detailed results saved to {output_file}")

    def _save_summary(self, results: List[Dict], total_claims: int,
                      total_original_correct: int, total_final_correct: int) -> None:
        """Compute aggregate statistics, write rag_claim_correctness_summary.json, print highlights.

        Rates divide by ``max(count, 1)``. The relative improvement divides by
        ``max(original_rate, 0.001)``, which avoids dividing by a zero baseline.
        """
        output_file = "rag_claim_correctness_summary.json"

        total_questions = len(results)

        original_correctness_rate = total_original_correct / max(total_claims, 1)
        final_correctness_rate = total_final_correct / max(total_claims, 1)
        improvement = final_correctness_rate - original_correctness_rate
        relative_improvement = improvement / max(original_correctness_rate, 0.001)
        improvement_count = total_final_correct - total_original_correct

        confident_claims_analysis = self._analyze_by_confidence(results)
        updated_claims_analysis = self._analyze_by_update_status(results)
        included_claims_analysis = self._analyze_by_inclusion_status(results)

        # .get() keeps the summary robust to result dicts lacking this key.
        question_improvements = [r.get('correctness_improvement', 0) for r in results]
        questions_improved = sum(1 for imp in question_improvements if imp > 0)
        questions_degraded = sum(1 for imp in question_improvements if imp < 0)
        questions_unchanged = sum(1 for imp in question_improvements if imp == 0)

        summary = {
            'overall_statistics': {
                'total_questions': total_questions,
                'total_claims': total_claims,
                'original_correct_claims': total_original_correct,
                'final_correct_claims': total_final_correct,
                'improvement_count': improvement_count,
                'original_correctness_rate': original_correctness_rate,
                'final_correctness_rate': final_correctness_rate,
                'absolute_improvement': improvement,
                'relative_improvement': relative_improvement
            },
            'question_level_analysis': {
                'questions_with_improvements': questions_improved,
                'questions_with_degradations': questions_degraded,
                'questions_unchanged': questions_unchanged,
                'avg_improvement_per_question': sum(question_improvements) / len(question_improvements) if question_improvements else 0,
                'max_improvement_in_question': max(question_improvements) if question_improvements else 0,
                'min_improvement_in_question': min(question_improvements) if question_improvements else 0
            },
            'claim_characteristics_analysis': {
                'by_confidence': confident_claims_analysis,
                'by_update_status': updated_claims_analysis,
                'by_inclusion_status': included_claims_analysis
            },
            'interpretation': {
                'methodology': 'Compare claim correctness before vs after RAG updates',
                'evaluation_criteria': 'Claims that help correctly derive ground truth answers',
                'improvement_meaning': f"RAG updates {'improved' if improvement > 0 else 'did not improve'} claim correctness by {improvement:.2%}"
            }
        }

        clean_summary = convert_numpy_types(summary)
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(clean_summary, f, ensure_ascii=False, indent=4)

        print(f"Summary saved to {output_file}")

        print("\n📊 Claim Correctness Analysis Summary:")
        print(f"  Total questions: {total_questions}")
        print(f"  Total claims: {total_claims}")
        print(
            f"  Original correct claims: {total_original_correct} ({original_correctness_rate:.2%})")
        print(
            f"  Final correct claims: {total_final_correct} ({final_correctness_rate:.2%})")
        # :+d prints an explicit sign, so a decrease shows as "-3" rather than "+-3".
        print(
            f"  Improvement: {improvement_count:+d} claims ({improvement:.2%})")
        print(f"  Questions improved: {questions_improved}")
        print(f"  Questions degraded: {questions_degraded}")

    @staticmethod
    def _analyze_by_flag(results: List[Dict], flag: str, true_label: str, false_label: str) -> Dict[str, Any]:
        """Split all analyzed claims by a boolean field and tally correctness per group.

        Args:
            results: Per-question results, each with a ``claims_analysis`` list.
            flag: Claim field to split on (its truthiness selects the group).
            true_label: Output key for claims where *flag* is truthy (emitted first).
            false_label: Output key for the remaining claims.
        """
        counts = {
            True: {'total': 0, 'original_correct': 0, 'final_correct': 0},
            False: {'total': 0, 'original_correct': 0, 'final_correct': 0},
        }
        for result in results:
            for claim in result['claims_analysis']:
                bucket = counts[bool(claim[flag])]
                bucket['total'] += 1
                if claim['original_is_correct']:
                    bucket['original_correct'] += 1
                if claim['final_is_correct']:
                    bucket['final_correct'] += 1

        def summarize(bucket: Dict[str, int]) -> Dict[str, Any]:
            denominator = max(bucket['total'], 1)
            return {
                'total': bucket['total'],
                'original_correct': bucket['original_correct'],
                'final_correct': bucket['final_correct'],
                'original_rate': bucket['original_correct'] / denominator,
                'final_rate': bucket['final_correct'] / denominator,
                'improvement': bucket['final_correct'] - bucket['original_correct']
            }

        return {
            true_label: summarize(counts[True]),
            false_label: summarize(counts[False])
        }

    def _analyze_by_confidence(self, results: List[Dict]) -> Dict[str, Any]:
        """Analyze correctness by confidence level (``is_confident``)."""
        return self._analyze_by_flag(
            results, 'is_confident', 'confident_claims', 'uncertain_claims')

    def _analyze_by_update_status(self, results: List[Dict]) -> Dict[str, Any]:
        """Analyze correctness by update status (``was_updated``)."""
        return self._analyze_by_flag(
            results, 'was_updated', 'updated_claims', 'not_updated_claims')

    def _analyze_by_inclusion_status(self, results: List[Dict]) -> Dict[str, Any]:
        """Analyze correctness by inclusion status (``is_included``)."""
        return self._analyze_by_flag(
            results, 'is_included', 'included_claims', 'not_included_claims')


def main():
    """Entry point: build the analyzer and run the full claim correctness analysis.

    Ctrl-C is reported as an interruption. Any other exception is reported
    together with its traceback rather than propagating.
    """
    print("🔍 RAG Claim Correctness Analysis")
    print("=" * 50)

    try:
        RAGClaimCorrectnessAnalyzer().run_correctness_analysis()
    except KeyboardInterrupt:
        print("\n\n⚠️  Claim correctness analysis interrupted by user")
    except Exception as exc:
        import traceback

        print(f"\n❌ Claim correctness analysis failed: {exc}")
        traceback.print_exc()


if __name__ == "__main__":
    main()
