#!/usr/bin/env python3
"""
Evaluation script (GPT only) to generate a Table 1-like CSV for the Epidemiology dataset.
Outputs evaluation.csv with columns: Dataset, Method, Threshold, AUROC, AUPR, F1, Precision, Recall
This replicates the GPT logic in meta_analysis.py (without Claude) for methods: random, verbalized, uncertainty, ours at thresholds 15/25/45.
"""

import json
import numpy as np
import pandas as pd
from pathlib import Path
from typing import Dict, List, Tuple
from sklearn import metrics
from sklearn.metrics import precision_recall_curve, f1_score, precision_score, recall_score


# Dataset label: written to the 'Dataset' column of evaluation.csv and echoed
# in the progress/error messages printed by main().
DATASET_NAME = 'Epidemiology'


def load_json(path: Path) -> Dict:
    """Read and parse the JSON document stored at *path*.

    A missing file or malformed JSON is reported on stdout and an empty dict
    is returned instead of raising. Otherwise the parsed top-level value is
    returned unchanged (whatever JSON type the file holds).
    """
    try:
        with open(path, 'r', encoding='utf-8') as handle:
            raw_text = handle.read()
        return json.loads(raw_text)
    except FileNotFoundError:
        print(f"Warning: {path} not found")
    except json.JSONDecodeError as e:
        print(f"Error decoding {path}: {e}")
    # Both handled failure modes fall through to the same "no data" value.
    return {}


def extract_claims_with_uncertainty_and_correctness(base_dir: Path) -> List[Dict]:
    """Join per-claim correctness labels with per-claim uncertainty metrics.

    Reads 'uncertainty_analysis_results.json' and
    'rag_claim_correctness_results.json' from *base_dir*. For every question
    in the correctness file, the first uncertainty entry with an equal
    'question' value is used; each correctness claim is then paired with the
    first uncertainty claim whose text is equal to it or contains / is
    contained in it. Claims without a partner are dropped.

    Returns one dict per paired claim (question index and text, claim text,
    the four correctness flags, and a copy of the uncertainty metrics), or an
    empty list when either file yields no data.
    """
    uncertainty_results = load_json(
        base_dir / 'uncertainty_analysis_results.json')
    rag_correctness = load_json(
        base_dir / 'rag_claim_correctness_results.json')

    if not (uncertainty_results and rag_correctness):
        return []

    def texts_match(lhs, rhs):
        # Both texts must be non-empty; equality or containment in either
        # direction counts as the same claim.
        return lhs and rhs and (lhs == rhs or lhs in rhs or rhs in lhs)

    flag_names = ('original_is_correct', 'final_is_correct',
                  'was_updated', 'is_included')
    merged: List[Dict] = []

    for question_idx, entry in enumerate(rag_correctness.get('results', [])):
        question = entry.get('question')
        verdicts = entry.get('claims_analysis', [])

        # First uncertainty record for the same question, if any.
        uncertainty_entry = next(
            (uq for uq in uncertainty_results if uq.get('question') == question),
            None)
        if not uncertainty_entry:
            continue

        candidates = uncertainty_entry.get('claim_uncertainties', [])

        for verdict in verdicts:
            claim_text = verdict.get('original_claim', '')

            # First uncertainty claim that matches this correctness claim.
            partner = next(
                (cand for cand in candidates
                 if texts_match(claim_text, cand.get('claim', ''))),
                None)
            if partner is None:
                continue

            metric_map = dict(partner.get('uncertainty_metrics', {}))
            # Normalize potential misspelled key
            if ('betweenness_centrality' not in metric_map
                    and 'betweeness_centrality' in metric_map):
                metric_map['betweenness_centrality'] = metric_map['betweeness_centrality']

            record = {
                'question_idx': question_idx,
                'question': question,
                'claim': claim_text,
            }
            for flag in flag_names:
                record[flag] = verdict.get(flag, False)
            record['uncertainty_metrics'] = metric_map
            merged.append(record)

    return merged


def calculate_threshold_for_rag_rate(claims_data: List[Dict], metric: str, target_rate: float) -> float:
    """Return the metric value located at the *target_rate* quantile position.

    The values of *metric* are collected from every claim that has an
    'uncertainty_metrics' entry (a missing metric counts as 1.0 for
    'verbalized_confidence' and 0.0 for anything else), sorted ascending, and
    the element at index ``int(count * target_rate)`` (capped at the last
    index) is returned. With no usable claims the result is 0.5.
    """
    fallback = 1.0 if metric == 'verbalized_confidence' else 0.0

    values = []
    for claim in claims_data:
        if 'uncertainty_metrics' not in claim:
            continue
        values.append(claim['uncertainty_metrics'].get(metric, fallback))

    if not values:
        return 0.5

    values.sort()
    last_index = len(values) - 1
    cut = int(len(values) * target_rate)
    if cut > last_index:
        cut = last_index
    return values[cut]


def apply_ours_method_with_indices(claims_data: List[Dict], target_rate: float, uncertainty_metric: str = 'closeness_centrality') -> Tuple[List[float], List[int]]:
    """Score claims with the "ours" policy and report which ones received RAG.

    Every claim starts with its *uncertainty_metric* value as the score (0.0
    when the metric is missing). Only claims whose 'tool_confidence' equals
    1.0 are eligible for RAG; among those, the ones with the lowest metric
    value (missing counts as 1.0 for ranking) are chosen, up to
    ``int(len(claims_data) * target_rate)``. A chosen claim whose
    'is_included' flag is truthy gets its score overwritten with 1.0.

    Returns ``(scores, rag_indices)`` where *rag_indices* lists the positions
    of the chosen claims in ranking order.
    """
    scores: List[float] = []
    eligible: List[int] = []
    for position, claim in enumerate(claims_data):
        claim_metrics = claim['uncertainty_metrics']
        scores.append(claim_metrics.get(uncertainty_metric, 0.0))
        # Filter claims with tool_confidence = 1
        if claim_metrics.get('tool_confidence', 0.0) == 1.0:
            eligible.append(position)

    rag_indices: List[int] = []
    if not eligible:
        return scores, rag_indices

    budget = int(len(claims_data) * target_rate)

    def ranking_value(position):
        return claims_data[position]['uncertainty_metrics'].get(
            uncertainty_metric, 1.0)

    # sorted() is stable, so ties keep their original claim order.
    ranked = sorted(eligible, key=ranking_value)
    for position in ranked[:min(budget, len(eligible))]:
        rag_indices.append(position)
        if claims_data[position].get('is_included', False):
            scores[position] = 1.0

    return scores, rag_indices


def _tie_break_selected(claim: Dict, threshold_pct: float, selection_rate: float) -> bool:
    """Decide reproducibly whether a claim lying exactly on the threshold gets RAG.

    The decision is a pseudo-random draw seeded from the claim text and the
    threshold percentage. The seed uses CRC32 rather than the built-in
    ``hash()``: string hashes are salted per interpreter process, so a
    ``hash()``-derived seed gives different selections on every run.
    A private ``random.Random`` keeps the global RNG state untouched.
    """
    import random
    import zlib

    claim_text = str(claim.get('claim', ''))
    seed = zlib.crc32(claim_text.encode('utf-8')) % 10000 + int(threshold_pct)
    return random.Random(seed).random() < selection_rate


def _score_claim(claim: Dict, base_score: float, gets_rag: bool) -> Tuple[float, int]:
    """Return ``(prediction, ground_truth)`` for one claim.

    A claim that received RAG is judged against 'final_is_correct' and is
    scored 1.0 when its 'is_included' flag is truthy; every other claim keeps
    *base_score* and is judged against 'original_is_correct'.
    """
    if gets_rag:
        prediction = 1.0 if claim.get('is_included', False) else base_score
        return prediction, int(claim.get('final_is_correct', False))
    return base_score, int(claim.get('original_is_correct', False))


def _apply_threshold_method(claims_data: List[Dict], metric: str, missing_value: float,
                            target_rate: float, threshold_pct: float) -> Tuple[List[float], List[int]]:
    """Send the claims with the lowest *metric* values to RAG.

    Claims strictly below the threshold always get RAG. Claims exactly on the
    threshold are sampled with the probability that would be needed to reach
    ``int(len(claims_data) * target_rate)`` RAG'd claims in expectation.
    """
    threshold = calculate_threshold_for_rag_rate(
        claims_data, metric, target_rate)
    values = [claim['uncertainty_metrics'].get(metric, missing_value)
              for claim in claims_data]

    target_count = int(len(values) * target_rate)
    below_count = sum(1 for v in values if v < threshold)
    tied_count = sum(1 for v in values if v == threshold)
    if tied_count > 0:
        # Fraction of the tied claims still needed, clamped to [0, 1].
        selection_rate = max(
            0.0, min(1.0, (target_count - below_count) / tied_count))
    else:
        selection_rate = 1.0

    predictions: List[float] = []
    ground_truth: List[int] = []
    for claim, value in zip(claims_data, values):
        if value < threshold:
            gets_rag = True
        elif value == threshold:
            gets_rag = _tie_break_selected(claim, threshold_pct, selection_rate)
        else:
            gets_rag = False
        prediction, truth = _score_claim(claim, value, gets_rag)
        predictions.append(prediction)
        ground_truth.append(truth)
    return predictions, ground_truth


def apply_rag_method(claims_data: List[Dict], method: str, threshold_pct: float,
                     uncertainty_metric: str = 'closeness_centrality', prediction_threshold: float = None) -> Tuple[List[float], List[int]]:
    """Simulate one RAG-selection policy and return scores with their labels.

    Args:
        claims_data: Claim dicts carrying 'uncertainty_metrics' plus the
            'is_included', 'original_is_correct' and 'final_is_correct' flags.
        method: One of 'random', 'verbalized', 'uncertainty' or 'ours'.
        threshold_pct: Share of claims (in percent) targeted for RAG.
        uncertainty_metric: Metric used by 'uncertainty' and 'ours'.
        prediction_threshold: If given, 'ours' scores are binarized at this
            value (RAG'd claims already scored 1.0 stay 1.0). Ignored by the
            other methods.

    Returns:
        ``(predictions, ground_truth)`` as parallel lists. Both are empty for
        empty input or an unrecognized *method*.
    """
    if not claims_data:
        return [], []

    target_rate = threshold_pct / 100.0

    if method == 'random':
        import random

        # A private generator gives the same seed-42 sample as seeding the
        # module-level RNG, without clobbering global random state.
        rng = random.Random(42)
        num_claims = len(claims_data)
        target_rag_count = int(num_claims * target_rate)
        rag_indices = set(rng.sample(range(num_claims),
                                     min(target_rag_count, num_claims)))

        predictions: List[float] = []
        ground_truth: List[int] = []
        for i, claim in enumerate(claims_data):
            # The random baseline always scores with closeness centrality,
            # independent of `uncertainty_metric`.
            base_score = claim['uncertainty_metrics'].get(
                'closeness_centrality', 0.0)
            prediction, truth = _score_claim(claim, base_score, i in rag_indices)
            predictions.append(prediction)
            ground_truth.append(truth)
        return predictions, ground_truth

    if method == 'verbalized':
        # NOTE(review): a missing 'verbalized_confidence' counts as 0.5 here
        # but as 1.0 inside calculate_threshold_for_rag_rate — confirm intent.
        return _apply_threshold_method(
            claims_data, 'verbalized_confidence', 0.5, target_rate, threshold_pct)

    if method == 'uncertainty':
        return _apply_threshold_method(
            claims_data, uncertainty_metric, 0.0, target_rate, threshold_pct)

    if method == 'ours':
        predictions, rag_index_list = apply_ours_method_with_indices(
            claims_data, target_rate, uncertainty_metric)
        # Set membership keeps the per-claim lookups below O(1).
        rag_indices = set(rag_index_list)

        ground_truth = [
            int(claim.get(
                'final_is_correct' if i in rag_indices else 'original_is_correct',
                False))
            for i, claim in enumerate(claims_data)
        ]

        if prediction_threshold is not None:
            predictions = [
                1.0 if (i in rag_indices and pred == 1.0)
                or pred >= prediction_threshold else 0.0
                for i, pred in enumerate(predictions)
            ]
        return predictions, ground_truth

    # Unrecognized method: keep the historical "no predictions" result.
    return [], []


def calculate_all_metrics(predictions: List[float], ground_truth: List[int]) -> Dict[str, float]:
    """Compute AUROC, AUPR, F1, precision and recall for scored predictions.

    AUROC and AUPR are computed on the raw scores. For F1 / precision /
    recall the scores are binarized at the precision-recall-curve threshold
    where precision and recall are closest to each other. All values are
    rounded to four decimals. When either input is empty, placeholder values
    (0.5 for the two area metrics, 0.0 for the rest) are returned.
    """
    if not (predictions and ground_truth):
        return {'auroc': 0.5, 'aupr': 0.5, 'f1': 0.0, 'precision': 0.0, 'recall': 0.0}

    scores: Dict[str, float] = {}
    scores['auroc'] = metrics.roc_auc_score(ground_truth, predictions)
    scores['aupr'] = metrics.average_precision_score(ground_truth, predictions)

    curve_precision, curve_recall, curve_thresholds = precision_recall_curve(
        ground_truth, predictions)

    # The curve arrays carry one trailing point without a threshold, so drop
    # it before looking for the precision/recall break-even position.
    gap = np.abs(curve_precision[:-1] - curve_recall[:-1])
    break_even = int(np.argmin(gap))
    cutoff = float(curve_thresholds[break_even])

    hard_labels = (np.array(predictions) >= cutoff).astype(int)
    scores['f1'] = f1_score(ground_truth, hard_labels)
    scores['precision'] = precision_score(
        ground_truth, hard_labels, zero_division=0)
    scores['recall'] = recall_score(ground_truth, hard_labels, zero_division=0)

    return {name: round(value, 4) for name, value in scores.items()}


def main():
    """Evaluate every method/threshold pair and write the results to evaluation.csv.

    Claims are loaded from the directory containing this script. Each
    combination of method (random, verbalized, uncertainty, ours) and RAG
    percentage (15, 25, 45) contributes one CSV row; a combination that
    raises is reported on stdout and skipped.
    """
    base_dir = Path(__file__).resolve().parent

    claims_data = extract_claims_with_uncertainty_and_correctness(base_dir)
    print(f"Loaded {len(claims_data)} claims for {DATASET_NAME} GPT-4o")

    # CSV column name -> key in the dict returned by calculate_all_metrics.
    metric_columns = (
        ('AUROC', 'auroc'),
        ('AUPR', 'aupr'),
        ('F1', 'f1'),
        ('Precision', 'precision'),
        ('Recall', 'recall'),
    )

    rows = []
    for method in ("random", "verbalized", "uncertainty", "ours"):
        for threshold in (15, 25, 45):
            try:
                scores, labels = apply_rag_method(
                    claims_data, method, threshold, uncertainty_metric='closeness_centrality')
                computed = calculate_all_metrics(scores, labels)
                row = {
                    'Dataset': DATASET_NAME,
                    'Method': method,
                    'Threshold': f"{threshold}%",
                }
                for column, key in metric_columns:
                    row[column] = computed[key]
                rows.append(row)
            except Exception as e:
                print(
                    f"Error processing {DATASET_NAME} {method} {threshold}%: {e}")

    out_path = base_dir / 'evaluation.csv'
    pd.DataFrame(rows).to_csv(out_path, index=False)
    print(f"Evaluation saved to {out_path}")


# Run the evaluation only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    main()
