import numpy as np
from collections import defaultdict

from utils.grpo import get_reward

def _pct(numerator, denominator):
    """Return ``numerator / denominator`` as a percentage rounded to 3 decimals.

    Returns 0.0 when ``denominator`` is 0 so that empty evaluations yield
    well-defined metrics instead of raising ``ZeroDivisionError``.
    """
    if not denominator:
        return 0.0
    return round(numerator / denominator * 100, 3)


def get_eval_metrics(generation_outputs, tokenizer):
    """Score sampled generations against oracle answers and aggregate metrics.

    Args:
        generation_outputs: Mapping ``example_idx -> list of sample dicts``.
            Each sample dict must provide ``'answer'`` (oracle answer),
            ``'generated_ids'`` (token ids to decode) and ``'entropy'``
            (per-token values, averaged with ``np.mean``). It may provide
            ``'stopping_reason'``.
        tokenizer: Object exposing ``decode(ids, skip_special_tokens=True)``
            (presumably a Hugging Face tokenizer; verify against caller).

    Returns:
        A ``(metrics, completions)`` tuple.

        ``metrics`` is a dict with these entries:

        - ``'answered'`` / ``'accuracy'``: percentage of all samples that
          produced an answer / a correct answer, rounded to 3 decimals.
        - ``'accuracy_answered'``: percentage of answered samples that are
          correct. Present only when at least one sample was answered.
        - ``'pass@N'``: percentage of examples with at least one correct
          sample.
        - ``'entropy'``: mean over samples of each sample's mean entropy.
          It is nan when there are no samples.
        - ``'stopping_max_length'``: raw count of samples whose
          ``stopping_reason`` is ``'max_length'``.

        ``completions`` is a ``defaultdict(list)`` mapping ``example_idx`` to
        ``(answer_model, answer_oracle, completion_text, num_generated_ids)``
        tuples, in sample order.

        All percentages are 0.0 when their denominator is zero (empty input).
    """
    answered = 0
    total_samples = 0
    total_examples = 0
    correct = 0
    correct_at_least_one = 0
    stopping_max_length = 0
    completions = defaultdict(list)
    entropy = []
    for example_idx, examples in generation_outputs.items():
        has_one_correct = False
        for example in examples:
            answer_oracle = example['answer']
            generated_ids = example['generated_ids']
            completion = tokenizer.decode(generated_ids, skip_special_tokens=True)

            # The first element of get_reward's result is not used here
            # (presumably the scalar reward; confirm in utils.grpo).
            _, answer_model, is_correct, is_answered = get_reward(completion, answer_oracle)
            completions[example_idx].append((answer_model, answer_oracle, completion, len(generated_ids)))

            entropy.append(np.mean(example['entropy']))
            if example.get('stopping_reason') == 'max_length':
                stopping_max_length += 1

            if is_answered:
                answered += 1

            total_samples += 1
            if is_correct:
                correct += 1
                has_one_correct = True
        total_examples += 1
        if has_one_correct:
            correct_at_least_one += 1

    metrics = dict()
    metrics['answered'] = _pct(answered, total_samples)
    metrics['accuracy'] = _pct(correct, total_samples)
    # Keep the key conditional: callers may rely on its absence when nothing was answered.
    if answered:
        metrics['accuracy_answered'] = _pct(correct, answered)
    metrics['pass@N'] = _pct(correct_at_least_one, total_examples)
    # np.mean([]) warns and returns nan; return nan explicitly without the warning.
    metrics['entropy'] = np.mean(entropy) if entropy else float('nan')
    metrics['stopping_max_length'] = stopping_max_length
    return metrics, completions
