from typing import List, Dict, Any

import torch

from gpatch.core.ppo_helper import calculate_grpo_advantages, create_mask
from megatron_datasets.utils import print_rank_0
from gpatch.core.aligner_helper import masked_mean_list



def test(args, rollout_batches: List[Dict[str, List[Any]]]):
    for rb in rollout_batches:
        # assert len(rb) == 8
        for k in rb.keys():
            rb[k] = rb[k][:args.ppo_sampling_keep]


# NOTE: this has not been rigorously tested.
def best_and_worst(args, rollout_batches: List[Dict[str, List[Any]]]):
    """Keep only the highest- and the lowest-reward sample of each rollout batch, in place.

    The argmax/argmin are taken over all rewards of the batch (not per prompt
    group). Each non-``None`` field is replaced by ``[field[best], field[worst]]``;
    when all rewards are equal, ``best == worst`` and the same sample appears twice.
    ``args`` is unused and kept for interface compatibility with the other strategies.
    """
    for rb in rollout_batches:
        # reshape(-1) lets 0-d and 1-D reward tensors be concatenated alike
        # (torch.cat rejects zero-dimensional tensors).
        rewards = torch.cat([r.reshape(-1) for r in rb["rewards"]])
        max_ind = torch.argmax(rewards).item()
        min_ind = torch.argmin(rewards).item()
        for k, v in rb.items():
            # Skip absent (None) fields instead of failing on None[...].
            if v is not None:
                rb[k] = [v[max_ind], v[min_ind]]

def rand(args, rollout_batches: List[Dict[str, List[Any]]]):
    """Keep ``ppo_sampling_keep`` uniformly random samples per prompt, in place.

    Each rollout batch is expected to hold ``rmbs * rrep`` samples laid out as
    ``rmbs`` consecutive groups of ``rrep`` samples, where
    ``rmbs = args.ppo_rollout_micro_batch_size`` and
    ``rrep = args.ppo_sampling_repeat``. From every group, ``rkeep`` distinct
    samples are drawn without replacement using torch's global RNG. They keep
    their original relative order, so each batch ends up with ``rmbs * rkeep``
    samples. Fields whose value is ``None`` are left untouched.

    Raises:
        ValueError: if ``rkeep`` is not in ``[1, rrep]``.
    """
    rmbs = args.ppo_rollout_micro_batch_size
    rrep = args.ppo_sampling_repeat
    rkeep = args.ppo_sampling_keep
    rbs = len(rollout_batches)
    print_rank_0(f"random: {rmbs=} {rrep=} {rkeep=} {rbs=}")
    if not 0 < rkeep <= rrep:
        raise ValueError(f"random: cannot keep {rkeep=} out of {rrep=} samples per prompt")

    # Flat index of the first sample of each prompt group, shape (rmbs, 1).
    row_offset = torch.arange(rmbs, dtype=torch.int64).unsqueeze(1) * rrep
    uniform = torch.ones((rmbs, rrep), dtype=torch.float64)
    for rb in rollout_batches:
        # One independent draw per batch, shape (rmbs, rkeep); sorting keeps the
        # original sample order within each group.
        picks = torch.multinomial(uniform, rkeep, replacement=False)
        picks, _ = torch.sort(picks, dim=1)
        keep_indices = (picks + row_offset).view(-1).tolist()
        for k, v in rb.items():
            if v is None:
                continue
            rb[k] = [v[i] for i in keep_indices]
            assert len(rb[k]) == rmbs * rkeep, f"{k=} {len(rb[k])=} {rmbs=} {rkeep=}"

def _select_max_variance_head_tail(values, k):
    """Return indices of ``k`` entries of 1-D tensor ``values`` chosen to maximize variance.

    Candidates are "the m smallest plus the (k - m) largest" entries for every m
    in [0, k]. The first candidate with the largest population variance wins.
    When ``values`` holds at most two distinct values, the balanced split
    ``m = k // 2`` is taken directly; for two-valued data it already maximizes
    variance. Returned indices are ordered by ascending value.

    Returns ``None`` when no valid candidate exists: ``k`` exceeds
    ``values.numel()``, or every candidate variance is NaN.
    """
    size = values.numel()
    if k > size:
        return None
    _, order = torch.sort(values, descending=False)  # ascending

    if len(torch.unique(values)) <= 2:
        m = k // 2
        return order[list(range(m)) + list(range(size - (k - m), size))]

    best_var = -float('inf')
    best = None
    for m in range(k + 1):
        # k <= size, so the head range [0, m) and tail range [size-(k-m), size) never overlap.
        chosen = order[list(range(m)) + list(range(size - (k - m), size))]
        var = torch.var(values[chosen], unbiased=False).item()
        if var > best_var:
            best_var = var
            best = chosen
    return best


def pods_var_r(args, rollout_batches: List[Dict[str, List[Any]]]):
    """Keep, per prompt, the ``ppo_sampling_keep`` samples whose rewards have maximal variance.

    This is the selector used for the 'pods' strategy in ``filter_samplings``.
    Each rollout batch is expected to hold ``rmbs * rrep`` samples laid out as
    ``rmbs`` consecutive groups of ``rrep`` samples, with one scalar reward per
    sample (rewards are stacked and viewed as ``(rbs, rmbs, rrep)``). Within
    each group the kept subset is chosen by ``_select_max_variance_head_tail``.
    Advantages and returns are then recomputed on the kept samples with a GRPO
    group size of ``rkeep``.

    ``rollout_batches`` is modified in place. First a ``mask`` entry is added to
    input dicts lacking one. Then its contents are replaced by ``rbs`` new dicts
    holding ``rmbs * rkeep`` samples each.

    Raises:
        ValueError: in cross-batch mode, because advantages cannot be recomputed
            across batches; or when no valid subset exists for a group.
    """
    rmbs = args.ppo_rollout_micro_batch_size
    rrep = args.ppo_sampling_repeat
    rkeep = args.ppo_sampling_keep

    rbs = len(rollout_batches)

    total_keep = rbs * rmbs * rkeep
    print_rank_0(f"{rbs=} {rmbs=} {rrep=} {rkeep=} {total_keep=}")

    if args.ppo_sampling_keeping_strategy_pods_cross_batch:
        # Fail fast: the advantages cannot be re-calculated across batches, so any
        # selection work done before this check would be wasted.
        raise ValueError("[PODS] Cannot re-calculate advantages for cross batch sampling")

    rollout_batches_rewards = []
    for rb in rollout_batches:
        if rb.get('mask', None) is None:
            rb['mask'] = create_mask(
                values=rb['logprobs'],
                prompt_lengths=rb['prompt_lengths'],
                sequence_lengths=rb['sequence_lengths'],
                dtype=rb['rewards'][0].dtype)
        rollout_batches_rewards.extend(rb['rewards'])
    # shape: (rbs, rmbs, rrep)
    rollout_batches_rewards = torch.stack(rollout_batches_rewards, dim=0).view(rbs, rmbs, rrep)
    print_rank_0(f'[PODS] {rollout_batches_rewards.shape=}')

    # Gather the kept samples into fresh dicts. Group ``rmb_idx`` of batch ``rb_idx``
    # occupies slots [rmb_idx * rkeep, (rmb_idx + 1) * rkeep) of ``data[rb_idx]``.
    data = []
    for rb_idx, rb in enumerate(rollout_batches):
        kept = {k: [None] * (rmbs * rkeep) for k in rb}
        for rmb_idx in range(rmbs):
            chosen = _select_max_variance_head_tail(rollout_batches_rewards[rb_idx, rmb_idx], rkeep)
            if chosen is None:
                raise ValueError(f"Cannot find valid indices for (rb={rb_idx}, rmb={rmb_idx})")
            for slot, sample_idx in enumerate(chosen.tolist()):
                src = rmb_idx * rrep + sample_idx
                dst = rmb_idx * rkeep + slot
                for k in kept:
                    kept[k][dst] = rb[k][src]
        data.append(kept)

    # Re-calculate advantages on the KEPT samples (groups of rkeep). Doing this on
    # the original rollout_batches would be discarded when they are replaced below.
    for rb in data:
        mask = create_mask(
            values=rb['logprobs'],
            prompt_lengths=rb['prompt_lengths'],
            sequence_lengths=rb['sequence_lengths'],
            dtype=rb['rewards'][0].dtype)
        rb['mask'] = mask
        advantages, returns = calculate_grpo_advantages(
            rewards=rb['rewards'],
            mask=mask,
            grpo_sampling_times=rkeep,
            grpo_advantage_epsilon=1e-6
        )
        rb["returns"] = returns
        rb["advantages"] = advantages

    # Careful: replace the contents of the caller's list in place.
    rollout_batches.clear()
    rollout_batches.extend(data)

def pods_var_A(args, rollout_batches: List[Dict[str, List[Any]]]):
    """Keep ``ppo_sampling_keep`` samples per prompt by maximizing the variance of
    sample-level advantages (the selector used for the 'd3s' strategy in
    ``filter_samplings``).

    Each rollout batch is expected to hold ``rmbs * rrep`` samples laid out as
    ``rmbs`` consecutive groups of ``rrep`` samples. This layout is implied by the
    ``.view(rmbs, rrep)`` and the ``rmb_id * rrep + sample_id`` indexing below.

    Missing ``mask`` / ``advantages`` / ``returns`` entries are first computed and
    stored on the *input* dicts; advantages use a GRPO group size of ``rrep``.
    Per-sample scores come from ``masked_mean_list`` over the token-level
    advantages; presumably this is a masked mean per sample, TODO confirm.

    Candidate subsets are "the m lowest + the (keep - m) highest" scores. The
    candidate with the largest population variance is kept.
    - Per-group mode: this runs on every group of ``rrep`` samples.
    - Cross-batch mode (``args.ppo_sampling_keeping_strategy_pods_cross_batch``):
      it runs once over all ``rbs * rmbs * rrep`` samples. The kept samples are
      redistributed into ``rbs`` batches in ascending-score order, so one output
      batch may mix samples from different input batches.

    Advantages are NOT recomputed for the kept subset. Every kept sample retains
    the advantage computed over its full group of ``rrep`` samples.

    ``rollout_batches`` is modified in place. Its contents are replaced by ``rbs``
    new dicts with the same keys as the (augmented) input dicts, each holding
    ``rmbs * rkeep`` samples.

    Raises:
        ValueError: if no valid candidate subset exists. Examples: ``rkeep > rrep``
            for a group with more than two distinct scores, the keep count exceeding
            the sample count in cross-batch mode, or all candidate variances being NaN.
    """
    rmbs = args.ppo_rollout_micro_batch_size
    rrep = args.ppo_sampling_repeat
    rkeep = args.ppo_sampling_keep

    rbs = len(rollout_batches)

    total_keep = rbs * rmbs * rkeep
    rollout_batches_advantages = []

    print_rank_0(f"{rbs=} {rmbs=} {rrep=} {rkeep=} {total_keep=}")
    for rbi, rb in enumerate(rollout_batches):
        # Build the token mask once and cache it on the input dict.
        mask = rb.get('mask', None)
        if mask is None:
            mask = create_mask(
                values=rb['logprobs'],
                prompt_lengths=rb['prompt_lengths'],
                sequence_lengths=rb['sequence_lengths'],
                dtype=rb['rewards'][0].dtype)
            rb['mask'] = mask

        # List[torch.Tensor]
        # Reuse existing advantages/returns; compute both (group size rrep) if either is missing.
        try:
            advantages = rb['advantages']
            returns = rb['returns']
        except KeyError:
            advantages, returns = calculate_grpo_advantages(
                rewards=rb['rewards'],
                mask=mask,
                grpo_sampling_times=rrep,
                grpo_advantage_epsilon=1e-6
            )
            assert advantages[0].dtype == torch.float32
            # carry token-level features with rb
            rb["returns"] = returns
            rb["advantages"] = advantages

        # Use sample-level advantages to select samples
        sample_advantages = masked_mean_list(
            values=advantages,
            mask=mask,
            dim=-1
        ).view(rmbs, rrep)
        # should be sample-level advantages
        rollout_batches_advantages.append(sample_advantages)
    # shape: (rbs, rmbs, rrep)
    rollout_batches_advantages = torch.stack(rollout_batches_advantages, dim=0)
    print_rank_0(f'[PODS] {rollout_batches_advantages.shape=}')

    # Sampling strategy
    if args.ppo_sampling_keeping_strategy_pods_cross_batch:
        # --- cross-group head/tail sampling ---
        flat_adv = rollout_batches_advantages.flatten()  # shape: (rbs*rmbs*rrep,)
        sorted_adv, sorted_indices = torch.sort(flat_adv, descending=False)

        best_var = -float('inf')
        best_indices = None
        # Try every split: m lowest + n highest, with m + n == total_keep.
        for m in range(total_keep + 1):
            n = total_keep - m
            if m > len(sorted_adv) or n > len(sorted_adv):
                continue
            indices = []
            if m > 0:
                indices += list(range(m))  # the m smallest
            if n > 0:
                indices += list(range(len(sorted_adv) - n, len(sorted_adv)))  # the n largest
            # Overlapping head/tail ranges collapse here and are rejected by the length check.
            indices = sorted(set(indices))
            if len(indices) != total_keep:
                continue
            chosen = sorted_indices[indices]
            subset = flat_adv[chosen]
            var = torch.var(subset, unbiased=False).item()
            # Strict '>' keeps the first (smallest m) split among ties.
            if var > best_var:
                best_var = var
                best_indices = chosen

        if best_indices is None:
            raise ValueError("Cannot find valid indices for cross-batch head-tail var-max")
        # Map flat indices back to (rb_idx, rmb_idx, sample_idx)
        rb_ids, rmb_ids, sample_ids = torch.unravel_index(best_indices, (rbs, rmbs, rrep))

    else:
        # For each (rb, rmb) group, pick rkeep of its rrep samples by maximizing variance
        rb_ids_list, rmb_ids_list, sample_ids_list = [], [], []
        for rb_idx in range(rbs):
            for rmb_idx in range(rmbs):
                reward_seq = rollout_batches_advantages[rb_idx, rmb_idx]  # shape: (rrep,)
                # Sort
                sorted_rewards, sorted_indices = torch.sort(reward_seq, descending=False)  # ascending

                best_var = -float('inf')
                best_indices = None

                unique_values = torch.unique(reward_seq)
                if len(unique_values) <= 2:
                    # Shortcut when the group has at most two distinct values:
                    # the balanced head/tail split already maximizes variance.
                    half = rkeep // 2
                    m = half
                    n = rkeep - m
                    indices = list(range(m)) + list(range(rrep - n, rrep))
                    best_indices = sorted_indices[indices]
                    # Skips the search loop below;
                    # best_var can be computed directly
                    subset = reward_seq[best_indices]
                    best_var = torch.var(subset, unbiased=False).item()
                else:
                    for m in range(rkeep + 1):
                        n = rkeep - m
                        if m > rrep or n > rrep or m + n > rrep:
                            continue
                        indices = []
                        if m > 0:
                            indices += list(range(m))  # first m (smallest)
                        if n > 0:
                            indices += list(range(rrep - n, rrep))  # last n (largest)
                        indices = sorted(set(indices))
                        if len(indices) != rkeep:
                            continue
                        chosen = sorted_indices[indices]
                        subset = reward_seq[chosen]
                        var = torch.var(subset, unbiased=False).item()
                        if var > best_var:
                            best_var = var
                            best_indices = chosen

                if best_indices is None:
                    raise ValueError(f"Cannot find valid indices for (rb={rb_idx}, rmb={rmb_idx})")
                rb_ids_list.extend([rb_idx] * rkeep)
                rmb_ids_list.extend([rmb_idx] * rkeep)
                sample_ids_list.extend(best_indices.tolist())

        # Convert to tensors, same format as the `pods` selector
        rb_ids = torch.tensor(rb_ids_list, dtype=torch.long)
        rmb_ids = torch.tensor(rmb_ids_list, dtype=torch.long)
        sample_ids = torch.tensor(sample_ids_list, dtype=torch.long)

    # List[Dict[str, List[Any]]]
    # One new dict per batch, each holding rkeep * rmbs slots per key.
    data=[{k: [None for _ in range(rkeep*rmbs)] for k in rollout_batches[rb_id].keys()} for rb_id in range(rbs)]

    # The count-th selected sample goes to output batch count // (rmbs * rkeep),
    # slot count % (rmbs * rkeep); the source is flat position rmb_id * rrep + sample_id.
    for count, (rb_id, rmb_id, sample_id) in enumerate(zip(rb_ids, rmb_ids, sample_ids)):
        target_rb_id = count // (rmbs * rkeep)
        target_rmbs_id = count % (rmbs * rkeep)
        for k in data[target_rb_id].keys():
            data[target_rb_id][k][target_rmbs_id] = rollout_batches[rb_id][k][rmb_id * rrep + sample_id]

    # NOTE: advantages are intentionally NOT re-calculated here; kept samples retain
    # the advantages computed over their full group of rrep samples. The previous
    # recompute logic is left below (commented out) for reference.
    # re-calculate advantages
    # if args.ppo_sampling_keeping_strategy_pods_cross_batch:
    #     # cross batch does not support re-calculating advantages
    #     pass
    # else:
    #     for rb in rollout_batches:
    #         mask = create_mask(
    #             values=rb['logprobs'],
    #             prompt_lengths=rb['prompt_lengths'],
    #             sequence_lengths=rb['sequence_lengths'],
    #             dtype=rb['rewards'][0].dtype)
    #         rb['mask'] = mask
    #         advantages, returns = calculate_grpo_advantages(
    #             rewards=rb['rewards'],
    #             mask=mask,
    #             grpo_sampling_times=rkeep,
    #             grpo_advantage_epsilon=1e-6
    #         )
    #         rb["returns"] = returns
    #         rb["advantages"] = advantages

    # Careful: clear the original rollout_batches
    rollout_batches.clear()
    rollout_batches.extend(data)


def pods(args, rollout_batches: List[Dict[str, List[Any]]]):
    """Keep the samples with the largest absolute sample-level advantage (PODS).

    Each rollout batch is expected to hold ``rmbs * rrep`` samples laid out as
    ``rmbs`` consecutive groups of ``rrep`` samples, where
    ``rmbs = args.ppo_rollout_micro_batch_size`` and
    ``rrep = args.ppo_sampling_repeat``. Sample-level scores are obtained with
    ``masked_mean_list`` over the token-level advantages.

    * Per-group mode (default): from every group keep the ``rkeep``
      (``args.ppo_sampling_keep``) samples with the largest ``|score|``.
    * Cross-batch mode (``args.ppo_sampling_keeping_strategy_pods_cross_batch``):
      keep the top ``rbs * rmbs * rkeep`` samples over all batches at once and
      redistribute them, in selection order, into ``rbs`` batches.

    Missing ``mask`` / ``advantages`` / ``returns`` entries are first computed
    on the input dicts, using a GRPO group size of ``rrep``. After selection,
    advantages and returns are recomputed on the *kept* samples. The grouping is
    one group of ``rkeep`` per prompt, or a single global group in cross-batch mode.

    ``rollout_batches`` is modified in place; its contents are replaced by ``rbs``
    new dicts holding ``rmbs * rkeep`` samples each.
    """
    rmbs = args.ppo_rollout_micro_batch_size
    rrep = args.ppo_sampling_repeat
    rkeep = args.ppo_sampling_keep
    cross_batch = args.ppo_sampling_keeping_strategy_pods_cross_batch

    rbs = len(rollout_batches)

    total_keep = rbs * rmbs * rkeep
    rollout_batches_advantages = []

    print_rank_0(f"{rbs=} {rmbs=} {rrep=} {rkeep=} {total_keep=}")

    for rb in rollout_batches:
        mask = rb.get('mask', None)
        if mask is None:
            mask = create_mask(
                values=rb['logprobs'],
                prompt_lengths=rb['prompt_lengths'],
                sequence_lengths=rb['sequence_lengths'],
                dtype=rb['rewards'][0].dtype)
            # carry token-level features with rb
            rb['mask'] = mask

        # List[torch.Tensor]; (re)compute both when either one is missing.
        if 'advantages' not in rb or 'returns' not in rb:
            advantages, returns = calculate_grpo_advantages(
                rewards=rb['rewards'],
                mask=mask,
                grpo_sampling_times=rrep,
                grpo_advantage_epsilon=1e-6
            )
            assert advantages[0].dtype == torch.float32
            # carry token-level features with rb
            rb["returns"] = returns
            rb["advantages"] = advantages
        advantages = rb['advantages']

        # Use sample-level advantages to select samples
        sample_advantages = masked_mean_list(
            values=advantages,
            mask=mask,
            dim=-1
        ).view(rmbs, rrep, 1)
        rollout_batches_advantages.append(sample_advantages)
    # shape: (rbs, rmbs, rrep, 1)
    rollout_batches_advantages = torch.stack(rollout_batches_advantages, dim=0)
    print_rank_0(f'[PODS] {rollout_batches_advantages.shape=}')

    if cross_batch:
        # shape: (rbs * rmbs * rrep, 1)
        f_norms = torch.abs(rollout_batches_advantages.flatten(start_dim=0, end_dim=2))
        print_rank_0(f'[PODS] {f_norms.shape=}')
        # shape: (total_keep, 1)
        _, top_k_indices = torch.topk(f_norms, k=total_keep, dim=0)
        print_rank_0(f'[PODS] {top_k_indices.shape=} {top_k_indices.flatten().tolist()}')
        # Flatten so every id tensor is 1-D of length total_keep.
        rb_ids, rmb_ids, sample_ids = torch.unravel_index(top_k_indices.flatten(), (rbs, rmbs, rrep))

    else:
        # shape: (rbs * rmbs, rrep, 1)
        f_norms = torch.abs(rollout_batches_advantages.flatten(start_dim=0, end_dim=1))
        print_rank_0(f'[PODS] {f_norms.shape=}')
        # shape: (rbs * rmbs, rkeep, 1)
        _, top_k_indices = torch.topk(f_norms, k=rkeep, dim=1)
        print_rank_0(f'[PODS] {top_k_indices.shape=} {top_k_indices.flatten().tolist()}')
        # shape: (rbs * rmbs * rkeep)
        sample_ids = top_k_indices.flatten()
        rb_indices, rmb_indices = torch.meshgrid(
            torch.arange(rbs),
            torch.arange(rmbs),
            indexing='ij'
        )
        # shape: (rbs * rmbs * rkeep)
        rb_ids = rb_indices.unsqueeze(2).expand(-1, -1, rkeep).flatten()
        rmb_ids = rmb_indices.unsqueeze(2).expand(-1, -1, rkeep).flatten()

    # List[Dict[str, List[Any]]]
    per_rb = rmbs * rkeep
    data = [{k: [None] * per_rb for k in rollout_batches[rb_id].keys()} for rb_id in range(rbs)]

    # The count-th selected sample goes to output batch count // per_rb, slot count % per_rb.
    for count, (rb_id, rmb_id, sample_id) in enumerate(zip(rb_ids, rmb_ids, sample_ids)):
        target_rb_id = count // per_rb
        target_rmbs_id = count % per_rb
        for k in data[target_rb_id].keys():
            data[target_rb_id][k][target_rmbs_id] = rollout_batches[rb_id][k][rmb_id * rrep + sample_id]

    # Re-calculate advantages on the KEPT samples in `data`. Writing them onto the
    # original rollout_batches would be discarded when those are replaced below.
    if cross_batch:
        # gather all kept masks and rewards (lists of torch.Tensor), total_keep items
        global_mask = []
        global_rewards = []
        for rb in data:
            global_mask.extend(rb['mask'])
            global_rewards.extend(rb['rewards'])
        global_advantages, global_returns = calculate_grpo_advantages(
            rewards=global_rewards,
            mask=global_mask,
            grpo_sampling_times=total_keep,
            grpo_advantage_epsilon=1e-6
        )
        # map global_advantages and global_returns back onto data, in gather order
        for rbi, rb in enumerate(data):
            rb['advantages'] = global_advantages[rbi * per_rb:(rbi + 1) * per_rb]
            rb['returns'] = global_returns[rbi * per_rb:(rbi + 1) * per_rb]
    else:
        for rb in data:
            mask = create_mask(
                values=rb['logprobs'],
                prompt_lengths=rb['prompt_lengths'],
                sequence_lengths=rb['sequence_lengths'],
                dtype=rb['rewards'][0].dtype)
            rb['mask'] = mask
            advantages, returns = calculate_grpo_advantages(
                rewards=rb['rewards'],
                mask=mask,
                grpo_sampling_times=rkeep,
                grpo_advantage_epsilon=1e-6
            )
            rb["returns"] = returns
            rb["advantages"] = advantages

    # Careful: replace the contents of the caller's list in place.
    rollout_batches.clear()
    rollout_batches.extend(data)

def filter_samplings(args, rollout_batches):
    """Apply the sample-keeping strategy named by ``args.ppo_sampling_keeping_strategy``.

    Supported strategies:
      - 'all': keep everything.
      - 'test': truncation.
      - 'pods': reward-variance selection via ``pods_var_r``.
      - 'd3s': advantage-variance selection via ``pods_var_A``.
      - 'random': random selection via ``rand``.

    Strategies modify ``rollout_batches`` in place. The same list object is
    returned for every strategy, so callers can use either style.

    Raises:
        NotImplementedError: for an unknown strategy name.
    """
    strategy = args.ppo_sampling_keeping_strategy
    if strategy == 'test':
        test(args, rollout_batches)
    elif strategy == 'pods':
        # NOTE: 'pods' is routed to the reward-variance selector, not to `pods`.
        pods_var_r(args, rollout_batches)
    elif strategy == 'd3s':
        pods_var_A(args, rollout_batches)
    elif strategy == 'random':
        rand(args, rollout_batches)
    elif strategy != 'all':
        raise NotImplementedError(f'unknown ppo_sampling_keeping_strategy: {strategy!r}')
    return rollout_batches
