from typing import Any
import torch
import argparse
import util
import os
import time

from torch.utils.tensorboard import SummaryWriter
from torch import nn
from torch import optim
from tqdm import tqdm
import numpy as np
import loader
import models
import config
from util import AverageStorage, ProgressMeter, accuracy
import torch.nn.functional as F
from sklearn import metrics
import collections
from  collections import OrderedDict
import math
import copy

def test(test_loader, model, criterion, device):
    """Evaluate ``model`` on ``test_loader`` using sliding-window inputs.

    Every sample in a batch is cut into overlapping, stride-1 windows of
    ``frame_num`` (= 5) consecutive frames taken from channel 0, and each
    window inherits its sample's label. Per batch the following are tracked:

    * ``Loss2``: mean squared reconstruction error over label-0 windows, / 10.
    * ``Loss3``: mean of ``kld0`` over label-0 windows.
    * ``Loss``:  ``Loss2 + Loss3``.
    * ``Acc``:   ``accuracy(kld, labels)`` over all windows.

    Args:
        test_loader: iterable of ``(inputs, labels, _)`` batches; ``inputs`` is
            indexed as ``inputs[sample, channel, frame, feature]``.
        model: called as ``model(x, stage="test")``; must return
            ``(x_tilde, kld, kld0)``.
        criterion: not used by the reported metrics; kept so existing callers
            keep working.
        device: device the batches are moved to.

    Returns:
        ``(loss_meter, acc_meter)`` holding the running averages.

    Raises:
        ValueError: if a sample has ``frame_num`` frames or fewer.
    """
    frame_num = 5
    loss_meter = AverageStorage('Loss', ':.4e')
    acc_meter = AverageStorage('Acc', ':.2%')
    loss2_meter = AverageStorage('Loss2', ':.4e')
    loss3_meter = AverageStorage('Loss3', ':.4e')
    progress = ProgressMeter(total=len(test_loader), step=20, prefix='Testing',
                             meters=[loss_meter, acc_meter, loss2_meter, loss3_meter])
    model.eval()
    # Pure evaluation: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        for i, (inputs, labels, _) in enumerate(test_loader):
            inputs = inputs.to(device)
            labels = labels.to(device)

            # Window starts run 0 .. length - frame_num - 1; the last possible
            # start is not used (window count preserved from the original code).
            num_windows = inputs.shape[2] - frame_num
            if num_windows < 1:
                raise ValueError(
                    f'need more than {frame_num} frames per sample, got {inputs.shape[2]}')
            # One stack instead of repeated torch.cat (which copied O(n^2) data).
            n_inputs = torch.stack([
                inputs[j, 0, k:k + frame_num, :]
                for j in range(inputs.shape[0])
                for k in range(num_windows)
            ]).unsqueeze(1)
            # Each sample's label, repeated once per window (sample-major order).
            n_labels = labels.repeat_interleave(num_windows, dim=0)

            x_tilde, kld, kld0 = model(n_inputs, stage="test")

            normal = n_labels == 0
            loss_2 = torch.mean(torch.square(x_tilde[normal] - n_inputs[normal])) / 10
            loss_3 = torch.mean(kld0[normal])
            loss = loss_2 + loss_3
            acc = accuracy(kld, n_labels)

            acc_meter.update(acc.item(), n_labels.size(0))
            loss_meter.update(loss.item(), n_inputs.size(0))
            loss2_meter.update(loss_2.item(), n_inputs.size(0))
            loss3_meter.update(loss_3.item(), n_inputs.size(0))
            progress.display(i)

    return loss_meter, acc_meter



def _window_batch(inputs, labels, frame_num=5):
    """Cut every sample of a batch into stride-1 windows of ``frame_num`` frames.

    Window starts run from 0 to ``length - frame_num - 1`` (the last possible
    start is not used, matching the original window count). Each window keeps
    its sample's label.

    Args:
        inputs: tensor indexed as ``inputs[sample, channel, frame, feature]``;
            only channel 0 is used.
        labels: per-sample labels whose first dimension equals ``inputs.shape[0]``.
        frame_num: number of consecutive frames per window.

    Returns:
        ``(windows, window_labels)``: ``windows`` has shape
        ``(samples * num_windows, 1, frame_num, features)`` in sample-major
        order, and ``window_labels`` repeats each label ``num_windows`` times.

    Raises:
        ValueError: if a sample has ``frame_num`` frames or fewer.
    """
    num_windows = inputs.shape[2] - frame_num
    if num_windows < 1:
        raise ValueError(
            f'need more than {frame_num} frames per sample, got {inputs.shape[2]}')
    windows = torch.stack([
        inputs[j, 0, k:k + frame_num, :]
        for j in range(inputs.shape[0])
        for k in range(num_windows)
    ]).unsqueeze(1)
    return windows, labels.repeat_interleave(num_windows, dim=0)


def train(train_loader, align_loader, model, model_, criterion, optimizer, scheduler, device, epoch):
    """Run one epoch of joint training of ``model`` and the target model ``model_``.

    For every batch of ``train_loader`` one target-domain batch is drawn from
    ``align_loader``. The align iterator restarts when it runs out. Each
    iteration performs four phases:

    1. Reconstruction: ``loss_2 + loss_3`` on label-0 windows of both models
       is back-propagated and a look-ahead ``optimizer.step()`` is taken.
    2. Alignment: for each parameter in ``name``, the row norms of the
       look-ahead update are computed for both models. The term
       ``mean(1 - cos(update_A, update_B)) / 500`` is back-propagated on top
       of the phase-1 gradients. Both models are then restored to their
       pre-look-ahead weights, and a second ``optimizer.step()`` is taken.
    3. Gradient agreement: classification losses on label-0 windows of both
       models are back-propagated (no step). For each parameter in ``name``,
       the rows whose gradients have positive cosine similarity across the two
       models form a mask kept in ``tt``. The ``lll`` meters count how many
       mask entries flipped since the previous iteration.
    4. Masked transfer: ``criterion(kld, labels)`` over all windows of
       ``model`` is back-propagated. ``model`` keeps its full gradient, while
       the ``AE_e``/``em``/``ff``/``fc`` parameters of ``model_`` receive
       ``model``'s gradients masked by ``tt`` (``em*`` gradients unmasked).
       After the step, ``model_``'s ``ff``/``em`` parameters are overwritten
       with ``model``'s.

    Args:
        train_loader: batches ``(inputs, labels, _)`` for ``model``.
        align_loader: target-domain batches with the same layout, for ``model_``.
        model: called as ``model(x, stage='test')``; returns ``(x_tilde, kld, kld0)``.
        model_: same interface and parameter names as ``model``.
        criterion: classification loss, used as ``criterion(kld, labels)``.
        optimizer: expected to cover the parameters of both models.
        scheduler: unused here; kept for interface compatibility.
        device: device batches are moved to.
        epoch: unused here; kept for interface compatibility.

    Returns:
        ``(loss_meter, acc_meter, acc_meter_p, lll)``.
    """
    loss_meter = AverageStorage('Loss', ':.4e')
    acc_meter = AverageStorage('Acc', ':.2%')
    loss2_meter = AverageStorage('Loss2', ':.4e')
    loss3_meter = AverageStorage('Loss3', ':.4e')
    loss_meter_p = AverageStorage('Lossp', ':.4e')
    loss_meter_a = AverageStorage('Lossa', ':.4e')
    acc_meter_p = AverageStorage('Accp', ':.2%')
    # One meter per entry of `name`: number of agreement-mask entries that flipped.
    lll = [AverageStorage(f'Layer{idx}', ':.4e') for idx in range(7)]
    progress = ProgressMeter(total=len(train_loader), step=20, prefix='Training',
                             meters=[loss_meter, acc_meter, loss2_meter, loss3_meter,
                                     loss_meter_p, loss_meter_a, acc_meter_p])

    # Weight tensors whose update / gradient directions are compared.
    name = ('AE_e.0.0.weight', 'AE_e.1.0.weight', 'AE_e.2.0.weight', 'AE_e.3.0.weight',
            'fc_mu.weight', 'fc_var.weight', 'ff.weight')
    # Other parameters that reuse the mask of the `name` entry they belong to.
    # The original code picked this up implicitly from parameter registration
    # order; an explicit map removes that dependence.
    mask_owner = {}
    for blk in range(4):
        owner = f'AE_e.{blk}.0.weight'
        mask_owner[f'AE_e.{blk}.0.bias'] = owner
        mask_owner[f'AE_e.{blk}.1.weight'] = owner
        mask_owner[f'AE_e.{blk}.1.bias'] = owner
    for head in ('fc_mu', 'fc_var', 'ff'):
        mask_owner[f'{head}.bias'] = f'{head}.weight'
    shared_prefixes = ('AE_e', 'em', 'ff', 'fc')
    frame_num = 5
    # Parameter name -> boolean agreement mask (unsqueezed at dim 1) from the
    # previous iteration of this epoch.
    tt = {}

    params_a = dict(model.named_parameters())
    params_b = dict(model_.named_parameters())
    model.train()
    model_.train()
    align_iter = iter(align_loader)
    for batch_idx, (inputs, labels, _) in enumerate(train_loader):
        try:
            tinputs, tlabels, _ = next(align_iter)
        except StopIteration:
            # The align loader is shorter than the training loader: start over.
            align_iter = iter(align_loader)
            tinputs, tlabels, _ = next(align_iter)

        n_inputs, n_labels = _window_batch(inputs.to(device), labels.to(device), frame_num)
        tn_inputs, tn_labels = _window_batch(tinputs.to(device), tlabels.to(device), frame_num)
        normal = n_labels == 0
        t_normal = tn_labels == 0
        t_anomalous = tn_labels == 1

        # --- 1. reconstruction losses on label-0 windows, look-ahead step ---
        optimizer.zero_grad()
        x_tilde, kld, kld0 = model(n_inputs, stage='test')
        loss_2 = torch.mean(torch.square(x_tilde[normal] - n_inputs[normal])) / 10
        loss_3 = torch.mean(kld0[normal])  # mean KL term over label-0 windows
        loss = loss_3 + loss_2
        loss_meter.update(loss.item(), n_inputs.size(0))
        loss2_meter.update(loss_2.item(), n_inputs.size(0))
        loss3_meter.update(loss_3.item(), n_inputs.size(0))

        tx_tilde, tkld, tkld0 = model_(tn_inputs, "test")
        loss_2p = torch.mean(torch.square(tx_tilde[t_normal] - tn_inputs[t_normal])) / 10
        loss_3p = torch.mean(tkld0[t_normal])
        loss_p = loss_2p + loss_3p
        loss_meter_p.update(loss_p.item(), tn_inputs.size(0))
        (loss + loss_p).backward()

        # Detached snapshots so the look-ahead step can be undone below.
        snap_a = {n: p.detach().clone() for n, p in params_a.items()}
        snap_b = {n: p.detach().clone() for n, p in params_b.items()}
        optimizer.step()

        # --- 2. align per-row update magnitudes between the two models ---
        total_diff = 0
        for pname in name:
            step_a = torch.linalg.norm(params_a[pname] - snap_a[pname], dim=1)
            # Must use model_'s own update: re-using model's (as before) made
            # step_a == step_b and the whole term identically zero.
            step_b = torch.linalg.norm(params_b[pname] - snap_b[pname], dim=1)
            total_diff = total_diff + (1 - F.cosine_similarity(step_a, step_b, dim=0))
        average_diff = total_diff / len(name) / 500
        loss_meter_a.update(average_diff.item(), n_inputs.size(0))
        average_diff.backward()  # accumulates onto the phase-1 gradients

        # Undo the look-ahead step, then take the real step. A plain
        # `j = snapshot[i]` loop only rebinds a local and restores nothing.
        # NOTE(review): optimizer state (e.g. momentum buffers) is not rolled back.
        with torch.no_grad():
            for pname, param in params_a.items():
                param.copy_(snap_a[pname])
            for pname, param in params_b.items():
                param.copy_(snap_b[pname])
        optimizer.step()

        # --- 3. per-row gradient agreement between the two models ---
        optimizer.zero_grad()
        x_tilde, kld, kld0 = model(n_inputs, stage='test')
        loss_s = criterion(kld[normal], n_labels[normal])
        tx_tilde, tkld, tkld0 = model_(tn_inputs, stage='test')
        loss_t = criterion(tkld[t_normal], tn_labels[t_normal])
        # Balanced target accuracy; recorded once per iteration.
        accp = (accuracy(tkld[t_normal], tn_labels[t_normal])
                + accuracy(tkld[t_anomalous], tn_labels[t_anomalous])) / 2
        acc_meter_p.update(accp.item(), tn_labels.size(0))
        (loss_s + loss_t).backward()

        first_pass = not tt  # nothing to compare against yet this epoch
        grads_a = {n: p.grad for n, p in params_a.items() if n in name}
        layer_idx = 0
        for pname, param in params_b.items():
            if pname not in name:
                continue
            agree = (F.cosine_similarity(param.grad, grads_a[pname], dim=1) > 0).unsqueeze(1)
            if first_pass:
                flips = 0
            else:
                flips = torch.sum(torch.abs(agree.int() - tt[pname].int())).item()
            # Constant weight, so the meter's average is a plain mean.
            lll[layer_idx].update(flips, 128)
            tt[pname] = agree
            layer_idx += 1

        # --- 4. classification step; model_ receives masked gradients ---
        optimizer.zero_grad()
        x_tilde, kld, kld0 = model(n_inputs, stage='test')
        loss1 = criterion(kld, n_labels)
        acc = accuracy(kld, n_labels)
        acc_meter.update(acc.item(), n_inputs.size(0))
        loss1.backward()

        masked_grads = {}
        for pname, param in params_a.items():
            if pname.startswith('em'):
                masked_grads[pname] = param.grad
            if pname in name:
                masked_grads[pname] = param.grad * tt[pname]
            elif pname in mask_owner:
                masked_grads[pname] = param.grad * tt[mask_owner[pname]].squeeze(1)
        for pname, param in params_b.items():
            if pname.startswith(shared_prefixes):
                param.grad = masked_grads[pname]
        optimizer.step()

        # Keep model_'s ff/em parameters identical to model's. Rebinding the
        # loop variable (as before) copied nothing.
        with torch.no_grad():
            for pname, param in params_b.items():
                if pname.startswith(('ff', 'em')):
                    param.copy_(params_a[pname])

        progress.display(batch_idx)

    return loss_meter, acc_meter, acc_meter_p, lll

def main():
    """Command-line entry point: build data loaders and models, then train.

    Both ``model`` and ``model_`` are loaded from the checkpoint given by
    ``--init_model``. Each epoch:

    * runs :func:`train`;
    * logs the training accuracy and the per-layer mask-flip meters to
      TensorBoard;
    * evaluates ``model_`` with :func:`test`;
    * saves both models whenever the test accuracy improves.
    """
    parser = argparse.ArgumentParser(description='')
    parser.add_argument('--model_name', type=str, default='cvae')
    parser.add_argument('--dataset_name', type=str, default='dcase')
    # Only the 'pretrained' schedule exists. Any other value used to leave the
    # train() results undefined and crash later with a NameError.
    parser.add_argument('--SOTA', type=str, default='pretrained', choices=['pretrained'])
    parser.add_argument('--device', type=str, default='cuda:0')
    # Raw string: the old non-raw literal relied on invalid escape sequences
    # such as "\g", which Python 3.12+ reports as a SyntaxWarning.
    parser.add_argument('--init_model', type=str,
                        default=r'C:\gkw\experiments\cv_demo_cvae_fan\output\cvae_dcase\pretrained\model_ori_pump_un0_001.pth',
                        help='checkpoint that both models are initialised from')
    args = parser.parse_args()

    device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
    print(device)

    # Model and log output locations.
    basic_dir = util.get_basic_dir(args.model_name, args.dataset_name, 'pretrained')
    log_dir = os.path.join(basic_dir, 'runs')
    model_path = os.path.join(basic_dir, 'model_ori_fan_un.pth')
    model_path1 = os.path.join(basic_dir, 'model_ori_toyconveyor_transfer1.pth')

    # Datasets.
    train_loader, data_config = loader.load_data(args.dataset_name, data_type='train')
    test_loader, _ = loader.load_data(args.dataset_name, data_type='test')
    align_loader, _ = loader.load_data(args.dataset_name, data_type='align')

    # The models built by models.load_model are saved to disk, then replaced
    # by the checkpoint loaded below.
    in_channels, num_classes = data_config['in_channels'], data_config['num_classes']
    model = models.load_model(args.model_name, in_channels=in_channels, num_classes=num_classes).to(device)
    model_ = models.load_model(args.model_name, in_channels=in_channels, num_classes=num_classes).to(device)
    torch.save(model, os.path.join(basic_dir, 'model_ori_fan_un0.pth'))
    torch.save(model_, os.path.join(basic_dir, 'model_ori_pump_un0.pth'))

    # Both models come from the same checkpoint, so they start with identical
    # weights. No extra parameter copy is needed.
    # NOTE(review): torch.load unpickles a whole nn.Module, so only load
    # trusted files. PyTorch >= 2.6 defaults to weights_only=True and rejects
    # such checkpoints unless weights_only=False is passed.
    model = torch.load(args.init_model, map_location=device)
    model_ = torch.load(args.init_model, map_location=device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(params=list(model.parameters()) + list(model_.parameters()),
                          lr=config.lr, momentum=config.momentum, weight_decay=config.weight_decay)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=config.num_epochs)

    writer = SummaryWriter(log_dir)
    since = time.time()
    best_acc = None
    best_epoch_a = None
    try:
        for epoch in tqdm(range(config.num_epochs)):
            _, acc, _, lll = train(train_loader, align_loader, model, model_, criterion,
                                   optimizer, scheduler, device, epoch)
            writer.add_scalar(tag='training acc1', scalar_value=acc.avg, global_step=epoch)
            for idx, meter in enumerate(lll):
                writer.add_scalar(tag=f'l{idx}', scalar_value=meter.avg, global_step=epoch)

            loss, acc = test(test_loader, model_, criterion, device)
            writer.add_scalar(tag='test loss', scalar_value=loss.avg, global_step=epoch)

            # Keep the checkpoints with the best test accuracy seen so far.
            if best_acc is None or best_acc < acc.avg:
                best_acc = acc.avg
                best_epoch_a = epoch
                torch.save(model, model_path)
                torch.save(model_, model_path1)

            scheduler.step()
    finally:
        # Flush pending TensorBoard events even if training is interrupted.
        writer.close()

    print('COMPLETE !!!')
    print('BEST acc', best_acc)
    print('BEST EPOCH', best_epoch_a)
    print('TIME CONSUMED', time.time() - since)


if __name__ == '__main__':
    # Script entry point: run the full training / evaluation pipeline.
    main()