import torch

def mp2sigma(mask_prob, mask_schedule_type):
    '''
    Map a per-sample masking probability to the total noise level and its rate.

    The noise level satisfies  mask_prob = 1 - exp(-sigma),
    i.e.  sigma = -log(1 - mask_prob).

    mask_prob: tensor of masking probabilities, e.g. shape [B]
    mask_schedule_type: 'cosine' or 'linear_new'; selects the d_sigma formula

    Returns (sigma, d_sigma), both with the shape of ``mask_prob``.
    NOTE(review): d_sigma matches d(sigma)/dt for mask_prob(t) = sin(pi*t/2)
    ('cosine') and mask_prob(t) = (1 - 1e-3) * t ('linear_new') — TODO confirm
    against the schedule definitions used by the caller.

    NOTE: at mask_prob == 1.0 both sigma and d_sigma are infinite.
    Raises NotImplementedError for any other schedule name.
    '''
    if mask_schedule_type not in ('cosine', 'linear_new'):
        raise NotImplementedError

    # -log(1 - p), computed with log1p to stay accurate when p is small.
    sigma = -torch.log1p(-mask_prob)
    keep_prob = 1 - mask_prob

    if mask_schedule_type == 'cosine':
        d_sigma = 0.5 * torch.pi * torch.sqrt((1 + mask_prob) / keep_prob)
    else:  # 'linear_new'
        d_sigma = (1 - 1e-3) / keep_prob

    return sigma, d_sigma

def p2score(pred, xt, mask_id, mp):
    '''
    Turn predicted probabilities into scores by dividing by p_t(x_t | x_0).

        pt(x_0) / pt(x_t) = pred / pt(x_t|x_0)

    At positions where ``xt`` holds the mask token the divisor is ``mp``;
    everywhere else it is ``1 - mp``.

    pred: [B, L, N] predicted probabilities
    xt: [B, L] noisy token ids
    mask_id: int, id of the mask token
    mp: [B] per-sample masking probability

    Returns a [B, L, N] tensor on ``pred.device``.
    '''
    # [B, L, 1]: True where the noisy token is the mask token.
    is_masked = (xt == mask_id).unsqueeze(-1).to(pred.device)
    # [B, 1, 1]: per-sample mask probability, broadcast over L and N.
    mp_b = mp.unsqueeze(-1).unsqueeze(-1).to(pred.device)

    # Choose the divisor per position ([B, L, 1]) and broadcast it across the
    # vocabulary. This avoids materialising [B, L, N] copies of the mask and of
    # mp, and it performs a single division instead of two full-size ones.
    denom = torch.where(is_masked, mp_b, 1.0 - mp_b)

    return pred / denom
        
def score_entropy(score, x, x0, mask_id, sigma):
    '''
    Per-token score-entropy loss, evaluated only where ``x`` is the mask token.

    score: [B, L, N] scores; only the entry at the clean token is read
    x: [B, L] noisy token ids
    x0: [B, L] clean token ids
    mask_id: int, id of the mask token
    sigma: [B] noise level per sample

    With ratio = 1 / (exp(sigma) - 1) and s = score[..., x0], each masked
    position gets
        s - ratio * log(s) + ratio * (log(ratio) - 1).
    Unmasked positions are 0. Returns a [B, L] tensor on ``x.device``.
    '''
    device = x.device

    # [B, 1] so sigma broadcasts along the sequence dimension.
    sigma = sigma.unsqueeze(-1)

    # Only positions currently holding the mask token contribute.
    rel_ind = x == mask_id

    # exp(sigma) - 1; expm1 keeps precision for small sigma.
    esigm1 = torch.where(
        sigma < 0.5,
        torch.expm1(sigma),
        torch.exp(sigma) - 1
    ).to(device)

    # [K] values, one per masked position.
    ratio = 1.0 / esigm1.expand_as(x)[rel_ind]
    other_ind = x0[rel_ind]

    # Score of the clean token at each masked position, shape [K].
    # Gather BEFORE taking the log: a log over the full [K, N] rows yields
    # -inf wherever score == 0, and log's backward (0 / 0) then writes NaN
    # into the gradient even though those entries are never used.
    score_x0 = torch.gather(score[rel_ind], -1, other_ind[..., None]).squeeze(-1)

    # negative term
    neg_term = ratio * torch.log(score_x0)

    # positive term: only the clean-token entry is used here. An earlier
    # variant, left commented out in the previous version, summed score over
    # all but the last vocab column.
    pos_term = score_x0

    # constant term
    const = ratio * (ratio.log() - 1.0)

    entropy = torch.zeros(*x.shape, device=device)
    # '+=' (rather than '=') writes through an in-place add, which tolerates a
    # dtype mismatch between score and the float32 buffer.
    entropy[rel_ind] += pos_term - neg_term + const

    return entropy

def DWDSE_loss(loss, dsigma):
    '''
    Weight a per-token loss by ``dsigma`` and reduce it to a scalar.

    loss: per-token loss; presumably [B, L], e.g. the output of
        ``score_entropy`` — TODO confirm against callers
    dsigma: [B] per-sample weight

    Each sample's row is scaled by its weight and summed over the last
    dimension. The per-sample totals are then averaged over the batch.
    '''
    weights = dsigma.unsqueeze(1)  # [B, 1], broadcast across the last dim
    per_sample = torch.sum(weights * loss, dim=-1)
    return torch.mean(per_sample)