import json
from pathlib import Path
from typing import Dict, Any, Tuple, List


def _question_preview(question: Any, limit: int = 100) -> str:
    """Return *question* as display text, cut to *limit* characters.

    A missing/null/empty question becomes ``""`` so the preview is always a
    string; text longer than *limit* is truncated and suffixed with ``"..."``.
    """
    text = str(question) if question else ""
    if len(text) > limit:
        return text[:limit] + "..."
    return text


def analyze_motivation_reasoning_results(results_file: str) -> Dict[str, Any]:
    """Analyze the motivation_inference_results_reasoning_eval.json file and
    compute the number of responses that are both correct and independent, as well as
    how many unique samples these involve.

    A response counts as "correct and independent" when both its
    ``is_correct_reasoning`` and ``is_independent`` fields are truthy; a
    missing field is treated as false. A sample whose ``evaluations`` entry is
    missing or null contributes zero responses.

    Args:
        results_file: Path to the motivation_inference_results_reasoning_eval.json file

    Returns:
        A dictionary with aggregated statistics:
        {
            "total_samples": total number of samples,
            "total_responses": total number of responses,
            "correct_and_independent_responses": number of responses that are both correct and independent,
            "samples_with_correct_and_independent": number of samples that contain at least one correct and independent response,
            "samples_with_all_correct_and_independent": number of samples where all responses are both correct and independent,
            "correct_and_independent_rate_per_response": rate of correct-and-independent per response,
            "sample_level_any_correct_and_independent_rate": sample-level coverage rate (any),
            "sample_level_all_correct_and_independent_rate": sample-level all rate (all),
            "detailed_sample_stats": detailed stats for each sample
        }

    Raises:
        FileNotFoundError: If ``results_file`` does not exist.
        ValueError: If the file has no (or an empty) ``detailed_results`` entry.
    """
    results_path = Path(results_file)
    if not results_path.exists():
        raise FileNotFoundError(f"Results file not found: {results_file}")

    with open(results_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    # Load detailed results
    detailed_results = data.get("detailed_results", [])
    if not detailed_results:
        raise ValueError("No detailed_results found in the results file")

    # Counters
    total_samples = len(detailed_results)
    total_responses = 0
    correct_and_independent_responses = 0
    samples_with_correct_and_independent = 0
    samples_with_all_correct_and_independent = 0
    detailed_sample_stats = []

    for sample_idx, sample in enumerate(detailed_results):
        # `or []` also covers an explicit JSON null, which .get() would
        # otherwise hand back as None and break len()/iteration.
        evaluations = sample.get("evaluations") or []
        sample_responses = len(evaluations)

        # Count responses that are both correct and independent in the current sample
        sample_correct_and_independent = sum(
            1
            for eval_result in evaluations
            if eval_result.get("is_correct_reasoning", False)
            and eval_result.get("is_independent", False)
        )

        total_responses += sample_responses
        correct_and_independent_responses += sample_correct_and_independent

        # Sample-level stats
        has_any_correct_and_independent = sample_correct_and_independent > 0
        # A sample without responses never counts as "all correct".
        has_all_correct_and_independent = (
            sample_responses > 0
            and sample_correct_and_independent == sample_responses
        )

        if has_any_correct_and_independent:
            samples_with_correct_and_independent += 1

        if has_all_correct_and_independent:
            samples_with_all_correct_and_independent += 1

        # Record detailed stats
        detailed_sample_stats.append({
            # Fall back to the positional index when the sample carries no id.
            "sample_id": sample.get("sample_id", sample_idx),
            "total_responses": sample_responses,
            "correct_and_independent_responses": sample_correct_and_independent,
            "has_any_correct_and_independent": has_any_correct_and_independent,
            "has_all_correct_and_independent": has_all_correct_and_independent,
            "question_preview": _question_preview(sample.get("question")),
        })

    # Compute rates. total_responses can be zero (every sample empty), so that
    # division is guarded; total_samples cannot, because an empty
    # detailed_results was rejected above.
    correct_and_independent_rate_per_response = (
        correct_and_independent_responses / total_responses if total_responses > 0 else 0.0
    )
    sample_level_any_rate = samples_with_correct_and_independent / total_samples
    sample_level_all_rate = samples_with_all_correct_and_independent / total_samples

    return {
        "total_samples": total_samples,
        "total_responses": total_responses,
        "correct_and_independent_responses": correct_and_independent_responses,
        "samples_with_correct_and_independent": samples_with_correct_and_independent,
        "samples_with_all_correct_and_independent": samples_with_all_correct_and_independent,
        "correct_and_independent_rate_per_response": correct_and_independent_rate_per_response,
        "sample_level_any_correct_and_independent_rate": sample_level_any_rate,
        "sample_level_all_correct_and_independent_rate": sample_level_all_rate,
        "detailed_sample_stats": detailed_sample_stats,
    }


def print_motivation_analysis(results_file: str, show_detailed: bool = False):
    """Print a human-readable report of the Motivation reasoning statistics.

    The numbers come from ``analyze_motivation_reasoning_results``. Any
    exception raised while computing or printing them is caught and reported
    as a single "Analysis failed: ..." line instead of propagating.

    Args:
        results_file: Path to motivation_inference_results_reasoning_eval.json
        show_detailed: Whether to also print one entry per sample
    """
    wide_rule = "=" * 80

    def yes_no(flag) -> str:
        # Render a truthy/falsy flag the way the report shows it.
        return 'Yes' if flag else 'No'

    try:
        report = analyze_motivation_reasoning_results(results_file)

        # Header
        print(wide_rule)
        print("Motivation Reasoning Results Analysis Report")
        print(wide_rule)
        print(f"Source file: {results_file}")
        print()

        # Raw totals
        n_samples = report['total_samples']
        n_responses = report['total_responses']
        n_hits = report['correct_and_independent_responses']
        print("📊 Overall Stats:")
        print(f"  • Total samples: {n_samples}")
        print(f"  • Total responses: {n_responses}")
        print(f"  • Correct AND independent responses: {n_hits}")
        print()

        # Rates, each followed by the fraction it was derived from
        n_any = report['samples_with_correct_and_independent']
        n_all = report['samples_with_all_correct_and_independent']
        print("🎯 Key Metrics:")
        print(f"  • Correct-and-independent rate per response: {report['correct_and_independent_rate_per_response']:.3f} ({n_hits}/{n_responses})")
        print(f"  • Samples containing any correct-and-independent response: {n_any}")
        print(f"  • Sample-level coverage (any): {report['sample_level_any_correct_and_independent_rate']:.3f} ({n_any}/{n_samples})")
        print(f"  • Samples where all responses are correct-and-independent: {n_all}")
        print(f"  • Sample-level all rate: {report['sample_level_all_correct_and_independent_rate']:.3f} ({n_all}/{n_samples})")
        print()

        # Optional per-sample section
        if show_detailed:
            narrow_rule = "-" * 40
            print("📋 Detailed per-sample stats:")
            print("-" * 80)
            for entry in report['detailed_sample_stats']:
                print(f"Sample {entry['sample_id']}:")
                print(f"  Responses: {entry['total_responses']}")
                print(f"  Correct AND independent responses: {entry['correct_and_independent_responses']}")
                print(f"  Contains any correct-and-independent: {yes_no(entry['has_any_correct_and_independent'])}")
                print(f"  All responses correct-and-independent: {yes_no(entry['has_all_correct_and_independent'])}")
                print(f"  Question preview: {entry['question_preview']}")
                print(narrow_rule)

        print(wide_rule)

    except Exception as e:
        # Best-effort reporting: surface the failure as text rather than a traceback.
        print(f"Analysis failed: {e}")


def find_samples_with_correct_and_independent_responses(results_file: str) -> List[Dict[str, Any]]:
    """Select the samples that have at least one correct-and-independent response.

    Args:
        results_file: Path to motivation_inference_results_reasoning_eval.json

    Returns:
        The per-sample entries from the ``detailed_sample_stats`` list produced
        by ``analyze_motivation_reasoning_results`` whose
        ``has_any_correct_and_independent`` flag is set, in their original order
    """
    per_sample_stats = analyze_motivation_reasoning_results(results_file)['detailed_sample_stats']

    matching: List[Dict[str, Any]] = []
    for entry in per_sample_stats:
        if entry['has_any_correct_and_independent']:
            matching.append(entry)
    return matching


def compare_correct_vs_independent(results_file: str) -> Dict[str, Any]:
    """Compare the distribution of correctness vs. independence.

    Every evaluation is classified by the truthiness of its
    ``is_correct_reasoning`` and ``is_independent`` fields (a missing field
    counts as false) into one of four buckets. A ``detailed_results`` or
    ``evaluations`` entry that is missing or null is treated as empty.

    Args:
        results_file: Path to motivation_inference_results_reasoning_eval.json

    Returns:
        A dictionary containing comparison statistics: ``total_responses``,
        the four bucket counts (``correct_and_independent``,
        ``correct_but_not_independent``, ``independent_but_not_correct``,
        ``neither_correct_nor_independent``) and the matching rates
        (``correct_and_independent_rate``, ``correct_but_not_independent_rate``,
        ``independent_but_not_correct_rate``, ``neither_rate``). All rates are
        0.0 when there are no responses.

    Raises:
        FileNotFoundError: If ``results_file`` does not exist.
    """
    results_path = Path(results_file)
    if not results_path.exists():
        raise FileNotFoundError(f"Results file not found: {results_file}")

    with open(results_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    # `or []` also covers an explicit JSON null.
    detailed_results = data.get("detailed_results") or []

    # Tally keyed by (is_correct, is_independent); pre-seeded so that every
    # combination is present even when it never occurs.
    counts = {
        (correct, independent): 0
        for correct in (True, False)
        for independent in (True, False)
    }
    for sample in detailed_results:
        for eval_result in sample.get("evaluations") or []:
            bucket = (
                bool(eval_result.get("is_correct_reasoning", False)),
                bool(eval_result.get("is_independent", False)),
            )
            counts[bucket] += 1

    total_responses = sum(counts.values())

    def rate(count: int) -> float:
        # Share of all responses; defined as 0.0 when there are none.
        return count / total_responses if total_responses > 0 else 0.0

    correct_and_independent = counts[(True, True)]
    correct_but_not_independent = counts[(True, False)]
    independent_but_not_correct = counts[(False, True)]
    neither_correct_nor_independent = counts[(False, False)]

    return {
        "total_responses": total_responses,
        "correct_and_independent": correct_and_independent,
        "correct_but_not_independent": correct_but_not_independent,
        "independent_but_not_correct": independent_but_not_correct,
        "neither_correct_nor_independent": neither_correct_nor_independent,
        "correct_and_independent_rate": rate(correct_and_independent),
        "correct_but_not_independent_rate": rate(correct_but_not_independent),
        "independent_but_not_correct_rate": rate(independent_but_not_correct),
        "neither_rate": rate(neither_correct_nor_independent),
    }


if __name__ == "__main__":
    import argparse

    # Command-line interface: one positional path plus two optional report switches.
    cli = argparse.ArgumentParser(description="Analyze Motivation reasoning results")
    cli.add_argument("results_file", help="Path to motivation_inference_results_reasoning_eval.json")
    cli.add_argument("--detailed", action="store_true", help="Show per-sample detailed statistics")
    cli.add_argument("--compare", action="store_true", help="Compare the distribution of correctness vs. independence")
    options = cli.parse_args()

    # The main report is always printed.
    print_motivation_analysis(options.results_file, show_detailed=options.detailed)

    if options.compare:
        banner = "=" * 80
        print("\n" + banner)
        print("Correctness vs Independence Distribution Comparison")
        print(banner)

        breakdown = compare_correct_vs_independent(options.results_file)
        print(f"Total responses: {breakdown['total_responses']}")

        # (label, count key, rate key) for each of the four combinations,
        # in the order they are displayed.
        rows = (
            ("Correct AND independent", "correct_and_independent", "correct_and_independent_rate"),
            ("Correct but NOT independent", "correct_but_not_independent", "correct_but_not_independent_rate"),
            ("Independent but NOT correct", "independent_but_not_correct", "independent_but_not_correct_rate"),
            ("Neither correct nor independent", "neither_correct_nor_independent", "neither_rate"),
        )
        for label, count_key, rate_key in rows:
            print(f"{label}: {breakdown[count_key]} ({breakdown[rate_key]:.3f})")