"""PKPO implementation and numerical demonstration.

Run this file as main with `python pkpo_supplementary.py` to verify the PKPO
algorithm vs a brute force method. If successful, it will print "okay".

Pass@K Policy Optimization: Solving Harder Reinforcement Learning Problems
Christian Walder, Deep Tejas Karkhanis
Google DeepMind
NeurIPS 2025
"""

import functools
import itertools
from typing import Callable

import numpy as np
import scipy.special as ss


def _m_normed(N: int, K: int, i: int, j: int) -> float:
  if i == j and i >= K-1:
    return (
        K / (N-K+1) *
        np.prod(np.arange(i-K+2, i+1) / np.arange(N-K+2, N+1))
    )
  elif j > i and j >= K-1 and K >= 2:
    return (
        K / (N-K+1) * (K-1) / N *
        np.prod(np.arange(j-K+2, j) / np.arange(N-K+2, N))
    )
  return 0


def _m_diagonal(N: int, K: int) -> np.ndarray:
  """Return the diagonal entries `_m_normed(N, K, i, i)` for i = 0..N-1.

  Args:
    N: matrix size; non-positive values give an empty array.
    K: subset size.

  Returns:
    A float array of length max(N, 0). The dtype is pinned to float so that
    the result does not silently become an integer array when every entry is
    zero (e.g. when K > N).
  """
  return np.array([_m_normed(N, K, i, i) for i in range(N)], dtype=float)


def rho(g: np.ndarray, K: int) -> float:
  """See Equation (3).

  Sorts `g` in ascending order and returns the sum of the sorted values
  weighted element-wise by `_m_diagonal(len(g), K)`.

  Args:
    g: 1-D array of values.
    K: subset size passed to `_m_diagonal`.

  Returns:
    The weighted sum as a scalar.
  """
  ascending = np.sort(g)
  weights = _m_diagonal(len(g), K)
  weighted = ascending * weights
  return weighted.sum()


def _delta(N: int, K: int, i: int) -> float:
  """Return `_m_normed` at (i, i+1) minus `_m_normed` at (i+1, i+1).

  Args:
    N: matrix size forwarded to `_m_normed`.
    K: subset size forwarded to `_m_normed`.
    i: row index; the column used for both entries is i+1.

  Returns:
    The off-diagonal entry minus the diagonal entry of column i+1.
  """
  column = i + 1
  off_diagonal = _m_normed(N, K, i, column)
  diagonal = _m_normed(N, K, column, column)
  return off_diagonal - diagonal


def _deltas(N: int, K: int) -> np.ndarray:
  """Return `_delta(N-1, K, i)` for i = 0..N-3 as a float array.

  Note the shift: the matrix size handed to `_delta` is N-1, not N, so a
  caller working with an M x M matrix passes N = M+1 and receives M-1 values.

  Args:
    N: one more than the matrix size; values below 3 give an empty array.
    K: subset size.

  Returns:
    A float array of length max(N-2, 0). The dtype is pinned to float so the
    result does not become an integer array when every entry is zero.
  """
  return np.array([_delta(N-1, K, i) for i in range(N-2)], dtype=float)


def _sorted_apply(func: Callable) -> Callable:
  def inner(x: np.ndarray, *args, **kwargs) -> np.ndarray:
    i_sort = np.argsort(x)
    func_x = np.zeros_like(x)
    func_x[i_sort] = func(x[i_sort], *args, **kwargs)
    return func_x
  return inner


@_sorted_apply
def s(g: np.ndarray, K: int):
  """See Equation (19).

  Receives `g` already sorted ascending (via `_sorted_apply`). Writing m for
  the `_m_normed` matrix of size len(g), the i-th output is

    g[i] * m[i, i] + sum over j > i of g[j] * m[i, j].

  Because m[i, j] does not depend on i for j > i, this is computed as a
  reversed cumulative sum of per-position contributions.

  Args:
    g: 1-D array of values.
    K: subset size.

  Returns:
    Array with one transformed value per entry of `g`.
  """
  count = len(g)
  diagonal = _m_diagonal(count, K)
  contributions = g * diagonal
  # Entry j additionally carries g[j+1] * (m[j, j+1] - m[j+1, j+1]); summing
  # from position i upwards then swaps each later diagonal term for the
  # matching off-diagonal one.
  corrections = g[1:] * _deltas(count + 1, K)
  contributions[:count - 1] += corrections
  flipped = contributions[::-1]
  return np.cumsum(flipped)[::-1]


@_sorted_apply
def _b(g: np.ndarray, K: int) -> np.ndarray:
  """Return, per element, a weighted sum over `g` with that element removed.

  Receives `g` already sorted ascending (via `_sorted_apply`). With m the
  `_m_normed` matrix of size len(g)-1, position j of the reduced (still
  sorted) array gets weight m[j, j] + j * m[j-1, j]. Output i is the dot
  product of these weights with `g` after dropping element i.

  Args:
    g: 1-D array of values.
    K: subset size.

  Returns:
    Array with one value per entry of `g`.
  """
  count = len(g)
  positions = np.arange(1, count)
  # astype(float) guards against an all-zero integer array from _m_diagonal.
  weights = (_m_diagonal(count - 1, K) * positions).astype(float)
  weights[1:] += _deltas(count, K) * positions[:-1]
  # Dropping the smallest element leaves g[1:] in the weighted positions.
  first = np.array([(weights * g[1:]).sum()])
  # Moving the dropped element one step up swaps g[i+1] for g[i] at slot i.
  increments = (g[:-1] - g[1:]) * weights
  stacked = np.concatenate((first, increments))
  return np.cumsum(stacked)


def sloo(g: np.ndarray, K: int) -> np.ndarray:
  """See Equation (29).

  Computes `s(g, K)` minus `_b(g, K)` divided by len(g) - 1.

  Args:
    g: 1-D array of values. A single-element `g` makes the divisor zero.
    K: subset size.

  Returns:
    Array with one value per entry of `g`.
  """
  transformed = s(g, K)
  leave_one_out = _b(g, K)
  divisor = len(g) - 1.0
  return transformed - leave_one_out / divisor


def sloo_minus_one(g: np.ndarray, K: int) -> np.ndarray:
  """See Equation (33).

  Computes `s(g, K)` minus `_b(g, K-1)` scaled by K / len(g) / (K-1).

  Args:
    g: 1-D array of values; must be non-empty (len(g) is a divisor).
    K: subset size; K-1 is a divisor, so K must not be 1.

  Returns:
    Array with one value per entry of `g`.
  """
  count = len(g)
  transformed = s(g, K)
  scale = K / count / (K-1)
  reduced = _b(g, K-1)
  return transformed - scale * reduced


def brute_s_loo_minus_one(g: np.ndarray, k: int) -> np.ndarray:
  """Brute force implementation of s_loo_minus_one.

  Enumerates every size-k subset of indices. Each member of a subset is
  credited with (max of g over the subset) minus (max of g over the subset
  without that member). The credits are divided by the number of subsets.

  Args:
    g: 1-D array of values.
    k: subset size. With k == 1 the reduced subset is empty and taking its
      max raises.

  Returns:
    Array with one averaged credit per entry of `g`.
  """
  n = len(g)
  totals = np.zeros_like(g)
  # Visit all n-choose-k index subsets.
  for subset in itertools.combinations(range(n), k):
    members = list(subset)
    for pos, member in enumerate(members):
      # The same subset with `member` left out.
      others = members[:pos] + members[pos+1:]
      totals[member] += g[members].max() - g[others].max()
  return totals / ss.binom(n, k)


def test_s_loo_minus_one():
  """Compare `sloo_minus_one` with `brute_s_loo_minus_one` on random inputs.

  For every n in 2..9 and k in 2..n, draws 99 uniform random vectors of
  length n and checks that both implementations agree.

  Raises:
    AssertionError: if the two implementations disagree; the message carries
      n, k and the offending input.
  """
  # Seeded generator so that any failure can be reproduced exactly.
  rng = np.random.default_rng(0)
  for n in range(2, 10):
    for k in range(2, n+1):
      for _ in range(99):
        g = rng.random(n)
        s_brute = brute_s_loo_minus_one(g, k)
        s_fast = sloo_minus_one(g, k)
        # Unlike a bare `assert`, this still checks under `python -O`.
        # Tolerances are np.allclose's defaults (exact zeros need atol > 0).
        np.testing.assert_allclose(
            s_brute, s_fast, rtol=1e-5, atol=1e-8,
            err_msg=f'mismatch for n={n}, k={k}, g={g!r}',
        )


# Script entry point: run the brute-force comparison, then report success.
if __name__ == '__main__':
  test_s_loo_minus_one()
  # Reached only if the test above returned without raising.
  print('okay')
