from typing import List, Union
import os
import torch
from collections import defaultdict
import numpy as np

from slime.utils.types import Sample

# Token id whose occurrences delimit the steps of a response (get_grpo_returns assigns
# step j's reward to all tokens up to and including the j-th occurrence). Override it
# with the STEP_REWARD_TOKEN_ID environment variable; defaults to 151670.
step_reward_token_id = int(os.getenv("STEP_REWARD_TOKEN_ID", 151670))

def get_grpo_returns(
    rewards: List[List[float]],
    kl: list[torch.Tensor],
    responses: list[torch.Tensor],
    reward_token_id: Union[int, None] = None,
):
    """Broadcast per-step rewards onto per-token return tensors (CP == 1, GRPO/GSPO).

    Each response is split into segments by occurrences of the step-reward token.
    Segment ``j`` runs from just after the previous reward token up to and including
    the ``j``-th reward token, and every position in it receives ``rewards[i][j]``.
    Positions after the last reward token receive the last step reward. A response
    with no reward token at all must carry exactly one reward, which is broadcast over
    the whole output.

    Args:
        rewards: Per-sample step rewards; ``rewards[i][j]`` is the reward of step ``j``
            of sample ``i``. Any object supporting ``len()`` and indexing works.
        kl: Per-sample tensors used only to size the output (shape and device);
            their values are ignored.
        responses: Per-sample tensors of response token ids (presumably 1-D; only
            indices along the first dimension are used).
        reward_token_id: Token id that terminates a step. Defaults to the
            module-level ``step_reward_token_id`` (``STEP_REWARD_TOKEN_ID`` env var).

    Returns:
        list[torch.Tensor]: One float32 tensor per sample, shaped like ``kl[i]``.
        Reward-token positions at or beyond ``kl[i]``'s length are silently dropped
        by slicing.

    Raises:
        AssertionError: If the number of reward tokens in a response differs from the
            number of its step rewards, or a response without any reward token does
            not carry exactly one reward.
    """
    if reward_token_id is None:
        # Fall back to the module-level default (configured via the environment).
        reward_token_id = step_reward_token_id

    returns = []
    for i, reward_list in enumerate(rewards):
        response = responses[i]

        # Output takes the shape/device of kl[i]; kl's values are not used.
        step_rewards_tensor = torch.zeros_like(kl[i], dtype=torch.float32)

        # Positions of the step-reward token (presumably the `</round>` marker; confirm
        # against the tokenizer).
        reward_token_indices = torch.where(response == reward_token_id)[0]
        if reward_token_indices.size(0) == 0:
            # No step marker: the whole response must carry a single reward.
            print("Warning: <step_reward_token_id> not found")
            assert len(reward_list) == 1, f"{len(reward_list)=} should be 1 when <step_reward_token_id> not found"
            step_rewards_tensor[:] = reward_list[0]
            returns.append(step_rewards_tensor)
            continue

        assert reward_token_indices.size(0) == len(reward_list), f"{reward_token_indices.size(0)=} != {len(reward_list)=} | {reward_list=} | {response.tolist()=}"

        # Segment j covers (previous marker, j-th marker], including the marker itself.
        # Positions are converted to Python ints once here; calling `.item()` per step
        # would force one device-to-host sync per step for GPU tensors.
        start_idx = 0
        for j, end_idx in enumerate(reward_token_indices.tolist()):
            step_rewards_tensor[start_idx : end_idx + 1] = reward_list[j]
            start_idx = end_idx + 1

        # Tokens after the final marker inherit the last step's reward.
        if start_idx < step_rewards_tensor.size(0):
            step_rewards_tensor[start_idx:] = reward_list[-1]

        returns.append(step_rewards_tensor)
    return returns

def compute_step_advantages_and_returns(args, rewards, kl, loss_masks, response_lengths, total_lengths, responses):
    """Dispatch on ``args.advantage_estimator`` and return ``(advantages, returns)``.

    - ``grpo`` / ``gspo``: returns come from ``get_grpo_returns``; advantages are a
      shallow list copy of the returns.
    - ``reinforce_plus_plus``: returns come from ``get_reinforce_plus_plus_returns``
      (using ``args.kl_coef`` and ``args.gamma``); advantages are a shallow list copy.
    - ``reinforce_plus_plus_baseline``: advantages come from
      ``get_reinforce_plus_plus_baseline_advantages`` and the same object is returned
      as ``returns``.

    ``loss_masks``, ``response_lengths`` and ``total_lengths`` are only used by the
    reinforce_plus_plus estimators; ``responses`` only by grpo/gspo.

    NOTE(review): ``get_reinforce_plus_plus_returns`` and
    ``get_reinforce_plus_plus_baseline_advantages`` are neither defined nor imported in
    the visible part of this file. Confirm they are in scope, otherwise those branches
    raise NameError.

    Raises:
        NotImplementedError: For any other estimator name.
    """
    estimator = args.advantage_estimator

    if estimator in ["grpo", "gspo"]:
        grpo_returns = get_grpo_returns(rewards, kl, responses)
        # TODO: is the copy necessary?
        return list(grpo_returns), grpo_returns

    if estimator == "reinforce_plus_plus":
        rpp_returns = get_reinforce_plus_plus_returns(
            rewards=rewards,
            kl=kl,
            loss_masks=loss_masks,
            response_lengths=response_lengths,
            total_lengths=total_lengths,
            kl_coef=args.kl_coef,
            gamma=args.gamma,
        )
        return list(rpp_returns), rpp_returns

    if estimator == "reinforce_plus_plus_baseline":
        baseline_advantages = get_reinforce_plus_plus_baseline_advantages(
            rewards=rewards,
            kl=kl,
            loss_masks=loss_masks,
            kl_coef=args.kl_coef,
        )
        # Advantages double as returns for this estimator.
        return baseline_advantages, baseline_advantages

    raise NotImplementedError(f"advantage_estimator {args.advantage_estimator} is not supported. ")

def group_step_reward_normalization(args, rewards: List[List[float]], group_ids: List, epsilon: float = 1e-6) -> List[List[float]]:
    """
    Performs group-wise normalization on rewards, where each reward is a list of step-rewards.

    For every group, the step rewards of all samples sharing that ``group_id`` are
    flattened together to compute the group mean (and std). Each step reward then has
    the group mean subtracted. For ``grpo``/``gspo`` with ``args.grpo_std_normalization``,
    it is also divided by ``std + epsilon``, where std is torch's unbiased estimator.
    A group with at most one step reward in total is treated as having zero spread, so
    its advantages stay finite (the unbiased std of a single value is NaN). Samples
    with more steps contribute more values to their group's statistics.

    Normalization only applies when ``args.advantage_estimator`` is ``grpo``, ``gspo``
    or ``reinforce_plus_plus_baseline`` and ``args.rewards_normalization`` is truthy.
    Otherwise shallow copies of the input reward lists are returned unchanged.

    Args:
        args: Configuration object. Must contain:
              - advantage_estimator (str)
              - rewards_normalization (bool)
              - n_samples_per_prompt (int)
              - grpo_std_normalization (bool), read only for grpo/gspo
        rewards (List[List[float]]): A list where each element is a list of rewards for each step of a sample.
        group_ids (list): A flat list of group identifiers, aligned with rewards.
        epsilon (float): A small value to avoid division by zero.

    Returns:
        List[List[float]]: A new list of normalized rewards, preserving the original nested structure.

    Raises:
        ValueError: If ``rewards`` and ``group_ids`` differ in length.
        AssertionError: If a group has fewer than ``args.n_samples_per_prompt`` samples.
    """
    if len(rewards) != len(group_ids):
        raise ValueError("Length of rewards and group_ids must be the same.")

    if not (args.advantage_estimator in ["grpo", "gspo", "reinforce_plus_plus_baseline"] and args.rewards_normalization):
        return [list(r) for r in rewards]

    use_std = args.advantage_estimator in ["grpo", "gspo"] and args.grpo_std_normalization

    id_to_scores = defaultdict(list)
    for group_id, reward_list in zip(group_ids, rewards):
        id_to_scores[group_id].append(reward_list)

    id_to_mean, id_to_std = {}, {}
    for group_id, list_of_reward_lists in id_to_scores.items():
        assert len(list_of_reward_lists) >= args.n_samples_per_prompt, f"Group {group_id} has less than {args.n_samples_per_prompt} samples."

        scores_tensor = torch.tensor([score for sublist in list_of_reward_lists for score in sublist], dtype=torch.float32)
        id_to_mean[group_id] = torch.mean(scores_tensor)
        if use_std:
            # torch.std uses Bessel's correction and is NaN for fewer than two values;
            # treat such groups as having zero spread so their (zero) centered
            # advantages stay finite instead of turning into NaN.
            if scores_tensor.numel() > 1:
                id_to_std[group_id] = torch.std(scores_tensor)
            else:
                id_to_std[group_id] = torch.zeros(())

    normalized_rewards = []
    for group_id, reward_list in zip(group_ids, rewards):
        reward_tensor = torch.tensor(reward_list, dtype=torch.float32)
        advantage_tensor = reward_tensor - id_to_mean[group_id]

        if use_std:
            advantage_tensor = advantage_tensor / (id_to_std[group_id] + epsilon)

        normalized_rewards.append(advantage_tensor.tolist())

    return normalized_rewards

def custom_reward_post_process_func(args, samples: Union[list[Sample], list[list[Sample]]]):
    """Return ``(raw_rewards, normalized_rewards)`` for a batch of samples.

    Raw rewards are read with ``sample.get_reward_value(args)``. They are then
    normalized per group with ``group_step_reward_normalization``, using
    ``sample.metadata["instance_id"]`` as the group key.

    NOTE(review): despite the ``list[list[Sample]]`` annotation, the body iterates
    ``samples`` as a flat sequence of ``Sample`` objects; nested lists would fail here.
    NOTE: an earlier (now disabled) guard rejected ``args.balance_data``, with the
    rationale that balance_data=True is incompatible with dynamic data generation
    (the suggested remedy was ``--balance-data False``).
    """
    reward_values = [item.get_reward_value(args) for item in samples]
    instance_ids = [item.metadata["instance_id"] for item in samples]

    normalized = group_step_reward_normalization(args, reward_values, instance_ids)
    return reward_values, normalized