import torch

from . import logminusexp

# with replacement


@torch.no_grad()
def wr_optimal_theory(logits_q, logits_p, temperature, num_candidates):
    """Theoretical optimal acceptance rate when candidates are drawn with replacement.

    input:
    logits_q - draft-model logits, shape (batch_size, vocab_size)
    logits_p - target-model logits, shape (batch_size, vocab_size)
    temperature - scalar; 0 selects the greedy (argmax) case
    num_candidates - number of draft candidates n
    return:
    optimal_alpha - shape (batch_size)

    For temperature > 0 the vocabulary is sorted by decreasing
    logits_q - logits_p. Q_m and P_m are the normalised cumulative
    probabilities of q and p over that order (both scaled by temperature).
    The result is 1 - exp(max_m logminusexp(n * log Q_m, log P_m)), i.e.
    1 - max_m (Q_m ** n - P_m) if logminusexp(a, b) is log(exp(a) - exp(b))
    (assumed from the name; confirm).
    """
    if temperature == 0:
        # Greedy decoding: accepted exactly when both argmaxes agree.
        same_pick = torch.argmax(logits_p, dim=-1) == torch.argmax(logits_q, dim=-1)
        return same_pick.float()

    def _log_cdf(sorted_logits):
        # log of the running sum of exp(sorted_logits), normalised by its total.
        running = torch.logcumsumexp(sorted_logits, dim=-1)
        return running - running[..., -1:]

    order = torch.sort(logits_q - logits_p, dim=-1, descending=True).indices
    log_cdf_q = _log_cdf(torch.gather(logits_q, 1, order) / temperature)
    log_cdf_p = _log_cdf(torch.gather(logits_p, 1, order) / temperature)

    log_gap = logminusexp(num_candidates * log_cdf_q, log_cdf_p)
    return 1 - torch.exp(torch.max(log_gap, dim=-1)[0])


# without replacement


def get_log_W(log_q, num_candidates):
    """Log of prefix elementary symmetric sums of q, built row by row.

    input:
    log_q - shape (batch_size, vocab_size)
    num_candidates - int
    return:
    log_W - shape (batch_size, num_candidates, vocab_size)
    W[..., 0, :] = cumsum(q[..., :])
    W[..., k, m] = W[..., k, m-1] + q[..., m] * W[..., k-1, m-1]
    with W[..., k-1, -1] taken as 0, so W[..., k, m] is the (k+1)-th
    elementary symmetric polynomial of q[..., :m+1].
    """
    log_W = log_q.new_zeros((log_q.size(0), num_candidates, log_q.size(1)))
    log_W[:, 0, :] = log_q.logcumsumexp(dim=-1)
    for row in range(1, num_candidates):
        # prev_shifted[..., m] = log W[..., row-1, m-1]; nothing precedes m = 0.
        prev_shifted = log_W[:, row - 1, :].roll(1, dims=-1)
        prev_shifted[:, 0] = float("-inf")
        log_W[:, row, :] = (log_q + prev_shifted).logcumsumexp(dim=-1)
    return log_W


@torch.no_grad()
def wor_optimal_theory(logits_q, logits_p, temperature, num_candidates):
    """Theoretical optimal acceptance rate when candidates are drawn without replacement.

    input:
    logits_q - draft-model logits, shape (batch_size, vocab_size)
    logits_p - target-model logits, shape (batch_size, vocab_size)
    temperature - scalar; 0 selects the greedy (argmax / top-k) case
    num_candidates - number of draft candidates n
    return:
    optimal_alpha - shape (batch_size)

    For temperature > 0 the vocabulary is sorted by decreasing
    logits_q - logits_p. Q_m = W[n-1, m] / W[n-1, -1], computed with
    get_log_W on the sorted, temperature-scaled q. P_m is the normalised
    cumulative probability of p in the same order. The result is
    1 - exp(max_m logminusexp(log Q_m, log P_m)), i.e. 1 - max_m (Q_m - P_m)
    if logminusexp(a, b) is log(exp(a) - exp(b)) (assumed; confirm).
    """
    if temperature == 0:
        # Greedy: accepted when the target argmax is among the draft's top-n tokens.
        target_token = torch.argmax(logits_p, dim=-1, keepdim=True)
        draft_topk = torch.topk(logits_q, num_candidates, dim=-1).indices
        return (draft_topk == target_token).any(dim=-1).float()

    order = torch.sort(logits_q - logits_p, dim=-1, descending=True).indices
    scaled_q = torch.gather(logits_q, 1, order) / temperature
    scaled_p = torch.gather(logits_p, 1, order) / temperature

    # Only the last row of W (k = num_candidates - 1) is used; normalise it
    # by its value at the final vocabulary position.
    log_W_last = get_log_W(
        torch.nn.functional.log_softmax(scaled_q, dim=-1), num_candidates
    )[..., -1, :]
    log_Q = log_W_last - log_W_last[..., -1:]

    running_p = torch.logcumsumexp(scaled_p, dim=-1)
    log_P = running_p - running_p[..., -1:]

    worst = torch.max(logminusexp(log_Q, log_P), dim=-1)[0]
    return 1 - torch.exp(worst)


# K-Seq theoretical analysis


def _K_Seq_theoretical_log_beta(log_q, log_p, log_rho):
    """Compute log(beta) with beta = sum_v min(q_v, p_v / rho), per batch row.

    input:
    log_q - shape (batch_size, vocab_size)
    log_p - shape (batch_size, vocab_size)
    log_rho - shape (batch_size)
    return:
    log_beta - shape (batch_size)
    """
    # Dividing p by rho is a subtraction in log space, broadcast over the vocab.
    scaled_log_p = log_p - log_rho[..., None]
    return torch.minimum(log_q, scaled_log_p).logsumexp(dim=-1)


def _K_Seq_theoretical_left_f(log_beta, num_candidates):
    """Compute log(1 - (1 - beta) ** num_candidates) in a numerically stable way.

    input:
    log_beta - shape (batch_size); log of beta, expected to be <= 0
    num_candidates - int
    return:
    left_f - shape (batch_size)
    """
    # beta as produced by _K_Seq_theoretical_log_beta is a sum of min(q, .)
    # over a normalised q, so beta <= 1. Rounding in logsumexp can still push
    # log_beta slightly above 0 (e.g. when q == p). That would make
    # log1p(-exp(log_beta)) NaN, so clamp to treat such values as beta == 1.
    log_beta = torch.clamp(log_beta, max=0.0)
    log_one_minus_beta = torch.log1p(-torch.exp(log_beta))
    return torch.log(-torch.expm1(num_candidates * log_one_minus_beta))


def _K_Seq_theoretical_right_f(log_rho, log_beta):
    """Return log(rho * beta), the quantity compared against left_f in the bisection.

    input:
    log_rho - shape (batch_size)
    log_beta - shape (batch_size)
    return:
    right_f - shape (batch_size)
    """
    return log_beta + log_rho


def _K_Seq_theoretical_log_rho(log_q, log_p, num_candidates, max_iter=20, tol=1e-6):
    """Find log(rho) per batch row by bisection.

    Solves left_f(beta(rho)) == log(rho) + log(beta(rho)), where
    beta(rho) = sum_v min(q_v, p_v / rho).

    input:
    log_q - shape (batch_size, vocab_size)
    log_p - shape (batch_size, vocab_size)
    num_candidates - int
    max_iter - maximum number of bisection steps
    tol - stop early once every row's bracket width is below this
    return:
    log_rho - shape (batch_size)
    """
    import math

    # Bracket log(rho) in [0, min(log(num_candidates), max_v(log_q - log_p))].
    # left_f - right_f is >= 0 at the left end and <= 0 at the right end.
    log_rho_left = torch.zeros(log_q.size(0), device=log_q.device, dtype=log_q.dtype)
    log_rho_right = torch.minimum(
        torch.full_like(log_rho_left, math.log(num_candidates)),
        torch.max(log_q - log_p, dim=-1).values,
    )
    log_rho_mid = (log_rho_left + log_rho_right) / 2

    for _ in range(max_iter):
        log_beta = _K_Seq_theoretical_log_beta(log_q, log_p, log_rho_mid)
        left_f = _K_Seq_theoretical_left_f(log_beta, num_candidates)
        right_f = _K_Seq_theoretical_right_f(log_rho_mid, log_beta)
        # Where left_f >= right_f the root lies to the right of mid.
        f_pos = left_f >= right_f
        log_rho_left = torch.where(f_pos, log_rho_mid, log_rho_left)
        log_rho_right = torch.where(f_pos, log_rho_right, log_rho_mid)
        log_rho_mid = (log_rho_left + log_rho_right) / 2
        # torch.all is well-defined for an empty batch, unlike a full torch.max
        # reduction, which raises on zero elements.
        if torch.all(torch.abs(log_rho_right - log_rho_left) < tol):
            break

    return log_rho_mid


@torch.no_grad()
def K_Seq_theory(logits_q, logits_p, temperature, num_candidates):
    """Theoretical acceptance rate of K-Seq.

    input:
    logits_q - shape (batch_size, vocab_size)
    logits_p - shape (batch_size, vocab_size)
    temperature - scalar; 0 selects the greedy (argmax) case
    num_candidates - int
    return:
    theoretical_alpha - shape (batch_size)

    For temperature > 0: rho is found by bisection, then
    beta = sum_v min(q_v, p_v / rho) and alpha = 1 - (1 - beta) ** num_candidates.
    """
    if temperature == 0:
        # Greedy: accepted exactly when both argmaxes agree.
        same_pick = torch.argmax(logits_p, dim=-1) == torch.argmax(logits_q, dim=-1)
        return same_pick.float()

    log_softmax = torch.nn.functional.log_softmax
    log_q = log_softmax(logits_q / temperature, dim=-1)
    log_p = log_softmax(logits_p / temperature, dim=-1)

    log_rho = _K_Seq_theoretical_log_rho(log_q, log_p, num_candidates)
    log_beta = _K_Seq_theoretical_log_beta(log_q, log_p, log_rho)
    return torch.exp(_K_Seq_theoretical_left_f(log_beta, num_candidates))


# GCSpS theoretical analysis


@torch.no_grad()
def GCSpS_theory(logits_q, logits_p, temperature, num_candidates):
    """Theoretical acceptance rate of GCSpS.

    input:
    logits_q - shape (batch_size, vocab_size)
    logits_p - shape (batch_size, vocab_size)
    temperature - scalar; 0 selects the greedy (top-k) case
    num_candidates - int
    return:
    theoretical_alpha - shape (batch_size)

    For temperature > 0 the result is sum_v min(p_v, q'_v). Here q' is 1 on
    the top (num_candidates - 1) draft tokens and, on the remaining tokens,
    q renormalised over just those tokens.
    """
    if temperature == 0:
        # Greedy: accepted when the target argmax is among the draft's top-n tokens.
        target_token = torch.argmax(logits_p, dim=-1, keepdim=True)
        draft_topk = torch.topk(logits_q, num_candidates, dim=-1).indices
        return (draft_topk == target_token).any(dim=-1).float()

    n_head = num_candidates - 1
    # torch.sort returns a fresh tensor, so editing it in place is safe.
    sorted_logits_q, order = torch.sort(logits_q, descending=True, dim=-1)

    # Drop the top n_head draft tokens and renormalise q over the tail. Then
    # give the dropped tokens log-probability 0 (probability 1).
    sorted_logits_q[:, :n_head] = -torch.inf
    log_q_tail = torch.nn.functional.log_softmax(sorted_logits_q / temperature, dim=-1)
    log_q_tail[:, :n_head] = 0.0

    log_p_sorted = torch.nn.functional.log_softmax(
        torch.gather(logits_p, 1, order) / temperature, dim=-1
    )
    return torch.minimum(log_p_sorted, log_q_tail).exp().sum(-1)
