import torch

def cal_acc_1(pred, truth, length):
    """Score predictions against targets, one sequence at a time.

    Args:
        pred: 2-D tensor of predicted ids laid out as (T, B).
        truth: 2-D tensor of target ids laid out as (T, B).
        length: per-sequence step counts, indexable by batch position. Only
            the first ``length[i]`` steps of sequence ``i`` are looked at.

    Returns:
        ``num / summ``, where ``summ`` is the number of inspected target
        entries that are non-zero and ``num`` is the number of those entries
        the prediction matches, plus, for every sequence, the number of
        distinct values shared by the inspected prediction and target slices.
        Because of that second term the result can be greater than 1.
        Returns 0 when no inspected target entry is non-zero.
    """
    # Work batch-first so that row i is sequence i.
    pred = pred.permute(1, 0)
    truth = truth.permute(1, 0)

    num = 0
    summ = 0
    for i in range(pred.shape[0]):
        steps = int(length[i])
        pred_i = pred[i][:steps]
        truth_i = truth[i][:steps]

        # Entries whose target is 0 are neither counted nor rewarded here
        # (presumably 0 is the padding id -- confirm against the vocabulary).
        scored = truth_i != 0
        num += int(((pred_i == truth_i) & scored).sum().item())
        summ += int(scored.sum().item())

        # Position-independent bonus: distinct values present in both slices.
        # tolist() works on any device, so no CUDA round-trip is required.
        num += len(set(pred_i.tolist()) & set(truth_i.tolist()))

    # Guard the empty case instead of raising ZeroDivisionError.
    return num / summ if summ > 0 else 0

def cal_acc_2(pred, truth, length):
    """Per-sequence score using masked tensor sums.

    ``pred`` and ``truth`` are 2-D tensors laid out as (T, B); ``length[i]``
    is how many leading steps of sequence ``i`` are inspected.

    The numerator is the count of inspected entries where the target is
    non-zero and the prediction equals it, plus, for each sequence, the
    number of distinct values found in both inspected slices. The
    denominator is the count of inspected non-zero target entries. Returns 0
    when the denominator is not positive.
    """
    pred_rows = pred.permute(1, 0)
    truth_rows = truth.permute(1, 0)

    hits = 0
    total = 0

    for row in range(pred_rows.shape[0]):
        limit = length[row]
        guess = pred_rows[row][:limit]
        target = truth_rows[row][:limit]

        # Only entries with a non-zero target take part in the positional count.
        scored = target != 0
        matched = (guess == target) & scored
        hits += torch.sum(matched).item()
        total += torch.sum(scored).item()

        # Distinct values shared by both slices, wherever they occur.
        guess_values = set(guess.cpu().numpy().tolist())
        target_values = set(target.cpu().numpy().tolist())
        hits += len(guess_values & target_values)

    if total > 0:
        return hits / total
    return 0

def cal_acc(pred, truth, length):
    """Fraction of non-zero target entries that the prediction matches.

    ``pred`` and ``truth`` are 2-D tensors laid out as (T, B). Every entry
    whose target is non-zero is scored; ``length`` is accepted but not used.
    Returns 0 when there is no non-zero target entry.
    """
    pred_bt = pred.permute(1, 0)
    truth_bt = truth.permute(1, 0)

    scored = truth_bt != 0
    matched = (pred_bt == truth_bt) & scored

    correct = torch.sum(matched).item()
    total = torch.sum(scored).item()

    if total > 0:
        return correct / total
    return 0


def epoch_time(start_time, end_time):
    """Split the interval ``end_time - start_time`` into minutes and seconds.

    Returns a ``(minutes, seconds)`` tuple of ints. Both parts are obtained
    with ``int()``, so fractional values are truncated toward zero.
    """
    duration = end_time - start_time
    whole_minutes = int(duration / 60)
    leftover_seconds = int(duration - 60 * whole_minutes)
    return whole_minutes, leftover_seconds