import string

import torch
import scipy.stats


def group_advantages(name, rewards, num_options):
    """Compute per-sample advantages for one group of sampled rewards.

    Args:
        name: Estimator name. Must start with ``'grpo'`` (rewards centered and
            divided by the group std) or ``'drgrpo'`` (rewards centered only).
            A ``'_mcq'`` suffix additionally scales the advantages by the
            posterior probability that the group's success rate exceeds
            ``1 / num_options``.
        rewards: List or tensor of rewards for the group. A reward equal to
            ``1.0`` is counted as correct; presumably rewards are binary 0/1
            (TODO confirm with callers). Integer/bool rewards are converted to
            float.
        num_options: Number of answer options; only read when ``name`` ends
            with ``'_mcq'``.

    Returns:
        Floating-point tensor of advantages with the same shape as
        ``rewards``. It is all zeros when every reward equals ``1.0`` or none
        does. For binary rewards such a group has no variance, so there is no
        learning signal and the GRPO std would be zero.

    Raises:
        NotImplementedError: if ``name`` does not select a known estimator.
    """
    # Validate up front. Degenerate groups return early below, and that would
    # otherwise hide a misspelled estimator name.
    if not name.startswith(('grpo', 'drgrpo')):
        raise NotImplementedError(f"unknown advantage estimator: {name!r}")

    rewards = torch.as_tensor(rewards)
    if not rewards.is_floating_point():
        # int/bool rewards: mean() and std() require a floating-point dtype.
        rewards = rewards.float()

    # Count on the tensor, not the input object, so both lists and tensors work.
    num_samples = rewards.numel()
    num_correct = int((rewards == 1.0).sum())

    if num_correct in (0, num_samples):
        return torch.zeros_like(rewards)

    if name.startswith('grpo'):
        advantages = get_advantages_grpo(rewards)
    else:  # 'drgrpo', guaranteed by the validation above
        advantages = get_advantages_drgrpo(rewards)

    if name.endswith('_mcq'):
        num_not_correct = num_samples - num_correct
        prob_better_than_random = get_prob_solvable(num_options, num_correct, num_not_correct)
        # Down-weight groups whose success rate may be no better than guessing.
        advantages *= prob_better_than_random

    return advantages


def get_advantages_grpo(rewards):
    """Return GRPO advantages: group rewards centered and scaled by their std.

    Args:
        rewards: Floating-point tensor of rewards for one group.

    Returns:
        Tensor of the same shape equal to ``(rewards - mean) / std``. Here
        ``std`` is torch's default sample standard deviation, which uses
        Bessel's correction. If the group has fewer than two elements, or all
        rewards are equal, the result is all zeros instead of NaN/inf. In both
        cases the std is undefined or zero.
    """
    if rewards.numel() < 2:
        # The sample std needs at least two values; otherwise it is NaN.
        return torch.zeros_like(rewards)

    std = rewards.std()
    centered = rewards - rewards.mean()
    if std == 0:
        # Constant group: every centered value is 0, so avoid the 0 / 0.
        return torch.zeros_like(centered)
    return centered / std


def get_advantages_drgrpo(rewards):
    """Return 'drgrpo' advantages: each reward minus the group's mean reward.

    Unlike ``get_advantages_grpo``, the centered rewards are not divided by
    the group's standard deviation.

    Args:
        rewards: Floating-point tensor of rewards for one group.

    Returns:
        Tensor of the same shape as ``rewards`` whose values sum to
        (approximately) zero.
    """
    return rewards - rewards.mean()


def get_prob_solvable(num_options, num_correct, num_incorrect, prior=1.0):
    """Posterior probability that the success rate beats uniform guessing.

    The success rate is modeled as
    ``Beta(prior + num_correct, prior + num_incorrect)``. The function returns
    the survival function of that distribution at ``1 / num_options``, i.e.
    ``P(rate > 1 / num_options)``. With the default ``prior=1.0`` the prior is
    Beta(1, 1), which is uniform.

    Args:
        num_options: Number of answer choices; the guessing threshold is
            ``1 / num_options``.
        num_correct: Count of correct samples, added to the first Beta
            parameter.
        num_incorrect: Count of incorrect samples, added to the second Beta
            parameter.
        prior: Pseudo-count added to both Beta parameters.

    Returns:
        The value of ``scipy.stats.beta.sf`` at the threshold, in [0, 1].
    """
    return scipy.stats.beta.sf(
        1 / num_options,           # threshold: success rate of random guessing
        prior + num_correct,       # alpha
        prior + num_incorrect,     # beta
    )
