#!/usr/bin/env python3
"""
Tools Confidence Analysis for Climate Claims
Evaluates whether claims can be verified using available climate simulation tools.
"""

import os
import json
import time
from typing import Dict, List, Any
from dotenv import load_dotenv
from openai import OpenAI
from tqdm import tqdm
import copy
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed

# Load variables from a .env file (python-dotenv) into the process environment
# at import time; the analyzer later reads OPENAI_API_KEY via os.getenv.
load_dotenv()


class ToolsConfidenceAnalyzer:
    """Analyze claims for tool-based verification confidence."""

    def __init__(self):
        """Create the OpenAI client and load the stored uncertainty results.

        The API key is taken from the ``OPENAI_API_KEY`` environment variable.
        The results loaded from disk are kept in ``self.uncertainty_results``.
        """
        api_key = os.getenv("OPENAI_API_KEY")
        self.client = OpenAI(api_key=api_key)

        # One entry per question, as produced by the uncertainty analysis step.
        self.uncertainty_results = self._load_uncertainty_results()
        question_count = len(self.uncertainty_results)

        print("Initialized ToolsConfidenceAnalyzer")
        print(f"Loaded {question_count} questions from uncertainty analysis")

    def _load_uncertainty_results(self) -> List[Dict]:
        """Read ``uncertainty_analysis_results.json`` from the working directory.

        Returns:
            The parsed JSON content, or an empty list (after printing a
            message) when the file does not exist.
        """
        try:
            handle = open("uncertainty_analysis_results.json", 'r', encoding='utf-8')
        except FileNotFoundError:
            print("❌ uncertainty_analysis_results.json not found!")
            return []

        with handle:
            return json.load(handle)

    def _get_tools_prompt(self) -> str:
        """Return the prompt template used to score a single claim.

        The template describes the available climate tools, the 0/1 scoring
        criteria and the required JSON reply format.  It contains the
        ``{question}`` and ``{claim}`` placeholders, which the caller fills in
        with ``str.format``; the braces of the JSON example are doubled so
        that ``format`` emits them as literal single braces.
        """
        return """
You are an expert in climate science and computational tools. You will evaluate whether climate-related claims can be verified using available climate simulation tools.

**AVAILABLE CLIMATE TOOLS:**

1. **query_lat_and_lon**
   - Purpose: Finds latitude and longitude coordinates for a given city name
   - Input: city_name (string)
   - Output: latitude and longitude coordinates
   - Use case: Geographic location lookup for climate analysis

2. **history_temperature**
   - Purpose: Retrieves historical average temperature for a specific location and year
   - Input: latitude (number), longitude (number), year (integer, 1850-2014)
   - Output: Historical temperature data and description
   - Use case: Getting baseline temperature data for climate comparisons

3. **future_temperature**
   - Purpose: Predicts future average temperature for a specific location and year under climate scenarios
   - Input: latitude (number), longitude (number), year (integer, 2015-2100), setting (climate scenario: ssp126, ssp245, ssp370, ssp585)
   - Output: Future temperature predictions and description
   - Use case: Climate projection analysis for different emission scenarios

4. **diy_greenhouse**
   - Purpose: Models local temperature impact of custom CO2 and CH4 emission changes
   - Input: latitude (number), longitude (number), year (integer), setting (scenario), delta_CO2 (percentage change), delta_CH4 (percentage change)
   - Output: Temperature predictions with greenhouse gas modifications
   - Use case: Assessing impact of greenhouse gas emission changes on local climate

5. **diy_aerosol**
   - Purpose: Models local temperature impact of custom SO2 and Black Carbon (BC) emission changes at specified points
   - Input: latitude (number), longitude (number), year (integer), setting (scenario), delta_SO2 (percentage), delta_BC (percentage), modify_points (coordinates where changes are applied)
   - Output: Temperature predictions with aerosol modifications
   - Use case: Evaluating aerosol intervention effects on local climate

6. **diy_aerosol_mean**
   - Purpose: Models global average temperature impact of custom SO2 and BC emission changes at specified points
   - Input: year (integer), setting (scenario), delta_SO2 (percentage), delta_BC (percentage), modify_points (coordinates)
   - Output: Global average temperature changes from aerosol modifications
   - Use case: Assessing global-scale effects of aerosol interventions

7. **is_land_or_sea**
   - Purpose: Determines if coordinates are located on land or sea
   - Input: longitude (number), latitude (number)
   - Output: Land/sea classification (1 for land, 0 for sea)
   - Use case: Geographic terrain classification for climate analysis

8. **diff_diy_aerosol_mean**
   - Purpose: Calculates difference in global average temperature caused by aerosol emission changes
   - Input: year (integer), setting (scenario), delta_SO2 (percentage), delta_BC (percentage), modify_points (coordinates)
   - Output: Temperature difference compared to baseline
   - Use case: Quantifying net global temperature effects of aerosol changes

9. **diy_greenhouse_summary**
   - Purpose: Provides comprehensive summary of temperature predictions under different scenarios with custom greenhouse gas changes
   - Input: longitude (number), latitude (number), delta_CO2 (percentage), delta_CH4 (percentage)
   - Output: Summary text of temperature predictions across scenarios
   - Use case: Getting overview of greenhouse gas impact across multiple scenarios

10. **location_summary**
    - Purpose: Retrieves comprehensive historical and future temperature data for a location
    - Input: longitude (number), latitude (number)
    - Output: Complete temperature data dictionary and description
    - Use case: Getting comprehensive climate profile for a specific location

**EVALUATION TASK:**
For each claim provided, you need to determine how well the available tools can help verify or assess the accuracy of that claim.

**SCORING CRITERIA:**
- **0**: The claim cannot be verified or assessed using any of the available tools
  - Examples: Claims about non-climate topics, general policy statements, claims requiring data/tools not available, claims that touch on climate aspects but cannot be directly verified with the specific tools provided
  
- **1**: The claim can be directly and comprehensively verified or assessed using the available tools
  - Examples: Claims about temperature changes, climate scenarios, aerosol/greenhouse gas impacts, geographic classifications, specific quantitative climate predictions

**RESPONSE FORMAT:**
Respond with ONLY a JSON object containing a single key "tool_confidence" with a value of 0 or 1.

Example: {{"tool_confidence": 1}}

**YOUR TASK:**
Question: {question}
Claim: {claim}

Evaluate how well the available climate simulation tools can verify or assess this specific claim in the context of the given question.
"""

    def _evaluate_claim_tool_confidence(self, question: str, claim: str) -> float:
        """Ask the model whether the available tools can verify ``claim``.

        The prompt template is filled with ``question`` and ``claim`` and sent
        to the chat-completions endpoint.  The reply is expected to be a JSON
        object of the form ``{"tool_confidence": 0}`` or
        ``{"tool_confidence": 1}``, optionally wrapped in a markdown code
        fence.

        Args:
            question: The question the claim belongs to.
            claim: The claim text to evaluate.

        Returns:
            ``1.0`` or ``0.0``.  ``0.0`` is also returned when the score in
            the reply is not 0/1, or when the request or the parsing still
            fails after three attempts.
        """
        prompt = self._get_tools_prompt().format(question=question, claim=claim)

        max_retries = 3
        for attempt in range(max_retries):
            is_last_attempt = attempt == max_retries - 1
            result_text = ""
            try:
                response = self.client.chat.completions.create(
                    model="gpt-4o",
                    messages=[{"role": "user", "content": prompt}],
                    max_tokens=100,
                    temperature=0.1
                )

                # The message content is optional in the API response; an
                # absent body is handled as an unparsable (empty) reply.
                result_text = (response.choices[0].message.content or "").strip()

                # Handle markdown code blocks
                if result_text.startswith('```'):
                    result_text = result_text.replace(
                        '```json', '').replace('```', '').strip()

                result_json = json.loads(result_text)
                if not isinstance(result_json, dict):
                    # Valid JSON but not an object (e.g. a bare ``1``): route
                    # it through the malformed-reply path instead of letting
                    # ``.get`` fail and be reported as an API error.
                    raise json.JSONDecodeError(
                        "Expected a JSON object", result_text, 0)

                tool_confidence = result_json.get('tool_confidence', 0)

                # Accept numeric scores as well as their string form ("1"),
                # which would otherwise be silently recorded as 0.
                try:
                    score = float(tool_confidence)
                except (TypeError, ValueError, OverflowError):
                    score = None

                # Validate the score
                if score in (0.0, 1.0):
                    return 1.0 if score else 0.0

                print(
                    f"Warning: Invalid tool_confidence score {tool_confidence}, using 0")
                return 0.0

            except json.JSONDecodeError as e:
                print(f"JSON decode error on attempt {attempt + 1}: {e}")
                print(f"Raw response: {result_text}")
                if is_last_attempt:
                    print("Failed to parse JSON after all retries, using 0")
                    return 0.0
                time.sleep(2)

            except Exception as e:
                print(f"API error on attempt {attempt + 1}: {e}")
                if is_last_attempt:
                    print("Failed API call after all retries, using 0")
                    return 0.0
                time.sleep(5)

        # Safety net; every path above returns on the last attempt.
        return 0.0

    def _process_single_claim(self, question: str, claim_idx: int, claim_data: Dict) -> tuple[int, Dict]:
        """Score one claim and return a copy of its data carrying the score.

        Args:
            question: The question the claim belongs to.
            claim_idx: Zero-based position of the claim; returned unchanged so
                the caller can put the result back in order.
            claim_data: Claim record; its ``'claim'`` entry holds the text.

        Returns:
            ``(claim_idx, copy_of_claim_data)``.  For a non-empty claim the
            copy has ``uncertainty_metrics['tool_confidence']`` set; an empty
            claim is returned as an unmodified copy.  ``claim_data`` itself is
            never mutated.
        """
        enriched = copy.deepcopy(claim_data)
        claim_text = claim_data.get('claim', '')

        if claim_text:
            score = self._evaluate_claim_tool_confidence(question, claim_text)

            # Create the metrics container on first use, then record the score.
            metrics = enriched.setdefault('uncertainty_metrics', {})
            metrics['tool_confidence'] = float(score)

            print(f"    Claim {claim_idx+1} tool confidence: {score}")
        else:
            print(f"  Skipping empty claim {claim_idx}")

        return claim_idx, enriched

    def _update_question_results(self, question_data: Dict) -> Dict:
        """Return a copy of ``question_data`` with every claim scored.

        Claims from ``question_data['claim_uncertainties']`` are processed in
        consecutive batches of four; each batch runs in its own thread pool
        and is finished before the next one starts.  Results are written back
        at their original positions.  If processing a claim raises, the
        claim's original data is kept.

        Args:
            question_data: Question record; it is not modified.

        Returns:
            A deep copy of ``question_data`` with updated claim entries.
        """
        result = copy.deepcopy(question_data)
        question_text = result.get('question', '')
        claims = result.get('claim_uncertainties', [])

        print(f"\nProcessing question: {question_text[:100]}...")
        print(
            f"Processing {len(claims)} claims with multithreading (batch size: 4)")

        if not claims:
            return result

        chunk = 4
        n_claims = len(claims)
        processed = {}  # claim index -> processed claim data

        for batch_no, lo in enumerate(range(0, n_claims, chunk), start=1):
            hi = min(lo + chunk, n_claims)
            print(f"  Processing batch {batch_no}: Claims {lo + 1}-{hi}")

            with ThreadPoolExecutor(max_workers=min(4, hi - lo)) as pool:
                # Remember which claim each future belongs to.
                pending = {}
                for idx in range(lo, hi):
                    job = pool.submit(
                        self._process_single_claim, question_text, idx, claims[idx])
                    pending[job] = idx

                for finished in as_completed(pending):
                    source_idx = pending[finished]
                    try:
                        returned_idx, new_claim = finished.result()
                        processed[returned_idx] = new_claim
                    except Exception as e:
                        print(
                            f"    Error processing claim {source_idx + 1}: {e}")
                        # Fall back to the unprocessed claim data.
                        processed[source_idx] = claims[source_idx]

        # Write the processed claims back at their original positions.
        target = result['claim_uncertainties']
        for pos in range(n_claims):
            if pos in processed:
                target[pos] = processed[pos]

        return result

    def _save_updated_results(self, updated_results: List[Dict]):
        """Persist ``updated_results`` to ``uncertainty_analysis_results.json``.

        NumPy scalars and arrays found anywhere inside the nested
        dict / list / tuple structure are converted to native Python values
        first, because ``json`` cannot serialize them.

        The data is written to a temporary sibling file which is then moved
        over the target with ``os.replace``.  The target is also this
        program's input file, so it must not be left truncated or half
        written when serialization fails or the process is interrupted in the
        middle of a write.

        Args:
            updated_results: Per-question result records to store.

        Raises:
            Exception: Any error raised while converting or writing is
                reported (message and traceback) and then re-raised.
        """
        target_path = "uncertainty_analysis_results.json"
        temp_path = target_path + ".tmp"

        try:
            # Imported once here rather than on every recursive call below.
            import numpy as np

            def convert_numpy_types(obj):
                """Convert numpy types to Python native types for JSON serialization."""
                if isinstance(obj, np.bool_):
                    return bool(obj)
                if isinstance(obj, np.integer):
                    return int(obj)
                if isinstance(obj, np.floating):
                    return float(obj)
                if isinstance(obj, np.ndarray):
                    return obj.tolist()
                if isinstance(obj, dict):
                    return {key: convert_numpy_types(value) for key, value in obj.items()}
                if isinstance(obj, (list, tuple)):
                    # Tuples are emitted as JSON arrays anyway; walk them too
                    # so numpy values nested inside are converted.
                    return [convert_numpy_types(item) for item in obj]
                return obj

            clean_results = convert_numpy_types(updated_results)

            # Serialize to the temp file first; the real file is only
            # replaced once the complete document has been written.
            with open(temp_path, 'w', encoding='utf-8') as f:
                json.dump(clean_results, f, ensure_ascii=False, indent=2)
            os.replace(temp_path, target_path)

            print("Successfully updated uncertainty_analysis_results.json")
        except Exception as e:
            print(f"Error saving updated results: {e}")
            import traceback
            traceback.print_exc()
            raise
        finally:
            # Best-effort cleanup of a partially written temp file.  After a
            # successful replace the file no longer exists, which is fine.
            try:
                os.remove(temp_path)
            except OSError:
                pass

    def analyze_tools_confidence(self):
        """Score every claim of every loaded question and persist the results.

        Questions are processed one after another.  After each successfully
        processed question the progress is saved (already processed questions
        followed by the not yet processed ones).  A question whose processing
        raises is kept in its original form.  A failing progress save is
        reported but does not discard or duplicate the processed question.

        When all questions are done, ``self.uncertainty_results`` is replaced
        by the updated list, so that ``generate_summary_stats`` works on the
        new scores, and a final save is performed.

        Returns:
            The list of updated question records, in the original order.

        Raises:
            Exception: Whatever the final save raises.
        """
        total_questions = len(self.uncertainty_results)

        print(f"\n=== Starting Tools Confidence Analysis ===")
        print(f"Total questions to process: {total_questions}")

        updated_results = []

        for question_idx, question_data in enumerate(self.uncertainty_results):
            print(
                f"\n--- Processing Question {question_idx + 1}/{total_questions} ---")

            try:
                updated_question = self._update_question_results(question_data)
            except Exception as e:
                print(f"Error processing question {question_idx + 1}: {e}")
                import traceback
                traceback.print_exc()
                # Add the original question data if processing failed
                updated_results.append(question_data)
                continue

            updated_results.append(updated_question)

            # Save progress after each question.  This is handled separately
            # from processing: the question has already been appended, so a
            # save error must not add it a second time.
            try:
                self._save_updated_results(
                    updated_results + self.uncertainty_results[question_idx + 1:])
            except Exception as e:
                print(
                    f"Warning: could not save progress after question {question_idx + 1}: {e}")

            print(f"Completed question {question_idx + 1}")

        print(f"\n=== Tools Confidence Analysis Complete ===")
        print(f"Processed {len(updated_results)} questions")

        # Keep the in-memory state in sync with what was computed; the
        # summary statistics are derived from this attribute.
        self.uncertainty_results = updated_results

        # Final save
        self._save_updated_results(updated_results)

        return updated_results

    def generate_summary_stats(self) -> Dict[str, Any]:
        """Summarize the ``tool_confidence`` scores in ``self.uncertainty_results``.

        Returns:
            ``{}`` when no results are loaded;
            ``{"error": "No tool confidence scores found"}`` when no claim
            carries a score; otherwise a dict with the claim counts, the
            mean / median / std / min / max of the scores, the number of 0 and
            1 scores, and their percentages among the scored claims.
        """
        if not self.uncertainty_results:
            return {}

        scores = []
        claim_count = 0

        for entry in self.uncertainty_results:
            for claim_entry in entry.get('claim_uncertainties', []):
                value = claim_entry.get(
                    'uncertainty_metrics', {}).get('tool_confidence')
                if value is not None:
                    scores.append(value)
                claim_count += 1

        if not scores:
            return {"error": "No tool confidence scores found"}

        # numpy is only needed once there is something to aggregate.
        import numpy as np

        score_array = np.array(scores)
        scored = len(scores)
        zero_count = int(np.sum(score_array == 0))
        one_count = int(np.sum(score_array == 1))

        summary = {
            "total_claims": claim_count,
            "claims_with_tool_confidence": scored,
            "tool_confidence_stats": {
                "mean": float(np.mean(score_array)),
                "median": float(np.median(score_array)),
                "std": float(np.std(score_array)),
                "min": float(np.min(score_array)),
                "max": float(np.max(score_array))
            },
            "tool_confidence_distribution": {
                "score_0": zero_count,
                "score_1": one_count
            }
        }

        # Shares of each score among the scored claims, in percent.
        summary["tool_confidence_percentages"] = {
            "score_0": (zero_count / scored) * 100,
            "score_1": (one_count / scored) * 100
        }

        return summary


def main():
    """Run the tools confidence analysis and write a summary report.

    The analysis results are persisted by the analyzer itself.  The summary
    statistics are written to ``tools_confidence_summary.json`` and the key
    figures are printed.  When no statistics could be computed, the reason
    reported by the analyzer is printed instead of placeholder values.
    """
    analyzer = ToolsConfidenceAnalyzer()

    # Run the analysis
    updated_results = analyzer.analyze_tools_confidence()

    # generate_summary_stats() reads analyzer.uncertainty_results.  Point it
    # at the freshly scored data so the summary is not computed from the
    # snapshot that was loaded before any score existed.
    analyzer.uncertainty_results = updated_results

    # Generate and save summary statistics
    summary_stats = analyzer.generate_summary_stats()
    if not summary_stats:
        return

    with open("tools_confidence_summary.json", 'w', encoding='utf-8') as f:
        json.dump(summary_stats, f, ensure_ascii=False, indent=2)

    def _fmt(value):
        """Format a statistic with three decimals, or 'N/A' if it is missing."""
        # A float format spec cannot be applied to the 'N/A' fallback string.
        if isinstance(value, (int, float)):
            return f"{value:.3f}"
        return "N/A"

    print(f"\n=== Tools Confidence Summary ===")

    if 'error' in summary_stats:
        # No scores were found: show the reason rather than "N/A" figures.
        print(f"Summary unavailable: {summary_stats['error']}")
    else:
        print(f"Total claims: {summary_stats.get('total_claims', 'N/A')}")
        print(
            f"Claims with tool confidence: {summary_stats.get('claims_with_tool_confidence', 'N/A')}")

        if 'tool_confidence_stats' in summary_stats:
            stats = summary_stats['tool_confidence_stats']
            print(f"Mean tool confidence: {_fmt(stats.get('mean'))}")
            print(f"Median tool confidence: {_fmt(stats.get('median'))}")

        if 'tool_confidence_distribution' in summary_stats:
            dist = summary_stats['tool_confidence_distribution']
            print(
                f"Distribution - Score 0: {dist.get('score_0', 'N/A')}, Score 1: {dist.get('score_1', 'N/A')}")

    print(f"Summary saved to: tools_confidence_summary.json")

# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
