import torch

def _get_local_best_and_worst_samples(args, state_dict, accelerator, count, is_ddp=True):
    """Keep only the locally highest- and lowest-reward samples of a rollout.

    The distinct reward values found in ``state_dict["final_reward"]`` are
    ranked. With ``k = min(args.num_best_samples, number of distinct rewards)``,
    a sample is kept when its reward is at least the k-th largest distinct
    value or at most the k-th smallest distinct value. Ties are all kept, so
    more than ``2 * k`` samples may survive.

    Every list entry of ``state_dict`` and every tensor entry with at least
    one dimension and a non-empty first dimension is then indexed with the
    kept sample positions. Other entries (scalars, 0-dim tensors, ...) are
    left untouched.

    Args:
        args: Object exposing ``num_best_samples``; the same count is used for
            both the best and the worst side.
        state_dict: Dictionary containing rollout data. ``"final_reward"`` may
            be a tensor or a plain sequence of numbers; it is assumed to hold
            one reward per sample (1-D) -- TODO confirm against callers.
        accelerator: Object providing ``print(...)`` and ``process_index``.
        count: Current training iteration count (not used in this function).
        is_ddp: Whether DistributedDataParallel is in use (not used in this
            function).

    Returns:
        The same ``state_dict`` object, modified in-place. It is returned
        unchanged when no rewards are present, and with every list / tensor
        entry emptied when ``args.num_best_samples <= 0``.
    """
    rewards = state_dict.get("final_reward", None)

    # ``len`` rather than truthiness: a multi-element tensor has no bool value.
    if rewards is None or len(rewards) == 0:
        accelerator.print("No rewards found in state_dict. Skipping sample selection.")
        return state_dict

    # Accept plain sequences as well as tensors; tensors are used as-is so the
    # original dtype and device are preserved.
    rewards_tensor = rewards if isinstance(rewards, torch.Tensor) else torch.as_tensor(rewards)

    # torch.unique returns the distinct values sorted ascending *in the same
    # dtype* as the rewards. Round-tripping through Python floats and
    # torch.tensor() would silently downcast float64 rewards to float32 and
    # shift the thresholds away from the actual reward values.
    unique_rewards = torch.unique(rewards_tensor)
    num_unique = unique_rewards.numel()

    # NOTE(review): both counts are derived from ``args.num_best_samples``;
    # confirm whether a separate "worst" count was intended.
    num_best_to_keep = min(args.num_best_samples, num_unique)
    num_worst_to_keep = min(args.num_best_samples, num_unique)

    if num_best_to_keep <= 0 and num_worst_to_keep <= 0:
        # Nothing may be kept: empty every per-sample container but keep the
        # trailing dimensions of tensors intact.
        accelerator.print("No samples to keep. Skipping sample selection.")
        for key in list(state_dict.keys()):
            value = state_dict[key]
            if isinstance(value, list):
                state_dict[key] = []
            elif isinstance(value, torch.Tensor) and value.dim() > 0:
                state_dict[key] = value.new_empty((0,) + value.shape[1:])
        return state_dict

    # Build the keep-mask with tensor comparisons: thresholds and rewards share
    # one dtype (exact comparison) and no per-element .item() sync is needed.
    keep_mask = torch.zeros_like(rewards_tensor, dtype=torch.bool)
    best_threshold = None
    worst_threshold = None

    if num_best_to_keep > 0:
        # k-th largest distinct reward (unique_rewards is ascending).
        best_cut = unique_rewards[num_unique - num_best_to_keep]
        best_threshold = best_cut.item()
        keep_mask |= rewards_tensor >= best_cut

    if num_worst_to_keep > 0:
        # k-th smallest distinct reward.
        worst_cut = unique_rewards[num_worst_to_keep - 1]
        worst_threshold = worst_cut.item()
        keep_mask |= rewards_tensor <= worst_cut

    accelerator.print(f"Local best reward threshold: {best_threshold}, worst reward threshold: {worst_threshold}")

    # Positions (along the first dimension) of the samples to keep, in their
    # original order.
    keep_indices = torch.nonzero(keep_mask, as_tuple=True)[0].tolist()

    accelerator.print(f"Process {accelerator.process_index}: Keeping {len(keep_indices)} samples (best + worst)")

    # Apply the selection to every per-sample container in the state dict.
    for key in list(state_dict.keys()):
        value = state_dict[key]
        if isinstance(value, list):
            state_dict[key] = [value[i] for i in keep_indices]
        elif isinstance(value, torch.Tensor) and value.dim() > 0 and value.size(0) > 0:
            state_dict[key] = value[keep_indices]

    print(f"Process {accelerator.process_index}: get {len(state_dict['final_reward'])} samples finally")

    return state_dict
