import torch

def get_samples(logits, labels, nsamples=10000):
    """Draw token samples from the categorical distributions given by ``logits``.

    Args:
        logits: tensor of shape (1, T, V); each position t defines a categorical
            distribution over V token ids.
        labels: only its batch dimension is checked (must be 1); its values
            are not used for sampling.
        nsamples: number of independent samples drawn per position. Defaults
            to 10000, the value that was previously hard-coded.

    Returns:
        Integer tensor of shape (1, T, nsamples) holding sampled token ids.
    """
    assert logits.shape[0] == 1
    assert labels.shape[0] == 1
    lprobs = torch.log_softmax(logits, dim=-1)
    distrib = torch.distributions.categorical.Categorical(logits=lprobs)
    # sample() prepends the sample dimension: (nsamples, 1, T) -> (1, T, nsamples)
    samples = distrib.sample([nsamples]).permute([1, 2, 0])
    return samples


def get_likelihood(logits, labels):
    """Average log-probability of ``labels`` under ``logits`` across positions.

    Args:
        logits: tensor of shape (1, T, V).
        labels: token ids with batch dimension 1. Either (1, T), which is
            given a trailing axis before indexing, or (1, T, S) for S token
            ids per position (e.g. samples).

    Returns:
        Tensor of log-probabilities averaged over dim 1 (positions):
        shape (1, 1) for (1, T) labels, (1, S) for (1, T, S) labels.
    """
    assert logits.shape[0] == 1
    assert labels.shape[0] == 1
    if labels.ndim == logits.ndim - 1:
        # One id per position: add the trailing index axis gather() needs.
        labels = labels.unsqueeze(-1)
    log_probs = torch.log_softmax(logits, dim=-1)
    picked = torch.gather(log_probs, -1, labels)
    return picked.mean(dim=1)


def get_sampling_discrepancy(logits_ref, logits_score, labels):
    """Monte-Carlo sampling discrepancy of ``labels``.

    Token alternatives are sampled from ``logits_ref``. The observed ``labels``
    and those samples are then scored under ``logits_score``. The result is
    the number of sample standard deviations by which the observed mean
    log-likelihood differs from the mean over samples.

    Args:
        logits_ref: reference-model logits, shape (1, T, V).
        logits_score: scoring-model logits, shape (1, T, V'). If V != V',
            both are truncated to the first min(V, V') token ids.
        labels: token ids with batch dimension 1.

    Returns:
        float: (ll(labels) - mean(ll(samples))) / std(ll(samples)).
    """
    for batch_first in (logits_ref, logits_score, labels):
        assert batch_first.shape[0] == 1

    ref_vocab, score_vocab = logits_ref.size(-1), logits_score.size(-1)
    if ref_vocab != score_vocab:
        # Vocabularies differ in size: compare only the ids both models cover.
        shared = min(ref_vocab, score_vocab)
        logits_ref = logits_ref[:, :, :shared]
        logits_score = logits_score[:, :, :shared]

    sampled_tokens = get_samples(logits_ref, labels)
    observed_ll = get_likelihood(logits_score, labels).squeeze(-1)
    sampled_ll = get_likelihood(logits_score, sampled_tokens)
    centre = sampled_ll.mean(dim=-1)
    spread = sampled_ll.std(dim=-1)
    return ((observed_ll - centre) / spread).item()


def get_sampling_discrepancy_analytic(logits_ref, logits_score, labels):
    """Closed-form sampling discrepancy of ``labels``.

    For each position t, the mean and variance of log p_score(x_t) are computed
    exactly, with x_t distributed as p_ref(. | context_t). The discrepancy is

        (sum_t log p_score(labels_t) - sum_t mean_t) / sqrt(sum_t var_t)

    logits_ref   : torch.FloatTensor, shape (1, T, V)
                   Reference model logits
    logits_score : torch.FloatTensor, shape (1, T, V')
                   Scoring model logits. If V != V', both inputs are truncated
                   to their first min(V, V') token ids.
    labels       : torch.LongTensor, shape (1, T) (a (1, T, 1) tensor is also
                   accepted). Token ids to score; presumably already aligned as
                   next-token targets for the logits -- confirm at call sites.

    Returns a dict:
        'discrepancy'    : float, the normalized discrepancy above.
        'log_likelihood' : float, sum over T of log p_score(labels).
    """
    assert logits_ref.shape[0] == 1
    assert logits_score.shape[0] == 1
    assert labels.shape[0] == 1
    if logits_ref.size(-1) != logits_score.size(-1):
        # Vocabularies differ in size: compare only the ids both models cover.
        vocab_size = min(logits_ref.size(-1), logits_score.size(-1))
        logits_ref = logits_ref[:, :, :vocab_size]
        logits_score = logits_score[:, :, :vocab_size]

    labels = labels.unsqueeze(-1) if labels.ndim == logits_score.ndim - 1 else labels
    # log p_score(v | context) -> (1, T, V)
    lprobs_score = torch.log_softmax(logits_score, dim=-1)
    # p_ref(v | context) -> (1, T, V)
    probs_ref = torch.softmax(logits_ref, dim=-1)
    # E_ref[log p_score(v)] = sum_v p_ref(v) * log p_score(v) -> (1, T)
    mean_ref = (probs_ref * lprobs_score).sum(dim=-1)
    # Var_ref[log p_score(v)] = sum_v p_ref(v) * (log p_score(v) - mean)^2 -> (1, T)
    # The centered form is used instead of E[X^2] - E[X]^2. When p_ref is
    # concentrated on a low-probability scoring token, that formula subtracts
    # two nearly equal large numbers. The result can round to 0 (-> inf
    # discrepancy) or go negative (-> NaN from sqrt). The centered sum is
    # always >= 0.
    centered = lprobs_score - mean_ref.unsqueeze(-1)
    var_ref = (probs_ref * torch.square(centered)).sum(dim=-1)
    # log p_score(labels) at each position -> (1, T)
    log_likelihood = lprobs_score.gather(dim=-1, index=labels).squeeze(-1)
    total_log_likelihood = log_likelihood.sum(dim=-1)
    discrepancy = (total_log_likelihood - mean_ref.sum(dim=-1)) / var_ref.sum(dim=-1).sqrt()
    discrepancy = discrepancy.mean()
    return {
        'discrepancy': discrepancy.item(),
        'log_likelihood': total_log_likelihood.mean().item()
    }