import pandas as pd
import re
import time
import numpy as np

from typing import List, Tuple

from src.claude import call_claude  # change this based on your API



def extract_answer(response: str) -> str:
    """Parse the model's final choice ('NN' or 'OPT') out of a free-text response.

    The strategy is tried in order:
      1. Scan the last three lines, bottom-up, for a line that is exactly the
         answer (optionally wrapped in double quotes or followed by a period).
      2. Otherwise take the last standalone ``NN``/``OPT`` token anywhere in
         the response (case-insensitive).
      3. Otherwise fall back to the phrases "nearest neighbor" / "optimized",
         checked in that order.

    Args:
        response: Raw model output. ``None`` or any other non-string value
            (e.g. from a failed API call) is treated as unparseable.

    Returns:
        ``'NN'``, ``'OPT'``, or ``'UNKNOWN'`` if no answer could be found
        (a warning is printed in that case).
    """
    # A failed request may yield None (or another non-string); treat it as
    # unparseable instead of crashing the whole evaluation on .strip().
    if not isinstance(response, str):
        print("Warning: Could not extract answer from response")
        return 'UNKNOWN'

    # 1. Prefer an explicit answer on one of the final three lines.
    lines = response.strip().split('\n')
    for line in reversed(lines[-3:]):
        line = line.strip().upper()
        if line in ('NN', 'OPT', '"NN"', '"OPT"', 'NN.', 'OPT.'):
            return line.replace('"', '').replace('.', '')

    # 2. Otherwise use the last standalone NN/OPT token; \b keeps words such
    #    as "OPTIMIZED" from matching.
    matches = re.findall(r'\b(NN|OPT)\b', response.upper())
    if matches:
        return matches[-1]

    # 3. Last resort: spelled-out names. "nearest neighbor" is checked first,
    #    so it wins if both phrases appear.
    lowered = response.lower()
    if re.search(r'\bnearest\s+neighbor\b', lowered):
        return 'NN'
    if re.search(r'\boptimized\b', lowered):
        return 'OPT'

    print("Warning: Could not extract answer from response")
    return 'UNKNOWN'


def evaluate_batch(
    df_batch: pd.DataFrame,
    model_name: str = 'claude3.7-sonnet',
    temperature: float = 1.0,
) -> Tuple[List[str], List[str]]:
    """Query the model on one batch of rows and parse its answers.

    Args:
        df_batch: Rows containing ``prompt_best_comparison`` (sent to the
            model) and ``label_best_comparison`` (the expected answer).
        model_name: Model identifier forwarded to ``call_claude``.
        temperature: Sampling temperature forwarded to ``call_claude``.
            NOTE(review): 1.0 makes repeated evaluations non-deterministic;
            consider a lower value if reproducibility matters.

    Returns:
        ``(predictions, ground_truth)``: two lists aligned with the rows of
        ``df_batch``. Each prediction is the output of ``extract_answer``.

    Raises:
        ValueError: If ``call_claude`` returns a different number of responses
            than prompts were sent. Without this check, predictions and labels
            would silently misalign for every later batch.
    """
    prompts = df_batch['prompt_best_comparison'].tolist()
    ground_truth = df_batch['label_best_comparison'].tolist()
    # list() so the length can be checked even if an iterator is returned.
    responses = list(call_claude(prompts, model_name=model_name, temperature=temperature))
    if len(responses) != len(prompts):
        raise ValueError(
            f"call_claude returned {len(responses)} responses for {len(prompts)} prompts"
        )
    predictions = [extract_answer(resp) for resp in responses]
    return predictions, ground_truth


def calculate_accuracy(predictions: List[str], ground_truth: List[str]) -> float:
    """Return the fraction of parseable predictions that match the ground truth.

    Predictions equal to ``'UNKNOWN'`` (extraction failures) are left out of
    both the numerator and the denominator. Pairs are matched positionally via
    ``zip``, so any extra items in the longer list are ignored.

    Returns:
        Accuracy in ``[0, 1]``, or ``0.0`` when no prediction could be scored.
    """
    scored = 0
    hits = 0
    for pred, truth in zip(predictions, ground_truth):
        if pred == 'UNKNOWN':
            continue  # extraction failure: not scored either way
        scored += 1
        hits += pred == truth
    return hits / scored if scored else 0.0


def _precision_recall(
    predictions: List[str], ground_truth: List[str], positive: str, negative: str
) -> Tuple[float, float]:
    """Return ``(precision, recall)`` for class ``positive`` against ``negative``.

    A true positive is predicted and labelled ``positive``. A false positive is
    predicted ``positive`` but labelled ``negative``. A false negative is
    predicted ``negative`` but labelled ``positive``. Any other label is
    ignored. A zero denominator yields 0.
    """
    pairs = list(zip(predictions, ground_truth))
    tp = sum(1 for p, g in pairs if p == positive and g == positive)
    fp = sum(1 for p, g in pairs if p == positive and g == negative)
    fn = sum(1 for p, g in pairs if p == negative and g == positive)
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    return precision, recall


def main():
    """Run the best-comparison evaluation end to end.

    The run proceeds as follows:
      1. Load the prompt CSV and check that the required columns exist.
      2. Drop rows whose prompt is null or blank.
      3. Query the model in batches of 10, printing running accuracy, class
         counts and an ETA after each batch.
      4. Print final accuracy, the prediction and ground-truth distributions,
         and per-class precision/recall.
      5. Write the rows and their predictions to
         ``prototype_best_comparison_results.csv``.
    """
    csv_path = 'data/prototype_comparison_prompts_extended.csv'
    print(f"Loading data from {csv_path}")
    df = pd.read_csv(csv_path)
    total_rows = len(df)
    print(f"Loaded {total_rows} rows")

    # evaluate_batch reads both columns. Check them up front so a missing
    # label column gets a readable message instead of a KeyError on batch 1.
    required_columns = ['prompt_best_comparison', 'label_best_comparison']
    missing_columns = [col for col in required_columns if col not in df.columns]
    if missing_columns:
        for col in missing_columns:
            print(f"ERROR: '{col}' column not found in CSV")
        print("Available columns:", df.columns.tolist())
        return

    # Remove rows where prompt_best_comparison is null/empty.
    df = df.dropna(subset=['prompt_best_comparison'])
    # astype(str) keeps .str usable even if pandas inferred a non-object dtype.
    df = df[df['prompt_best_comparison'].astype(str).str.strip() != '']
    total_rows = len(df)
    print(f"After filtering for valid best comparison prompts: {total_rows} rows")

    if total_rows == 0:
        print("ERROR: No valid best comparison prompts found!")
        return

    # Initialize tracking variables.
    batch_size = 10
    predictions_all = []
    ground_truth_all = []

    # Ceiling division, so a final partial batch is still processed.
    num_batches = (total_rows + batch_size - 1) // batch_size

    print(f"\nProcessing {total_rows} rows in {num_batches} batches of {batch_size}")
    print(f"This will make {total_rows} total LLM queries\n")
    print("-" * 80)

    start_time = time.time()

    for batch_idx in range(num_batches):
        batch_start = batch_idx * batch_size
        batch_end = min(batch_start + batch_size, total_rows)
        df_batch = df.iloc[batch_start:batch_end]

        batch_num = batch_idx + 1
        print(f"\nBatch {batch_num}/{num_batches} (rows {batch_start+1}-{batch_end})")

        # Evaluate best comparison prompts.
        print("  Evaluating best comparison prompts...")
        preds, truth = evaluate_batch(df_batch)
        predictions_all.extend(preds)
        ground_truth_all.extend(truth)

        # Running accuracy excludes extraction failures (see calculate_accuracy).
        acc = calculate_accuracy(predictions_all, ground_truth_all)
        failures = predictions_all.count('UNKNOWN')

        print(f"\n  Running Accuracy after {batch_end} rows:")
        print(f"    Best Comparison:  {acc:.2%} (failures: {failures})")

        # Count predictions and labels for each class.
        nn_preds = predictions_all.count('NN')
        opt_preds = predictions_all.count('OPT')
        nn_truth = ground_truth_all.count('NN')
        opt_truth = ground_truth_all.count('OPT')

        print(f"    Predictions: NN={nn_preds}, OPT={opt_preds}")
        print(f"    Ground Truth: NN={nn_truth}, OPT={opt_truth}")

        # Time estimate: average wall time per batch so far x batches left.
        elapsed = time.time() - start_time
        avg_time_per_batch = elapsed / batch_num
        remaining_batches = num_batches - batch_num
        eta_minutes = remaining_batches * avg_time_per_batch / 60

        print(f"  Time elapsed: {elapsed/60:.1f} min | ETA: {eta_minutes:.1f} min")
        print("-" * 80)

    # Final results
    print("\n" + "=" * 80)
    print("FINAL RESULTS")
    print("=" * 80)

    acc_final = calculate_accuracy(predictions_all, ground_truth_all)

    print(f"\nTotal rows processed: {total_rows}")
    print(f"Total LLM queries made: {total_rows}")

    print("\nFinal Accuracy:")
    print(f"  Best Comparison Model: {acc_final:.2%}")

    # Detailed breakdown; total_rows > 0 is guaranteed by the early return above.
    nn_preds_final = predictions_all.count('NN')
    opt_preds_final = predictions_all.count('OPT')
    nn_truth_final = ground_truth_all.count('NN')
    opt_truth_final = ground_truth_all.count('OPT')
    failures_final = predictions_all.count('UNKNOWN')

    print("\nPrediction Distribution:")
    print(f"  Predicted NN: {nn_preds_final} ({nn_preds_final/total_rows:.1%})")
    print(f"  Predicted OPT: {opt_preds_final} ({opt_preds_final/total_rows:.1%})")
    print(f"  Failures: {failures_final} ({failures_final/total_rows:.1%})")

    print("\nGround Truth Distribution:")
    print(f"  Actual NN closer: {nn_truth_final} ({nn_truth_final/total_rows:.1%})")
    print(f"  Actual OPT closer: {opt_truth_final} ({opt_truth_final/total_rows:.1%})")

    # Per-class precision and recall, ignoring extraction failures.
    valid_predictions = [p for p in predictions_all if p != 'UNKNOWN']
    valid_ground_truth = [g for p, g in zip(predictions_all, ground_truth_all) if p != 'UNKNOWN']

    if valid_predictions:
        nn_precision, nn_recall = _precision_recall(
            valid_predictions, valid_ground_truth, 'NN', 'OPT'
        )
        opt_precision, opt_recall = _precision_recall(
            valid_predictions, valid_ground_truth, 'OPT', 'NN'
        )

        print("\nDetailed Metrics (excluding failures):")
        print(f"  NN - Precision: {nn_precision:.2%}, Recall: {nn_recall:.2%}")
        print(f"  OPT - Precision: {opt_precision:.2%}, Recall: {opt_recall:.2%}")

    # Save results to CSV. Lists are assigned positionally, so they must have
    # exactly one entry per filtered row.
    results_df = df.copy()
    results_df['pred_best'] = predictions_all
    results_df['correct_best'] = [p == g for p, g in zip(predictions_all, ground_truth_all)]

    output_path = 'prototype_best_comparison_results.csv'
    results_df.to_csv(output_path, index=False)
    print(f"\nResults saved to {output_path}")
    



if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not when imported.
    main()