# Copyright (c) 2026 Anonymous
# All Rights Reserved
# This codebase is provided for peer review purposes only.

import torch


@torch.compile()
def _stage_linear_routing(self, q_rter, router, auxfree_bias, num_expert_active):
    """Select the top-k experts per token with a linear router and return their scores.

    Expert selection ranks the router logits after adding ``auxfree_bias``.
    The returned scores are gathered from the logits *without* that bias, so
    the bias only influences which experts are picked, not the values
    returned. ``self`` is accepted but not used by this function.

    Args:
        q_rter: (num_head_per_rank, batch_size, head_size); bfloat16; contiguous.
            Upcast to float32 before the projection.
        router: (num_head_per_rank, head_size, num_expert); float32; contiguous.
        auxfree_bias: (num_head_per_rank, num_expert); float32; contiguous;
            detached. Reshaped to (num_head_per_rank, 1, num_expert) so it
            broadcasts over the batch axis.
        num_expert_active: int; number of experts kept per token (``k``).

    Returns:
        ``(scores, expert_assign)`` where
        scores: (num_head_per_rank, batch_size, num_expert_active); float32;
            contiguous. Unbiased router logits at the selected experts.
        expert_assign: (num_head_per_rank, batch_size, num_expert_active);
            int64; contiguous; detached. Selected expert indices. ``topk``
            runs with ``sorted=False``, so the order along the last axis is
            not guaranteed to be by score.
    """
    heads, _, experts = router.shape

    # Float32 router logits: (num_head_per_rank, batch_size, num_expert).
    logits = torch.matmul(q_rter.float(), router)

    # The bias is used only for ranking; the size-1 axis broadcasts it over the batch.
    bias = auxfree_bias.view(heads, 1, experts)
    _, expert_assign = torch.topk(
        logits + bias,
        num_expert_active,
        dim=2,
        largest=True,
        sorted=False,
    )
    expert_assign = expert_assign.detach()

    # Read the unbiased logits back at the chosen expert slots.
    scores = logits.gather(2, expert_assign)
    return scores, expert_assign
