from sklearn.metrics import auc, roc_curve, confusion_matrix, precision_recall_curve
import numpy as np
import torch
import pickle
import matplotlib.pyplot as plt


def cal_false_alarm(gt, preds, threshold=0.5, repeat=16):
    """Compute the frame-level false-alarm rate FP / (FP + TN).

    Args:
        gt: tensor of frame-level ground-truth labels; label 0 is treated as
            the negative class, so only frames with ``gt == 0`` contribute.
        preds: tensor of snippet-level scores. Each score is repeated
            ``repeat`` times (and the result flattened) so that it lines up
            frame-by-frame with ``gt``.
        threshold: scores ``>= threshold`` count as positive predictions.
        repeat: number of ground-truth frames covered by one snippet score.

    Returns:
        The false-alarm rate as a float, or NaN when ``gt`` contains no
        negative (label 0) frames.

    Raises:
        ValueError: if the expanded predictions and ``gt`` differ in length.
    """
    scores = np.repeat(preds.cpu().detach().numpy(), repeat)
    labels = gt.cpu().detach().numpy().reshape(-1)
    if scores.shape[0] != labels.shape[0]:
        raise ValueError(
            f'Expanded predictions ({scores.shape[0]}) and ground truth '
            f'({labels.shape[0]}) have different lengths.'
        )

    # Binarize in a single comparison; two in-place masked assignments would
    # re-classify freshly zeroed scores as positive whenever threshold <= 0.
    predicted_positive = scores >= threshold
    negatives = labels == 0
    fp = int(np.count_nonzero(negatives & predicted_positive))
    tn = int(np.count_nonzero(negatives & ~predicted_positive))

    if fp + tn == 0:
        # No negative frames: the rate is undefined.
        return float('nan')
    return fp / (fp + tn)


def test_func(dataloader, model, gt, dataset, *, device='cuda'):
    """Run ``model`` over the test set and compute frame-level metrics.

    Args:
        dataloader: iterable of ``(v_input, label)`` batches. ``v_input`` is
            cast to float and moved to ``device``; ``label`` is ignored.
            Only ``seq_len[0]`` is used to advance through ``gt``, so this
            presumably expects batch size 1 — TODO confirm.
        model: called as ``model(v_input, seq_len)`` and expected to return
            a 3-tuple whose first element holds the scores. It is put into
            eval mode and left there.
        gt: frame-level ground-truth labels for the whole test set,
            concatenated in dataloader order; each snippet score covers 16
            frames.
        dataset: one of ``'ucf-crime'``, ``'xd-violence'``, ``'shanghaiTech'``.
        device: torch device that inputs and ground truth are moved to.

    Returns:
        ``(auc, n_far)``. ``auc`` is the PR-AUC for ``'xd-violence'`` and the
        ROC-AUC otherwise. ``n_far`` is the false-alarm rate (threshold 0.5)
        computed over videos whose ground-truth slice contains no positive
        frames.

    Raises:
        RuntimeError: if ``dataset`` is not one of the supported names.
    """
    # Fail fast instead of after a full inference pass.
    if dataset not in ('ucf-crime', 'xd-violence', 'shanghaiTech'):
        raise RuntimeError('Invalid dataset.')

    frames_per_snippet = 16  # ground-truth frames covered by one snippet score
    snippet_scores = []
    normal_scores = []
    normal_labels = []

    with torch.no_grad():
        model.eval()
        gt_all = torch.tensor(gt.copy()).to(device)
        offset = 0
        for v_input, _label in dataloader:
            v_input = v_input.float().to(device, non_blocking=True)
            # Per-sample count of positions along dim 1 whose dim-2 features
            # are not all zero — presumably the un-padded length.
            seq_len = torch.sum(torch.max(torch.abs(v_input), dim=2)[0] > 0, 1)

            logits, _feat, _ = model(v_input, seq_len)
            # Average over dim 0 (presumably crops/batch — TODO confirm).
            logits = torch.mean(logits, 0)
            snippet_scores.append(logits)

            # Slice this video's frames out of the concatenated ground truth.
            n_frames = int(seq_len[0]) * frames_per_snippet
            labels = gt_all[offset:offset + n_frames]
            offset += n_frames
            if torch.sum(labels) == 0:
                normal_labels.append(labels)
                normal_scores.append(logits)

    # Concatenate once at the end; torch.cat inside the loop is quadratic.
    n_far = cal_false_alarm(
        torch.cat(normal_labels) if normal_labels else torch.zeros(0),
        torch.cat(normal_scores) if normal_scores else torch.zeros(0),
    )
    frame_scores = np.repeat(
        torch.cat(snippet_scores).cpu().numpy(), frames_per_snippet
    )
    gt_np = np.asarray(gt)

    if dataset == 'xd-violence':
        pre, rec, _ = precision_recall_curve(gt_np, frame_scores)
        return auc(rec, pre), n_far
    fpr, tpr, _ = roc_curve(gt_np, frame_scores)
    return auc(fpr, tpr), n_far

def plot_scatter(x, y, split=140, path='dot.png'):
    """Scatter-plot ``y`` against ``x``, colouring points before/after ``split``.

    The figure is written to ``path`` first and then displayed.

    Args:
        x: sliceable sequence of x coordinates.
        y: sliceable sequence of y coordinates, same length as ``x``.
        split: index separating the red (first) group from the blue
            (remaining) group.
        path: file the figure is saved to.
    """
    fig, ax = plt.subplots()
    ax.scatter(x[:split], y[:split], color='red', label=f'First {split} points')
    ax.scatter(x[split:], y[split:], color='blue', label=f'Points after {split}')
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_title('Scatter Plot')
    ax.legend()
    # Save before show(): once show() returns the figure may be closed,
    # and a later savefig would write an empty image.
    fig.savefig(path)
    plt.show()