import torch
import argparse
import util
import os
import time
import pprint

from torch.utils.tensorboard import SummaryWriter
from torch import nn
import torch.nn.functional as F
import numpy as np

from sklearn import metrics
import loader
import models
import config
from util import AverageStorage, ProgressMeter, accuracy, ClassAccuracy
import collections
from collections import OrderedDict
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

def _make_windows(inputs, labels, frame_num=5):
    """Cut channel 0 of every sample into overlapping ``frame_num``-frame windows.

    Windows start at frames ``0 .. length - frame_num - 1`` with stride 1. The
    last possible window (starting at ``length - frame_num``) is deliberately
    excluded to match the windowing this script has always used.

    Args:
        inputs: tensor indexed as ``inputs[sample, channel, frame, ...]``
            (assumed ``(B, C, L, F)`` — TODO confirm against the loader).
        labels: per-sample labels whose first dimension is the sample dimension.
        frame_num: number of consecutive frames in each window.

    Returns:
        ``(windows, window_labels)``: ``windows`` has shape
        ``(B * W, 1, frame_num, ...)`` with ``W = length - frame_num``, ordered
        sample-major then by window start. ``window_labels`` repeats each
        sample's label ``W`` times in the same order.

    Raises:
        ValueError: if the frame axis is not longer than ``frame_num``.
    """
    length = inputs.shape[2]
    num_windows = length - frame_num
    if num_windows <= 0:
        raise ValueError(
            f"sequence length {length} must be greater than frame_num={frame_num}"
        )
    # (B, L, ...) -> (B, L - frame_num + 1, ..., frame_num): unfold appends the window axis last.
    windows = inputs[:, 0].unfold(1, frame_num, 1)
    # Move the window axis next to the window index: (B, L - frame_num + 1, frame_num, ...),
    # then drop the final window and flatten (sample, window) into one batch axis.
    windows = windows.movedim(-1, 2)[:, :num_windows]
    windows = windows.reshape(-1, 1, *windows.shape[2:])
    window_labels = labels.repeat_interleave(num_windows, dim=0)
    return windows, window_labels


def _best_balanced_threshold(scores, labels, num_thresholds=1000):
    """Grid-search the score threshold that maximises balanced accuracy.

    A sample is predicted positive when ``score >= threshold``. Balanced accuracy
    is the mean of the recall on samples labelled 1 and on samples labelled 0.

    Args:
        scores: 1-D numpy array of per-window anomaly scores.
        labels: 1-D numpy array of 0/1 labels aligned with ``scores``.
        num_thresholds: number of evenly spaced candidate thresholds between
            ``scores.min()`` and ``scores.max()``.

    Returns:
        ``(best_threshold, best_accuracy)``. The first threshold reaching the
        maximum wins ties.
    """
    thresholds = np.linspace(scores.min(), scores.max(), num_thresholds)
    # Class masks do not depend on the threshold; compute them once.
    is_positive = labels == 1
    is_negative = labels == 0
    best_acc = 0
    best_threshold = 0
    for threshold in thresholds:
        preds = (scores >= threshold).astype(int)
        recall_pos = np.mean(preds[is_positive] == labels[is_positive])
        recall_neg = np.mean(preds[is_negative] == labels[is_negative])
        acc = (recall_pos + recall_neg) / 2
        if acc > best_acc:
            best_acc = acc
            best_threshold = threshold
    return best_threshold, best_acc


def test(test_loader, align_loader, model, model_, criterion, device, num_classes):
    """Evaluate ``model`` on ``test_loader`` and plot a t-SNE of learned features.

    The function works in three stages.

    1. Test stage. Each batch from ``test_loader`` is cut into 5-frame windows
       and scored with ``model(windows, "test")``. The per-window anomaly score
       is the mean of ``kld0`` over dim 1. The function prints the ROC AUC of
       these scores and the best balanced-accuracy threshold.
    2. Align stage. The windows of every batch from ``align_loader`` are passed
       through ``model_``. Their labels are shifted by +2, so they plot as
       separate classes.
    3. Plot stage. The ``kld0`` features of both stages are embedded with t-SNE
       and shown in a scatter plot.

    Args:
        test_loader: iterable of ``(inputs, labels, _)`` batches for evaluation.
        align_loader: iterable of ``(inputs, labels, _)`` batches used only for
            the t-SNE comparison.
        model: network evaluated on ``test_loader``; called as
            ``model(x, "test") -> (x_tilde, kld, kld0)``.
        model_: second network applied to ``align_loader``, with the same call
            convention.
        criterion: currently unused. It is kept for interface compatibility;
            the reported loss is ``mse(x_tilde, x) / 1000 + mean(kld0)``.
        device: torch device the batches are moved to.
        num_classes: unused; kept for interface compatibility.

    Returns:
        ``(loss_meter, acc_meter)`` with the per-batch test loss and accuracy.

    Raises:
        ValueError: if ``test_loader`` yields no batches, or if a sequence is not
            longer than the window size.
    """
    frame_num = 5
    # Align-set labels are offset so the plot shows them as classes 2 and 3
    # next to the test set's 0/1 (presumably normal/anomalous — confirm).
    align_label_offset = 2

    loss_meter = AverageStorage('Loss', ':4e')
    acc_meter = AverageStorage('Acc', ':.2%')
    progress = ProgressMeter(total=len(test_loader), step=20, prefix='Testing',
                             meters=[loss_meter, acc_meter])
    # Both networks are only evaluated here, so both go to eval mode.
    model.eval()
    model_.eval()

    all_scores = []
    all_labels = []
    feature_chunks = []
    label_chunks = []

    # no_grad: nothing is trained here. Keeping graphs would also pin every
    # batch's activations in memory through the accumulated features.
    with torch.no_grad():
        for i, (inputs, labels, _) in enumerate(test_loader):
            inputs = inputs.to(device)
            labels = labels.to(device)
            n_inputs, n_labels = _make_windows(inputs, labels, frame_num)
            x_tilde, kld, kld0 = model(n_inputs, "test")

            # Reported loss: scaled reconstruction MSE plus the mean kld0 term.
            loss = F.mse_loss(x_tilde, n_inputs) / 1000 + torch.mean(kld0)
            acc = accuracy(kld, n_labels)
            acc_meter.update(acc.item(), labels.size(0))
            loss_meter.update(loss.item(), inputs.size(0))

            kld0_cpu = kld0.detach().cpu()
            labels_cpu = n_labels.cpu()
            all_scores.append(torch.mean(kld0_cpu, dim=1))
            all_labels.append(labels_cpu)
            feature_chunks.append(kld0_cpu)
            label_chunks.append(labels_cpu)

            progress.display(i)

    if not all_scores:
        raise ValueError("test_loader yielded no batches")

    # Merge all per-window scores and labels into flat numpy arrays.
    all_scores = torch.cat(all_scores).numpy()
    all_labels = torch.cat(all_labels).numpy()

    auc_score = metrics.roc_auc_score(all_labels, all_scores)
    print("AUC: ", auc_score)

    best_threshold, best_acc = _best_balanced_threshold(all_scores, all_labels)
    print("Best Threshold:", best_threshold)
    print("Best Accuracy:", best_acc)

    with torch.no_grad():
        for inputs, labels, _ in align_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            n_inputs, n_labels = _make_windows(inputs, labels, frame_num)
            _, _, kld0 = model_(n_inputs, "test")
            feature_chunks.append(kld0.detach().cpu())
            label_chunks.append((n_labels + align_label_offset).cpu())

    features_combined = torch.cat(feature_chunks).numpy()
    labels_combined = torch.cat(label_chunks).numpy()
    tsne = TSNE(n_components=2, random_state=0, perplexity=15)
    transformed_features = tsne.fit_transform(features_combined)
    colors = ['r', 'g', 'b', 'y']
    markers = ['o', '^', 's', 'x']  # one marker shape per class

    for cls, (color, marker) in enumerate(zip(colors, markers)):
        # Select the embedded points belonging to this class.
        indices = labels_combined == cls
        plt.scatter(transformed_features[indices, 0], transformed_features[indices, 1],
                    c=color, marker=marker, label=f'Class {cls+1}')

    plt.legend()
    plt.show()
    return loss_meter, acc_meter

    

def _load_full_model(path, device):
    """Load a whole pickled ``nn.Module`` (saved via ``torch.save(model)``) onto ``device``.

    NOTE(security): full-model checkpoints are unpickled, which can execute
    arbitrary code. Only load checkpoints you trust.
    """
    try:
        # PyTorch >= 2.6 defaults to weights_only=True, which rejects full modules.
        net = torch.load(path, map_location=device, weights_only=False)
    except TypeError:
        # PyTorch < 1.13 has no ``weights_only`` keyword.
        net = torch.load(path, map_location=device)
    return net.to(device)


def main():
    """Parse CLI options, load the data and both pretrained models, and run ``test``."""
    parser = argparse.ArgumentParser(
        description='Evaluate pretrained models and plot a t-SNE of their features')
    parser.add_argument('--model_name', type=str, default='cvae')
    parser.add_argument('--dataset_name', type=str, default='dcase')
    # --SOTA and --model_type are kept so existing command lines still parse.
    parser.add_argument('--SOTA', type=str, default='pretrained')
    parser.add_argument('--model_type', type=str, default='ori')
    parser.add_argument('--device', type=str, default='cuda:0')
    # Raw strings: Windows backslashes must not be read as escape sequences.
    parser.add_argument(
        '--model_path', type=str,
        default=r'C:\gkw\experiments\cv_demo_cvae_fan_ppp\output\cvae_dcase\pretrained\model_ori_slider_transfer.pth',
        help='full-model checkpoint evaluated on the test set')
    parser.add_argument(
        '--align_model_path', type=str,
        default=r'C:\gkw\experiments\cv_demo_cvae_fan_ppp\output\cvae_dcase\pretrained\model_ori_fan_un.pth',
        help='full-model checkpoint applied to the align set for the t-SNE plot')
    args = parser.parse_args()

    device = torch.device(args.device if torch.cuda.is_available() else 'cpu')

    # Datasets
    test_loader, _ = loader.load_data(args.dataset_name, data_type='test')
    align_loader, _ = loader.load_data(args.dataset_name, data_type='align')
    data_config = config.get_data_config(args.dataset_name)
    num_classes = data_config['num_classes']

    # Models (whole pickled modules, not state dicts)
    model = _load_full_model(args.model_path, device)
    model_ = _load_full_model(args.align_model_path, device)

    criterion = nn.CrossEntropyLoss()

    since = time.time()
    test(test_loader, align_loader, model, model_, criterion, device, num_classes)

    print('-' * 20 + 'Clean Dataset Test' + '-' * 20)
    print('TIME CONSUMED', time.time() - since)


# Script entry point: run the evaluation only when executed directly, not on import.
if __name__ == '__main__':
    main()