import torch
import numpy as np

def matching(C):
    """
    Monotone (non-decreasing) many-to-one matching: every a_i is assigned to some b_j,
    j can repeat, and order is preserved (j_1 <= j_2 <= ... <= j_n).

    Solved by dynamic programming with the recurrence
        dp[i][j] = C[i, j] + min(dp[i-1][k] for k <= j)
    in O(n * m) time. Only the previous dp row is kept; back-pointers are stored
    for all cells so the optimal assignment can be recovered.

    Tie-breaking: when keeping the same column is as cheap as jumping from a
    smaller column, the same column is kept (fewer jumps); among equally cheap
    smaller columns the smallest index wins; among equally cheap final columns
    the smallest index wins.

    Args:
        C: (n, m) cost matrix where C[i,j] is cost of matching a_i to b_j.
           A torch tensor, or anything torch.as_tensor accepts (numpy array,
           nested list). Costs are accumulated in float64.
    Returns:
        total_cost: float
        pairs:      list of (i, j) in increasing i (0-based indices)
        pair_costs: list of C[i, j] aligned with pairs
    Raises:
        ValueError: if C is not 2-D, or if m == 0 while n > 0 (nothing to match to).
    """
    C = torch.as_tensor(C)
    if C.dim() != 2:
        raise ValueError(
            f"C must be a 2-D cost matrix, got shape {tuple(C.shape)}."
        )
    n, m = C.shape
    if n == 0:
        return 0.0, [], []
    if m == 0:
        raise ValueError("Infeasible: |B|=0 but all A must be matched.")

    # Work on a detached float64 CPU view so the DP runs on plain Python floats:
    # this avoids per-element tensor indexing in the inner loop and works for
    # inputs that require grad or live on another device.
    costs = C.detach().to(device="cpu", dtype=torch.float64)

    # ptr[i, j] = column chosen for a_{i-1} on the best path ending at (i, j).
    # Row 0 has no predecessor and is never read.
    ptr = torch.empty((n, m), dtype=torch.int32)

    # Base row (i=0): a_0 may be matched to any b_j.
    dp = costs[0].tolist()

    for i in range(1, n):
        row = costs[i].tolist()
        prev = dp

        # Column 0 can only be reached by staying in column 0.
        dp = [row[0] + prev[0]]
        back = [0]

        # Running minimum of prev[0..j-1] and the first index that attains it.
        best, best_idx = prev[0], 0
        for j in range(1, m):
            stay = prev[j]
            if stay <= best:
                # Staying in column j is at least as cheap as any jump.
                dp.append(row[j] + stay)
                back.append(j)
                if stay < best:  # strict: keep the earliest index on ties
                    best, best_idx = stay, j
            else:
                dp.append(row[j] + best)
                back.append(best_idx)

        ptr[i] = torch.tensor(back, dtype=torch.int32)

    # Cheapest terminal column; min() returns the first one on ties.
    j = min(range(m), key=dp.__getitem__)
    total_cost = float(dp[j])

    # Walk the back-pointers from the last row up to the first.
    pairs, pair_costs = [], []
    for i in range(n - 1, -1, -1):
        pairs.append((i, j))
        pair_costs.append(float(costs[i, j]))
        if i > 0:
            j = int(ptr[i, j])
    pairs.reverse()
    pair_costs.reverse()
    return total_cost, pairs, pair_costs


if __name__ == '__main__':
    # Smoke test: 5 items of A against 3 items of B.
    # Expected output:
    #   (7.0, [(0, 0), (1, 0), (2, 1), (3, 2), (4, 2)], [1.0, 2.0, 2.0, 1.0, 1.0])
    C = np.array([
        [ 1,  5,  9],   # a0 vs b0,b1,b2
        [ 2,  4,  8],   # a1
        [ 6,  2,  5],   # a2
        [ 9,  3,  1],   # a3
        [10,  4,  1],   # a4
    ], dtype=float)
    C = torch.from_numpy(C)
    print(matching(C))
