from typing import Optional, Union, Any

import random
import torch
import numpy as np

from collections import defaultdict

from slime.utils.types import Sample

def group_adv_normalization(args, rewards: list, group_ids: list, epsilon: float = 1e-6) -> list:
    """
    Normalize rewards within each group, where groups may differ in size.

    When ``args.advantage_estimator`` is one of "grpo", "gspo" or
    "reinforce_plus_plus_baseline" and ``args.rewards_normalization`` is set,
    each reward has its group's mean subtracted. For "grpo"/"gspo" with
    ``args.grpo_std_normalization`` set, the result is further divided by the
    group's unbiased standard deviation plus ``epsilon``. A group with a single
    sample is treated as having zero spread, so its advantage is 0 rather than
    NaN. In every other configuration the rewards are returned unchanged, as a
    new list.

    Args:
        args: Configuration object. Must contain:
              - advantage_estimator (str)
              - rewards_normalization (bool)
              - n_samples_per_prompt (int): minimum number of samples each group
                must have when normalization is applied
              - grpo_std_normalization (bool)
        rewards (list): A flat list of reward values.
        group_ids (list): A flat list of hashable group identifiers, aligned with rewards.
        epsilon (float): A small value to avoid division by zero.

    Returns:
        list: A new list of normalized rewards (Python floats when normalization
        is applied, otherwise a shallow copy of ``rewards``).

    Raises:
        ValueError: If ``rewards`` and ``group_ids`` differ in length, or if a
            group has fewer than ``args.n_samples_per_prompt`` samples while
            normalization is enabled.
    """
    if len(rewards) != len(group_ids):
        raise ValueError("Length of rewards and group_ids must be the same.")

    normalize = (
        args.advantage_estimator in ["grpo", "gspo", "reinforce_plus_plus_baseline"]
        and args.rewards_normalization
    )
    if not normalize:
        # Return a copy so callers never alias (and accidentally mutate) the input list.
        return list(rewards)

    # Std division only applies to GRPO-style estimators; decide it once, not per element.
    use_std = args.advantage_estimator in ["grpo", "gspo"] and args.grpo_std_normalization

    id_to_scores = defaultdict(list)
    for group_id, reward in zip(group_ids, rewards):
        id_to_scores[group_id].append(reward)

    id_to_mean, id_to_std = {}, {}
    for group_id, score_list in id_to_scores.items():
        # Explicit raise (not assert) so the check survives `python -O`.
        if len(score_list) < args.n_samples_per_prompt:
            raise ValueError(f"Group {group_id} has less than {args.n_samples_per_prompt} samples.")

        scores_tensor = torch.tensor(score_list, dtype=torch.float32)
        id_to_mean[group_id] = torch.mean(scores_tensor)
        if use_std:
            # torch.std uses Bessel's correction (N-1) and returns NaN for a single
            # element; a one-sample group has no spread, so use 0 instead.
            if scores_tensor.numel() > 1:
                id_to_std[group_id] = torch.std(scores_tensor)
            else:
                id_to_std[group_id] = torch.zeros((), dtype=torch.float32)

    normalized_rewards = []
    for group_id, reward in zip(group_ids, rewards):
        advantage = reward - id_to_mean[group_id]
        if use_std:
            advantage = advantage / (id_to_std[group_id] + epsilon)
        normalized_rewards.append(advantage.item())

    return normalized_rewards



def custom_reward_post_process_func(args, samples: Union[list[Sample], list[list[Sample]]]):
    """
    Compute raw and group-normalized rewards for a batch of samples.

    Samples are grouped by ``sample.metadata['instance_id']`` and normalized with
    ``group_adv_normalization``, so groups may have different sizes.

    Args:
        args: Configuration object. ``args.balance_data`` must be falsy; it must
              also carry the fields required by ``group_adv_normalization``.
        samples: The samples to score. NOTE(review): despite the annotation, this
              is iterated as a flat sequence of ``Sample`` objects; nested lists
              are not flattened here. Confirm callers always pass a flat list.

    Returns:
        tuple[list, list]: ``(raw_rewards, rewards)``. ``raw_rewards`` holds the
        per-sample values from ``sample.get_reward_value(args)``; ``rewards`` holds
        their group-normalized counterparts.

    Raises:
        ValueError: If ``args.balance_data`` is truthy.
    """
    # Explicit raise (not ``assert``) so this configuration guard survives ``python -O``.
    if args.balance_data:
        raise ValueError(
            "balance_data=True is incompatible with the current dynamic data generation. "
            "Please set --balance-data False in your arguments."
        )

    raw_rewards = [sample.get_reward_value(args) for sample in samples]
    group_ids = [sample.metadata['instance_id'] for sample in samples]

    rewards = group_adv_normalization(args, raw_rewards, group_ids)

    return raw_rewards, rewards
