#!/usr/bin/env python3
"""
Tools Confidence Analysis for Epidemiology Claims
Evaluates whether claims can be verified using available epidemiology research tools.
"""

import os
import json
import time
from typing import Dict, List, Any
from dotenv import load_dotenv
from openai import OpenAI
from tqdm import tqdm
import copy
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed

# Load variables from a .env file (python-dotenv) into the process environment
# at import time; the analyzer below reads OPENAI_API_KEY via os.getenv().
load_dotenv()


class ToolsConfidenceAnalyzer:
    """Analyze epidemiology claims for tool-based verification confidence."""

    def __init__(self):
        """Create the OpenAI client and load the uncertainty analysis results.

        The API key is taken from the ``OPENAI_API_KEY`` environment variable.
        The loaded results are kept on ``self.uncertainty_results``; a short
        status report is printed once loading has finished.
        """
        api_key = os.getenv("OPENAI_API_KEY")
        self.client = OpenAI(api_key=api_key)

        # One entry per question, as returned by _load_uncertainty_results().
        self.uncertainty_results = self._load_uncertainty_results()
        question_count = len(self.uncertainty_results)

        print("Initialized ToolsConfidenceAnalyzer")
        print(f"Loaded {question_count} questions from uncertainty analysis")

    def _load_uncertainty_results(self) -> List[Dict]:
        """Read ``uncertainty_analysis_results.json`` from the working directory.

        Returns:
            The parsed JSON content, or an empty list (after printing a
            notice) when the file does not exist.
        """
        results_path = "uncertainty_analysis_results.json"

        try:
            handle = open(results_path, 'r', encoding='utf-8')
        except FileNotFoundError:
            print("❌ uncertainty_analysis_results.json not found!")
            return []

        with handle:
            return json.load(handle)

    def _get_tools_prompt(self) -> str:
        """Return the prompt template describing the available tools.

        The template is meant to be filled with ``str.format``: it contains
        the ``{question}`` and ``{claim}`` placeholders, and the literal JSON
        example uses doubled braces so that it survives formatting.
        """
        template = """
You are an expert in epidemiology and public health computational tools. You will evaluate whether epidemiology-related claims can be verified using available epidemiology simulation tools.

**AVAILABLE EPIDEMIOLOGY TOOLS:**

1. **GLEAM-AI Influenza Forecasting Simulator**
   - Purpose: Models influenza epidemic dynamics across US states over multiple months (28+ weeks) with focus on seasonal flu patterns and comprehensive epidemiological forecasting
   - Input: 
     * r0_value (float): Basic reproduction number (1.75-3.0)
     * seasonality_level (string): Seasonality pattern ("no seasonality", "moderate seasonality", "high seasonality")
     * starting_date (string): Simulation start date in YYYY-MM-DD format (typically September-November for flu season)
     * prior_immunity_level (string): Prior immunity percentage (e.g., "20%", "40%", "60%")
     * target_states (list): US states for analysis (2-5 states from available US states)
     * target_metric (string): Single health metric to analyze (focus on hospitalizations)
   - Output: Comprehensive structured outlook including:
     * Peak magnitude with 90% confidence intervals
     * Peak timing (date and weeks from start)
     * Initial trend (qualitative and quantitative growth rates)
     * Total burden (cumulative values over season)
     * Uncertainty quantification across multiple stochastic runs
   - Use case: Seasonal influenza forecasting, CDC FluSight target analysis, hospital capacity planning, epidemic trajectory prediction

**AVAILABLE INPUT PARAMETERS:**
- r0_value: Basic reproduction number indicating influenza transmissibility (1.75-3.0)
- seasonality_level: Seasonal pattern strength affecting transmission dynamics
- starting_date: Flu season start date (September-November timeframe)
- prior_immunity_level: Prior population immunity percentage (10%-60%)
- target_states: Geographic scope of analysis (US states)
- target_metric: Single health outcome to analyze

**AVAILABLE TARGET METRICS:**
- hospital incidence: New hospitalizations per time period
- hospital prevalence: Current hospitalizations at any given time
- latent incidence: New infections (including pre-symptomatic/asymptomatic)

**SUPPORTED ANALYSIS CAPABILITIES:**
- Seasonal flu peak timing and magnitude prediction
- Multi-state epidemic modeling across US states
- CDC FluSight target analysis (rate trends, peak timing, peak height, season end, threshold achievement)
- Initial growth rate analysis (4-week trends)
- Hospital capacity forecasting and planning
- Uncertainty quantification through multiple stochastic simulation runs
- Comparative seasonality impact assessment

**EVALUATION TASK:**
For each claim provided, you need to determine how well the available tools can help verify or assess the accuracy of that claim.

**SCORING CRITERIA:**
- **0**: The claim cannot be verified or assessed using any of the available epidemiology tools
  - Examples: Claims about non-epidemiology topics, vaccine effectiveness, behavioral factors, economic impacts without quantifiable health metrics, social aspects that cannot be directly modeled, claims requiring data/tools not available in GLEAM-AI
  
- **1**: The claim can be directly and comprehensively verified or assessed using the available epidemiology tools
  - Examples: Claims about seasonal flu patterns, peak timing and magnitude, R0 effects on disease spread, seasonality impact on transmission, hospital incidence/prevalence predictions, latent infection patterns, state-specific flu dynamics, CDC FluSight target outcomes

**EXAMPLES OF VERIFIABLE CLAIMS (Score 1):**
- "High seasonality flu scenarios with 40% prior immunity in Texas show hospital incidence peaks occurring 2 weeks earlier than moderate seasonality"
- "R0 = 2.8 influenza seasons with 30% prior immunity result in 40% higher peak hospitalizations compared to R0 = 2.0 scenarios"
- "California flu seasons starting in early October with 20% prior immunity show different peak timing patterns than those starting in late November"
- "Latent incidence peaks consistently occur 3-5 days before hospital incidence peaks across multiple states regardless of prior immunity level"
- "Scenarios with 60% prior immunity show suppressed outbreak patterns compared to 20% prior immunity scenarios"

**EXAMPLES OF NON-VERIFIABLE CLAIMS (Score 0):**
- "Vaccine effectiveness varies by age group during flu season" (requires vaccine efficacy data beyond simulation)
- "Public awareness campaigns increase flu vaccination rates" (requires behavioral data)
- "Healthcare worker shortages impact hospital capacity" (requires staffing data beyond epidemiological modeling)
- "Economic burden of influenza exceeds respiratory syncytial virus" (requires economic analysis beyond health outcomes)

**RESPONSE FORMAT:**
Respond with ONLY a JSON object containing a single key "tool_confidence" with a value of 0 or 1.

Example: {{"tool_confidence": 1}}

**YOUR TASK:**
Question: {question}
Claim: {claim}

Evaluate how well the available epidemiology simulation tools can verify or assess this specific claim in the context of the given question.
"""
        return template

    def _evaluate_claim_tool_confidence(self, question: str, claim: str,
                                        max_retries: int = 3) -> float:
        """Ask the model whether the tools can verify *claim*; return 0.0 or 1.0.

        The prompt from ``_get_tools_prompt`` is filled with the question and
        the claim and sent to ``gpt-4o``. The reply is searched for a JSON
        object, so markdown fences or prose around the JSON do not break
        parsing.

        Args:
            question: The question the claim belongs to.
            claim: The claim text to evaluate.
            max_retries: Number of attempts made when the API call fails or
                the reply contains no JSON object.

        Returns:
            ``1.0`` or ``0.0``. ``0.0`` is also the fallback when the score is
            missing or invalid, or when every attempt failed.
        """
        prompt = self._get_tools_prompt().format(question=question, claim=claim)

        for attempt in range(max_retries):
            is_last_attempt = attempt == max_retries - 1

            try:
                response = self.client.chat.completions.create(
                    model="gpt-4o",
                    messages=[{"role": "user", "content": prompt}],
                    max_tokens=100,
                    temperature=0.1
                )
                # The message content may be None; treat that as an empty
                # reply so it is handled as a parse failure, not an API error.
                result_text = (
                    response.choices[0].message.content or "").strip()
            except Exception as e:
                # Boundary with the third-party client: any failure is retried.
                print(f"API error on attempt {attempt + 1}: {e}")
                if is_last_attempt:
                    print("Failed API call after all retries, using 0")
                    return 0.0
                time.sleep(5)
                continue

            try:
                payload = self._extract_json_object(result_text)
            except ValueError as e:
                print(f"JSON decode error on attempt {attempt + 1}: {e}")
                print(f"Raw response: {result_text}")
                if is_last_attempt:
                    print("Failed to parse JSON after all retries, using 0")
                    return 0.0
                time.sleep(2)
                continue

            raw_score = payload.get('tool_confidence', 0)
            score = raw_score
            # Accept the score when the model quotes it ("0" / "1").
            if isinstance(score, str) and score.strip() in ("0", "1"):
                score = int(score.strip())

            if isinstance(score, (int, float)) and score in (0, 1):
                return float(score)

            print(
                f"Warning: Invalid tool_confidence score {raw_score}, using 0")
            return 0.0

        # Only reached when max_retries is not positive.
        return 0.0

    @staticmethod
    def _extract_json_object(text: str) -> Dict:
        """Return the first JSON object that can be decoded from *text*.

        Every ``{`` is tried in order as a possible start of the object, so
        surrounding prose and markdown code fences are ignored.

        Raises:
            ValueError: If no JSON object can be decoded from *text*.
        """
        decoder = json.JSONDecoder()
        start = text.find('{')
        while start != -1:
            try:
                payload, _ = decoder.raw_decode(text, start)
            except json.JSONDecodeError:
                start = text.find('{', start + 1)
            else:
                return payload
        raise ValueError("no JSON object found in response")

    def _process_single_claim(self, question: str, claim_idx: int, claim_data: Dict) -> tuple[int, Dict]:
        """Return ``(claim_idx, copy of claim_data)`` with a tool confidence added.

        The input dict is never modified. When the claim text is missing or
        empty the copy is returned unchanged; otherwise the score obtained
        from ``_evaluate_claim_tool_confidence`` is stored under
        ``uncertainty_metrics['tool_confidence']``.
        """
        enriched = copy.deepcopy(claim_data)
        claim_text = claim_data.get('claim', '')

        # Nothing to evaluate without claim text.
        if not claim_text:
            print(f"  Skipping empty claim {claim_idx}")
            return claim_idx, enriched

        score = self._evaluate_claim_tool_confidence(question, claim_text)

        # Create the metrics dict on demand, then record the score in it.
        metrics = enriched.setdefault('uncertainty_metrics', {})
        metrics['tool_confidence'] = float(score)

        print(f"    Claim {claim_idx+1} tool confidence: {score}")

        return claim_idx, enriched

    def _update_question_results(self, question_data: Dict) -> Dict:
        """Return a copy of *question_data* whose claims carry tool confidence.

        Claims are handled four at a time; each batch gets its own thread
        pool that runs ``_process_single_claim`` for every claim in it. A
        claim whose processing raises keeps its original (copied) data. The
        order of the claims is preserved and the input is not modified.
        """
        result = copy.deepcopy(question_data)
        question_text = result.get('question', '')
        claims = result.get('claim_uncertainties', [])

        print(f"\nProcessing question: {question_text[:100]}...")
        print(
            f"Processing {len(claims)} claims with multithreading (batch size: 4)")

        if not claims:
            return result

        chunk = 4
        claim_count = len(claims)
        processed = {}  # claim position -> processed claim dict

        for batch_no, lo in enumerate(range(0, claim_count, chunk), start=1):
            hi = min(lo + chunk, claim_count)

            print(f"  Processing batch {batch_no}: Claims {lo + 1}-{hi}")

            # A batch never holds more than `chunk` claims, so one worker per
            # claim is enough.
            with ThreadPoolExecutor(max_workers=hi - lo) as pool:
                pending = {}
                for position in range(lo, hi):
                    job = pool.submit(
                        self._process_single_claim,
                        question_text, position, claims[position])
                    pending[job] = position

                # Jobs finish in arbitrary order; results are keyed by
                # position so the original order can be restored afterwards.
                for job in as_completed(pending):
                    position = pending[job]
                    try:
                        returned_idx, claim_result = job.result()
                    except Exception as e:
                        print(
                            f"    Error processing claim {position + 1}: {e}")
                        # Fall back to the unprocessed claim.
                        processed[position] = claims[position]
                    else:
                        processed[returned_idx] = claim_result

        # Write the processed claims back at their original positions.
        for position in range(claim_count):
            if position in processed:
                claims[position] = processed[position]

        return result

    def _save_updated_results(self, updated_results: List[Dict]):
        """Atomically write *updated_results* to the results JSON file.

        The data is first written to a temporary file next to
        ``uncertainty_analysis_results.json`` and then moved over it with
        ``os.replace``. A failure while serializing therefore leaves the
        existing results file untouched instead of truncated.

        numpy scalars and arrays are converted to native Python values
        during serialization; numpy itself is optional.

        Raises:
            Exception: Any error raised while writing is reported (message
                and traceback) and then re-raised.
        """
        target_path = "uncertainty_analysis_results.json"
        temp_path = target_path + ".tmp"

        try:
            import numpy as np
        except ImportError:
            # Without numpy there cannot be numpy values to convert.
            np = None

        def to_native(obj):
            """``json.dump`` fallback for values it cannot serialize itself."""
            if np is not None:
                if isinstance(obj, np.bool_):
                    return bool(obj)
                if isinstance(obj, np.integer):
                    return int(obj)
                if isinstance(obj, np.floating):
                    return float(obj)
                if isinstance(obj, np.ndarray):
                    return obj.tolist()
            raise TypeError(
                f"Object of type {type(obj).__name__} is not JSON serializable")

        try:
            with open(temp_path, 'w', encoding='utf-8') as f:
                json.dump(updated_results, f, ensure_ascii=False, indent=2,
                          default=to_native)
            # Swap the finished file into place in a single step.
            os.replace(temp_path, target_path)
            print("Successfully updated uncertainty_analysis_results.json")
        except Exception as e:
            print(f"Error saving updated results: {e}")
            import traceback
            traceback.print_exc()
            # Best-effort removal of the partial temp file; a cleanup failure
            # must not hide the original error.
            try:
                os.remove(temp_path)
            except OSError:
                pass
            raise

    def analyze_tools_confidence(self):
        """Add tool confidence to every claim of every loaded question.

        Questions are processed in order with ``_update_question_results``.
        After each successfully processed question the combined results
        (processed so far plus the untouched remainder) are saved as a
        progress checkpoint. A question whose processing fails keeps its
        original data; a failed checkpoint is reported and processing goes
        on.

        When finished, ``self.uncertainty_results`` is replaced by the updated
        results so that ``generate_summary_stats`` reports on the new scores,
        and the results are saved one final time.

        Returns:
            The list of updated question dicts, in the original order. An
            empty list is returned, and nothing is written, when no
            questions were loaded.

        Raises:
            Exception: Whatever the final save raises.
        """
        print(f"\n=== Starting Tools Confidence Analysis ===")
        total_questions = len(self.uncertainty_results)
        print(f"Total questions to process: {total_questions}")

        if not self.uncertainty_results:
            # Saving here would write "[]" to the results file.
            print("No questions to process; nothing will be saved")
            return []

        updated_results = []

        for question_idx, question_data in enumerate(self.uncertainty_results):
            print(
                f"\n--- Processing Question {question_idx + 1}/{total_questions} ---")

            try:
                updated_question = self._update_question_results(question_data)
            except Exception as e:
                print(f"Error processing question {question_idx + 1}: {e}")
                import traceback
                traceback.print_exc()
                # Keep the original question data if processing failed
                updated_results.append(question_data)
                continue

            updated_results.append(updated_question)

            # The checkpoint is handled separately from processing, so a save
            # error cannot add this question to the results a second time.
            try:
                self._save_updated_results(
                    updated_results + self.uncertainty_results[question_idx + 1:])
            except Exception as e:
                print(
                    f"Warning: could not save progress after question {question_idx + 1}: {e}")

            print(f"Completed question {question_idx + 1}")

        print(f"\n=== Tools Confidence Analysis Complete ===")
        print(f"Processed {len(updated_results)} questions")

        # Keep the in-memory results in sync with what is about to be saved.
        self.uncertainty_results = updated_results

        # Final save
        self._save_updated_results(updated_results)

        return updated_results

    def generate_summary_stats(self) -> Dict[str, Any]:
        """Summarize the tool confidence scores in ``self.uncertainty_results``.

        Returns:
            ``{}`` when no results are loaded, ``{"error": ...}`` when no
            claim carries a ``tool_confidence`` value, and otherwise a dict
            with the claim counts, descriptive statistics, and the count and
            percentage of claims scored 0 and 1.
        """
        if not self.uncertainty_results:
            return {}

        # Flatten the claims of all questions into a single list.
        all_claims = [
            claim_entry
            for question_entry in self.uncertainty_results
            for claim_entry in question_entry.get('claim_uncertainties', [])
        ]

        # Claims without a recorded score count only towards the total.
        candidate_scores = (
            claim_entry.get('uncertainty_metrics', {}).get('tool_confidence')
            for claim_entry in all_claims
        )
        scores = [score for score in candidate_scores if score is not None]

        if not scores:
            return {"error": "No tool confidence scores found"}

        import numpy as np

        score_array = np.array(scores)
        scored_count = len(scores)
        zero_count = int(np.sum(score_array == 0))
        one_count = int(np.sum(score_array == 1))

        summary = {
            "total_claims": len(all_claims),
            "claims_with_tool_confidence": scored_count,
            "tool_confidence_stats": {
                "mean": float(np.mean(score_array)),
                "median": float(np.median(score_array)),
                "std": float(np.std(score_array)),
                "min": float(np.min(score_array)),
                "max": float(np.max(score_array))
            },
            "tool_confidence_distribution": {
                "score_0": zero_count,
                "score_1": one_count
            }
        }

        # Share of scored claims at each score, in percent.
        summary["tool_confidence_percentages"] = {
            "score_0": (zero_count / scored_count) * 100,
            "score_1": (one_count / scored_count) * 100
        }

        return summary


def main():
    """Run the analysis, then write and print the summary statistics.

    The summary is written to ``tools_confidence_summary.json`` in the
    working directory. Nothing is written or printed when the summary is
    empty.
    """
    analyzer = ToolsConfidenceAnalyzer()

    # Score all claims; the results file is updated by the analyzer itself.
    analyzer.analyze_tools_confidence()

    summary_stats = analyzer.generate_summary_stats()
    if not summary_stats:
        return

    with open("tools_confidence_summary.json", 'w', encoding='utf-8') as f:
        json.dump(summary_stats, f, ensure_ascii=False, indent=2)

    total_claims = summary_stats.get('total_claims', 'N/A')
    scored_claims = summary_stats.get('claims_with_tool_confidence', 'N/A')

    print("\n=== Tools Confidence Summary ===")
    print(f"Total claims: {total_claims}")
    print(f"Claims with tool confidence: {scored_claims}")

    if 'tool_confidence_stats' in summary_stats:
        stats = summary_stats['tool_confidence_stats']
        print(f"Mean tool confidence: {stats.get('mean', 'N/A'):.3f}")
        print(f"Median tool confidence: {stats.get('median', 'N/A'):.3f}")

    if 'tool_confidence_distribution' in summary_stats:
        dist = summary_stats['tool_confidence_distribution']
        zero_count = dist.get('score_0', 'N/A')
        one_count = dist.get('score_1', 'N/A')
        print(f"Distribution - Score 0: {zero_count}, Score 1: {one_count}")

    print("Summary saved to: tools_confidence_summary.json")

# Run the analysis only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
