#!/usr/bin/env python3
"""
Ground truth claim analysis script for Epidemiology2 domain.
Decomposes reference answers into claims and applies our method for final prediction.
"""

import json
import os
from typing import Any, Dict, List, Optional, Tuple

import numpy as np

from uncertainty.claim_decomposer import ClaimDecomposer


class GroundTruthClaimAnalysis:
    """Ground truth claim analysis for selected questions.

    On construction the class loads three JSON files from the current working
    directory. For each selected question it joins per-claim correctness
    records with per-claim uncertainty metrics, applies the "ours" selection
    method, and reports which claims are kept and which are filtered out.
    """

    # Default fraction of all claims targeted for RAG.
    DEFAULT_TARGET_RAG_RATE = 0.45
    # Claims whose prediction is below this value are reported as filtered out.
    DEFAULT_FILTER_THRESHOLD = 0.42
    # Indices into the loaded questions used when run_analysis gets no list.
    DEFAULT_SELECTED_QUESTIONS = (
        10, 20, 30, 40, 50, 60, 70, 80, 90, 100,
        110, 120, 130, 140, 150, 160, 170, 180, 190, 199,
    )

    def __init__(self):
        """Create the claim decomposer and load the three JSON inputs."""
        self.claim_decomposer = ClaimDecomposer()

        # Load required data (each falls back to [] if missing or malformed).
        self.final_answers = self._load_json('final_questions.json')
        self.uncertainty_results = self._load_json('uncertainty_analysis_results.json')
        self.rag_correctness = self._load_json('rag_claim_correctness_results.json')

        print("GroundTruthClaimAnalysis initialized for Epidemiology2 domain")

    def _load_json(self, filepath: str) -> Any:
        """Load a JSON file, returning [] when it is missing or malformed.

        Args:
            filepath: Path of the JSON file to read (UTF-8).

        Returns:
            The decoded JSON value (a list or a dict, depending on the file),
            or an empty list if the file does not exist or cannot be decoded.
        """
        try:
            with open(filepath, 'r', encoding='utf-8') as f:
                return json.load(f)
        except FileNotFoundError:
            print(f"Warning: {filepath} not found")
        except json.JSONDecodeError as e:
            print(f"Error decoding {filepath}: {e}")
        return []

    def decompose_reference_answer(self, reference_answer: str) -> List[str]:
        """Decompose a reference answer into a list of claim strings.

        Args:
            reference_answer: Free-text reference answer.

        Returns:
            The 'claim' field of every item returned by the decomposer, or an
            empty list if decomposition fails for any reason.
        """
        try:
            claims_data = self.claim_decomposer.decompose_response(reference_answer)
            return [claim['claim'] for claim in claims_data]
        except Exception as e:
            # Deliberate best-effort: a failed decomposition must not abort the run.
            print(f"Error decomposing reference answer: {e}")
            return []

    @staticmethod
    def _find_by_question(records: Any, question: str) -> Any:
        """Return the first dict in records whose 'question' equals question, else None."""
        if not isinstance(records, list):
            return None
        for record in records:
            if isinstance(record, dict) and record.get('question') == question:
                return record
        return None

    @staticmethod
    def _match_uncertainty_entry(original_claim: str, claim_uncertainties: List[Dict]) -> Any:
        """Find the uncertainty entry that corresponds to original_claim.

        An exact text match always wins. Only when no entry matches exactly is
        the first entry related by substring containment (in either direction)
        used. Returns None when nothing matches.
        """
        if not original_claim:
            return None

        first_substring_match = None
        for entry in claim_uncertainties:
            candidate = entry.get('claim', '')
            if not candidate:
                continue
            if candidate == original_claim:
                return entry
            if first_substring_match is None and (
                original_claim in candidate or candidate in original_claim
            ):
                first_substring_match = entry
        return first_substring_match

    def _extract_claims_with_uncertainty_and_correctness(self, question: str) -> List[Dict]:
        """Join correctness and uncertainty records for one question.

        Args:
            question: Question text used as the lookup key in both data sets.

        Returns:
            One combined dict per correctness claim that could be matched to
            an uncertainty claim. Returns an empty list when the question is
            missing from either data set.
        """
        uncertainty_question = self._find_by_question(self.uncertainty_results, question)
        if not uncertainty_question:
            return []

        # The correctness data is read as a dict with a 'results' list. The
        # isinstance guard also covers the [] that _load_json returns on failure.
        correctness_records = (
            self.rag_correctness.get('results')
            if isinstance(self.rag_correctness, dict) else None
        )
        correctness_question = self._find_by_question(correctness_records, question)
        if not correctness_question:
            return []

        claim_uncertainties = uncertainty_question.get('claim_uncertainties') or []
        claims_analysis = correctness_question.get('claims_analysis') or []

        claims_data = []
        for claim_correctness in claims_analysis:
            original_claim = claim_correctness.get('original_claim', '')
            matched = self._match_uncertainty_entry(original_claim, claim_uncertainties)
            if matched is None:
                continue

            metrics = matched.get('uncertainty_metrics') or {}
            claims_data.append({
                'original_claim': original_claim,
                'final_claim': claim_correctness.get('final_claim', original_claim),
                'claim': original_claim,  # Keep for backward compatibility
                'original_is_correct': claim_correctness.get('original_is_correct', False),
                'final_is_correct': claim_correctness.get('final_is_correct', False),
                'was_updated': claim_correctness.get('was_updated', False),
                'is_included': claim_correctness.get('is_included', False),
                'uncertainty_metrics': metrics,
                'closeness_centrality': metrics.get('closeness_centrality', 0.0),
                'tool_confidence': metrics.get('tool_confidence', 0.0),
            })

        return claims_data

    def _calculate_threshold_for_rag_rate(self, claims_data: List[Dict], target_rate: float) -> float:
        """Return the closeness_centrality value at the target_rate quantile.

        Args:
            claims_data: Claim dicts; entries without 'uncertainty_metrics'
                are ignored.
            target_rate: Fraction of claims that should fall below the
                threshold.

        Returns:
            The value at index int(n * target_rate) of the ascending metric
            values, with the index clamped to the valid range. Returns 0.5
            when no claim carries metrics.
        """
        metric_values = sorted(
            claim['uncertainty_metrics'].get('closeness_centrality', 0.0)
            for claim in claims_data
            if 'uncertainty_metrics' in claim
        )
        if not metric_values:
            return 0.5

        # Clamp on both sides: a negative rate must not wrap around to the
        # largest values through negative indexing.
        threshold_index = int(len(metric_values) * target_rate)
        threshold_index = max(0, min(threshold_index, len(metric_values) - 1))
        return metric_values[threshold_index]

    @staticmethod
    def _build_prediction_outputs(
        claims_data: List[Dict],
        ragged_included_claims: set,
        filter_threshold: float,
    ) -> Tuple[List[Dict], List[Dict]]:
        """Assign a prediction to every claim and split claims by filter_threshold.

        A claim whose original text is in ragged_included_claims gets
        prediction 1.0; every other claim uses its closeness_centrality.
        Returns (selected_claims, filtered_claims).
        """
        selected_claims = []
        filtered_claims = []

        for claim in claims_data:
            is_included = claim['original_claim'] in ragged_included_claims
            prediction = 1.0 if is_included else claim['closeness_centrality']

            if prediction >= filter_threshold:
                selected_claims.append({
                    "claim": claim['final_claim'],
                    "original_claim": claim['original_claim'],
                    "final_is_correct": claim['final_is_correct'],
                    "closeness_centrality": claim['closeness_centrality'],
                    "tool_confidence": claim['tool_confidence'],
                    "updated": claim.get('was_updated', False),
                    "is_included": is_included,
                    "prediction": prediction,
                })

            # Kept as a separate test (not else) so a NaN prediction lands in
            # neither list, exactly as before.
            if prediction < filter_threshold:
                filtered_claims.append({
                    "claim": claim['final_claim'],
                    "original_claim": claim['original_claim'],
                    "final_is_correct": claim['final_is_correct'],
                    "closeness_centrality": claim['closeness_centrality'],
                    "tool_confidence": claim['tool_confidence'],
                })

        return selected_claims, filtered_claims

    def _apply_ours_method(
        self,
        claims_data: List[Dict],
        target_rate: float = DEFAULT_TARGET_RAG_RATE,
        filter_threshold: float = DEFAULT_FILTER_THRESHOLD,
    ) -> Tuple[float, float, List[Dict], List[Dict]]:
        """Apply the 'ours' method: tool filter plus uncertainty-based selection.

        Args:
            claims_data: Combined claim dicts as produced by
                _extract_claims_with_uncertainty_and_correctness.
            target_rate: Fraction of ALL claims to target for RAG.
            filter_threshold: Minimum prediction for a claim to be selected.

        Returns:
            (rag_threshold, filter_threshold, selected_claims, filtered_claims).
            For empty input the result is (0.0, 0.0, [], []).
        """
        if not claims_data:
            return 0.0, 0.0, [], []

        # Step 1: keep only claims with tool_confidence == 1.
        tool_filtered_claims = [
            claim for claim in claims_data
            if claim['uncertainty_metrics'].get('tool_confidence', 0.0) == 1.0
        ]

        if not tool_filtered_claims:
            # No tool-verifiable claims: fall back to the uncertainty-only rule
            # applied to all claims.
            rag_threshold = self._calculate_threshold_for_rag_rate(claims_data, target_rate)
            ragged_included_claims = {
                claim['original_claim']
                for claim in claims_data
                if claim['uncertainty_metrics'].get('closeness_centrality', 0.0) < rag_threshold
                and claim.get('is_included', False)
            }
        else:
            # Step 2: the RAG budget is computed from ALL claims, never negative.
            total_target_rag_count = max(0, int(len(claims_data) * target_rate))

            # Step 3: order tool-filtered claims by ascending closeness_centrality
            # (lowest value first; the sort is stable for ties).
            ranked = sorted(
                tool_filtered_claims,
                key=lambda claim: claim['uncertainty_metrics'].get('closeness_centrality', 1.0),
            )

            # Step 4: take as many as the budget allows.
            rag_selected = ranked[:total_target_rag_count]

            # The threshold is the metric value of the last claim taken.
            if rag_selected:
                rag_threshold = rag_selected[-1]['uncertainty_metrics'].get('closeness_centrality', 0.0)
            else:
                rag_threshold = 0.0

            ragged_included_claims = {
                claim['original_claim']
                for claim in rag_selected
                if claim.get('is_included', False)
            }

        # Step 5: compute predictions for all claims and split by the filter threshold.
        selected_claims, filtered_claims = self._build_prediction_outputs(
            claims_data, ragged_included_claims, filter_threshold
        )
        return rag_threshold, filter_threshold, selected_claims, filtered_claims

    def process_selected_questions(self, selected_questions: List[int]) -> List[Dict]:
        """Process the given question indices and return one result per usable question.

        A question is skipped, with a printed warning, when its index is out
        of range (negative indices included), when it has no reference answer,
        or when no claims could be matched for it.

        Args:
            selected_questions: Indices into the loaded final answers.

        Returns:
            A list of result dicts holding the question, the reference answer,
            the reference answer's claims and the output of the 'ours' method.
        """
        results = []

        for question_idx in selected_questions:
            print(f"\nProcessing question {question_idx}...")

            # Reject negative indices too; they would otherwise silently wrap
            # around to the end of the list.
            if not 0 <= question_idx < len(self.final_answers):
                print(f"Warning: Question {question_idx} not found in final_answers")
                continue

            question_data = self.final_answers[question_idx]
            open_question = question_data.get('open_question', '')
            reference_answer = question_data.get('reference_answer', '')

            if not reference_answer:
                print(f"Warning: No reference answer for question {question_idx}")
                continue

            # Check for matching claims before decomposing, so the decomposer
            # is not invoked for questions that get skipped anyway.
            claims_data = self._extract_claims_with_uncertainty_and_correctness(open_question)
            if not claims_data:
                print(f"Warning: No matching claims found for question {question_idx}")
                continue

            reference_claims = self.decompose_reference_answer(reference_answer)

            rag_threshold, filter_threshold, selected_claims, filtered_claims = (
                self._apply_ours_method(claims_data)
            )

            results.append({
                "open_question": open_question,
                "reference_answer": reference_answer,
                "reference_answer_claims": reference_claims,
                "our_method_claims": {
                    "rag_threshold": round(rag_threshold, 6),
                    "filter_threshold": round(filter_threshold, 6),
                    "selected_claims": selected_claims,
                    "filtered_claims": filtered_claims,
                },
            })
            print(f"Question {question_idx} processed successfully")
            print(f"  Reference claims: {len(reference_claims)}")
            print(f"  Selected claims: {len(selected_claims)}")
            print(f"  Filtered claims: {len(filtered_claims)}")

        return results

    def run_analysis(
        self,
        output_file: str = "final_answers_claim.json",
        selected_questions: Optional[List[int]] = None,
    ):
        """Run the ground truth claim analysis and write the results as JSON.

        Args:
            output_file: Path of the JSON file to write.
            selected_questions: Question indices to process; when None, the
                class default DEFAULT_SELECTED_QUESTIONS is used.
        """
        if selected_questions is None:
            selected_questions = list(self.DEFAULT_SELECTED_QUESTIONS)

        print(f"Starting ground truth claim analysis for questions: {selected_questions}")

        results = self.process_selected_questions(selected_questions)

        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(results, f, ensure_ascii=False, indent=2)

        print(f"\nAnalysis complete! Results saved to {output_file}")
        print(f"Processed {len(results)} questions successfully")


def main():
    """Script entry point: build the analyzer and run it with default settings."""
    GroundTruthClaimAnalysis().run_analysis()


# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()