"""
Ren et al. 2020: SEEKS algorithm for exact best-k selection.
Reference: https://github.com/WenboRen/Topk-Ranking-from-Pairwise-Comparisons
"""
import numpy as np
import math

# Shorthand for pi. The confidence schedules below use delta_t = 6*delta/(PI^2 t^2),
# whose sum over t = 1, 2, ... equals delta because sum(1/t^2) = pi^2/6.
PI = math.pi

def compare(p):
    """Draw one Bernoulli(p) outcome.

    Args:
        p: Success probability. ``np.random.random()`` is uniform on [0, 1),
            so for p in [0, 1] the result is True with probability p.

    Returns:
        True (possibly a NumPy bool) when the uniform draw falls below ``p``.
    """
    sample = np.random.random()
    return sample < p

def distribute_item(i, piv, epsilon, su, sd, delta, sup, smid, sdown, mode):
    """Classify item ``i`` against a pivot using repeated random comparisons.

    Each round draws ``compare(piv)``. The outcome XOR ``mode`` counts as a
    "win" for ``i``, so ``mode=True`` counts losses instead. Callers pass
    ``P[i, pivot]`` as ``piv``, presumably the probability that ``i`` beats the
    pivot.

    The item is appended to exactly one of the caller-owned lists:
      * ``sup``   - the win rate is confidently above ``0.5 + su``;
      * ``sdown`` - the win rate is confidently below ``0.5 - sd``;
      * ``smid``  - otherwise, once the sample budget is used up.

    Args:
        i: Item identifier to append to the chosen list.
        piv: Probability handed to ``compare`` each round.
        epsilon: Accuracy parameter. It sets the sample cap and the final
            decision margin ``0.5 * epsilon``.
        su: Extra upward margin for entering ``sup``.
        sd: Extra downward margin for entering ``sdown``.
        delta: Failure-probability parameter for this item.
        sup, smid, sdown: Lists mutated in place.
        mode: When True, invert every comparison outcome.

    Returns:
        Number of comparisons performed.
    """
    # Sample cap: (2 / eps^2) * log(4 / delta), rounded up.
    budget = int(math.ceil((2.0 / (epsilon * epsilon)) * math.log(4.0 / delta)))
    early_up = 0.5 + su
    early_down = 0.5 - sd
    wins = 0
    rounds = 0
    while rounds < budget:
        rounds += 1
        if mode ^ compare(piv):
            wins += 1
        ratio = wins / rounds
        # Two-sided Hoeffding radius at level delta_t = 6*delta/(PI^2 t^2);
        # summing delta_t over all t gives delta (anytime-valid confidence).
        radius = math.sqrt((0.5 / rounds) * math.log((PI * PI * rounds * rounds) / (3.0 * delta)))
        if ratio - radius > early_up:
            sup.append(i)
            return rounds
        if ratio + radius < early_down:
            sdown.append(i)
            return rounds
    # Budget exhausted: decide by the empirical rate with a 0.5 * epsilon margin.
    # (If budget <= 0 no rounds ran, and this division raises ZeroDivisionError.)
    ratio = wins / rounds
    if ratio > 0.5 + 0.5 * epsilon + su:
        target = sup
    elif ratio < 0.5 - 0.5 * epsilon - sd:
        target = sdown
    else:
        target = smid
    target.append(i)
    return rounds

def epsilon_quick_select(S, P, k, epsilon, delta, mode=False):
    """Select ``k`` items from ``S`` by randomized quick-select over noisy comparisons.

    A uniformly random pivot is chosen. Every other item is classified against
    it with ``distribute_item`` at accuracy ``epsilon / 2``, with no extra
    margins. The procedure then recurses into whichever side still has to
    supply items.

    Args:
        S: Sequence of item identifiers.
        P: Comparison-probability table indexed as ``P[item, pivot]``.
        k: Number of items to return. ``k <= 0`` returns no items.
        epsilon: Accuracy parameter.
        delta: Total failure-probability budget for this call, including recursion.
        mode: Passed through to ``distribute_item``. True inverts every comparison.

    Returns:
        A tuple ``(selected, comps)``. ``selected`` is a list with
        ``min(k, len(S))`` items (or ``[]`` when ``k <= 0``). ``comps`` is the
        number of comparisons used.
    """
    n = len(S)
    if k <= 0:
        # Nothing to select. Without this guard, n == 1 would divide by zero
        # in delta / (n * (n - 1)) below.
        return [], 0
    if n <= k:
        return list(S), 0  # Base case: all elements are in top-k
    pivot = S[np.random.randint(n)]
    sup, smid, sdown = [], [pivot], []
    # Spend delta / n on this level, split over the n - 1 classified items.
    delta_item = delta / (n * (n - 1))
    comps = 0
    for item in S:
        if item == pivot:
            continue
        comps += distribute_item(item, P[item, pivot], epsilon / 2.0, 0.0, 0.0, delta_item, sup, smid, sdown, mode)
    # The remaining (n - 1) / n of the budget goes to the recursive call.
    delta_rest = (n - 1) * delta / n
    if len(sup) > k:
        result, c = epsilon_quick_select(sup, P, k, epsilon, delta_rest, mode)
        return result, comps + c
    if len(sup) + len(smid) >= k:
        return sup + smid[:k - len(sup)], comps
    # Here len(sdown) > k - len(sup) - len(smid) >= 1, so the recursion is well-posed.
    remaining = k - len(sup) - len(smid)
    tail, c = epsilon_quick_select(sdown, P, remaining, epsilon, delta_rest, mode)
    return sup + smid + tail, comps + c

def tournament_k_selection(S, P, k, epsilon, delta, mode=False):
    """Reduce ``S`` to at most ``k`` items through repeated group tournaments.

    Each round splits the survivors, in order, into consecutive groups of
    ``2 * k``. Every group larger than ``k`` is cut down to ``k`` items with
    ``epsilon_quick_select``. Smaller groups pass through unchanged. Round
    ``t`` uses accuracy ``0.25 * epsilon * 0.8**t`` and confidence
    ``6 * delta / (PI^2 t^2)``. The confidence is divided by ``k`` for each
    quick-select call.

    Args:
        S: Sequence of item identifiers.
        P: Comparison-probability table forwarded to ``epsilon_quick_select``.
        k: Number of items to keep.
        epsilon: Accuracy parameter.
        delta: Failure-probability budget.
        mode: Forwarded to ``epsilon_quick_select``. True inverts comparisons.

    Returns:
        A tuple ``(survivors, comps)``. ``survivors`` is the final list of items.
        ``comps`` is the total number of comparisons used.
    """
    group_size = 2 * k
    eps_round = 0.25 * epsilon
    survivors = list(S)
    total = 0
    round_no = 0
    while len(survivors) > k:
        round_no += 1
        eps_round *= 0.8
        delta_round = 6.0 * delta / (PI * PI * round_no * round_no)
        # Open a new chunk whenever the current one is full; otherwise extend it.
        chunks = []
        for item in survivors:
            if chunks and len(chunks[-1]) != group_size:
                chunks[-1].append(item)
            else:
                chunks.append([item])
        survivors = []
        for chunk in chunks:
            if len(chunk) > k:
                winners, used = epsilon_quick_select(chunk, P, k, eps_round, delta_round / k, mode)
                total += used
                survivors.extend(winners)
            else:
                survivors.extend(chunk)
    return survivors, total

def seebs(S, P, delta, mode=False):
    """Find a single best item by sequential elimination (SEEBS).

    Round ``t`` halves the accuracy ``epsilon_t`` (1/2, 1/4, ...) and uses
    confidence ``delta_t = 6 * delta / (PI^2 t^2)``. A pivot ``v`` is picked
    with ``tournament_k_selection(..., k=1, ...)`` using ``2 * delta_t / 3``.
    Every other remaining item is classified against ``v`` with
    ``distribute_item``. Items in ``sup`` or ``smid`` (which contains ``v``)
    survive to the next round; ``sdown`` items are dropped.

    Args:
        S: Sequence of item identifiers.
        P: Comparison-probability table indexed as ``P[item, pivot]``.
        delta: Overall failure-probability budget.
        mode: Forwarded to the helpers. True inverts every comparison.

    Returns:
        A tuple ``(best, comps)``: the last surviving item and the number of
        comparisons used.

    Raises:
        IndexError: If ``S`` is empty.
    """
    epsilon_t = 1.0
    R = list(S)
    t = 1
    comps = 0
    while len(R) > 1:
        epsilon_t *= 0.5
        delta_t = 6.0 * delta / (PI * PI * t * t)
        v_list, c = tournament_k_selection(R, P, 1, epsilon_t / 3.0, 2.0 * delta_t / 3.0, mode)
        comps += c
        v = v_list[-1]
        # Split the remaining delta_t / 3 across the len(R) - 1 Distribute-Item
        # calls (union bound), as seeks() does. The round then stays within
        # delta_t. len(R) >= 2 here, so the divisor is positive.
        delta_item = delta_t / (3.0 * (len(R) - 1))
        sup, smid, sdown = [], [v], []
        for i in R:
            if i == v:
                continue
            comps += distribute_item(i, P[i, v], epsilon_t / 3.0, 0.0, epsilon_t / 3.0, delta_item, sup, smid, sdown, mode)
        R = sup + smid
        t += 1
    return R[-1], comps

def seeks(S, P, k, delta, mode=False):
    """Select the best ``k`` items by sequential elimination (SEEKS).

    Round ``t`` halves the accuracy ``epsilon_t`` and uses confidence
    ``delta_t = 6 * delta / (PI^2 t^2)``, split into thirds:
      1. ``tournament_k_selection`` picks ``k_t`` candidates ``A_t`` from
         ``R_t``.
      2. A second tournament with the orientation flipped (``not mode``)
         picks one pivot ``v_t`` from ``A_t``. Presumably this is the weakest
         of the candidates.
      3. Every other item in ``R_t`` is classified against ``v_t``. Each of
         the ``len(R_t) - 1`` items gets ``delta_t / (3 * (len(R_t) - 1))``.
    ``sup`` items are accepted into ``S_t``. ``smid`` items (which include
    ``v_t``) stay in ``R_t``. ``sdown`` items are discarded. ``k_t`` always
    equals ``k - len(S_t)``.

    Args:
        S: Sequence of item identifiers.
        P: Comparison-probability table indexed as ``P[item, pivot]``.
        k: Number of items to select.
        delta: Overall failure-probability budget.
        mode: Forwarded to the helpers. True inverts every comparison.

    Returns:
        A tuple ``(selected, comps)``. ``selected`` is ``S_t`` followed by at
        most ``k - len(S_t)`` items of ``R_t``. ``comps`` is the number of
        comparisons used.
    """
    epsilon_t = 1.0
    R_t, S_t = list(S), []
    t, k_t = 1, k
    comps = 0
    # Inside the loop len(R_t) > k_t >= 1, so len(R_t) - 1 >= 1 below.
    while len(S_t) < k and len(S_t) + len(R_t) > k:
        epsilon_t *= 0.5
        delta_t = 6.0 * delta / (PI * PI * t * t)
        A_t, c = tournament_k_selection(R_t, P, k_t, epsilon_t / 3.0, delta_t / 3.0, mode)
        comps += c
        v_list, c = tournament_k_selection(A_t, P, 1, epsilon_t / 3.0, delta_t / 3.0, not mode)
        comps += c
        v_t = v_list[-1]
        delta_item = delta_t / (3.0 * (len(R_t) - 1))
        sup, smid, sdown = [], [v_t], []
        for i in R_t:
            if i == v_t:
                continue
            comps += distribute_item(i, P[i, v_t], epsilon_t / 3.0, epsilon_t / 3.0, epsilon_t / 3.0, delta_item, sup, smid, sdown, mode)
        S_t.extend(sup)
        R_t = smid
        k_t -= len(sup)
        t += 1
    # k - len(S_t) is negative when k < 0, or when more than k items were
    # accepted. A negative slice bound would keep all but the last items of
    # R_t, so clamp it at 0.
    return S_t + R_t[:max(0, k - len(S_t))], comps

def seeks_v2(S, P, k, delta, mode=False):
    """Variant of SEEKS that picks candidates with ``epsilon_quick_select``.

    This is identical to ``seeks`` except in step 1 of each round. There the
    ``k_t`` candidates ``A_t`` come from
    ``epsilon_quick_select(R_t, P, k_t, epsilon_t / 3, delta_t / 3, mode)``
    rather than from ``tournament_k_selection``.

    Round ``t`` halves ``epsilon_t`` and uses ``delta_t = 6 * delta / (PI^2 t^2)``.
    A pivot ``v_t`` is drawn from ``A_t`` by a flipped-orientation tournament.
    Every other item in ``R_t`` is then classified against ``v_t``, each with
    ``delta_t / (3 * (len(R_t) - 1))``. ``sup`` items are accepted into ``S_t``,
    ``smid`` items (including ``v_t``) stay in ``R_t``, and ``sdown`` items
    are discarded.

    Args:
        S: Sequence of item identifiers.
        P: Comparison-probability table indexed as ``P[item, pivot]``.
        k: Number of items to select.
        delta: Overall failure-probability budget.
        mode: Forwarded to the helpers. True inverts every comparison.

    Returns:
        A tuple ``(selected, comps)``. ``selected`` is ``S_t`` followed by at
        most ``k - len(S_t)`` items of ``R_t``. ``comps`` is the number of
        comparisons used.
    """
    epsilon_t = 1.0
    R_t, S_t = list(S), []
    t, k_t = 1, k
    comps = 0
    # Inside the loop len(R_t) > k_t >= 1, so len(R_t) - 1 >= 1 below.
    while len(S_t) < k and len(S_t) + len(R_t) > k:
        epsilon_t *= 0.5
        delta_t = 6.0 * delta / (PI * PI * t * t)
        A_t, c = epsilon_quick_select(R_t, P, k_t, epsilon_t / 3.0, delta_t / 3.0, mode)
        comps += c
        v_list, c = tournament_k_selection(A_t, P, 1, epsilon_t / 3.0, delta_t / 3.0, not mode)
        comps += c
        v_t = v_list[-1]
        delta_item = delta_t / (3.0 * (len(R_t) - 1))
        sup, smid, sdown = [], [v_t], []
        for i in R_t:
            if i == v_t:
                continue
            comps += distribute_item(i, P[i, v_t], epsilon_t / 3.0, epsilon_t / 3.0, epsilon_t / 3.0, delta_item, sup, smid, sdown, mode)
        S_t.extend(sup)
        R_t = smid
        k_t -= len(sup)
        t += 1
    # k - len(S_t) is negative when k < 0, or when more than k items were
    # accepted. A negative slice bound would keep all but the last items of
    # R_t, so clamp it at 0.
    return S_t + R_t[:max(0, k - len(S_t))], comps
