import torch
import numpy as np
import os
# Debug / experiment switches, read from the environment once at import time.
# Each name is None when the variable is unset, otherwise the raw string value.

# Denoising visualization switch. Tested with `is not None`, so any value
# (even an empty string) enables it.
DISTRL_DEBUG_VIS = os.environ.get('DISTRL_DEBUG_VIS', None)
# When set (tested with `is not None`), _collect_rollout_flat stores VAE-encoded
# images in the slot normally used for log-probabilities.
DISTRL_LOSS_ADDNOISE = os.environ.get('DISTRL_LOSS_ADDNOISE', None)
# Extra per-process logging. Tested for truthiness, so an empty string disables it.
DISTRL_DEBUG_DETAIL_LOG = os.environ.get('DISTRL_DEBUG_DETAIL_LOG', None)
# Converted to an integer 0/1 further down in this module; when 1, the last
# inference step is left out of the data collected by _collect_rollout_flat.
DISTRL_RL_NOTRAIN_LASTSTEP = os.environ.get('DISTRL_RL_NOTRAIN_LASTSTEP', None)
# When truthy, reward normalization in _collect_rollout_flat only subtracts the
# mean and skips the division by the standard deviation.
DISTRL_RL_GRPO_NO_STD = os.environ.get('DISTRL_RL_GRPO_NO_STD', None)

def get_random_indices(num_indices, sample_size):
    """Draw distinct indices uniformly at random from ``range(num_indices)``.

    Sampling is done without replacement through NumPy's global random
    state, so the result is reproducible after ``np.random.seed``.

    Args:
        num_indices (int): Size of the index range to draw from.
        sample_size (int): How many distinct indices to draw.

    Returns:
        A numpy array holding `sample_size` indices, none of them repeated.

    Raises:
        ValueError: Raised by NumPy when `sample_size` exceeds `num_indices`.
    """
    drawn = np.random.choice(num_indices, sample_size, replace=False)
    return drawn

# Collapse the raw environment value into an integer flag: 1 for any non-empty
# string, 0 when the variable is unset or empty. The integer form lets the
# rollout code subtract it directly from the number of inference steps.
DISTRL_RL_NOTRAIN_LASTSTEP = 1 if DISTRL_RL_NOTRAIN_LASTSTEP else 0

from distrl.profiling import FakeProfilerTimer

# Import the visualization module only if debug visualization is enabled.
# When DISTRL_DEBUG_VIS is unset the name `visualize` is never bound, so any
# code that uses it must guard on the same flag.
if DISTRL_DEBUG_VIS is not None:
    from . import visualize

# Only when DISTRL_LOSS_ADDNOISE is set: build the normalization that
# _collect_rollout_flat applies to generated images before VAE-encoding them.
# Normalize with mean 0.5 / std 0.5 computes (x - 0.5) / 0.5 per channel,
# i.e. it maps values in [0, 1] onto [-1, 1].
# NOTE(review): inplace=True presumably makes Normalize overwrite its input
# tensor instead of allocating a new one — confirm against torchvision docs.
if DISTRL_LOSS_ADDNOISE is not None:
    from torchvision import transforms
    transform = transforms.Compose([
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
    ])

def _collect_rollout(args, pipe, policy_model, is_ddp, image_pool, state_dict, accelerator = None, count = 0):
    """Collects rollout trajectories from the environment.

    Args:
        args: Arguments.
        pipe: Diffusion pipeline.
        policy_model: Policy model (used as UNet).
        is_ddp: Whether using distributed data parallel.
        image_pool: ImagePool instance for FID calculation.
        state_dict: State dictionary to store collected data.
    """
    # Set pipe.unet to policy_model to use it for inference
    original_unet = pipe.unet
    pipe.unet = policy_model

    # Sample indices and prompts for replacement
    indices_groups, prompts_groups = image_pool.sample_replacement_indices(
        num_groups=args.num_groups,
        num_samples=args.num_samples
    )   # indices_groups: List(num_groups, List(num_samples))

    # Flatten all prompts and indices for batch processing
    all_prompts = []
    all_indices = []
    for group_prompts, group_indices in zip(prompts_groups, indices_groups):
        all_prompts.extend(group_prompts)
        all_indices.extend(group_indices)

    # Store all generated data
    # Initialize empty lists for collecting tensors
    all_latents_list = [[] for _ in range(args.num_inference_steps + 1)]  # +1 for initial noise
    all_unconditional_list = []
    all_guided_list = []
    all_log_probs_list = [[] for _ in range(args.num_inference_steps)]
    all_features = []

    # Process all prompts in batches
    for batch_start in range(0, len(all_prompts), args.g_batch_size):
        batch_end = min(batch_start + args.g_batch_size, len(all_prompts))
        batch_prompts = all_prompts[batch_start:batch_end]  # (g_batch_size, )

        # Collect rollout data from diffusion sampling process
        with torch.no_grad():
            # Get trajectories from the pipeline
            (
                images,  # ndarray[g_batch_size, 512, 512, 3]
                latents_list,   # List(num_inference_steps + 1, Tensor[g_batch_size, 4, 64, 64])
                unconditional_prompt_embeds,  # Tensor[g_batch_size, 77, 768]
                guided_prompt_embeds,   # Tensor[g_batch_size, 77, 768]
                log_prob_list,  # List(num_inference_steps, Tensor[g_batch_size])
                _,
            ) = pipe.forward_collect_traj_ddim(
                prompt=batch_prompts,
                is_ddp=is_ddp,
                num_inference_steps=args.num_inference_steps,
                guidance_scale=args.guidance_scale,
                output_type="pt"
            )

            # Extract inception features for new images
            new_features = image_pool.extract_inception_features(images)
            new_features = [feature.squeeze(0) for feature in new_features]     # List(g_batch_size, Tensor[2048])

            # Store batch results in lists
            for t in range(args.num_inference_steps + 1):
                all_latents_list[t].append(latents_list[t])
            all_unconditional_list.append(unconditional_prompt_embeds)
            all_guided_list.append(guided_prompt_embeds)
            for t in range(args.num_inference_steps):
                all_log_probs_list[t].append(log_prob_list[t])
            all_features.extend(new_features)

            # Clean up to save memory
            del images, latents_list, unconditional_prompt_embeds, guided_prompt_embeds, log_prob_list, new_features
            torch.cuda.empty_cache()

    # Concatenate all tensors
    all_latents = [torch.cat(tensors) for tensors in all_latents_list]  # List(num_inference_steps + 1, Tensor[num_groups * num_samples, 4, 64, 64])
    all_unconditional = torch.cat(all_unconditional_list)  # Tensor[num_groups * num_samples, 77, 768]
    all_guided = torch.cat(all_guided_list)  # Tensor[num_groups * num_samples, 77, 768]
    all_log_probs = [torch.cat(tensors) for tensors in all_log_probs_list]  # List(num_inference_steps, Tensor[num_groups * num_samples])

    # Calculate FID reward for each group
    group_rewards = []
    for group_idx in range(args.num_groups):
        start_idx = group_idx * args.num_samples
        end_idx = start_idx + args.num_samples
        group_features = all_features[start_idx:end_idx]
        group_indices = all_indices[start_idx:end_idx]

        # Calculate FID reward by replacing features in pool
        fid_reward = image_pool.compute_fid(
            temp_features=group_features,
            temp_indices=group_indices
        )
        group_rewards.append(-fid_reward)

    # Organize data by groups and update state_dict
    for group_idx in range(args.num_groups):
        start_idx = group_idx * args.num_samples
        end_idx = start_idx + args.num_samples

        # Get group data
        group_indices = all_indices[start_idx:end_idx]
        group_reward = group_rewards[group_idx]

        # Initialize lists to store tensors for this group
        group_states = []
        group_next_states = []
        group_unconditional = []
        group_guided = []
        group_log_probs = []

        # Process each timestep
        for t in range(args.num_inference_steps):
            # Get data for this timestep
            current_latents = all_latents[t][start_idx:end_idx]  # [num_samples, 4, 64, 64]
            next_latents = all_latents[t+1][start_idx:end_idx]  # [num_samples, 4, 64, 64]
            current_unconditional = all_unconditional[start_idx:end_idx]  # [num_samples, 77, 768]
            current_guided = all_guided[start_idx:end_idx]  # [num_samples, 77, 768]
            current_log_probs = all_log_probs[t][start_idx:end_idx]  # [num_samples]

            # Store tensors for this group
            group_states.append(current_latents)
            group_next_states.append(next_latents)
            group_unconditional.append(current_unconditional)
            group_guided.append(current_guided)
            group_log_probs.append(current_log_probs)

        # Stack tensors for this group
        group_states = torch.stack(group_states)  # [num_inference_steps, num_samples, 4, 64, 64]
        group_next_states = torch.stack(group_next_states)  # [num_inference_steps, num_samples, 4, 64, 64]
        group_unconditional = torch.stack(group_unconditional)  # [num_inference_steps, num_samples, 77, 768]
        group_guided = torch.stack(group_guided)  # [num_inference_steps, num_samples, 77, 768]
        group_log_probs = torch.stack(group_log_probs)  # [num_inference_steps, num_samples]

        # Create timestep tensor for this group (shared across samples)
        group_timesteps = torch.arange(args.num_inference_steps, device=state_dict["timestep"].device)

        # Create reward tensor for this group (shared across samples and timesteps)
        group_rewards = torch.full((args.num_inference_steps,), float(group_reward), device=state_dict["final_reward"].device)

        # Create pool indices tensor for this group (shared across timesteps)
        group_pool_indices = torch.tensor([group_indices] * args.num_inference_steps, device=state_dict["pool_indices"].device)

        # Update state_dict
        state_dict["state"] = torch.cat((state_dict["state"], group_states))
        state_dict["next_state"] = torch.cat((state_dict["next_state"], group_next_states))
        state_dict["timestep"] = torch.cat((state_dict["timestep"], group_timesteps))
        state_dict["final_reward"] = torch.cat((state_dict["final_reward"], group_rewards))
        state_dict["unconditional_prompt_embeds"] = torch.cat((state_dict["unconditional_prompt_embeds"], group_unconditional))
        state_dict["guided_prompt_embeds"] = torch.cat((state_dict["guided_prompt_embeds"], group_guided))
        state_dict["log_prob"] = torch.cat((state_dict["log_prob"], group_log_probs))
        state_dict["pool_indices"] = torch.cat((state_dict["pool_indices"], group_pool_indices))

    # Restore original unet
    pipe.unet = original_unet

def prepare_policy_samples(state_dict, args):
    """Pick, once per training round, which collected samples the policy will train on.

    A set of distinct row indices into the collected data is drawn up front
    and stored together with a read cursor, so that successive policy steps
    can walk through it without repeating a sample.

    Args:
        state_dict: State dictionary containing collected data. The number of
            available samples is taken from the first dimension of
            ``state_dict["state"]``. The keys ``policy_indices`` and
            ``policy_index_position`` are written in place.
        args: Arguments. Reads ``p_batch_size``,
            ``gradient_accumulation_steps`` and ``p_step``.
    """
    available = state_dict["state"].shape[0]

    # Samples consumed by a full round of policy updates, capped at what was
    # actually collected because the draw is without replacement.
    requested = args.p_batch_size * args.gradient_accumulation_steps * args.p_step
    draw_count = min(requested, available)

    state_dict["policy_indices"] = get_random_indices(available, draw_count)
    # Cursor into policy_indices; starts at the beginning.
    state_dict["policy_index_position"] = 0

def _collect_rollout_flat(args, pipe, policy_model, is_ddp, image_pool, state_dict, accelerator = None, count = 0, profiler = FakeProfilerTimer()):
    """Collect rollout trajectories and append them to ``state_dict`` as flat rows.

    The policy model is temporarily installed as ``pipe.unet``, trajectories
    are sampled for ``args.num_groups * args.num_samples`` prompts in batches
    of ``args.g_batch_size``, and each group is scored with the negated value
    returned by ``image_pool.compute_fid``. Unlike ``_collect_rollout``, the
    timestep and sample axes of every group are merged, so each ``state_dict``
    entry gains ``steps * num_samples`` rows per group, where ``steps`` is
    ``args.num_inference_steps - DISTRL_RL_NOTRAIN_LASTSTEP``.

    Args:
        args: Arguments. Reads ``num_groups``, ``num_samples``,
            ``num_inference_steps``, ``g_batch_size``, ``guidance_scale``,
            ``grpo_flag`` and, when visualization is enabled,
            ``save_interval`` and ``output_dir``.
        pipe: Diffusion pipeline providing ``forward_collect_traj_ddim``.
        policy_model: Policy model (used as UNet while sampling).
        is_ddp: Whether using distributed data parallel. When true and
            ``args.grpo_flag == 2``, rewards are all-gathered across processes
            before normalization.
        image_pool: ImagePool instance used to pick replacement indices,
            extract inception features and compute FID.
        state_dict: Dictionary of tensors, updated in place. Keys written:
            ``state``, ``next_state``, ``timestep``, ``final_reward``,
            ``unconditional_prompt_embeds``, ``guided_prompt_embeds``,
            ``log_prob`` and ``pool_indices``.
        accelerator: Optional accelerator object, used for main-process
            detection and logging.
        count: Current outer step; used in profiler labels and to decide
            whether to visualize.
        profiler: Object with ``start(label)`` / ``end(label)`` used to time
            the phases of the rollout.

    Returns:
        dict with the keys:
            ``fid``: the rewards, i.e. NEGATED FID values despite the key
                name. With ``grpo_flag == 2`` these are the un-normalized
                rewards gathered from all processes, otherwise the local ones.
            ``max_reward``: largest reward. With ``grpo_flag == 2`` it is
                expressed as ``(max - mean) / std``.
            ``local_fids``: FID values of this process's groups.
            ``reward_mean`` / ``reward_std``: only with ``grpo_flag == 2``.
    """
    # Number of denoising transitions stored per sample.
    num_train_steps = args.num_inference_steps - DISTRL_RL_NOTRAIN_LASTSTEP
    group_size = args.num_samples

    # Route the pipeline through the policy model for the rollout. The finally
    # clause puts the original UNet back even when sampling or scoring raises.
    original_unet = pipe.unet
    pipe.unet = policy_model
    try:
        # Sample indices and prompts for replacement
        indices_groups, prompts_groups = image_pool.sample_replacement_indices(
            num_groups=args.num_groups,
            num_samples=args.num_samples
        )   # indices_groups: List(num_groups, List(num_samples))

        if accelerator is not None and DISTRL_DEBUG_DETAIL_LOG:
            # Convert prompts to class IDs for logging
            class_ids_groups = [
                sorted(pipe._extract_class_from_prompt(prompt) for prompt in group_prompts)
                for group_prompts in prompts_groups
            ]
            print(
                f"Process {accelerator.process_index}: "
                f"rollout class IDs [{len(class_ids_groups)}x{len(indices_groups[0])}]: {class_ids_groups}"
            )

        # Flatten all prompts and indices so generation can use its own batch
        # size, independent of the group size.
        all_prompts = []
        all_indices = []
        for group_prompts, group_indices in zip(prompts_groups, indices_groups):
            all_prompts.extend(group_prompts)
            all_indices.extend(group_indices)

        # Per-timestep accumulators; latents have one extra entry for the initial noise.
        all_latents_list = [[] for _ in range(num_train_steps + 1)]
        all_log_probs_list = [[] for _ in range(num_train_steps)]
        all_unconditional_list = []
        all_guided_list = []
        all_features = []

        # Visualization is done for at most one batch per call.
        vis_done = False

        # Process all prompts in batches
        for batch_start in range(0, len(all_prompts), args.g_batch_size):
            batch_prompts = all_prompts[batch_start:batch_start + args.g_batch_size]

            # Rollout data is only stored, never back-propagated through.
            with torch.no_grad():
                profiler.start(f"Rollout Trajectory (step {count})")
                (
                    images,  # generated images (output_type="pt") — assumed tensor, TODO confirm
                    latents_list,   # List(num_inference_steps + 1, Tensor[g_batch_size, 4, 64, 64])
                    unconditional_prompt_embeds,  # Tensor[g_batch_size, 77, 768]
                    guided_prompt_embeds,   # Tensor[g_batch_size, 77, 768]
                    log_prob_list,  # List(num_inference_steps, Tensor[g_batch_size])
                    _,
                ) = pipe.forward_collect_traj_ddim(
                    prompt=batch_prompts,
                    is_ddp=is_ddp,
                    num_inference_steps=args.num_inference_steps,
                    guidance_scale=args.guidance_scale,
                    output_type="pt"
                )
                profiler.end(f"Rollout Trajectory (step {count})")

                # Visualize one batch if DISTRL_DEBUG_VIS is enabled and not done yet
                if DISTRL_DEBUG_VIS is not None and not vis_done:
                    should_visualize = count % args.save_interval == 0
                    if is_ddp and accelerator is not None:
                        # Only visualize on the main process in distributed training
                        should_visualize = accelerator.is_main_process and should_visualize
                    elif is_ddp:
                        # Distributed but no accelerator: the main process
                        # cannot be identified, so skip everywhere.
                        should_visualize = False

                    if should_visualize:
                        profiler.start(f"Visualize Denoising Process (step {count})")
                        visualize.visualize_denoising_process(
                            pipe=pipe,
                            latents_list=latents_list,
                            images=images,
                            prompts=batch_prompts,
                            output_dir=args.output_dir,
                            num_inference_steps=args.num_inference_steps,
                            count=count
                        )
                        vis_done = True
                        if accelerator is not None:
                            accelerator.print(f"Saved visualization to {args.output_dir}/debug/vis/step_{count}")
                        else:
                            print(f"Saved visualization to {args.output_dir}/debug/vis/step_{count}")
                        profiler.end(f"Visualize Denoising Process (step {count})")

                # Extract inception features for new images, one entry per image
                new_features = image_pool.extract_inception_features(images)
                new_features = [feature.squeeze(0) for feature in new_features]

                # Store batch results
                for t in range(num_train_steps + 1):
                    all_latents_list[t].append(latents_list[t])
                all_unconditional_list.append(unconditional_prompt_embeds)
                all_guided_list.append(guided_prompt_embeds)
                if DISTRL_LOSS_ADDNOISE is not None:
                    # HACK: the log-prob slot carries the VAE encoding of the
                    # final images instead of log-probabilities, repeated for
                    # every stored timestep. Already under no_grad here.
                    images_vae = images / 255.0
                    images_vae = transform(images_vae)
                    images_vae = pipe.encode_images(images_vae.to(dtype=next(pipe.vae.parameters()).dtype)).detach().float().cpu()
                    for t in range(num_train_steps):
                        all_log_probs_list[t].append(images_vae)
                else:
                    for t in range(num_train_steps):
                        all_log_probs_list[t].append(log_prob_list[t])
                all_features.extend(new_features)

                # Clean up to save memory
                del images, latents_list, unconditional_prompt_embeds, guided_prompt_embeds, log_prob_list, new_features
                torch.cuda.empty_cache()

        # Concatenate batches back into one tensor per timestep / field
        all_latents = [torch.cat(tensors) for tensors in all_latents_list]
        all_unconditional = torch.cat(all_unconditional_list)
        all_guided = torch.cat(all_guided_list)
        all_log_probs = [torch.cat(tensors) for tensors in all_log_probs_list]

        # Calculate the reward (negated FID) for each group
        profiler.start(f"Calculate FID Reward (step {count})")
        group_rewards = []
        for group_idx in range(args.num_groups):
            start_idx = group_idx * group_size
            end_idx = start_idx + group_size
            fid_reward = image_pool.compute_fid(
                temp_features=all_features[start_idx:end_idx],
                temp_indices=all_indices[start_idx:end_idx]
            )
            group_rewards.append(-fid_reward)
        # Keep the raw FIDs for reporting before rewards get normalized.
        local_fids = [-x for x in group_rewards]
        profiler.end(f"Calculate FID Reward (step {count})")

        # Normalize rewards if grpo_flag is 2
        if args.grpo_flag == 2:
            profiler.start(f"Normalize Rewards (step {count})")
            # Fall back to CPU so that runs without CUDA still work.
            reward_device = 'cuda' if torch.cuda.is_available() else 'cpu'
            rewards_tensor = torch.tensor(group_rewards, device=reward_device)

            if is_ddp:
                # Distributed case: statistics are computed over the rewards of all processes
                import torch.distributed as dist

                world_size = dist.get_world_size()
                gathered_rewards_list = [torch.zeros_like(rewards_tensor) for _ in range(world_size)]
                dist.all_gather(gathered_rewards_list, rewards_tensor)
                gathered_rewards = torch.cat(gathered_rewards_list, dim=0)
            else:
                # Non-distributed case: use only local rewards
                gathered_rewards = rewards_tensor

            reward_mean = gathered_rewards.mean().item()
            # std() of a single element is NaN, and max(nan, 1e-8) stays NaN,
            # which would turn every normalized reward into NaN. Treat the
            # spread of a single reward as zero instead.
            if gathered_rewards.numel() > 1:
                reward_std = gathered_rewards.std().item()
            else:
                reward_std = 0.0
            # If std is very small, use a minimum value to avoid division by zero
            reward_std = max(reward_std, 1e-8)

            if accelerator is not None and DISTRL_DEBUG_DETAIL_LOG:
                accelerator.print(f"reward_mean: {reward_mean}, reward_std: {reward_std}, gathered_rewards: {gathered_rewards}")

            if DISTRL_RL_GRPO_NO_STD:
                group_rewards = [reward - reward_mean for reward in group_rewards]
            else:
                group_rewards = [(reward - reward_mean) / reward_std for reward in group_rewards]
            profiler.end(f"Normalize Rewards (step {count})")

        # Organize data by groups and update state_dict
        profiler.start(f"Organize Data by Groups (step {count})")
        row_keys = (
            "state", "next_state", "timestep", "final_reward",
            "unconditional_prompt_embeds", "guided_prompt_embeds",
            "log_prob", "pool_indices",
        )
        for group_idx in range(args.num_groups):
            start_idx = group_idx * group_size
            end_idx = start_idx + group_size

            group_indices = all_indices[start_idx:end_idx]
            group_reward = group_rewards[group_idx]

            # Stack along a leading timestep axis, then merge the timestep and
            # sample axes so rows are ordered timestep-major:
            # [steps, num_samples, ...] -> [steps * num_samples, ...]
            group_states = torch.stack(
                [all_latents[t][start_idx:end_idx] for t in range(num_train_steps)]
            ).flatten(0, 1)
            group_next_states = torch.stack(
                [all_latents[t + 1][start_idx:end_idx] for t in range(num_train_steps)]
            ).flatten(0, 1)
            # Prompt embeddings do not change over time; repeat them once per timestep.
            group_unconditional = torch.stack(
                [all_unconditional[start_idx:end_idx]] * num_train_steps
            ).flatten(0, 1)
            group_guided = torch.stack(
                [all_guided[start_idx:end_idx]] * num_train_steps
            ).flatten(0, 1)
            # TODO: temp — trailing dims are kept so feature maps can live in the log_prob field
            group_log_probs = torch.stack(
                [all_log_probs[t][start_idx:end_idx] for t in range(num_train_steps)]
            ).flatten(0, 1)

            # Timestep index of every row, in the same timestep-major order
            batch_timestep = torch.arange(num_train_steps, device=state_dict["timestep"].device)
            group_timesteps = batch_timestep.unsqueeze(1).expand(-1, group_size).reshape(-1)

            # The group's reward, repeated for every row
            batch_final_reward = torch.full(
                (num_train_steps,), float(group_reward), device=state_dict["final_reward"].device
            )
            group_rewards_tensor = batch_final_reward.unsqueeze(1).expand(-1, group_size).reshape(-1)

            # Pool index of every row (the group's indices repeated per timestep)
            group_pool_indices = torch.tensor(
                [group_indices] * num_train_steps, device=state_dict["pool_indices"].device
            ).reshape(-1)

            # Update state_dict
            state_dict["state"] = torch.cat((state_dict["state"], group_states))
            state_dict["next_state"] = torch.cat((state_dict["next_state"], group_next_states))
            state_dict["timestep"] = torch.cat((state_dict["timestep"], group_timesteps))
            state_dict["final_reward"] = torch.cat((state_dict["final_reward"], group_rewards_tensor))
            state_dict["unconditional_prompt_embeds"] = torch.cat((state_dict["unconditional_prompt_embeds"], group_unconditional))
            state_dict["guided_prompt_embeds"] = torch.cat((state_dict["guided_prompt_embeds"], group_guided))
            state_dict["log_prob"] = torch.cat((state_dict["log_prob"], group_log_probs))
            state_dict["pool_indices"] = torch.cat((state_dict["pool_indices"], group_pool_indices))

            # Internal invariant: every field must hold the same number of rows.
            assert len({state_dict[key].shape[0] for key in row_keys}) == 1, (
                ", ".join(f"{key}: {state_dict[key].shape}" for key in row_keys) + " mismatch"
            )
        profiler.end(f"Organize Data by Groups (step {count})")

        # return gathered_rewards or group_rewards
        if args.grpo_flag == 2:
            return {
                "fid": gathered_rewards.tolist(),
                "max_reward": ((gathered_rewards.max() - reward_mean) / reward_std).item(),
                "reward_mean": reward_mean,
                "reward_std": reward_std,
                "local_fids": local_fids
            }
        return {
            "fid": group_rewards,
            "max_reward": max(group_rewards),
            "local_fids": local_fids
        }
    finally:
        # Restore original unet
        pipe.unet = original_unet
