import torch

from . import logminusexp
from .theory import (
    _K_Seq_theoretical_log_rho,
    _K_Seq_theoretical_log_beta,
    _K_Seq_theoretical_left_f,
)


def accept_one_of(drafts, targettoken):
    """
    Flag the rows whose target token appears among that row's draft candidates.

    input:
    drafts - candidate tokens; the last dimension enumerates the candidates
    targettoken - one token per row; a trailing axis is appended to it so that
        it broadcasts against the last dimension of drafts
    output:
    int32 tensor with the last dimension of drafts reduced away:
        1 where targettoken equals at least one candidate, otherwise 0
    """
    target_column = targettoken[..., None]
    hits = torch.eq(drafts, target_column)
    return torch.any(hits, dim=-1).to(torch.int32)


# with replacement


def _wr_recursive_verify_step(log_q, log_p, draft):
    """
    Run one accept/reject test for a single draft token per batch row and
    build the residual target distribution used if that draft is rejected.

    input:
    log_q - shape (batch_size, vocab_size)
    log_p - shape (batch_size, vocab_size)
    draft - shape (batch_size)
    output:
    mask - shape (batch_size), True where the draft token is accepted
    log_p_res - shape (batch_size, vocab_size), log-softmax of the residual
    """
    token_idx = draft.unsqueeze(-1)
    # Acceptance probability is exp(log_p[draft] - log_q[draft]); values above
    # one simply mean the draft is always accepted by the comparison below.
    log_ratio = log_p.gather(1, token_idx) - log_q.gather(1, token_idx)
    accept_prob = log_ratio.exp().squeeze(-1)

    # NOTE(review): logminusexp is defined elsewhere; presumably it returns
    # log(exp(log_p) - exp(log_q)) elementwise - confirm against its definition.
    residual_logits = logminusexp(log_p, log_q)
    # A row that is infinite in every entry would turn into NaN under
    # log_softmax; resetting it to zeros makes that row uniform instead.
    degenerate_rows = torch.isinf(residual_logits).all(dim=-1)
    residual_logits[degenerate_rows] = 0.0
    log_p_res = torch.log_softmax(residual_logits, dim=-1)

    accepted = torch.rand_like(accept_prob) < accept_prob
    return accepted, log_p_res


def wr_recursive_verify(logits_q, logits_p, temperature, drafts):
    """
    Verify draft candidates one after another (candidates drawn with
    replacement) and return one token per batch row.

    Candidates are tested column by column with _wr_recursive_verify_step.
    The first accepted candidate of a row becomes that row's output. After
    every step log_p is replaced by the residual returned from the step, so
    rows in which every candidate is rejected are sampled from the final
    residual distribution.

    input:
    logits_q - shape (batch_size, vocab_size)
    logits_p - shape (batch_size, vocab_size)
    temperature - scalar; 0 selects greedy decoding (argmax of logits_p)
    drafts - shape (batch_size, num_candidates)
    output:
    targettoken - shape (batch_size); dtype of drafts, except on the
        temperature == 0 path, which returns the argmax indices directly
    """
    num_candidates = drafts.size(1)
    if temperature == 0:
        return torch.argmax(logits_p, dim=-1)

    log_p = torch.log_softmax(logits_p / temperature, dim=-1)
    log_q = torch.log_softmax(logits_q / temperature, dim=-1)

    accepted = torch.zeros(drafts.size(0), dtype=torch.bool, device=drafts.device)
    final_draft = torch.zeros(drafts.size(0), dtype=drafts.dtype, device=drafts.device)
    for i in range(num_candidates):
        mask, log_p = _wr_recursive_verify_step(log_q, log_p, drafts[:, i])
        # Only rows that have not accepted an earlier candidate may take this one.
        mask = mask & ~accepted
        final_draft[mask] = drafts[mask, i]
        accepted = accepted | mask
        if torch.all(accepted):
            break
    if not torch.all(accepted):
        last_token = torch.multinomial(torch.softmax(log_p, dim=-1), 1).squeeze(-1)
        # torch.multinomial returns int64 indices, while indexed assignment
        # requires matching dtypes; cast so that e.g. int32 drafts do not raise.
        final_draft[~accepted] = last_token[~accepted].to(final_draft.dtype)
    targettoken = final_draft
    return targettoken


# without replacement


def wor_recursive_verify(logits_q, logits_p, temperature, drafts):
    """
    Verify draft candidates one after another (candidates drawn without
    replacement) and return one token per batch row.

    Works like the with-replacement variant, with one addition: after each
    step the candidate just tested is removed from the draft distribution
    (its log-probability is set to -inf) and log_q is renormalised.

    input:
    logits_q - shape (batch_size, vocab_size)
    logits_p - shape (batch_size, vocab_size)
    temperature - scalar; 0 selects greedy decoding (argmax of logits_p)
    drafts - shape (batch_size, num_candidates)
    output:
    targettoken - shape (batch_size); dtype of drafts, except on the
        temperature == 0 path, which returns the argmax indices directly
    """
    num_candidates = drafts.size(1)
    if temperature == 0:
        return torch.argmax(logits_p, dim=-1)

    log_p = torch.log_softmax(logits_p / temperature, dim=-1)
    log_q = torch.log_softmax(logits_q / temperature, dim=-1)

    accepted = torch.zeros(drafts.size(0), dtype=torch.bool, device=drafts.device)
    final_draft = torch.zeros(drafts.size(0), dtype=drafts.dtype, device=drafts.device)
    for i in range(num_candidates):
        mask, log_p = _wr_recursive_verify_step(log_q, log_p, drafts[:, i])

        # Remove the candidate just tested from q and renormalise. log_q is a
        # local tensor produced by log_softmax, so the in-place write does not
        # touch the caller's logits_q.
        # NOTE: a row whose entries are all -inf after this removal becomes
        # NaN under log_softmax.
        log_q.scatter_(-1, drafts[:, i].unsqueeze(-1), -torch.inf)
        log_q = torch.log_softmax(log_q, dim=-1)

        # Only rows that have not accepted an earlier candidate may take this one.
        mask = mask & ~accepted
        final_draft[mask] = drafts[mask, i]
        accepted = accepted | mask
        if torch.all(accepted):
            break
    if not torch.all(accepted):
        last_token = torch.multinomial(torch.softmax(log_p, dim=-1), 1).squeeze(-1)
        # torch.multinomial returns int64 indices, while indexed assignment
        # requires matching dtypes; cast so that e.g. int32 drafts do not raise.
        final_draft[~accepted] = last_token[~accepted].to(final_draft.dtype)
    targettoken = final_draft
    return targettoken


# K-Seq


def K_Seq_verify(logits_q, logits_p, temperature, drafts):
    """
    K-Seq verification: test every draft candidate independently against a
    draft distribution scaled by rho, and return the first accepted candidate
    of each row, or a sample from the residual when none is accepted.

    input:
    logits_q - indexed as (batch_size, vocab_size)
    logits_p - indexed as (batch_size, vocab_size)
    temperature - scalar; 0 selects greedy decoding (argmax of logits_p)
    drafts - shape (batch_size, num_candidates)
    output:
    one token per batch row
    """
    num_candidates = drafts.size(1)
    if temperature == 0:
        return torch.argmax(logits_p, dim=-1)

    log_p = torch.log_softmax(logits_p / temperature, dim=-1)
    log_q = torch.log_softmax(logits_q / temperature, dim=-1)

    log_rho = _K_Seq_theoretical_log_rho(log_q, log_p, num_candidates)
    log_beta = _K_Seq_theoretical_log_beta(log_q, log_p, log_rho)
    # NOTE(review): log_p_acc is never read below; the call is kept unchanged
    # because the helper is defined elsewhere.
    log_p_acc = _K_Seq_theoretical_left_f(log_beta, num_candidates)

    log_rho_col = log_rho.unsqueeze(-1)

    # Residual logits for the fallback sample. A row that is infinite in every
    # entry is reset to zeros so that the softmax below is uniform, not NaN.
    residual_logits = logminusexp(log_p, log_q + log_rho_col)
    degenerate_rows = torch.isinf(residual_logits).all(dim=-1)
    residual_logits[degenerate_rows] = 0.0

    # Per-candidate acceptance probability: exp(log_p - log_q - log_rho).
    log_ratio = log_p.gather(1, drafts) - log_q.gather(1, drafts) - log_rho_col
    accept_prob = log_ratio.exp()
    accepted = torch.rand_like(accept_prob) < accept_prob

    fallback = torch.multinomial(torch.softmax(residual_logits, dim=-1), 1)
    candidates = torch.cat([drafts, fallback], dim=-1)

    # Append an always-true column so argmax lands on the fallback token when
    # no candidate was accepted; otherwise it picks the first accepted one.
    always_true = torch.ones_like(accepted[:, :1])
    selectable = torch.cat([accepted, always_true], dim=-1)
    first_hit = selectable.int().argmax(dim=-1, keepdim=True)
    return candidates.gather(1, first_hit).squeeze(-1)


# GCSpS


def GCSpS_verify(logits_q, logits_p, temperature, drafts):
    """
    GCSpS verification: test only the last draft candidate of each row
    against a draft distribution from which the top (num_candidates - 1)
    entries of logits_q have been removed.

    input:
    logits_q - indexed as (batch_size, vocab_size); not modified (a clone is masked)
    logits_p - indexed as (batch_size, vocab_size)
    temperature - scalar; 0 selects greedy decoding (argmax of logits_p)
    drafts - shape (batch_size, num_candidates); only the last column is verified
    output:
    one token per batch row: the last draft where it is accepted, otherwise a
    sample from the residual returned by _wr_recursive_verify_step. The dtype
    is that of drafts, except on the temperature == 0 path, which returns the
    argmax indices directly.
    """
    num_candidates = drafts.size(1)
    if temperature == 0:
        return torch.argmax(logits_p, dim=-1)

    logits_qn = logits_q.clone()
    if num_candidates > 1:
        # Mask the (num_candidates - 1) largest draft logits before normalising.
        topkm1 = torch.topk(logits_q, num_candidates - 1, dim=-1).indices
        logits_qn.scatter_(-1, topkm1, -torch.inf)
    log_qn = torch.nn.functional.log_softmax(logits_qn / temperature, dim=-1)
    log_p = torch.nn.functional.log_softmax(logits_p / temperature, dim=-1)
    mask, log_p = _wr_recursive_verify_step(log_qn, log_p, drafts[:, -1])
    final_draft = drafts[:, -1].clone()
    last_token = torch.multinomial(torch.softmax(log_p, dim=-1), 1).squeeze(-1)
    # torch.multinomial returns int64 indices, while indexed assignment
    # requires matching dtypes; cast so that e.g. int32 drafts do not raise.
    final_draft[~mask] = last_token[~mask].to(final_draft.dtype)
    return final_draft
