import torch
def loss_ff_V3(conf, target, iso_scores, ff_values, max_query_template_size):
    """
    Loss function for maxclique prediction: ranking terms on the per-size
    scores plus an MSE term on the directly predicted value.

    Args:
        conf: configuration object; reads ``conf.model.delta``,
            ``conf.model.gamma``, ``conf.model.LAMBDA`` and
            ``conf.model.LAMBDA2``.
        target: 1-D tensor with one target size per batch row.
        iso_scores: tensor of shape
            ``(batch, max_query_template_size - 2)``; column ``j`` is paired
            with candidate size ``j + 2``.
        ff_values: predictions compared against ``target`` with MSE.
        max_query_template_size: exclusive upper bound of the candidate
            sizes ``2 .. max_query_template_size - 1``.

    Returns:
        ``(total_loss, stats)``: ``total_loss`` is a 0-dim tensor and
        ``stats`` maps each term's name to a Python number.

    Rows whose target has no column for ``target`` or ``target + 1`` are
    left out of the first-threshold term, rows without any valid size
    contribute zero to the perfect-match term, and a term with no eligible
    row is reported as 0 instead of NaN.
    """
    relu = torch.nn.ReLU()

    # One candidate size per iso_scores column, broadcast against the batch.
    sizes = torch.arange(2, max_query_template_size, device=target.device)[None, :]
    target_col = target[:, None]
    loss_mask1 = target_col >= sizes      # sizes up to and including the target
    loss_mask2 = target_col == sizes      # the target size itself
    loss_mask3 = target_col + 1 == sizes  # first size above the target

    # Adjacent size pairs where both sizes are <= target.
    loss_mask1_both = loss_mask1[:, :-1] & loss_mask1[:, 1:]
    compute_loss_bool = loss_mask1_both.any(-1)

    masked_scores = iso_scores * loss_mask1

    # Term 1: negative mean score over the valid sizes.  The clamp avoids
    # 0/0 = NaN for rows with no valid size (target < 2); their numerator is
    # already zero, so they simply contribute 0.
    loss_term1 = -(masked_scores.sum(-1) / loss_mask1.sum(-1).clamp(min=1))

    # Term 2: hinge between score[target + 1] and score[target].  Each mask
    # has at most one hit per row; restrict to rows that have a hit in both so
    # the two selections stay aligned row by row (otherwise their lengths
    # differ and the subtraction either fails or broadcasts incorrectly).
    has_boundary = loss_mask2.any(-1) & loss_mask3.any(-1)
    boundary_scores = iso_scores[has_boundary]
    loss_term2 = relu(
        boundary_scores[loss_mask3[has_boundary]]
        - boundary_scores[loss_mask2[has_boundary]]
        + conf.model.delta
    )

    # Term 3: hinge on the drop between adjacent sizes.
    # NOTE(review): the hinge is taken on mask-multiplied scores, so the pair
    # (target, target + 1) adds relu(score[target] - gamma * delta) to the
    # numerator although the denominator only counts pairs with both sizes
    # <= target.  Kept unchanged to preserve the existing objective; confirm
    # whether this is intended.
    pair_diffs_le_gt_size = relu(
        masked_scores[:, :-1]
        - masked_scores[:, 1:]
        - conf.model.gamma * conf.model.delta
    )
    loss_term3 = (
        pair_diffs_le_gt_size.sum(-1)[compute_loss_bool]
        / loss_mask1_both[compute_loss_bool].sum(-1)
    )

    mse_loss_term = torch.nn.functional.mse_loss(ff_values, target)

    # mean() of an empty tensor is NaN; report 0 when a term has no rows.
    term1_mean = loss_term1.mean()
    term2_mean = loss_term2.mean() if loss_term2.numel() > 0 else iso_scores.new_zeros(())
    term3_mean = loss_term3.mean() if loss_term3.numel() > 0 else iso_scores.new_zeros(())

    total_loss = (
        term1_mean
        + conf.model.LAMBDA / 2 * term2_mean
        + conf.model.LAMBDA / 2 * term3_mean
        + conf.model.LAMBDA2 * mse_loss_term
    )
    stats = {
        'perfect-match-loss': term1_mean.item(),
        'first-thresh-loss': term2_mean.item(),
        'lower-thresh-loss': term3_mean.item(),
        'train-ff-mse-loss': mse_loss_term.item(),
    }
    return total_loss, stats


def loss_mse_only(conf, target, iso_scores, ff_values, max_query_template_size):
    """
    Loss function for maxclique prediction that uses only the MSE between
    ``ff_values`` and ``target``.

    ``conf``, ``iso_scores`` and ``max_query_template_size`` are accepted but
    not read here.

    Returns:
        ``(loss, stats)``: ``loss`` is the MSE tensor; ``stats`` reports 0 for
        the three ranking terms and the MSE value as a Python number.
    """
    ff_mse = torch.nn.functional.mse_loss(ff_values, target)

    # The ranking terms are not computed here; report them as zero.
    stats = dict.fromkeys(
        ('perfect-match-loss', 'first-thresh-loss', 'lower-thresh-loss'), 0
    )
    stats['train-ff-mse-loss'] = ff_mse.item()
    return ff_mse, stats


def loss_iso_only(conf, target, iso_scores, ff_values, max_query_template_size):
    """
    Loss function for maxclique prediction that uses only the ranking terms
    on the per-size scores (no MSE term).

    Args:
        conf: configuration object; reads ``conf.model.delta``,
            ``conf.model.gamma`` and ``conf.model.LAMBDA``.
        target: 1-D tensor with one target size per batch row.
        iso_scores: tensor of shape
            ``(batch, max_query_template_size - 2)``; column ``j`` is paired
            with candidate size ``j + 2``.
        ff_values: accepted but not read here.
        max_query_template_size: exclusive upper bound of the candidate
            sizes ``2 .. max_query_template_size - 1``.

    Returns:
        ``(total_loss, stats)``: ``total_loss`` is a 0-dim tensor and
        ``stats`` maps each term's name to a Python number
        (``'train-ff-mse-loss'`` is always 0).

    Rows whose target has no column for ``target`` or ``target + 1`` are
    left out of the first-threshold term, rows without any valid size
    contribute zero to the perfect-match term, and a term with no eligible
    row is reported as 0 instead of NaN.
    """
    relu = torch.nn.ReLU()

    # One candidate size per iso_scores column, broadcast against the batch.
    sizes = torch.arange(2, max_query_template_size, device=target.device)[None, :]
    target_col = target[:, None]
    loss_mask1 = target_col >= sizes      # sizes up to and including the target
    loss_mask2 = target_col == sizes      # the target size itself
    loss_mask3 = target_col + 1 == sizes  # first size above the target

    # Adjacent size pairs where both sizes are <= target.
    loss_mask1_both = loss_mask1[:, :-1] & loss_mask1[:, 1:]
    compute_loss_bool = loss_mask1_both.any(-1)

    masked_scores = iso_scores * loss_mask1

    # Term 1: negative mean score over the valid sizes.  The clamp avoids
    # 0/0 = NaN for rows with no valid size (target < 2); their numerator is
    # already zero, so they simply contribute 0.
    loss_term1 = -(masked_scores.sum(-1) / loss_mask1.sum(-1).clamp(min=1))

    # Term 2: hinge between score[target + 1] and score[target].  Each mask
    # has at most one hit per row; restrict to rows that have a hit in both so
    # the two selections stay aligned row by row (otherwise their lengths
    # differ and the subtraction either fails or broadcasts incorrectly).
    has_boundary = loss_mask2.any(-1) & loss_mask3.any(-1)
    boundary_scores = iso_scores[has_boundary]
    loss_term2 = relu(
        boundary_scores[loss_mask3[has_boundary]]
        - boundary_scores[loss_mask2[has_boundary]]
        + conf.model.delta
    )

    # Term 3: hinge on the drop between adjacent sizes.
    # NOTE(review): the hinge is taken on mask-multiplied scores, so the pair
    # (target, target + 1) adds relu(score[target] - gamma * delta) to the
    # numerator although the denominator only counts pairs with both sizes
    # <= target.  Kept unchanged to preserve the existing objective; confirm
    # whether this is intended.
    pair_diffs_le_gt_size = relu(
        masked_scores[:, :-1]
        - masked_scores[:, 1:]
        - conf.model.gamma * conf.model.delta
    )
    loss_term3 = (
        pair_diffs_le_gt_size.sum(-1)[compute_loss_bool]
        / loss_mask1_both[compute_loss_bool].sum(-1)
    )

    # mean() of an empty tensor is NaN; report 0 when a term has no rows.
    term1_mean = loss_term1.mean()
    term2_mean = loss_term2.mean() if loss_term2.numel() > 0 else iso_scores.new_zeros(())
    term3_mean = loss_term3.mean() if loss_term3.numel() > 0 else iso_scores.new_zeros(())

    total_loss = (
        term1_mean
        + conf.model.LAMBDA / 2 * term2_mean
        + conf.model.LAMBDA / 2 * term3_mean
    )
    stats = {
        'perfect-match-loss': term1_mean.item(),
        'first-thresh-loss': term2_mean.item(),
        'lower-thresh-loss': term3_mean.item(),
        'train-ff-mse-loss': 0,
    }
    return total_loss, stats


