"""
P(Answer) Evaluation using sampling-based approach.

This module implements the P(Answer) (Probability of attempting to answer) calibration method using 
a sampling-based approach that works with vLLM and other APIs.

The ground truth P(Answer) is computed as the fraction of non-refusing samples (grades A or B) 
when generating multiple answers at a given temperature (typically T=1.0).
"""

from dataclasses import dataclass, field
import os
from typing import Literal, Optional, List, Dict, Any
import numpy as np
import pandas as pd
from datetime import datetime
import json
from pathlib import Path
from collections import Counter
from tqdm import tqdm

from src.models.predict import predict_dataset
from src.data.dataset_loader import load_by_name
from src.utils import parse_args
from src.core.calibration_metrics import (
    calibration_summary,
    convert_grades_to_binary
)
from src.evaluation.evaluate import (
    clean_output, 
    extract_answer, 
    create_timestamped_dir,
    load_model_config
)


def batch_compute_panswer_vllm(
    test_set: Dict[str, List], 
    model_config: dict,
    n_samples: int = 16,
    temperature: float = 1.0,
    num_proc: int = 50,
    batch_size: int = 100,
    prompt_template_path: str = "prompts/PROMPT_D.txt",
    grader_config: dict = None,
    verbose: bool = False
) -> tuple[List[float], List[Dict]]:
    """
    Compute P(Answer) scores for a dataset using sampling-based approach with vLLM.
    
    P(Answer) is the probability that the model will attempt to answer (not refuse).
    For every question, ``n_samples`` responses are generated, each response is
    graded, and the score is the fraction of samples graded 'A' or 'B'.
    
    Args:
        test_set: Dataset with 'question' and 'answer' fields
        model_config: Model configuration dictionary. 'model_name' and
            'inference_backend' are required; 'max_tokens', 'suffix', 'top_p',
            'max_thinking_tokens', 'google_api_key' and 'retry_attempts' are optional.
        n_samples: Number of samples to generate per question (must be >= 1)
        temperature: Sampling temperature
        num_proc: Number of parallel processes
        batch_size: Number of questions to process at once (must be >= 1)
        prompt_template_path: Path to prompt template (default: PROMPT_D for moderate refusal)
        grader_config: Configuration for grading model ('backend' and 'model' keys).
            When None, a default OpenAI-backend grader is used.
        verbose: Whether to print progress
        
    Returns:
        Tuple of (P(Answer) scores, detailed sample data for each question)

    Raises:
        ValueError: If ``n_samples`` or ``batch_size`` is smaller than 1.
        RuntimeError: If generation or grading returns a different number of rows
            than were submitted, which would misalign samples and questions.
    """
    # Fail fast: n_samples <= 0 would otherwise divide by zero (or yield -0.0 scores),
    # and a negative batch_size would silently produce an empty result.
    if n_samples < 1:
        raise ValueError(f"n_samples must be a positive integer, got {n_samples}")
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    # Imported lazily so the module can be imported without `datasets` being loaded
    from datasets import Dataset

    # Resolve the default grader once instead of on every batch
    if grader_config is None:
        grader_config = {
            'backend': 'openai',
            'model': 'google/gemini-2.0-flash-lite-001'
        }

    questions = test_set['question']
    answers = test_set['answer']
    n_questions = len(questions)
    
    panswer_scores = []
    detailed_samples = []
    
    # Process in batches for efficiency
    for batch_start in tqdm(range(0, n_questions, batch_size), desc="Computing P(Answer) scores"):
        batch_end = min(batch_start + batch_size, n_questions)
        batch_questions = questions[batch_start:batch_end]
        batch_answers = answers[batch_start:batch_end]
        
        # Repeat every (question, answer) pair n_samples times, back to back, so the
        # samples of the i-th question occupy rows [i * n_samples, (i + 1) * n_samples).
        expanded_questions = []
        expanded_answers = []
        for question, answer in zip(batch_questions, batch_answers):
            expanded_questions.extend([question] * n_samples)
            expanded_answers.extend([answer] * n_samples)
        
        # Create dataset for batch inference
        batch_dataset = Dataset.from_dict({
            'question': expanded_questions,
            'answer': expanded_answers
        })
        
        if verbose:
            print(f"Generating {len(expanded_questions)} samples for batch {batch_start//batch_size + 1}")
        
        # Generate predictions for all samples
        predicted_dataset = predict_dataset(
            batch_dataset,
            model_name=model_config['model_name'],
            prompt_template_path=prompt_template_path,
            num_proc=num_proc,
            output_column="predicted_answer_raw",
            temperature=temperature,  # Use specified temperature for sampling
            max_tokens=model_config.get('max_tokens', 2048),
            suffix=model_config.get('suffix'),
            top_p=model_config.get('top_p', 0.95),
            max_thinking_tokens=model_config.get('max_thinking_tokens'),
            inference_backend=model_config['inference_backend'],
            google_api_key=model_config.get('google_api_key'),
            retry_attempts=model_config.get('retry_attempts', 5)
        )
        
        # Extract answers
        predicted_dataset = predicted_dataset.map(
            lambda x: {"predicted_answer": extract_answer(x["predicted_answer_raw"])}
        )
        
        # Grade all predictions
        graded_dataset = predict_dataset(
            predicted_dataset,
            model_name=grader_config['model'],
            prompt_template_path="prompts/GRADER.txt",
            output_column="grade",
            num_proc=num_proc,
            max_tokens=10,
            inference_backend=grader_config['backend'],
            google_api_key=model_config.get('google_api_key'),
            suffix=None
        )
        
        # Materialise each column exactly once; indexing dataset['col'][j] inside the
        # aggregation loop would re-fetch the whole column for every sample.
        grades = [clean_output(grade) for grade in graded_dataset['grade']]
        raw_responses = list(graded_dataset['predicted_answer_raw'])
        predicted_answers = list(predicted_dataset['predicted_answer'])

        # The aggregation below is purely positional, so a changed row count would
        # silently attribute samples to the wrong question.
        expected_rows = len(expanded_questions)
        if not (len(grades) == len(raw_responses) == len(predicted_answers) == expected_rows):
            raise RuntimeError(
                f"Expected {expected_rows} rows after generation and grading, got "
                f"{len(predicted_answers)} predictions, {len(raw_responses)} raw responses "
                f"and {len(grades)} grades"
            )
        
        # Aggregate results per question
        for i, (question, answer) in enumerate(zip(batch_questions, batch_answers)):
            # Samples of question i are contiguous (see the expansion above)
            start = i * n_samples
            stop = start + n_samples
            question_grades = grades[start:stop]
            question_raw_responses = raw_responses[start:stop]
            question_predicted_answers = predicted_answers[start:stop]
            
            # Calculate P(Answer) as fraction of non-refusing samples (grades A or B)
            # This represents the probability of attempting to answer
            answer_count = sum(1 for grade in question_grades if grade in ('A', 'B'))
            panswer_score = answer_count / n_samples
            panswer_scores.append(panswer_score)
            
            # Store detailed sample data
            grade_counter = Counter(question_grades)
            detailed_samples.append({
                'question_idx': batch_start + i,
                'question': question,
                'ground_truth': answer,
                'n_samples': n_samples,
                'sample_responses': question_raw_responses,
                'sample_predicted_answers': question_predicted_answers,
                'sample_grades': question_grades,
                'panswer_score': panswer_score,
                'answer_count': answer_count,
                'grade_distribution': dict(grade_counter)
            })
            
            if verbose and i < 3:  # Show the first 3 examples of each batch
                print(f"\nQuestion {batch_start + i}: {question[:50]}...")
                print(f"Grades: {grade_counter}")
                print(f"P(Answer) = {panswer_score:.3f}")
    
    return panswer_scores, detailed_samples


@dataclass
class Arguments:
    """Options for the P(Answer) evaluation script.

    ``model_config`` has to be supplied. After construction, the value is passed to
    ``load_model_config`` and every key/value pair of the returned mapping is set
    as an attribute on the instance. When ``use_same_model_for_grading`` is true,
    the grader backend and model are then replaced by the instance's
    ``inference_backend`` and ``model_name`` attributes.
    """
    dataset_name: str = "triviaqa"
    max_samples: int = 1000
    verbose: bool = False
    num_proc: int = 50
    model_config: str = field(default=None)  # Mandatory; checked in __post_init__
    n_samples_panswer: int = 40  # Samples drawn per question for P(Answer)
    panswer_temperature: float = 1.0  # Sampling temperature for P(Answer)
    batch_size: int = 100  # Questions handled per P(Answer) batch
    prompt_template: str = "PROMPT_D"  # Moderate-refusal template by default
    use_same_model_for_grading: bool = False
    grader_backend: Literal["openai", "google", "vllm", "vllm_offline"] = "openai"
    grader_model: str = "google/gemini-2.0-flash-lite-001"
    # Base directories for outputs
    results_base: str = "results"
    logs_base: str = "logs"

    def __post_init__(self):
        """Check that ``model_config`` is set and merge the loaded settings into self.

        Raises:
            ValueError: If ``model_config`` was left as None.
        """
        if self.model_config is None:
            raise ValueError("model_config must be specified explicitly")

        # Every entry of the loaded configuration becomes an instance attribute
        loaded_settings = load_model_config(self.model_config)
        for setting_name, setting_value in loaded_settings.items():
            setattr(self, setting_name, setting_value)

        if not self.use_same_model_for_grading:
            return

        # Grade with the very model that is being evaluated
        self.grader_backend = self.inference_backend
        self.grader_model = self.model_name


def _flatten_sample_details(detailed_samples):
    """Expand per-question sample summaries into one flat row per generated sample."""
    rows = []
    for sample_data in detailed_samples:
        n_samples = sample_data['n_samples']
        per_sample = zip(
            sample_data['sample_responses'],
            sample_data['sample_predicted_answers'],
            sample_data['sample_grades'],
        )
        for sample_idx, (raw_response, predicted_answer, grade) in enumerate(per_sample):
            row = {
                'question_idx': sample_data['question_idx'],
                'question': sample_data['question'],
                'ground_truth': sample_data['ground_truth'],
                'sample_idx': sample_idx,
                'sample_response_raw': raw_response,
                'sample_predicted_answer': predicted_answer,
                'sample_grade': grade,
                'panswer_score': sample_data['panswer_score'],
                'answer_count': sample_data['answer_count'],
                'n_samples': n_samples
            }
            # Add grade distribution as separate columns
            for grade_name, count in sample_data['grade_distribution'].items():
                row[f'grade_{grade_name}_count'] = count
            rows.append(row)
    return rows


def main():
    """Run the full P(Answer) evaluation and write its artefacts.

    Steps:
        1. Parse arguments and create timestamped results/logs directories.
        2. Load the dataset and compute sampling-based P(Answer) scores.
        3. Generate and grade one prediction per question at the model's
           configured temperature.
        4. Print summary statistics and calibration metrics.
        5. Save the per-question CSV, the per-sample CSV and the metrics JSON
           (to both the logs and the results directory).
        6. Try to generate plots; a failure here is reported but not raised.

    Raises:
        ValueError: If the loaded dataset contains no questions.
    """
    args = parse_args(Arguments)
    print(f"Arguments: {args}")
    
    # Create output directories
    results_dir, logs_dir = create_timestamped_dir(
        "panswer",
        results_base=args.results_base or "results",
        logs_base=args.logs_base or "logs",
    )
    
    # Load dataset
    test_set = load_by_name(args.dataset_name, max_samples=args.max_samples)
    n_questions = len(test_set['question'])
    print(f"Loaded {n_questions} questions from {args.dataset_name}")
    if n_questions == 0:
        # Every rate below divides by the number of questions; stop with a clear
        # message instead of a ZeroDivisionError after the model calls.
        raise ValueError(f"Dataset '{args.dataset_name}' contains no questions; nothing to evaluate")
    
    # Prepare model config
    model_config = {
        'model_name': args.model_name,
        'inference_backend': args.inference_backend,
        'temperature': args.temperature,
        'top_p': args.top_p,
        'max_tokens': args.max_tokens,
        'suffix': args.suffix,
        'max_thinking_tokens': getattr(args, 'max_thinking_tokens', None),
        'google_api_key': getattr(args, 'google_api_key', None),
        'retry_attempts': getattr(args, 'retry_attempts', 5)
    }
    
    # Prepare grader config
    grader_config = {
        'backend': args.grader_backend,
        'model': args.grader_model
    }
    
    # Compute P(Answer) scores using sampling
    print(f"\nComputing P(Answer) scores with {args.n_samples_panswer} samples per question at T={args.panswer_temperature}")
    prompt_path = f"prompts/{args.prompt_template}.txt"
    
    # Use standard multi-process approach
    panswer_scores, detailed_samples = batch_compute_panswer_vllm(
        test_set=test_set,
        model_config=model_config,
        n_samples=args.n_samples_panswer,
        temperature=args.panswer_temperature,
        num_proc=args.num_proc,
        batch_size=args.batch_size,
        prompt_template_path=prompt_path,
        grader_config=grader_config,
        verbose=args.verbose
    )
    
    # Now get single predictions at the model's configured temperature for evaluation
    print(f"\nGenerating single predictions at T={args.temperature} for final evaluation")
    
    # Use standard predict_dataset
    predicted_dataset = predict_dataset(
        test_set,
        model_name=args.model_name,
        prompt_template_path=prompt_path,
        num_proc=args.num_proc,
        output_column="predicted_answer_raw",
        temperature=args.temperature,
        max_tokens=args.max_tokens,
        suffix=args.suffix,
        top_p=args.top_p,
        max_thinking_tokens=model_config['max_thinking_tokens'],
        inference_backend=args.inference_backend,
        google_api_key=model_config['google_api_key'],
        # Same retry policy as the sampling pass, which also forwards retry_attempts
        retry_attempts=model_config['retry_attempts']
    )
    
    # Extract answers
    predicted_dataset = predicted_dataset.map(
        lambda x: {"predicted_answer": extract_answer(x["predicted_answer_raw"])}
    )
    
    # Grade single predictions
    graded_dataset = predict_dataset(
        predicted_dataset,
        model_name=args.grader_model,
        prompt_template_path="prompts/GRADER.txt",
        output_column="grade",
        num_proc=args.num_proc,
        max_tokens=10,
        inference_backend=args.grader_backend,
        google_api_key=model_config['google_api_key'],
        suffix=None
    )
    
    # Create results DataFrame
    results_df = pd.DataFrame({
        'question': test_set['question'],
        'answer': test_set['answer'],
        'predicted_answer_raw': predicted_dataset['predicted_answer_raw'],
        'predicted_answer': predicted_dataset['predicted_answer'],
        'grade': [clean_output(grade) for grade in graded_dataset['grade']],
        'panswer_score': panswer_scores
    })
    
    # Create binary labels from single predictions
    # NOTE(review): the label is 1 only for grade 'A' (correct), whereas the score
    # counts grades 'A' and 'B' (any attempt) — confirm this is the intended target.
    binary_labels = np.array([1 if grade == 'A' else 0 for grade in results_df['grade'].values])
    scores_array = np.array(panswer_scores)
    
    # Summary statistics, computed once and reused for both printing and the JSON
    total_questions = len(results_df)
    grade_counts = Counter(results_df['grade'])
    correct_rate = grade_counts.get('A', 0) / total_questions
    incorrect_rate = grade_counts.get('B', 0) / total_questions
    refusal_rate = grade_counts.get('C', 0) / total_questions
    mean_panswer = float(np.mean(panswer_scores))
    std_panswer = float(np.std(panswer_scores))

    print(f"\nGrade distribution: {grade_counts}")
    print(f"Correct rate (A): {correct_rate:.3f}")
    print(f"Incorrect rate (B): {incorrect_rate:.3f}")
    print(f"Refusal rate (C): {refusal_rate:.3f}")
    print(f"Mean P(Answer): {mean_panswer:.3f} (std: {std_panswer:.3f})")
    
    # Calibration analysis using binary labels
    print("\n=== Calibration Analysis ===")
    print(calibration_summary(scores_array, binary_labels, args.model_config))
    
    # Compute calibration metrics without plotting
    from src.core.calibration_metrics import compute_all_metrics
    metrics = compute_all_metrics(scores_array, binary_labels, n_bins=10)
    
    # Save results to logs directory only
    csv_path = f"{logs_dir}/panswer_{args.dataset_name}_{args.model_config}.csv"
    results_df.to_csv(csv_path, index=False)
    print(f"\nSaved results to {csv_path}")
    
    # Save detailed sample data (one row per generated sample)
    detailed_df = pd.DataFrame(_flatten_sample_details(detailed_samples))
    detailed_csv_path = f"{logs_dir}/panswer_detailed_{args.dataset_name}_{args.model_config}.csv"
    detailed_df.to_csv(detailed_csv_path, index=False)
    print(f"Saved detailed sample data to {detailed_csv_path}")
    
    # Save metrics and configuration with plotting data
    output_data = {
        "configuration": {
            "dataset_name": args.dataset_name,
            "model_config": args.model_config,
            "model_name": args.model_name,
            "max_samples": args.max_samples,
            "n_samples_panswer": args.n_samples_panswer,
            "panswer_temperature": args.panswer_temperature,
            "model_temperature": args.temperature,
            "prompt_template": args.prompt_template,
            "grader_backend": args.grader_backend,
            "grader_model": args.grader_model
        },
        "results": {
            "total_questions": total_questions,
            "grade_distribution": dict(grade_counts),
            "correct_rate": correct_rate,
            "incorrect_rate": incorrect_rate,
            "refusal_rate": refusal_rate,
            "mean_panswer": mean_panswer,
            "std_panswer": std_panswer,
            "panswer_scores": [float(score) for score in panswer_scores],
            "binary_labels": binary_labels.tolist()
        },
        "calibration_metrics": metrics
    }
    
    # Serialise once, before opening any file, so a non-serialisable metric cannot
    # leave a truncated JSON file behind; then write the logs copy and the results copy.
    serialized = json.dumps(output_data, indent=2)
    json_path = f"{logs_dir}/panswer_metrics_{args.dataset_name}_{args.model_config}.json"
    results_json_path = f"{results_dir}/panswer_metrics_{args.dataset_name}_{args.model_config}.json"
    for target_path, message in (
        (json_path, "Saved metrics to"),
        (results_json_path, "Saved metrics copy to"),
    ):
        with open(target_path, 'w') as f:
            f.write(serialized)
        print(f"{message} {target_path}")
    
    # Automatically generate plots
    print("\n" + "="*60)
    print("P(Answer) evaluation complete!")
    print(f"Data saved to: {logs_dir}")
    print("Generating plots automatically...")
    
    # Plotting is best-effort: all data is already on disk, so a failure is only reported
    try:
        from src.visualization.calibration_plotting import load_and_plot_from_logs_dir
        load_and_plot_from_logs_dir(logs_dir)
        print("✅ Plots generated successfully!")
    except Exception as e:
        print(f"❌ Plot generation failed: {e}")
        print(f"You can manually run: python -m src.visualization.calibration_plotting {logs_dir} --from_logs")
    
    print("="*60)


# Script entry point: run the evaluation only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
