"""
Optimized public leaderboard view with efficient database queries.
"""
from django.shortcuts import render
from django.db.models import Prefetch, Q
from collections import defaultdict
from model_evaluation.models import (
    Model, ModelTier, ModelAttempt, ModelSubquestionAnswer,
    ModelGrading, ModelAnswer, ModelGradingSession
)
from questions.models import Question, QuestionState, Subquestion
import json


# Models whose stored display name differs from the name shown publicly.
_DISPLAY_NAME_ALIASES = {"Grok 4 Heavy": "Grok 4"}

# Progress-grade buckets reported for every model on the progress chart.
_PROGRESS_BUCKETS = (0, 1, 2, 3)


def _public_display_name(model):
    """Return the name under which *model* is shown on the leaderboard."""
    return _DISPLAY_NAME_ALIASES.get(model.display_name, model.display_name)


def _is_effectively_correct(answer):
    """Return whether a subquestion answer row counts as correct.

    An explicit ``admin_override`` wins over the automatic ``is_correct``
    flag: an override equal to 1 counts as correct, any other non-null
    override counts as incorrect.
    """
    override = answer['admin_override']
    if override is None:
        return bool(answer['is_correct'])
    return override == 1


def _attempt_percentage(subquestions, answers):
    """Return the percentage of points earned in a single attempt.

    Every subquestion of the question contributes to the possible points;
    subquestions without an answer row simply earn nothing. Returns ``None``
    when the subquestions carry no points at all.
    """
    answers_by_subquestion = {a['subquestion_id']: a for a in answers}
    points_possible = 0
    points_earned = 0
    for subq in subquestions:
        points_possible += subq.points
        answer = answers_by_subquestion.get(subq.id)
        if answer is not None and _is_effectively_correct(answer):
            points_earned += subq.points

    if points_possible > 0:
        return (points_earned / points_possible) * 100
    return None


def _question_scores_for_model(model, question_subquestions, attempt_lookup,
                               subq_answers_by_attempt):
    """Return ``{question_id: percentage}`` for the questions *model* was evaluated on.

    Tier 1 models are scored on attempts 1 and 2 (averaged when both were
    evaluated); all other tiers are scored on attempt 1 only. Questions
    without an attempt, or whose attempts have no subquestion answers, are
    left out of the result.
    """
    attempt_numbers = (1, 2) if model.tier.tier_number == 1 else (1,)

    question_scores = {}
    for question_id, subquestions in question_subquestions.items():
        attempt_percentages = []
        for attempt_number in attempt_numbers:
            attempt_id = attempt_lookup.get((model.id, question_id, attempt_number))
            if attempt_id is None:
                continue
            answers = subq_answers_by_attempt.get(attempt_id)
            if not answers:
                continue
            percentage = _attempt_percentage(subquestions, answers)
            if percentage is not None:
                attempt_percentages.append(percentage)

        if attempt_percentages:
            question_scores[question_id] = (
                sum(attempt_percentages) / len(attempt_percentages)
            )
    return question_scores


def _leaderboard_row(model, question_scores, total_questions):
    """Build the leaderboard entry for *model* from its per-question scores.

    ``score_percentage`` treats every unevaluated question as 0%, while
    ``max_percentage`` treats it as 100%. Both are ``None`` when there are no
    benchmark questions at all.
    """
    unevaluated = total_questions - len(question_scores)
    if total_questions > 0:
        score_sum = sum(question_scores.values())
        base_percentage = score_sum / total_questions
        max_percentage = (score_sum + unevaluated * 100.0) / total_questions
    else:
        base_percentage = None
        max_percentage = None

    return {
        'model': model,
        'display_name': _public_display_name(model),
        'tier': model.tier.tier_number,
        'score_percentage': base_percentage,
        'max_percentage': max_percentage,
        'questions_attempted': len(question_scores),
        'questions_total': total_questions,
        'has_unevaluated': unevaluated > 0,
        'framework_type': model.framework_type,
    }


def _progress_row(model, question_ids, grades_by_model_question):
    """Return the progress-chart entry for *model*, or ``None`` if it has no grades.

    Each graded question is placed in the bucket of its average grade,
    rounded half up. Only questions that actually have grades are counted.
    """
    counts = dict.fromkeys(_PROGRESS_BUCKETS, 0)
    for question_id in question_ids:
        grades = grades_by_model_question.get((model.id, question_id))
        if not grades:
            continue
        bucket = int(sum(grades) / len(grades) + 0.5)
        # Clamp so a stray out-of-range grade cannot raise KeyError and take
        # the whole public page down.
        bucket = min(max(bucket, _PROGRESS_BUCKETS[0]), _PROGRESS_BUCKETS[-1])
        counts[bucket] += 1

    total = sum(counts.values())
    if total == 0:
        return None

    return {
        'model_name': _public_display_name(model),
        'tier': model.tier.tier_number,
        'percentages': [100.0 * counts[b] / total for b in _PROGRESS_BUCKETS],
        'total_questions': total,
        'has_grading': True,  # Models without gradings are never listed
        'framework_type': model.framework_type,
    }


def _pairwise_win_matrix(ordered_model_ids, question_scores_by_model):
    """Return the matrix of pairwise win counts between models.

    ``matrix[i][j]`` is the number of questions, among those both models were
    evaluated on, where model ``i`` scored strictly higher than model ``j``.
    Diagonal cells are ``None``.
    """
    matrix = []
    for i, row_model_id in enumerate(ordered_model_ids):
        row_scores = question_scores_by_model.get(row_model_id, {})
        row = []
        for j, col_model_id in enumerate(ordered_model_ids):
            if i == j:
                row.append(None)
                continue
            col_scores = question_scores_by_model.get(col_model_id, {})
            row.append(sum(
                1 for qid, score in row_scores.items()
                if qid in col_scores and score > col_scores[qid]
            ))
        matrix.append(row)
    return matrix


def optimized_public_leaderboard(request):
    """
    Render the public leaderboard using a fixed number of database queries.

    All questions, models, attempts, subquestion answers and gradings are
    loaded up front; scores, progress statistics and the pairwise comparison
    matrix are then computed from in-memory lookups.

    The template receives ``models``, ``progress_stats`` and ``pairwise_data``
    (plus their ``*_json`` serializations), ``total_questions`` and
    ``total_models``. When no question state with status ``'active'`` exists,
    only empty ``models``, ``progress_stats`` and ``pairwise_data`` are passed.
    """
    active_state = QuestionState.objects.filter(status='active').first()
    if not active_state:
        return render(request, 'questions/leaderboard.html', {
            'models': [],
            'progress_stats': None,
            'pairwise_data': None,
        })

    # Benchmark questions that are not published. The queryset is evaluated
    # exactly once; the ids are derived from the loaded rows instead of
    # issuing a second query for them.
    included_questions = list(
        Question.objects.filter(
            status=active_state,
            benchmark_inclusion=True,
            published__isnull=True,
        ).prefetch_related('subquestion_set')
    )
    question_ids = [question.id for question in included_questions]

    # question_id -> subquestions; questions without subquestions are not scored.
    question_subquestions = {}
    for question in included_questions:
        subquestions = list(question.subquestion_set.all())
        if subquestions:
            question_subquestions[question.id] = subquestions

    models = list(
        Model.objects.filter(is_active=True)
        .select_related('tier')
        .order_by('tier__tier_number', 'display_name')
    )

    # (model_id, question_id, attempt_number) -> id of the newest matching attempt.
    # values() already restricts the selected columns, so no select_related here.
    all_attempts = ModelAttempt.objects.filter(
        question_id__in=question_ids,
        model__in=models,
    ).values('id', 'model_id', 'question_id', 'attempt_number')

    attempt_lookup = {}
    for attempt in all_attempts:
        key = (attempt['model_id'], attempt['question_id'], attempt['attempt_number'])
        if key not in attempt_lookup or attempt['id'] > attempt_lookup[key]:
            attempt_lookup[key] = attempt['id']

    # attempt_id -> subquestion answer rows of that attempt.
    all_subq_answers = ModelSubquestionAnswer.objects.filter(
        attempt_id__in=list(attempt_lookup.values()),
    ).values('attempt_id', 'subquestion_id', 'is_correct', 'admin_override')

    subq_answers_by_attempt = defaultdict(list)
    for answer in all_subq_answers:
        subq_answers_by_attempt[answer['attempt_id']].append(answer)

    # Only completed gradings from finalized sessions are used (the original
    # code notes this mirrors the scores page logic).
    finalized_sessions = ModelGradingSession.objects.filter(
        question_id__in=question_ids,
        session_status='finalized',
    )
    all_gradings = ModelGrading.objects.filter(
        session__in=finalized_sessions,
        grading_status='completed',
        model_answer__released_for_grading=True,
        progress_grade__isnull=False,
    ).values(
        'model_answer__model_id',
        'model_answer__model__display_name',
        'model_answer__question_id',
        'progress_grade',
    )

    # (model_id, question_id) -> progress grades.
    grades_by_model_question = defaultdict(list)
    for grading in all_gradings:
        key = (grading['model_answer__model_id'], grading['model_answer__question_id'])
        grades_by_model_question[key].append(grading['progress_grade'])

    # 1. Subquestion scores. Per-question scores are keyed by model id (not by
    # display name) so that two models sharing a public name cannot overwrite
    # each other's results in the pairwise comparison.
    total_questions = len(question_subquestions)
    models_with_scores = []
    question_scores_by_model = {}
    for model in models:
        question_scores = _question_scores_for_model(
            model, question_subquestions, attempt_lookup, subq_answers_by_attempt
        )
        question_scores_by_model[model.id] = question_scores
        models_with_scores.append(
            _leaderboard_row(model, question_scores, total_questions)
        )

    # Best score first; models without a score sort last.
    models_with_scores.sort(
        key=lambda row: row['score_percentage'] if row['score_percentage'] is not None else -1,
        reverse=True,
    )

    # 2. Progress grades, ordered by the share of questions in the top bucket.
    progress_models = []
    for model in models:
        progress_row = _progress_row(model, question_ids, grades_by_model_question)
        if progress_row is not None:
            progress_models.append(progress_row)
    progress_models.sort(key=lambda row: row['percentages'][3], reverse=True)

    progress_stats = {'models': progress_models} if progress_models else None

    # 3. Pairwise comparisons over ALL models, in leaderboard order. The
    # comparison uses the same per-question score (average over evaluated
    # attempts) that feeds the leaderboard percentage.
    all_model_info = [
        {
            'name': row['display_name'],
            'tier': row['tier'],
            'score': row['score_percentage'],
            'framework_type': row['framework_type'],
        }
        for row in models_with_scores
    ]
    pairwise_data = {
        'models': all_model_info,
        'matrix': _pairwise_win_matrix(
            [row['model'].id for row in models_with_scores],
            question_scores_by_model,
        ),
        'total_questions': len(question_ids),
        'total_subquestions': sum(len(sq) for sq in question_subquestions.values()),
    }

    # JSON payloads for the template; model instances are reduced to plain data.
    models_json = [
        {
            'model': {
                'display_name': row['display_name'],
                'id': row['model'].id,
            },
            'display_name': row['display_name'],
            'tier': row['tier'],
            'score_percentage': row['score_percentage'],
            'max_percentage': row['max_percentage'],
            'has_unevaluated': row['has_unevaluated'],
            'questions_attempted': row['questions_attempted'],
            'questions_total': row['questions_total'],
            'framework_type': row['framework_type'],
        }
        for row in models_with_scores
    ]

    context = {
        'models': models_with_scores,
        'models_json': json.dumps(models_json),
        'progress_stats': progress_stats,
        'progress_stats_json': json.dumps(progress_stats) if progress_stats else None,
        'pairwise_data': pairwise_data,
        'pairwise_data_json': json.dumps(pairwise_data),
        'total_questions': len(question_ids),
        'total_models': len(models_with_scores),
    }

    return render(request, 'questions/leaderboard.html', context)