import torch


def calculate_jump_scores(predictions, device):
    """Return, per sample, the fraction of adjacent checkpoint pairs whose labels differ.

    For a row ``p`` with ``T`` checkpoints the score is
    ``1 - #{t : p[t + 1] == p[t]} / (T - 1)``.

    Args:
        predictions: tensor indexed as ``predictions[sample, checkpoint]``
            (assumed 2-D, as the original per-row indexing implied).
        device: unused; kept so existing callers keep working.

    Returns:
        1-D float32 CPU tensor with one score per row of ``predictions``.
        Rows with fewer than two checkpoints have no adjacent pairs and
        score 0.
    """
    num_samples = predictions.shape[0]
    num_checkpoints = predictions.shape[-1]
    if num_checkpoints < 2:
        # No adjacent pair exists, so no label can switch. The previous
        # loop computed 0 / 0 here and returned NaN.
        return torch.zeros(num_samples)

    # Vectorized over all rows at once; the previous version also built a
    # per-pair "label_switches" matrix that was never used.
    unchanged = (predictions[:, 1:] == predictions[:, :-1]).sum(dim=1)
    scores = 1 - unchanged / (num_checkpoints - 1)
    # The previous implementation always returned a CPU float32 tensor.
    return scores.to(device="cpu", dtype=torch.float32)


def get_a_t(predictions):
    """Mark where each checkpoint's prediction differs from the final one.

    Args:
        predictions: 2-D tensor indexed as ``predictions[sample, checkpoint]``.

    Returns:
        Float tensor shaped like ``predictions``: 1.0 where the entry differs
        from that row's last-column prediction, 0.0 where it matches.
    """
    # Keep the last column as a (rows, 1) slice so the comparison broadcasts
    # across every checkpoint without materialising a repeated copy.
    reference = predictions[:, -1:]
    return predictions.ne(reference).float()


def get_mean_a_t(predictions):
    """Per-checkpoint mean and variance of the disagreement-with-final indicator.

    The indicator is 1.0 where ``predictions[sample, checkpoint]`` differs
    from that sample's last-column prediction and 0.0 otherwise. It is
    reduced over the sample axis (dim 0).

    Args:
        predictions: 2-D tensor indexed as ``predictions[sample, checkpoint]``.

    Returns:
        ``(mean, variance)``: two 1-D float tensors with one entry per
        checkpoint. The variance uses ``torch.var``'s default (unbiased)
        correction, so a single-sample input yields NaN. The final entry
        compares the last column with itself and is therefore 0 for
        non-NaN predictions.
    """
    disagreement = predictions.ne(predictions[:, -1:]).float()
    mean_per_checkpoint = disagreement.mean(dim=0)
    var_per_checkpoint = disagreement.var(dim=0)
    return mean_per_checkpoint, var_per_checkpoint


def calculate_nntd_max_score(predictions, device, args):
    """Maximum weighted disagreement with the final prediction, per sample.

    Checkpoint ``t`` of ``T`` is weighted by ``torch.linspace(0, 1, T)[t] **
    args.nntd_k``. A checkpoint contributes its weight when its prediction
    differs from the row's last-column prediction, and 0 otherwise. The score
    is the largest contribution in the row. For a positive exponent this is
    the weight of the latest disagreeing checkpoint.

    Args:
        predictions: 2-D tensor indexed as ``predictions[sample, checkpoint]``.
        device: device the weight vector is moved to.
        args: object providing the exponent ``nntd_k``.

    Returns:
        1-D float tensor with one score per row of ``predictions``.
    """
    steps = predictions.shape[-1]
    weights = torch.linspace(0, 1, steps).to(device) ** (args.nntd_k)

    # 1 where a checkpoint disagrees with the final checkpoint, else 0.
    disagrees = predictions.ne(predictions[:, -1:]).int()

    # (T,) weights broadcast against the (rows, T) indicator.
    return (weights * disagrees).amax(dim=1)


def calculate_nntd_sum_score(predictions, start, step, device, args):
    """Sum of weighted disagreements with the final prediction over a strided subset of checkpoints.

    Checkpoint ``t`` of ``T`` is weighted by ``torch.linspace(0, 1, T)[t] **
    args.nntd_k`` and contributes that weight when its prediction differs
    from the row's last-column prediction. Only the checkpoint columns
    selected by ``[start::step]`` are summed.

    Args:
        predictions: 2-D tensor indexed as ``predictions[sample, checkpoint]``.
        start: first checkpoint column included in the sum.
        step: stride between included checkpoint columns.
        device: device the weight vector is moved to.
        args: object providing the exponent ``nntd_k``.

    Returns:
        1-D float tensor with one score per row of ``predictions``.
    """
    steps_total = predictions.shape[-1]
    weights = torch.linspace(0, 1, steps_total).to(device) ** (args.nntd_k)

    disagrees = predictions.ne(predictions[:, -1:]).int()
    contributions = weights * disagrees

    # Select the checkpoints only after weighting, so each column keeps the
    # weight of its original position.
    return contributions[:, start::step].sum(dim=1)


def calculate_nntd_sum_score_noargs(predictions, start, step, device, k):
    """Same score as ``calculate_nntd_sum_score``, with the exponent passed directly as ``k``.

    Checkpoint ``t`` of ``T`` is weighted by ``torch.linspace(0, 1, T)[t] ** k``
    and contributes that weight when its prediction differs from the row's
    last-column prediction. Contributions in the checkpoint columns selected
    by ``[start::step]`` are summed per row.

    Args:
        predictions: 2-D tensor indexed as ``predictions[sample, checkpoint]``.
        start: first checkpoint column included in the sum.
        step: stride between included checkpoint columns.
        device: device the weight vector is moved to.
        k: exponent applied to the linear weights.

    Returns:
        1-D float tensor with one score per row of ``predictions``.
    """
    position_weights = torch.linspace(0, 1, predictions.shape[-1]).to(device)
    position_weights = position_weights ** (k)

    final_column = predictions[:, -1:]
    mismatch = predictions.ne(final_column).int()

    weighted = position_weights * mismatch
    selected = weighted[:, start::step]
    return selected.sum(dim=1)
