from typing import Optional

import ray
import random
import torch
import torch.distributed as dist

from slime.utils.seqlen_balancing import get_seqlen_balanced_partitions
from slime.utils.timer import Timer
from slime.utils.misc import load_function


class DataIterator:
    """Hand out successive micro-batches from a dict of per-sample sequences.

    The batching mode is fixed at construction; at most one of the two options may be given:

    * ``micro_batch_size`` -- each ``get_next`` call slices the next
      ``micro_batch_size`` contiguous items and advances ``offset`` by that many.
    * ``micro_batch_indices`` -- each ``get_next`` call gathers the items at the
      positions listed in ``micro_batch_indices[offset]`` and advances ``offset`` by one.
    """

    def __init__(
        self,
        rollout_data,
        micro_batch_size: Optional[int] = None,
        micro_batch_indices: Optional[list[list[int]]] = None,
    ):
        assert micro_batch_size is None or micro_batch_indices is None
        self.rollout_data = rollout_data
        self.micro_batch_size = micro_batch_size
        self.micro_batch_indices = micro_batch_indices
        # Cursor: an item offset in size mode, a micro-batch number in index mode.
        self.offset = 0

    def get_next(self, keys):
        """Return ``{key: micro-batch}`` for every key in ``keys`` and advance the cursor.

        Keys that are absent from ``rollout_data`` (or map to ``None``) yield ``None``.
        The cursor advances even if none of the requested keys has data.
        """
        use_indices = self.micro_batch_indices is not None
        start = self.offset
        result = {}
        for name in keys:
            column = self.rollout_data.get(name, None)
            if column is None:
                result[name] = None
            elif use_indices:
                result[name] = [column[pos] for pos in self.micro_batch_indices[start]]
            else:
                stop = start + self.micro_batch_size
                assert stop <= len(
                    column
                ), f"offset: {self.offset}, micro_batch_size: {self.micro_batch_size}, len(vals): {len(column)}"
                result[name] = column[start:stop]

        step = 1 if use_indices else self.micro_batch_size
        self.offset = start + step
        return result

    def reset(self):
        """Rewind the cursor to the first micro-batch and return ``self`` for chaining."""
        self.offset = 0
        return self


def get_minimum_num_micro_batch_size(total_lengths, max_tokens_per_gpu, cp_size):
    """Estimate how many micro-batches are needed to hold ``total_lengths``.

    Uses first-fit bin packing with a per-bin capacity of
    ``max_tokens_per_gpu * cp_size``. Each length goes into the first existing bin
    that still has room; otherwise it opens a new bin. A length larger than the
    capacity gets a bin of its own.

    Returns:
        The number of bins opened (0 for an empty ``total_lengths``).
    """
    capacity = max_tokens_per_gpu * cp_size
    bin_loads = []
    for seq_len in total_lengths:
        first_fit = next(
            (slot for slot, load in enumerate(bin_loads) if load + seq_len <= capacity),
            None,
        )
        if first_fit is None:
            bin_loads.append(seq_len)
        else:
            bin_loads[first_fit] += seq_len
    return len(bin_loads)


def _broadcast_from_rank0(obj, rank):
    """Broadcast a picklable ``obj`` from global rank 0 and return the broadcast value.

    Collective over the default process group: every rank must call it. ``obj`` is
    only read on rank 0; other ranks may pass ``None``.
    """
    holder = [obj if rank == 0 else None]
    dist.broadcast_object_list(holder, src=0)
    return holder[0]


def _truncate_to_dp_multiple(data, dp_size, rank):
    """Randomly drop samples from ``data`` (in place) so the sample count divides by ``dp_size``.

    The sample count is ``len(data["rewards"])``. Every list-valued entry of exactly
    that length is filtered with the same kept indices, and the original order is
    preserved. When truncation is needed this performs a collective broadcast; all
    ranks take that branch together because ``data`` is identical on every rank.

    Raises:
        AssertionError: if the sample count is still not a multiple of ``dp_size``
            afterwards (e.g. ``data["rewards"]`` is not a list).
    """
    num_samples = len(data["rewards"])
    num_to_keep = (num_samples // dp_size) * dp_size
    if num_to_keep < num_samples:
        num_truncated = num_samples - num_to_keep
        print(f"Warning: Randomly Truncating {num_truncated} samples to align batch for DP. "
              f"Total samples: {num_samples}, Keeping: {num_to_keep}, DP size: {dp_size}")

        # Choose the surviving indices on rank 0 only and broadcast them. Sampling
        # independently on each rank would use each rank's own RNG state, which is not
        # guaranteed to match. Ranks could then keep different subsets, and the later
        # strided DP split would duplicate some samples and lose others.
        indices_to_keep = sorted(random.sample(range(num_samples), num_to_keep)) if rank == 0 else None
        indices_to_keep = _broadcast_from_rank0(indices_to_keep, rank)

        # Collect the keys before mutating: truncating "rewards" changes the matched length.
        keys_to_truncate = [key for key, val in data.items() if isinstance(val, list) and len(val) == num_samples]
        for key in keys_to_truncate:
            original_list = data[key]
            data[key] = [original_list[i] for i in indices_to_keep]

    assert len(data["rewards"]) % dp_size == 0, \
        f"Data alignment failed! Samples: {len(data['rewards'])}, DP size: {dp_size}"


def _group_balanced_partitions(total_lengths, n_samples_per_prompt, dp_size):
    """Split sample indices across ``dp_size`` ranks, keeping each prompt group together.

    Samples are grouped into consecutive runs of ``n_samples_per_prompt``. Groups are
    balanced by their summed length via ``get_seqlen_balanced_partitions(...,
    equal_size=True)`` and then expanded back to sample indices. Trailing samples
    that do not form a complete group are not assigned to any rank.
    """
    group = n_samples_per_prompt
    num_groups = len(total_lengths) // group
    group_lengths = [sum(total_lengths[g * group:(g + 1) * group]) for g in range(num_groups)]
    group_partitions = get_seqlen_balanced_partitions(group_lengths, dp_size, equal_size=True)
    return [
        [idx for g in rank_groups for idx in range(g * group, (g + 1) * group)]
        for rank_groups in group_partitions
    ]


def process_rollout_data(args, rollout_data_ref, dp_rank, dp_size):
    """Fetch a rollout batch, broadcast it to all ranks, and return this DP rank's shard.

    Collective over the default process group: every rank must call this. Global
    rank 0 resolves ``ray.get(rollout_data_ref.inner)`` and broadcasts the result.
    The result is presumably a dict of per-sample lists; confirm against the producer.

    Args:
        args: Config namespace. Reads ``iterResearch`` and ``balance_data``, plus
            ``n_samples_per_prompt`` (default 1) when ``balance_data`` is set.
        rollout_data_ref: Object whose ``inner`` attribute is passed to ``ray.get``
            on rank 0.
        dp_rank: Index of this rank's data-parallel shard.
        dp_size: Number of data-parallel shards.

    Returns:
        A dict containing:

        * ``raw_reward``: the full value, unsharded and taken before any truncation.
        * This rank's shard of each of the keys listed in the loop below that is
          present, including the computed ``total_lengths``.
        * ``tokens`` as ``torch.long`` tensors and ``loss_masks`` as ``torch.int``
          tensors, both on the current CUDA device.
        * ``rollout_log_probs`` as bf16 CUDA tensors after ``slice_log_prob_with_cp``.

        Sharding is group-balanced when ``args.balance_data`` is set, and strided
        (``val[dp_rank::dp_size]``) otherwise.

    Raises:
        AssertionError: if ``iterResearch`` and ``balance_data`` are both enabled, or
            if DP alignment truncation fails.
    """
    rank = dist.get_rank()
    payload = ray.get(rollout_data_ref.inner) if rank == 0 else None
    data = _broadcast_from_rank0(payload, rank)

    # Save the unprocessed reward (before any DP truncation) for logging.
    rollout_data = {"raw_reward": data["raw_reward"]}

    if args.iterResearch:
        assert not args.balance_data, \
        "balance_data=True is incompatible with the current dynamic data generation. " \
        "Please set --balance-data False in your arguments."
        _truncate_to_dp_multiple(data, dp_size, rank)

    total_lengths = [len(t) for t in data["tokens"]]
    data["total_lengths"] = total_lengths

    # Save the seqlens of the whole (unsharded) rollout batch.
    Timer().seq_lens = total_lengths

    if args.balance_data:
        partitions = _group_balanced_partitions(
            total_lengths, getattr(args, "n_samples_per_prompt", 1), dp_size
        )

        def get_partition(val):
            return [val[i] for i in partitions[dp_rank]]
    else:

        def get_partition(val):
            return val[dp_rank::dp_size]

    device = torch.cuda.current_device()
    for key in (
        "tokens",
        "total_lengths",
        "response_lengths",
        "rewards",
        "truncated",
        "loss_masks",
        "round_number",
        "sample_indices",
        "rollout_log_probs",
    ):
        if key not in data:
            continue
        val = get_partition(data[key])
        # Move tokens and loss masks to the GPU in advance.
        if key == "tokens":
            val = [torch.tensor(t, dtype=torch.long, device=device) for t in val]
        elif key == "loss_masks":
            val = [torch.tensor(t, dtype=torch.int, device=device) for t in val]
        rollout_data[key] = val

    if "rollout_log_probs" in rollout_data:
        from slime.backends.megatron_utils.cp_utils import slice_log_prob_with_cp

        rollout_data["rollout_log_probs"] = [
            torch.tensor(
                slice_log_prob_with_cp(log_prob, total_length, response_length),
                device=device,
                dtype=torch.bfloat16,  # TODO: hardcode to bf16 at the moment
            )
            for log_prob, total_length, response_length in zip(
                rollout_data["rollout_log_probs"], rollout_data["total_lengths"], rollout_data["response_lengths"]
            )
        ]

    return rollout_data
