import torch
import torch.nn as nn
import torch.nn.functional as F

def top_pct_mask_per_seq_with_valid(logp, valid_mask, pct=0.20, at_least_one=True):
    """Select, independently per row, the top ``pct`` fraction of valid positions.

    Args:
        logp: ``[B, T]`` score tensor; larger values are selected first.
        valid_mask: ``[B, T]`` mask of positions eligible for selection. Any
            dtype is accepted; non-zero entries are treated as valid.
        pct: Fraction of each row's valid positions to keep. The per-row count
            is ``ceil(pct * n_valid)``.
        at_least_one: If True, a row with at least one valid position keeps at
            least one position even when ``ceil(pct * n_valid)`` is 0. Rows
            with no valid positions always select nothing.

    Returns:
        ``[B, T]`` bool tensor that is True at the selected positions. Invalid
        positions are never selected.
    """
    assert logp.shape == valid_mask.shape
    # `~` and `masked_fill` require a bool mask; accept 0/1 int/float masks too.
    # `.bool()` returns the tensor itself when it is already bool.
    valid_mask = valid_mask.bool()

    valid_len = valid_mask.sum(-1)
    k = (pct * valid_len.float()).ceil().long()
    if at_least_one:
        k = torch.where(valid_len > 0, torch.clamp(k, min=1), torch.zeros_like(k))

    # Push invalid positions to the bottom of the ranking.
    x = logp.masked_fill(~valid_mask, float('-inf'))
    # Double argsort turns scores into ranks: rank 0 is the largest score in the row.
    rank = x.argsort(dim=-1, descending=True).argsort(dim=-1)
    return (rank < k.unsqueeze(-1)) & valid_mask


def get_loss(model, ref_model, inputs, loss_type, beta=0.1, attention_temp=2.0, identification=False):
    """Return ``(forget_loss, regularization_loss)`` for the requested recipe.

    ``loss_type`` is matched by substring.

    * Forget term: the first tag found in the order GA, NPO, ME, attn, DPO,
      IDK picks the loss (``0`` if none match). For GA/NPO/ME/attn the
      ``*_iden`` variant is used when ``identification`` is truthy.
    * Regularization term: 'attn' picks ``kl_loss_attn`` if 'KL' is also
      present, otherwise ``gd_loss``, and prints both terms. Otherwise GD,
      KL, AP are checked in that order (``0`` if none match).
    * ``loss_type == 'LLMU'`` (exact match) uses ``ga_loss`` as the forget
      term and ``mismatch_loss + kl_loss`` as the regularization term.

    ``inputs`` is passed through unchanged; each loss indexes it positionally.
    """
    # 'LLMU' contains none of the substring tags checked below, so the general
    # path would compute nothing before both terms are replaced; handle it first.
    if loss_type == 'LLMU':
        forget_loss = ga_loss(model, inputs)
        regularization_loss = mismatch_loss(model, inputs) + kl_loss(model, ref_model, inputs)
        return forget_loss, regularization_loss

    # ---- forget term: first matching tag wins ----
    if 'GA' in loss_type:
        forget_loss = (
            ga_loss_iden(model, inputs, ref_model, attention_temp=attention_temp)
            if identification
            else ga_loss(model, inputs)
        )
    elif 'NPO' in loss_type:
        forget_loss = (
            npo_loss_iden(model, ref_model, inputs, beta=beta, attention_temp=attention_temp)
            if identification
            else npo_loss(model, ref_model, inputs, beta=beta)
        )
    elif 'ME' in loss_type:
        forget_loss = (
            me_loss_iden(model, inputs, ref_model, attention_temp=attention_temp)
            if identification
            else me_loss(model, inputs)
        )
    elif 'attn' in loss_type:
        forget_loss = (
            attn_loss_iden(model, ref_model, inputs, attention_temp=attention_temp)
            if identification
            else attn_loss(model, ref_model, inputs, attention_temp=attention_temp)
        )
    elif 'DPO' in loss_type:
        forget_loss = dpo_loss(model, ref_model, inputs, beta=beta)
    elif 'IDK' in loss_type:
        forget_loss = idk_loss(model, inputs)
    else:
        forget_loss = 0

    # ---- regularization term ----
    if 'attn' in loss_type:
        regularization_loss = (
            kl_loss_attn(model, ref_model, inputs) if 'KL' in loss_type else gd_loss(model, inputs)
        )
        print(f"our forget loss: {forget_loss.item()}        our retain loss: {regularization_loss.item()}")
    elif 'GD' in loss_type:
        regularization_loss = gd_loss(model, inputs)
    elif 'KL' in loss_type:
        regularization_loss = kl_loss(model, ref_model, inputs)
    elif 'AP' in loss_type:
        regularization_loss = ap_loss(model, inputs, beta=beta)
    else:
        regularization_loss = 0

    return forget_loss, regularization_loss



def ga_loss_iden(model, inputs, ref_model, attention_temp):
    """Gradient-ascent forget loss restricted to "identified" tokens.

    Tokens are scored by ``log p_ref - log p_ref(attention_temp)``, where the
    second term comes from ``ref_model`` called with ``attention_temp``.
    ``top_pct_mask_per_seq_with_valid`` then keeps the highest-scoring 20% of
    each sequence's supervised positions (``labels != -100``, shifted by one).

    Args:
        model: Trainable model; must return an object with ``.logits``.
        inputs: Sequence whose element 0 is ``(input_ids, labels, attention_mask)``.
        ref_model: Frozen reference model; also accepts ``attention_temp``.
        attention_temp: Forwarded to ``ref_model`` for the second reference pass.

    Returns:
        Scalar tensor: per sequence, the sum of model log-probs at the selected
        tokens divided by that sequence's number of supervised tokens, then
        averaged over the batch. Minimizing it lowers the likelihood of the
        selected tokens.
    """
    input_ids, labels, attention_mask = inputs[0]
    outputs = model(input_ids, labels=None, attention_mask=attention_mask)

    # Logits at position t predict input_ids[:, t + 1].
    target_ids = input_ids[:, 1:].unsqueeze(-1)
    log_probs = F.log_softmax(outputs.logits[..., :-1, :], dim=-1)
    token_log_probs = log_probs.gather(dim=-1, index=target_ids).squeeze(-1)

    with torch.no_grad():
        ref_outputs = ref_model(input_ids, labels=None, attention_mask=attention_mask)
        ref_log_probs = F.log_softmax(ref_outputs.logits[..., :-1, :], dim=-1)
        ref_token_log_probs = ref_log_probs.gather(dim=-1, index=target_ids).squeeze(-1)

        ref_outputs_temp = ref_model(
            input_ids, labels=None, attention_mask=attention_mask, attention_temp=attention_temp
        )
        ref_log_probs_temp = F.log_softmax(ref_outputs_temp.logits[..., :-1, :], dim=-1)
        ref_token_log_probs_temp = ref_log_probs_temp.gather(dim=-1, index=target_ids).squeeze(-1)

    mask = labels[:, 1:] != -100
    new_mask = top_pct_mask_per_seq_with_valid(ref_token_log_probs - ref_token_log_probs_temp, mask, pct=0.2)

    # clamp(min=1): a sequence with no supervised tokens contributes 0 instead of
    # NaN (0 / 0), which would otherwise poison the batch mean.
    per_seq = (new_mask * token_log_probs).sum(-1) / mask.sum(-1).clamp(min=1)
    return per_seq.mean()


def npo_loss_iden(model, ref_model, inputs, beta=0.1, attention_temp=2.0):
    """NPO forget loss restricted to "identified" tokens.

    Tokens are scored by ``log p_ref - log p_ref(attention_temp)``, where the
    second term comes from ``ref_model`` called with ``attention_temp``.
    ``top_pct_mask_per_seq_with_valid`` keeps the highest-scoring 20% of each
    sequence's supervised positions (``labels != -100``, shifted by one).
    NPO is then applied to the summed NLL over those tokens, with the same
    convention ``npo_loss`` uses with ``get_batch_loss``.

    Args:
        model: Trainable model; must return an object with ``.logits``.
        ref_model: Frozen reference model; also accepts ``attention_temp``.
        inputs: Sequence whose element 0 is ``(input_ids, labels, attention_mask)``.
        beta: NPO inverse temperature.
        attention_temp: Forwarded to ``ref_model`` for the scoring pass.

    Returns:
        Scalar tensor ``-(2 / beta) * mean(log sigmoid(beta * (NLL_cur - NLL_ref)))``.
    """
    input_ids, labels, attention_mask = inputs[0]
    outputs = model(input_ids, labels=None, attention_mask=attention_mask)

    # Logits at position t predict input_ids[:, t + 1].
    target_ids = input_ids[:, 1:].unsqueeze(-1)
    log_probs = F.log_softmax(outputs.logits[..., :-1, :], dim=-1)
    token_log_probs = log_probs.gather(dim=-1, index=target_ids).squeeze(-1)

    with torch.no_grad():
        ref_outputs = ref_model(input_ids, labels=None, attention_mask=attention_mask)
        ref_log_probs = F.log_softmax(ref_outputs.logits[..., :-1, :], dim=-1)
        ref_token_log_probs = ref_log_probs.gather(dim=-1, index=target_ids).squeeze(-1)

        ref_outputs_temp = ref_model(
            input_ids, labels=None, attention_mask=attention_mask, attention_temp=attention_temp
        )
        ref_log_probs_temp = F.log_softmax(ref_outputs_temp.logits[..., :-1, :], dim=-1)
        ref_token_log_probs_temp = ref_log_probs_temp.gather(dim=-1, index=target_ids).squeeze(-1)

    mask = labels[:, 1:] != -100
    new_mask = top_pct_mask_per_seq_with_valid(ref_token_log_probs - ref_token_log_probs_temp, mask, pct=0.2)

    # Summed NLL (negative log-prob) over the selected tokens. The sign matters:
    # npo_loss feeds NLLs from get_batch_loss, so loss_current - loss_ref must be
    # -log(pi / pi_ref). Using raw log-probs here would flip the gradient and
    # *increase* the likelihood of the forget tokens.
    loss_current = -(token_log_probs * new_mask).sum(-1)
    loss_ref = -(ref_token_log_probs * new_mask).sum(-1)

    neg_log_ratios = loss_current - loss_ref
    return -F.logsigmoid(beta * neg_log_ratios).mean() * 2 / beta


# def get_batch_loss(logits, labels):
#     shifted_labels = labels[..., 1:].contiguous()
#     logits = logits[..., :-1, :].contiguous()
#     loss_function = nn.CrossEntropyLoss(ignore_index=-100, reduction='none')
#     # get the sum loss for each sequence in a batch
#     loss = loss_function(logits.transpose(-1, -2), shifted_labels).sum(dim=-1)
#     return loss


def me_loss_iden(model, inputs, ref_model, attention_temp):
    """Maximum-entropy forget loss restricted to "identified" tokens.

    Tokens are scored by ``log p_ref - log p_ref(attention_temp)``, where the
    second term comes from ``ref_model`` called with ``attention_temp``.
    ``top_pct_mask_per_seq_with_valid`` keeps the highest-scoring 20% of each
    sequence's supervised positions (``labels != -100``, shifted by one).
    At those positions the loss is KL(uniform || p_model) over the vocabulary.

    Args:
        model: Trainable model; must return an object with ``.logits``.
        inputs: Sequence whose element 0 is ``(input_ids, labels, attention_mask)``.
        ref_model: Frozen reference model; also accepts ``attention_temp``.
        attention_temp: Forwarded to ``ref_model`` for the scoring pass.

    Returns:
        Scalar tensor: the summed KL over the selected positions, divided by
        the total number of supervised positions in the batch.
    """
    input_ids, labels, attention_mask = inputs[0]
    outputs = model(input_ids, labels=None, attention_mask=attention_mask)
    logits = outputs.logits

    with torch.no_grad():
        # Logits at position t predict input_ids[:, t + 1].
        target_ids = input_ids[:, 1:].unsqueeze(-1)
        ref_outputs = ref_model(input_ids, labels=None, attention_mask=attention_mask)
        ref_log_probs = F.log_softmax(ref_outputs.logits[..., :-1, :], dim=-1)
        ref_token_log_probs = ref_log_probs.gather(dim=-1, index=target_ids).squeeze(-1)

        ref_outputs_temp = ref_model(
            input_ids, labels=None, attention_mask=attention_mask, attention_temp=attention_temp
        )
        ref_log_probs_temp = F.log_softmax(ref_outputs_temp.logits[..., :-1, :], dim=-1)
        ref_token_log_probs_temp = ref_log_probs_temp.gather(dim=-1, index=target_ids).squeeze(-1)

    mask = labels[:, 1:] != -100  # (bs, seq_len - 1)
    new_mask = top_pct_mask_per_seq_with_valid(ref_token_log_probs - ref_token_log_probs_temp, mask, pct=0.2)

    num_labels = logits.shape[-1]
    assert logits.shape[:-1] == labels.shape, "Logits and labels must have compatible shapes."

    # Drop the final position: it has no next token to predict.
    logits = logits[:, :-1, :]  # (bs, seq_len - 1, vocab_size)
    soft_outputs = F.softmax(logits, dim=-1).view(-1, num_labels)  # (bs*(seq_len - 1), vocab_size)
    uniform_dist = torch.full_like(soft_outputs, 1.0 / num_labels)

    # Pointwise target * (log target - input), summed over vocab: KL(uniform || p).
    kl_div = F.kl_div((soft_outputs + 1e-12).log(), uniform_dist, reduction='none').sum(-1)  # (bs*(seq_len - 1))

    # kl_div is flat but new_mask is (bs, seq_len - 1); flatten the mask so the
    # positions line up. Multiplying by the 2-D mask only broadcast when bs == 1.
    masked_kl_div = kl_div * new_mask.reshape(-1)
    return masked_kl_div.sum() / mask.sum()



def attn_loss_iden(model, ref_model, inputs, attention_temp):
    """Attention-temperature distillation forget loss on "identified" tokens.

    Tokens are scored by ``log p_ref - log p_ref(attention_temp)``, where the
    second term comes from ``ref_model`` called with ``attention_temp``.
    ``top_pct_mask_per_seq_with_valid`` keeps the highest-scoring 20% of each
    sequence's supervised positions (``labels != -100``, shifted by one).
    At those positions the model is penalized by KL(ref_temp || model) over
    the full next-token distribution. Prints the three LM losses for logging.

    Args:
        model: Trainable model; must return an object with ``.logits`` and ``.loss``.
        ref_model: Frozen reference model; also accepts ``attention_temp``.
        inputs: Sequence whose element 0 is ``(input_ids, labels, attention_mask)``.
        attention_temp: Forwarded to ``ref_model`` for the target distribution.

    Returns:
        Scalar tensor: per sequence, the summed KL at the selected tokens
        divided by that sequence's number of supervised tokens, then averaged
        over the batch.
    """
    input_ids, labels, attention_mask = inputs[0]
    outputs = model(input_ids, labels=labels, attention_mask=attention_mask)
    base_loss = outputs.loss.item()
    log_probs = F.log_softmax(outputs.logits[:, :-1, :], dim=-1)

    with torch.no_grad():
        # Logits at position t predict input_ids[:, t + 1].
        target_ids = input_ids[:, 1:].unsqueeze(-1)
        ref_outputs = ref_model(input_ids, labels=labels, attention_mask=attention_mask)
        ref_log_probs = F.log_softmax(ref_outputs.logits[..., :-1, :], dim=-1)
        ref_token_log_probs = ref_log_probs.gather(dim=-1, index=target_ids).squeeze(-1)
        ref_loss = ref_outputs.loss.item()

        ref_outputs_temp = ref_model(
            input_ids, labels=labels, attention_mask=attention_mask, attention_temp=attention_temp
        )
        ref_log_probs_temp = F.log_softmax(ref_outputs_temp.logits[..., :-1, :], dim=-1)
        ref_token_log_probs_temp = ref_log_probs_temp.gather(dim=-1, index=target_ids).squeeze(-1)
        ref_loss_temp = ref_outputs_temp.loss.item()

    mask = labels[:, 1:] != -100
    new_mask = top_pct_mask_per_seq_with_valid(ref_token_log_probs - ref_token_log_probs_temp, mask, pct=0.2)

    print(f"Forget loss: {base_loss}                Ref Forget loss: {ref_loss_temp}            Original Forget loss: {ref_loss}")

    assert log_probs.shape[:-1] == new_mask.shape, "Logits and labels must have compatible shapes."

    # With log_target=True the pointwise term is exp(target) * (target - input),
    # so summing over the vocabulary gives KL(ref_temp || model) per position.
    kl_div = F.kl_div(log_probs, ref_log_probs_temp, reduction='none', log_target=True).sum(-1)  # (bs, seq_len - 1)
    masked_kl_div = kl_div * new_mask

    # clamp(min=1): a sequence with no supervised tokens contributes 0 instead of
    # NaN (0 / 0), which would otherwise poison the batch mean.
    return (masked_kl_div.sum(-1) / mask.sum(-1).clamp(min=1)).mean()












def attn_loss(model, ref_model, inputs, attention_temp, parallel=False):
    """Attention-temperature distillation forget loss over all supervised tokens.

    Runs ``model`` and ``ref_model`` (the latter with ``attention_temp``) on the
    forget batch ``inputs[0]``. The model is penalized by KL(ref_temp || model)
    over next-token distributions at positions whose shifted label is not
    ``-100``. Prints both LM losses for logging.

    Args:
        model: Trainable model; must return an object with ``.logits`` and ``.loss``.
        ref_model: Frozen reference model; also accepts ``attention_temp``.
        inputs: Sequence whose element 0 is ``(input_ids, labels, attention_mask)``.
        attention_temp: Forwarded to ``ref_model``.
        parallel: If True, average the KL over every supervised token in the
            batch. If False (default), average within each sequence, then over
            sequences.

    Returns:
        Scalar loss tensor.
    """
    input_ids, labels, attention_mask = inputs[0]
    outputs = model(input_ids, labels=labels, attention_mask=attention_mask)
    base_loss = outputs.loss.item()
    log_probs = F.log_softmax(outputs.logits[:, :-1, :], dim=-1)

    with torch.no_grad():
        ref_outputs = ref_model(input_ids, labels=labels, attention_mask=attention_mask, attention_temp=attention_temp)
        # Move next to the trainable model's outputs in case the models live on different devices.
        ref_log_probs = F.log_softmax(ref_outputs.logits[:, :-1, :], dim=-1).to(log_probs.device)
        ref_loss = ref_outputs.loss.item()

    print(f"Forget loss: {base_loss}                Ref Forget loss: {ref_loss}")

    # Position t predicts token t + 1, so the mask uses labels shifted by one.
    loss_mask = labels[:, 1:] != -100  # (bs, seq_len - 1)
    assert log_probs.shape[:-1] == loss_mask.shape, "Logits and labels must have compatible shapes."

    # With log_target=True the pointwise term is exp(target) * (target - input),
    # so summing over the vocabulary gives KL(ref_temp || model) per position.
    kl_div = F.kl_div(log_probs, ref_log_probs, reduction='none', log_target=True).sum(-1)  # (bs, seq_len - 1)
    masked_kl_div = kl_div * loss_mask

    # clamp(min=1): an empty sequence/batch yields 0 instead of NaN (0 / 0).
    if parallel:
        return masked_kl_div.sum() / loss_mask.sum().clamp(min=1)
    return (masked_kl_div.sum(-1) / loss_mask.sum(-1).clamp(min=1)).mean()

# Regularization Loss: KL
def kl_loss_attn(model, ref_model, inputs, parallel=False):
    """Retain regularizer for the 'attn' recipe: KL(ref || model) on ``inputs[1]``.

    Same computation as ``kl_loss``, but also prints the model and reference
    LM losses on the retain batch.

    Args:
        model: Trainable model; must return an object with ``.logits`` and ``.loss``.
        ref_model: Frozen reference model.
        inputs: Sequence whose element 1 is ``(input_ids, labels, attention_mask)``.
        parallel: If True, average the KL over every supervised token in the
            batch. If False (default), average within each sequence, then over
            sequences.

    Returns:
        Scalar loss tensor.
    """
    input_ids, labels, attention_mask = inputs[1]

    # Position t predicts token t + 1, so the mask uses labels shifted by one.
    loss_mask = labels[:, 1:] != -100  # (bs, seq_len - 1)

    outputs = model(input_ids, labels=labels, attention_mask=attention_mask)
    base_loss = outputs.loss.item()
    log_probs = F.log_softmax(outputs.logits[:, :-1, :], dim=-1)

    with torch.no_grad():
        ref_outputs = ref_model(input_ids, labels=labels, attention_mask=attention_mask)
        # Move next to the trainable model's outputs in case the models live on different devices.
        ref_log_probs = F.log_softmax(ref_outputs.logits[:, :-1, :], dim=-1).to(log_probs.device)
        ref_loss = ref_outputs.loss.item()

    print(f"Retain loss: {base_loss}        Ref Retain loss: {ref_loss}")

    # With log_target=True the pointwise term is exp(target) * (target - input),
    # so summing over the vocabulary gives KL(ref || model) per position.
    kl_div = F.kl_div(log_probs, ref_log_probs, reduction='none', log_target=True).sum(-1)  # (bs, seq_len - 1)
    masked_kl_div = kl_div * loss_mask

    # clamp(min=1): an empty sequence/batch yields 0 instead of NaN (0 / 0).
    if parallel:
        return masked_kl_div.sum() / loss_mask.sum().clamp(min=1)
    return (masked_kl_div.sum(-1) / loss_mask.sum(-1).clamp(min=1)).mean()

# Regularization Loss: GD
def gd_loss(model, inputs):
    """Gradient-descent retain loss: the model's own LM loss on ``inputs[1]``.

    NOTE(review): ``gd_loss`` is defined twice in this module; the later
    definition shadows this one at import time. Consider deleting one copy.
    """
    ids, lbls, am = inputs[1]
    return model(ids, labels=lbls, attention_mask=am).loss


def ga_loss(model, inputs):
    """Gradient-ascent forget loss: the negated LM loss on ``inputs[0]``.

    Minimizing this value maximizes the model's LM loss on the forget batch.
    """
    ids, lbls, am = inputs[0]
    lm_loss = model(ids, labels=lbls, attention_mask=am).loss
    return -lm_loss


def npo_loss(model, ref_model, inputs, beta=0.1):
    """Negative Preference Optimization forget loss on ``inputs[0]``.

    ``get_batch_loss`` returns the summed per-sequence NLL, so
    ``nll_policy - nll_ref`` equals ``-log(pi / pi_ref)`` for each sequence.

    Returns:
        Scalar tensor ``-(2 / beta) * mean(log sigmoid(beta * (nll_policy - nll_ref)))``.
    """
    input_ids, labels, attention_mask = inputs[0]
    forward_kwargs = dict(labels=labels, attention_mask=attention_mask)

    nll_policy = get_batch_loss(model(input_ids, **forward_kwargs).logits, labels)
    with torch.no_grad():
        nll_ref = get_batch_loss(ref_model(input_ids, **forward_kwargs).logits, labels)

    return -F.logsigmoid(beta * (nll_policy - nll_ref)).mean() * 2 / beta


def idk_loss(model, inputs):
    """Supervised LM loss on the "I don't know" forget batch ``inputs[2]``."""
    ids, lbls, am = inputs[2]
    return model(ids, labels=lbls, attention_mask=am).loss


def dpo_loss(model, ref_model, inputs, beta=0.1):
    """DPO forget loss: prefer the "I don't know" answer over the forget answer.

    ``inputs[0]`` is the forget batch (rejected) and ``inputs[2]`` the IDK
    batch (chosen). Log-likelihoods are negated summed NLLs from
    ``get_batch_loss``.

    Returns:
        Scalar tensor ``-(2 / beta) * mean(log sigmoid(beta * margin))``, where
        ``margin = (logp_idk - logp_forget) - (ref_logp_idk - ref_logp_forget)``.
    """
    fg_ids, fg_labels, fg_mask = inputs[0]
    idk_ids, idk_labels, idk_mask = inputs[2]

    idk_out = model(idk_ids, labels=idk_labels, attention_mask=idk_mask)
    fg_out = model(fg_ids, labels=fg_labels, attention_mask=fg_mask)
    logp_idk = -get_batch_loss(idk_out.logits, idk_labels)
    logp_fg = -get_batch_loss(fg_out.logits, fg_labels)

    with torch.no_grad():
        idk_ref = ref_model(idk_ids, labels=idk_labels, attention_mask=idk_mask)
        fg_ref = ref_model(fg_ids, labels=fg_labels, attention_mask=fg_mask)
        ref_logp_idk = -get_batch_loss(idk_ref.logits, idk_labels)
        ref_logp_fg = -get_batch_loss(fg_ref.logits, fg_labels)

    policy_margin = logp_idk - logp_fg
    reference_margin = ref_logp_idk - ref_logp_fg
    return -F.logsigmoid(beta * (policy_margin - reference_margin)).mean() * 2 / beta


# Regularization Loss: AP
def ap_loss(model, inputs, beta=0.1):
    """Answer-preservation regularizer on retain data.

    ``inputs[1]`` is the retain batch and ``inputs[3]`` the retain "I don't
    know" batch. Both are scored by summed NLL from ``get_batch_loss``.
    Minimizing the result pushes ``NLL(idk) - NLL(retain answer)`` upward,
    i.e. keeps the real answer more likely than the refusal.

    Returns:
        Scalar tensor ``-(1 / beta) * mean(log sigmoid(beta * (nll_idk - nll_retain)))``.
    """
    retain_batch, retain_idk_batch = inputs[1], inputs[3]
    retain_ids, retain_lbls, retain_am = retain_batch
    idk_ids, idk_lbls, idk_am = retain_idk_batch

    retain_out = model(retain_ids, labels=retain_lbls, attention_mask=retain_am)
    idk_out = model(idk_ids, labels=idk_lbls, attention_mask=idk_am)

    nll_retain = get_batch_loss(retain_out.logits, retain_lbls)
    nll_idk = get_batch_loss(idk_out.logits, idk_lbls)

    return -F.logsigmoid(beta * (nll_idk - nll_retain)).mean() / beta


# # Regularization Loss: KL
# def kl_loss(model, ref_model, inputs):
#     retain_inputs = inputs[1]
#     input_ids, labels, attention_mask = retain_inputs

#     outputs = model(input_ids, labels=labels, attention_mask=attention_mask)
#     probs = F.log_softmax(outputs.logits, dim=-1).view(-1, outputs.logits.shape[-1])

#     with torch.no_grad():
#         outputs_ref = ref_model(input_ids, labels=labels, attention_mask=attention_mask)
#     ref_probs = F.log_softmax(outputs_ref.logits, dim=-1).view(-1, outputs_ref.logits.shape[-1])

#     loss = nn.functional.kl_div(
#         probs, ref_probs, reduction='batchmean', log_target=True)

#     return loss


# Regularization Loss: KL
def kl_loss(model, ref_model, inputs, parallel=False):
    """Retain regularizer: KL(ref || model) on the retain batch ``inputs[1]``.

    Compares next-token distributions of ``model`` and the frozen
    ``ref_model`` at positions whose shifted label is not ``-100``.

    Args:
        model: Trainable model; must return an object with ``.logits``.
        ref_model: Frozen reference model.
        inputs: Sequence whose element 1 is ``(input_ids, labels, attention_mask)``.
        parallel: If True, average the KL over every supervised token in the
            batch. If False (default), average within each sequence, then over
            sequences.

    Returns:
        Scalar loss tensor.
    """
    input_ids, labels, attention_mask = inputs[1]

    # Position t predicts token t + 1, so the mask uses labels shifted by one.
    loss_mask = labels[:, 1:] != -100  # (bs, seq_len - 1)

    outputs = model(input_ids, labels=labels, attention_mask=attention_mask)
    log_probs = F.log_softmax(outputs.logits[:, :-1, :], dim=-1)

    with torch.no_grad():
        ref_outputs = ref_model(input_ids, labels=labels, attention_mask=attention_mask)
        # Move next to the trainable model's outputs in case the models live on different devices.
        ref_log_probs = F.log_softmax(ref_outputs.logits[:, :-1, :], dim=-1).to(log_probs.device)

    # With log_target=True the pointwise term is exp(target) * (target - input),
    # so summing over the vocabulary gives KL(ref || model) per position.
    kl_div = F.kl_div(log_probs, ref_log_probs, reduction='none', log_target=True).sum(-1)  # (bs, seq_len - 1)
    masked_kl_div = kl_div * loss_mask

    # clamp(min=1): an empty sequence/batch yields 0 instead of NaN (0 / 0).
    if parallel:
        return masked_kl_div.sum() / loss_mask.sum().clamp(min=1)
    return (masked_kl_div.sum(-1) / loss_mask.sum(-1).clamp(min=1)).mean()


def mismatch_loss(model, inputs):
    """Supervised LM loss on the mismatch batch stored at ``inputs[4]``."""
    ids, lbls, am = inputs[4]
    result = model(ids, labels=lbls, attention_mask=am)
    return result.loss


# Regularization Loss: GD
def gd_loss(model, inputs):
    """Gradient-descent retain loss: the model's own LM loss on ``inputs[1]``.

    NOTE(review): this module defines ``gd_loss`` twice; being the later
    definition, this is the one bound at import time.
    """
    retain_ids, retain_lbls, retain_am = inputs[1]
    retain_out = model(retain_ids, labels=retain_lbls, attention_mask=retain_am)
    return retain_out.loss


def get_batch_loss(logits, labels):
    """Return the summed next-token cross-entropy of each sequence.

    Logits at position t are scored against ``labels[..., t + 1]``. Positions
    labelled ``-100`` contribute nothing.

    Args:
        logits: ``(..., seq_len, vocab)`` model outputs.
        labels: ``(..., seq_len)`` target ids, ``-100`` for ignored positions.

    Returns:
        ``(...)`` tensor of per-sequence summed NLL.
    """
    targets = labels[..., 1:].contiguous()
    preds = logits[..., :-1, :].contiguous()
    # cross_entropy expects the class dimension second, hence the transpose.
    per_token = F.cross_entropy(preds.transpose(-1, -2), targets, ignore_index=-100, reduction='none')
    return per_token.sum(dim=-1)


def me_loss(model, inputs):
    """Maximum-entropy forget loss on ``inputs[0]``; delegates to ``get_me_loss``.

    The model is called without labels; only its logits are used.
    """
    input_ids, labels, attention_mask = inputs[0]
    logits = model(input_ids, labels=None, attention_mask=attention_mask).logits
    return get_me_loss(logits, labels)


def get_me_loss(logits, labels):
    """Masked mean of KL(uniform || p) over next-token predictions.

    Args:
        logits: ``(bs, seq_len, vocab)`` model outputs.
        labels: ``(bs, seq_len)`` targets; positions whose shifted label is
            ``-100`` are excluded.

    Returns:
        Scalar tensor: the summed per-position KL divided by the number of
        supervised positions.
    """
    vocab = logits.shape[-1]
    assert logits.shape[:-1] == labels.shape, "Logits and labels must have compatible shapes."

    # Position t predicts token t + 1, so drop the last logit and the first label.
    probs = F.softmax(logits[:, :-1, :], dim=-1).view(-1, vocab)  # (bs*(seq_len - 1), vocab)
    uniform = torch.full_like(probs, 1.0 / vocab)
    token_mask = (labels[:, 1:] != -100).view(-1)  # (bs*(seq_len - 1))

    # Pointwise target * (log target - input), summed over vocab: KL(uniform || p).
    per_token_kl = F.kl_div((probs + 1e-12).log(), uniform, reduction='none').sum(-1)
    return (per_token_kl * token_mask).sum() / token_mask.sum()
