import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import time
import random
import matplotlib.pyplot as plt
import scipy.io as sio
# Import-time side effect: switches on autograd anomaly detection for the
# whole process. PyTorch documents this mode as a debugging aid that slows
# execution, so every training run that imports this module pays that cost.
# NOTE(review): confirm this is still wanted outside of debugging sessions.
torch.autograd.set_detect_anomaly(True)

from datadPreprocessing import Get_data_set
from utils import random_mini_batches_standardtwoModality, Get_mini_batches, Get_mini_batches_standard
from data_show import calculate_acc, calculate_acc_Dict
import lion_pytorch

class HGRDecNetwork(nn.Module):
    """Two-branch encoder that fuses two image modalities into ``k_maxL`` outputs.

    Each modality goes through its own ``ResNet50`` stem, is flattened, and is
    reduced by three fully connected stages (64 -> 32 -> 16 features). The two
    16-feature codes are concatenated (32 features), projected to ``k_maxL``
    outputs by ``fc_en_4`` and mapped back to 32 features by ``fc_de_4``.

    Args:
        iHSIz: channel count of the first input (presumably hyperspectral
            bands, going by the name -- confirm with the caller).
        iSARz: channel count of the second input (presumably SAR channels).
        ipatchsize: side length of the square input patches.
        ik_maxL: number of outputs of the final encoder layer.
    """

    def __init__(self,iHSIz,iSARz,ipatchsize,ik_maxL):
        super(HGRDecNetwork, self).__init__()

        self.HSIz = iHSIz
        self.SARz = iSARz
        self.patchsize = ipatchsize
        self.k_maxL = ik_maxL
        # Kept exactly as before in case external code reads this attribute;
        # it is no longer used to size the first fully connected layer.
        en1Inputpara = int((ipatchsize + 1) / 4)
        self.en1InputLen = en1Inputpara * en1Inputpara

        # The ResNet50 stem is a 7x7 convolution with stride 2 and padding 3
        # producing 64 channels (BatchNorm and ReLU keep the shape), so the
        # spatial side becomes (p + 2*3 - 7) // 2 + 1 = (p - 1) // 2 + 1.
        # The previous expression en1InputLen * 64 * (en1Inputpara + 1) only
        # matched this for patch sizes 11 and 12; this one yields the same
        # value there and the correct value for every other patch size.
        stem_side = (int(ipatchsize) - 1) // 2 + 1
        stem_features = 64 * stem_side * stem_side

        self.resnet50_1 = ResNet50(self.HSIz)
        self.resnet50_2 = ResNet50(self.SARz)

        # NOTE(review): in PyTorch, BatchNorm ``momentum`` is the weight given
        # to the newest batch statistic, so 0.9 makes the running statistics
        # follow the latest batch closely -- confirm 0.9 is intended here.
        self.fc_en_1_x1 = nn.Sequential(nn.Linear(stem_features, 64),
                                        nn.BatchNorm1d(64, momentum=0.9),
                                        nn.PReLU())
        self.fc_en_1_x2 = nn.Sequential(nn.Linear(stem_features, 64),
                                        nn.BatchNorm1d(64, momentum=0.9),
                                        nn.PReLU())

        self.fc_en_2_x1 = nn.Sequential(nn.Linear(64, 32),
                                        nn.BatchNorm1d(32, momentum=0.9),
                                        nn.PReLU())
        self.fc_en_2_x2 = nn.Sequential(nn.Linear(64, 32),
                                        nn.BatchNorm1d(32, momentum=0.9),
                                        nn.PReLU())

        self.fc_en_3_x1 = nn.Sequential(nn.Linear(32, 16),
                                        nn.BatchNorm1d(16, momentum=0.9),
                                        nn.PReLU())
        self.fc_en_3_x2 = nn.Sequential(nn.Linear(32, 16),
                                        nn.BatchNorm1d(16, momentum=0.9),
                                        nn.PReLU())

        self.fc_en_4 = nn.Sequential(nn.Linear(16 * 2, self.k_maxL))

        self.fc_de_4 = nn.Sequential(nn.Linear(self.k_maxL, 16 * 2),
                                     nn.BatchNorm1d(32, momentum=0.9),
                                     nn.PReLU())

    def forward(self, x1, x2):
        """Run both branches and return every intermediate representation.

        Args:
            x1: first-modality batch; reshaped to
                ``(-1, HSIz, patchsize, patchsize)``.
            x2: second-modality batch; reshaped to
                ``(-1, SARz, patchsize, patchsize)``.

        Returns:
            An 11-tuple ``(x_en_4, x1_conv1, x2_conv1, x1_en_1, x2_en_1,
            x1_en_2, x2_en_2, x1_en_3, x2_en_3, joint_layer, x_de_4)`` where
            ``x_en_4`` is the ``k_maxL``-wide output, the ``*_conv1`` entries
            are the flattened stem features, ``joint_layer`` is the
            concatenated 32-feature code and ``x_de_4`` is its reconstruction
            from ``x_en_4``.
        """
        # Follow the device the module actually lives on instead of assuming
        # cuda:0 whenever CUDA is available; the old behaviour broke a model
        # that was deliberately kept on the CPU or placed on another GPU.
        device = self.fc_en_4[0].weight.device

        x1 = x1.view(-1, self.HSIz, self.patchsize, self.patchsize).to(device)
        x2 = x2.view(-1, self.SARz, self.patchsize, self.patchsize).to(device)

        # Stem features, flattened to (batch, channels * height * width).
        x1_conv1 = torch.flatten(self.resnet50_1(x1), start_dim=1)
        x2_conv1 = torch.flatten(self.resnet50_2(x2), start_dim=1)

        x1_en_1 = self.fc_en_1_x1(x1_conv1)
        x2_en_1 = self.fc_en_1_x2(x2_conv1)

        x1_en_2 = self.fc_en_2_x1(x1_en_1)
        x2_en_2 = self.fc_en_2_x2(x2_en_1)

        x1_en_3 = self.fc_en_3_x1(x1_en_2)
        x2_en_3 = self.fc_en_3_x2(x2_en_2)

        # Both codes are already on ``device``, so no extra transfer is needed.
        joint_layer = torch.cat((x1_en_3, x2_en_3), dim=1)

        x_en_4 = self.fc_en_4(joint_layer)
        x_de_4 = self.fc_de_4(x_en_4)
        return x_en_4, x1_conv1, x2_conv1, x1_en_1, x2_en_1, x1_en_2, x2_en_2, x1_en_3, x2_en_3, joint_layer, x_de_4

class ResNet50(nn.Module):
    """Input stem of a torchvision ResNet-50 adapted to ``in_channels`` channels.

    The full pretrained ResNet-50 is fetched through ``torch.hub``, its first
    convolution is replaced by a freshly initialised one that accepts
    ``in_channels`` channels, and only the first three child modules are used
    in ``forward``.

    NOTE(review): the whole backbone stays registered as ``self.resnet``, so
    all of its parameters are part of this module even though ``forward`` only
    touches the first three children. Because the first convolution is
    replaced, it does not carry pretrained weights. Confirm both are intended.
    """

    def __init__(self, in_channels):
        super().__init__()

        backbone = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True)
        # Swap in a convolution that matches the number of input channels.
        backbone.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.resnet = backbone

        # Keep only the leading three children of the backbone as the stem.
        stem_modules = list(backbone.children())[:3]
        self.res50 = torch.nn.Sequential(*stem_modules)

    def forward(self, x):
        """Apply the stem to ``x`` and return its output."""
        stem_output = self.res50(x)
        return stem_output

# def HGRscore(f, g):                    #HGRscore2



class DeepCCA(nn.Module):                 #DeepCCA
    """Pair of independent fully connected towers, one per input view.

    Each tower is a stack of ``nn.Linear`` layers, each followed by ReLU. The
    first layer maps the view's input width to ``hidden_dim``; the remaining
    ``num_layers - 1`` layers map ``hidden_dim`` to ``hidden_dim``.

    Args:
        input_dim1: feature width of the first view.
        input_dim2: feature width of the second view.
        hidden_dim: width of every layer output in both towers.
        num_layers: number of linear layers per tower (at least one layer is
            always created).
    """

    def __init__(self, input_dim1, input_dim2, hidden_dim, num_layers):
        super().__init__()
        # Tower 1 is built completely before tower 2, input layer first.
        self.layers1 = self._make_tower(input_dim1, hidden_dim, num_layers)
        self.layers2 = self._make_tower(input_dim2, hidden_dim, num_layers)

    @staticmethod
    def _make_tower(in_features, width, depth):
        """Build one tower: an input layer followed by ``depth - 1`` hidden layers."""
        tower = nn.ModuleList([nn.Linear(in_features, width)])
        tower.extend(nn.Linear(width, width) for _ in range(depth - 1))
        return tower

    def forward(self, x1, x2):
        """Return the two views after passing through their respective towers."""
        out1 = x1
        for linear in self.layers1:
            out1 = torch.relu(linear(out1))

        out2 = x2
        for linear in self.layers2:
            out2 = torch.relu(linear(out2))

        return out1, out2

def HGRscore(x1, x2):
    """Row-wise cosine similarity between two batches of feature vectors.

    Despite the name, the value computed for each pair of rows is
    ``dot(row1, row2) / (||row1|| * ||row2||)``.

    Args:
        x1: 2-D tensor, one feature vector per row.
        x2: 2-D tensor with the same row width as ``x1``. If the row counts
            differ, only the first ``min(len(x1), len(x2))`` pairs are used.

    Returns:
        A tuple ``(result, corr, tra)``:

        * ``result`` -- mean similarity over all row pairs, detached from the
          autograd graph.
        * ``corr`` and ``tra`` -- both the similarity of the *last* row pair
          (still attached to the graph).

    Raises:
        ValueError: if there is no row pair to score. The previous loop-based
            version failed with ``UnboundLocalError`` in that case.
    """
    n_pairs = min(x1.shape[0], x2.shape[0])
    if n_pairs == 0:
        raise ValueError("HGRscore needs at least one pair of rows, got an empty batch")

    rows1 = x1[:n_pairs]
    rows2 = x2[:n_pairs]

    # One vectorised pass instead of a Python loop over rows; this also keeps
    # the result on the inputs' device and avoids a host sync per row.
    dots = (rows1 * rows2).sum(dim=1)
    norms = rows1.norm(dim=1) * rows2.norm(dim=1)
    cosine = dots / norms

    corr = tra = cosine[-1]

    # NOTE(review): the original built the mean with ``torch.tensor(results)``,
    # which copies the values out of the graph, so this term has never
    # contributed gradients to the loss. That behaviour is preserved here
    # explicitly; confirm whether the term is meant to be trainable (and with
    # which sign) before removing the ``detach``.
    result = cosine.detach().mean()
    return result, corr, tra


# class SoftCCA(nn.Module):
#     def __init__(self, input_dim1, input_dim2, hidden_dim, num_layers, lambda_=1.0):
#         super(SoftCCA, self).__init__()
#         self.layers1 = nn.ModuleList([nn.Linear(input_dim1, hidden_dim)])
#         for _ in range(num_layers - 1):
#             self.layers1.append(nn.Linear(hidden_dim, hidden_dim))
#         self.layers2 = nn.ModuleList([nn.Linear(input_dim2, hidden_dim)])
#         for _ in range(num_layers - 1):
#             self.layers2.append(nn.Linear(hidden_dim, hidden_dim))
#         self.lambda_ = lambda_
#
#     def forward(self, x1, x2):
#         for layer in self.layers1:
#             x1 = torch.relu(layer(x1))
#         for layer in self.layers2:
#             x2 = torch.relu(layer(x2))
#         return x1, x2
#
#     def stochastic_decorrelation_loss(self, z1, z2):
#         m = z1.shape[0]  # mini-batch size
#         k = z1.shape[1]  # number of neurons/feature channels
#
#         C_mini_z1 = (1 / (m - 1)) * torch.matmul(z1.T, z1)
#         C_mini_z2 = (1 / (m - 1)) * torch.matmul(z2.T, z2)
#
#         C_accu_z1 = torch.zeros_like(C_mini_z1)
#         C_accu_z2 = torch.zeros_like(C_mini_z2)
#         c_z1 = 0
#         c_z2 = 0
#
#         for i in range(m):
#             C_accu_z1 = torch.add(C_accu_z1, C_mini_z1)
#             C_accu_z2 = torch.add(C_accu_z2, C_mini_z2)
#             c_z1 += 1
#             c_z2 += 1
#
#         C_appx_z1 = C_accu_z1 / c_z1
#         C_appx_z2 = C_accu_z2 / c_z2
#
#         sdl_loss_z1 = 0
#         sdl_loss_z2 = 0
#         for i in range(k):
#             for j in range(k):
#                 if i!= j:
#                     sdl_loss_z1 += torch.abs(C_appx_z1[i][j])
#                     sdl_loss_z2 += torch.abs(C_appx_z2[i][j])
#
#         return sdl_loss_z1 + sdl_loss_z2
#
# def HGRscore(x1, x2):
# # def compute_correlation_soft_cca(x1, x2):
#     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#     x1 = x1.clone().detach().float().to(device)
#     x2 = x2.clone().detach().float().to(device)
#
#     model = SoftCCA(input_dim1=x1.shape[1], input_dim2=x2.shape[1], hidden_dim=64, num_layers=3, lambda_=0.1)
#     model.to(device)
#
#     optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
#
#     for epoch in range(100):
#         optimizer.zero_grad()
#         z1, z2 = model(x1, x2)
#
#     z1_flat = z1.view(-1)
#     z2_flat = z2.view(-1)
#     corr = torch.mean(torch.dot(z1_flat, z2_flat)) / (torch.norm(z1_flat) * torch.norm(z2_flat))
#     tra = corr
#     result = 1-corr
#
#     return result, corr, tra


# Define Loss Function
class MuCNN_Loss(nn.Module):
    """Weighted sum of classification, similarity and reconstruction terms.

    The total is ``1.0 * cross-entropy + 0.0125 * s0 + 0.025 * s1 + 0.05 * s2
    + 0.1 * s3 + 0.1 * MSE`` where ``s0..s3`` are the first values returned by
    ``HGRscore`` for the four feature pairs.

    Args:
        beta_reg: stored as ``self.reg``; not used by ``forward``.
    """

    def __init__(self, beta_reg=0.1):
        super().__init__()
        self.reg = beta_reg
        self.CEL = nn.CrossEntropyLoss()
        self.MSE = nn.MSELoss()

    def forward(self, y_es, y_re, c1, c2, c3, c4, c5, c6, c7, c8, r1, r2):
        """Compute the combined loss.

        Args:
            y_es: predictions passed to ``nn.CrossEntropyLoss``.
            y_re: targets passed to ``nn.CrossEntropyLoss``.
            c1..c8: four feature pairs ``(c1, c2)``, ``(c3, c4)``,
                ``(c5, c6)``, ``(c7, c8)`` scored with ``HGRscore``.
            r1, r2: tensors compared with ``nn.MSELoss``.

        Returns:
            ``(loss, corr, tra)`` where ``corr`` and ``tra`` are the second
            and third values ``HGRscore`` returns for the pair ``(c7, c8)``.
        """
        classification_term = self.CEL(y_es, y_re)
        reconstruction_term = self.MSE(r1, r2)

        score_0 = HGRscore(c1, c2)[0]
        score_1 = HGRscore(c3, c4)[0]
        score_2 = HGRscore(c5, c6)[0]
        score_3, corr, tra = HGRscore(c7, c8)

        # Accumulate in the same left-to-right order as the weighted sum above.
        total = 1.0 * classification_term
        total = total + 0.0125 * score_0
        total = total + 0.025 * score_1
        total = total + 0.05 * score_2
        total = total + 0.1 * score_3
        total = total + 0.1 * reconstruction_term
        return total, corr, tra


def train_mynetwork_HighMemory(x1_train_set, x2_train_set, x1_test_set, x2_test_set, y_train_set, y_test_set,
                               HSIz, SARz, patchsize, k_maxL, pad_width, kaugmentation, filename_prefix, save_path,
                               min_AA_Save=0.7, min_AO_Save=0.6, learning_rate=1e-4, num_epochs=500, minibatch_size=32,
                               seed=-1, is_scheduler=1, is_show=0, is_random=1):
    """Train ``HGRDecNetwork`` with the Lion optimizer on pre-extracted test patches.

    The test mini-batches are built once up front from ``x1_test_set`` /
    ``x2_test_set`` (hence "high memory"). Every epoch trains on freshly
    shuffled mini-batches, evaluates on all test mini-batches, logs to stdout
    and to ``<save_path>/<filename_prefix>logger.txt``, and checkpoints the
    model plus the test outputs when the accuracy criteria below are met.

    Args:
        x1_train_set, x2_train_set, y_train_set: training data handed to
            ``random_mini_batches_standardtwoModality``.
        x1_test_set, x2_test_set, y_test_set: test data handed to
            ``Get_mini_batches``; ``y_test_set`` is also passed to
            ``calculate_acc_Dict``.
        HSIz, SARz, patchsize, k_maxL: forwarded to ``HGRDecNetwork``.
        pad_width: only stored in the saved ``.mat`` file.
        kaugmentation: used to derive the un-augmented training-set size
            written to the log.
        filename_prefix: prefix of the log file and of the final checkpoint.
        save_path: directory that receives the log, checkpoints and features.
        min_AA_Save, min_AO_Save: minimum ``AA`` / ``AO`` values involved in
            the checkpoint rule.
        learning_rate: optimizer learning rate and scheduler ``max_lr``.
        num_epochs: the loop runs ``num_epochs + 1`` epochs.
        minibatch_size: training batch size; test batches are 20x larger.
        seed: base seed; ``-1`` means "use the current time". It is
            incremented every epoch and passed to the mini-batch shuffler.
        is_scheduler: when truthy, a ``OneCycleLR`` scheduler is stepped once
            per epoch.
        is_show: unused.
        is_random: when truthy, seeds ``random``, NumPy and PyTorch with
            ``seed``.

    Returns:
        ``(val_losslist, val_acc, val_aa, test_joint_layer_all)``: per-epoch
        validation loss, ``AO`` and ``AA``, plus the list of per-sample model
        outputs from the last epoch.
    """
    t0 = time.time()
    # Single GPU or CPU
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    if seed == -1:
        seed = int(time.time())
    if is_random:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)

    # Define Loss Value
    train_losslist = []
    val_losslist = []
    train_acc = []
    val_acc = []
    val_aa = []
    max_ao = 0.
    max_aa = 0.
    max_aoplusaa = min_AO_Save + min_AA_Save
    # Instantiate Object
    model = HGRDecNetwork(HSIz, SARz, patchsize, k_maxL)
    model.to(device)

    # Define Loss Function and Optimizer
    criterion = MuCNN_Loss()

    optimizer = lion_pytorch.Lion(model.parameters(), learning_rate)

    if is_scheduler:
        # NOTE(review): the schedule length is 5 * int(num_epochs / 4) steps
        # while the scheduler is stepped once per epoch (num_epochs + 1 times).
        # That fits the default of 500 epochs, but small epoch counts may
        # exceed the schedule or yield zero steps -- confirm intended usage.
        lr_s = optim.lr_scheduler.OneCycleLR(optimizer,max_lr=learning_rate, steps_per_epoch=int(num_epochs/4),
                                             epochs=5, anneal_strategy='cos')

    # Train and Evaluate Model

    # Test batches are built once and reused by every epoch.
    vallableminibatches = Get_mini_batches(x1_test_set, x2_test_set, y_test_set, minibatch_size*20)

    t1 = time.time()
    spend1 = t1 - t0
    print('Data minibatch Time: {}'.format(spend1))
    with open(save_path + '/' + filename_prefix+ 'logger.txt', 'a') as file:
        file.write('Data minibatch Time: {}\n'.format(spend1))

    test_joint_layer_all = []
    for epoch in range(num_epochs + 1):
        t0 = time.time()

        model.train()

        train_correct = 0.
        train_loss_total = 0.0

        corr_all = 0.
        tra_all = 0.
        loss_val_all = 0.
        corr_val_all = 0.
        tra_val_all = 0.

        # A new seed per epoch gives a different shuffle each time.
        seed = seed + 1
        bmatsaved = 0
        minibatches = random_mini_batches_standardtwoModality(x1_train_set, x2_train_set, y_train_set, minibatch_size,
                                                              seed)

        for minibatch in minibatches:
            (batch_x1, batch_x2, batch_y) = minibatch
            batch_x1 = torch.from_numpy(batch_x1).to(device)
            batch_x2 = torch.from_numpy(batch_x2).to(device)
            batch_y = torch.from_numpy(batch_y).to(device)
            # Forward Propagation
            joint_layer, c1, c2, c3, c4, c5, c6, c7, c8, r1, r2 = model(batch_x1, batch_x2)
            loss, corr_train, tra_train = criterion(joint_layer, batch_y, c1, c2, c3, c4, c5, c6, c7, c8, r1, r2)

            train_loss_total += loss.item()
            # Accumulate plain floats. Summing the tensors themselves kept
            # every batch's autograd graph alive until the end of the epoch.
            corr_all += corr_train.item()
            tra_all += tra_train.item()

            train_correct = train_correct + calculate_acc(joint_layer, batch_y)

            # Backward Propagation
            # A fresh graph is built for every batch, so it never has to be
            # retained; the former ``retain_graph=True`` plus retry-on-error
            # path only held on to memory and could not recover from a real
            # failure.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        train_loss = train_loss_total / len(minibatches)
        train_accuracy = train_correct / len(minibatches)
        corr = corr_all / len(minibatches)
        tra = tra_all / len(minibatches)

        if is_scheduler:
            lr_s.step()

        # Test Function
        model.eval()

        test_joint_layer_all = []
        for vallableminibatch in vallableminibatches:
            (batch_x1, batch_x2, batch_y) = vallableminibatch
            batch_x1 = torch.from_numpy(batch_x1).to(device)
            batch_x2 = torch.from_numpy(batch_x2).to(device)
            batch_y = torch.from_numpy(batch_y).to(device)

            with torch.no_grad():
                test_joint_layer, c1, c2, c3, c4, c5, c6, c7, c8, r1, r2 = model(batch_x1, batch_x2)
                loss_val, corr_val, tra_val = criterion(test_joint_layer, batch_y, c1, c2, c3, c4, c5, c6, c7, c8, r1,
                                                        r2)
            # One entry per test sample, in batch order.
            test_joint_layer_all.extend(test_joint_layer.cpu().detach().numpy())
            loss_val_all += loss_val.item()
            corr_val_all += corr_val.item()
            tra_val_all += tra_val.item()

        val_loss = loss_val_all / len(vallableminibatches)
        val_corr = corr_val_all / len(vallableminibatches)
        val_tra = tra_val_all / len(vallableminibatches)

        acc_Dict = calculate_acc_Dict(test_joint_layer_all, y_test_set, k_maxL)
        test_AO = acc_Dict['AO']
        test_AA = acc_Dict['AA']

        # Checkpoint when either metric beats its best so far while at least
        # one of them reaches its minimum, or when their sum beats the best
        # sum (which starts at min_AO_Save + min_AA_Save).
        if (((test_AO > max_ao or test_AA > max_aa) and (test_AO >= min_AO_Save or test_AA >= min_AA_Save)) or (
                test_AO + test_AA) > max_aoplusaa):
            bmatsaved = 1
            torch.save(model, save_path + '/best_AO' + str(test_AO) + '_AA' + str(test_AA) + '.pt')
            sio.savemat(save_path + '/feature_AO' + str(test_AO) + '_AA' + str(test_AA) + '.mat',
                        {'feature': test_joint_layer_all, 'acc_Dict': acc_Dict, 'pad_width': pad_width})
            if test_AO > max_ao:
                max_ao = test_AO
            if test_AA > max_aa:
                max_aa = test_AA
            if (test_AO + test_AA) > max_aoplusaa:
                max_aoplusaa = test_AO + test_AA

        t1 = time.time()
        spend1 = t1 - t0

        # Logging interval of 1: this branch runs every epoch.
        if (epoch + 1) % 1 == 0:
            print(
                'Epoch [{}/{}], corr:{:.4f}, trace:{:.4f}, Loss: {:.4f}, Accuracy: {:.4f}, Val_Loss: {:.4f}, Val_trace: {:.4f}, Val_corr: {:.4f}, Val_Accuuracy: {:.4f}, Val_AA: {:.4f}, Learning_rate: {:.4e}, Epoch time: {:.4f}s'.format(
                    epoch, num_epochs, corr, tra, train_loss, train_accuracy, val_loss, val_tra, val_corr, test_AO,
                    test_AA, optimizer.param_groups[0]["lr"], spend1))
            print('max_ao:{:.4f}, max_aa:{:.4f}, Bsave: {}'.format(max_ao, max_aa, bmatsaved))
            train_losslist.append(train_loss)
            train_acc.append(train_accuracy)
            val_losslist.append(val_loss)
            val_acc.append(test_AO)
            val_aa.append(test_AA)
            with open(save_path + '/' + filename_prefix + 'logger.txt', 'a') as file:
                file.write(
                    'Epoch [{}/{}], corr:{:.4f}, trace:{:.4f}, Loss: {:.4f}, Accuracy: {:.4f}, Val_Loss: {:.4f}, Val_trace: {:.4f}, Val_corr: {:.4f}, Val_Accuuracy: {:.4f}, Val_AA: {:.4f}, Learning_rate: {:.4e}, Epoch time: {}s, max_ao:{:.4f}, max_aa:{:.4f}, Bsave: {:.4f}\n'.format(
                        epoch, num_epochs, corr, tra, train_loss, train_accuracy, val_loss, val_tra, val_corr, test_AO,
                        test_AA, optimizer.param_groups[0]["lr"], spend1, max_ao, max_aa, bmatsaved))
    with open(save_path + '/' + filename_prefix + 'logger.txt', 'a') as file:
        # Training-set size before augmentation, derived from kaugmentation.
        rtrainlen = int(len(x1_train_set) / (1 + kaugmentation))
        file.write(
            'train_set: {}, test_set: {}, Proportion of training set:{:4f}%\n'.format(rtrainlen, y_test_set.shape[0],
                                                                                      rtrainlen * 100 / (
                                                                                              rtrainlen +
                                                                                              y_test_set.shape[0])))

    torch.save(model, save_path + '/' + filename_prefix + 'last.pt')
    return val_losslist, val_acc, val_aa, test_joint_layer_all

def train_mynetwork_LowMemory(x1_train_set, x2_train_set, lable_test_set, y_train_set, y_test_set,
                              HSIz, SARz, patchsize, k_maxL, pad_width, kaugmentation, filename_prefix, save_path,
                              SARpad, HSIpad, corewidth, kmethod, min_AA_Save=0.7, min_AO_Save=0.6, learning_rate=1e-4,
                              num_epochs=500, minibatch_size=32, seed=-1, is_scheduler=1, is_show=0, is_random=1):
    """Train ``HGRDecNetwork`` with Adam, building test inputs batch by batch.

    Unlike ``train_mynetwork_HighMemory``, the test inputs are not passed in
    ready-made: for every test mini-batch the two model inputs are produced on
    the fly by ``Get_data_set`` from ``HSIpad`` / ``SARpad`` and the batch's
    entries of ``lable_test_set``. Each epoch trains on freshly shuffled
    mini-batches, evaluates on all test mini-batches, logs to stdout and to
    ``<save_path>/<filename_prefix>logger.txt``, and checkpoints the model and
    the test outputs when the accuracy rule in the loop is satisfied.

    Args:
        x1_train_set, x2_train_set, y_train_set: training data handed to
            ``random_mini_batches_standardtwoModality``.
        lable_test_set: its transpose is split into test mini-batches
            (presumably pixel positions of the test samples -- TODO confirm).
        y_test_set: test targets; also passed to ``calculate_acc_Dict``.
        HSIz, SARz, patchsize, k_maxL: forwarded to ``HGRDecNetwork``.
        pad_width, corewidth, kmethod: forwarded to ``Get_data_set``;
            ``pad_width`` is also stored in the saved ``.mat`` file.
        kaugmentation: used to derive the un-augmented training-set size
            written to the log.
        filename_prefix: prefix of the log file and of the final checkpoint.
        save_path: directory that receives the log, checkpoints and features.
        SARpad, HSIpad: sources from which ``Get_data_set`` builds the test
            inputs (presumably the padded full images -- TODO confirm).
        min_AA_Save, min_AO_Save: minimum ``AA`` / ``AO`` values involved in
            the checkpoint rule.
        learning_rate: optimizer learning rate and scheduler ``max_lr``.
        num_epochs: the loop runs ``num_epochs + 1`` epochs.
        minibatch_size: batch size for both training and test batches.
        seed: base seed; ``-1`` means "use the current time". It is
            incremented every epoch and passed to the mini-batch shuffler.
        is_scheduler: when truthy, a ``OneCycleLR`` scheduler is stepped once
            per epoch.
        is_show: unused.
        is_random: when truthy, seeds ``random``, NumPy and PyTorch with
            ``seed``.

    Returns:
        ``(val_losslist, val_acc, val_aa, test_joint_layer_all)``: per-epoch
        validation loss, ``AO`` and ``AA``, plus the list of per-sample model
        outputs from the last epoch.
    """
    t0 = time.time()
    # Single GPU or CPU
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # -1 requests a time-based seed.
    if seed == -1:
        seed = int(time.time())
    if is_random:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)



    # Define Loss Value
    train_losslist = []
    val_losslist = []
    train_acc = []
    val_acc = []
    val_aa = []
    max_ao = 0.
    max_aa = 0.
    # Best AO + AA sum so far; starts at the sum of the two minimums.
    max_aoplusaa = min_AO_Save + min_AA_Save
    # Instantiate Object
    model = HGRDecNetwork(HSIz, SARz, patchsize, k_maxL)
    model.to(device)

    # Define Loss Function and Optimizer
    criterion = MuCNN_Loss()

    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    if is_scheduler:

        # NOTE(review): the schedule length is 5 * int(num_epochs / 4) steps
        # while the scheduler is stepped once per epoch (num_epochs + 1 times).
        # That fits the default of 500 epochs, but small epoch counts may
        # exceed the schedule or yield zero steps -- confirm intended usage.
        lr_s = optim.lr_scheduler.OneCycleLR(optimizer,max_lr=learning_rate, steps_per_epoch=int(num_epochs/4),
                                             epochs=5, anneal_strategy='cos')

    # Train and Evaluate Model
    # Only the test entries and targets are batched here; the model inputs are
    # built per batch inside the evaluation loop.
    vallableminibatches = Get_mini_batches_standard(lable_test_set.T, y_test_set, minibatch_size)

    t1 = time.time()
    spend1 = t1 - t0
    print('Data minibatch Time: {}'.format(spend1))
    with open(save_path + '/' + filename_prefix + 'logger.txt', 'a') as file:
        file.write('Data minibatch Time: {}\n'.format(spend1))

    test_joint_layer_all = []
    for epoch in range(num_epochs + 1):
        t0 = time.time()

        model.train()

        # Per-epoch accumulators.
        train_correct = 0.
        train_loss_total = 0.0

        corr_all = 0.
        tra_all = 0.
        loss_val_all = 0.
        corr_val_all = 0.
        tra_val_all = 0.

        # A new seed per epoch gives a different shuffle each time.
        seed = seed + 1
        # Set to 1 below if a checkpoint is written during this epoch.
        bmatsaved = 0
        minibatches = random_mini_batches_standardtwoModality(x1_train_set, x2_train_set, y_train_set, minibatch_size,
                                                              seed)

        for minibatch in minibatches:
            (batch_x1, batch_x2, batch_y) = minibatch
            batch_x1 = torch.from_numpy(batch_x1).to(device)
            batch_x2 = torch.from_numpy(batch_x2).to(device)
            batch_y = torch.from_numpy(batch_y).to(device)
            # Forward Propagation
            # The first model output is the k_maxL-wide prediction; the rest
            # are the intermediate features consumed by the loss.
            joint_layer, c1, c2, c3, c4, c5, c6, c7, c8, r1, r2 = model(batch_x1, batch_x2)
            loss, corr_train, tra_train = criterion(joint_layer, batch_y, c1, c2, c3, c4, c5, c6, c7, c8, r1, r2)

            # .item() keeps the running sums as plain Python floats.
            train_loss_total += loss.item()
            corr_all += corr_train.item()
            tra_all += tra_train.item()


            train_correct = train_correct + calculate_acc(joint_layer, batch_y)

            # Backward Propagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Averages over the number of training mini-batches.
        train_loss = train_loss_total / len(minibatches)
        train_accuracy = train_correct / len(minibatches)
        corr = corr_all / len(minibatches)
        tra = tra_all / len(minibatches)


        if is_scheduler:
            lr_s.step()


        # Test Function
        model.eval()

        # Rebuilt every epoch: one model-output row per test sample.
        test_joint_layer_all = []
        for vallableminibatch in vallableminibatches:
            (batch_lable, batch_y) = vallableminibatch
            # Build this batch's two model inputs on demand; only the first
            # value returned by Get_data_set is used.
            batch_x2,__, __ = Get_data_set(SARpad, batch_lable.T,pad_width,corewidth,kmethod)
            batch_x1,__, __ = Get_data_set(HSIpad, batch_lable.T,pad_width,corewidth,kmethod)
            batch_x1 = torch.from_numpy(batch_x1).to(device)
            batch_x2 = torch.from_numpy(batch_x2).to(device)
            batch_y = torch.from_numpy(batch_y).to(device)



            with torch.no_grad():
                test_joint_layer, c1, c2, c3, c4, c5, c6, c7, c8, r1, r2 = model(batch_x1, batch_x2)
                loss_val, corr_val, tra_val = criterion(test_joint_layer, batch_y, c1, c2, c3, c4, c5, c6, c7, c8, r1,
                                                        r2)
            test_joint_layer_all.extend(test_joint_layer.cpu().detach().numpy())
            loss_val_all += loss_val.item()
            corr_val_all += corr_val.item()
            tra_val_all += tra_val.item()


        # Averages over the number of test mini-batches.
        val_loss = loss_val_all / len(vallableminibatches)
        val_corr = corr_val_all / len(vallableminibatches)
        val_tra = tra_val_all / len(vallableminibatches)

        acc_Dict = calculate_acc_Dict(test_joint_layer_all, y_test_set, k_maxL)
        test_AO = acc_Dict['AO']
        test_AA = acc_Dict['AA']

        # Checkpoint when either metric beats its best so far while at least
        # one of them reaches its minimum, or when their sum beats the best
        # sum seen so far.
        if (((test_AO > max_ao or test_AA > max_aa) and (test_AO >= min_AO_Save or test_AA >= min_AA_Save)) or (
                test_AO + test_AA) > max_aoplusaa):
            bmatsaved = 1
            # Saves the whole model object and the test outputs, with the
            # metric values embedded in the file names.
            torch.save(model, save_path + '/best_AO' + str(test_AO) + '_AA' + str(test_AA) + '.pt')
            sio.savemat(save_path + '/feature_AO' + str(test_AO) + '_AA' + str(test_AA) + '.mat',
                        {'feature': test_joint_layer_all, 'acc_Dict': acc_Dict, 'pad_width': pad_width})
            if test_AO > max_ao:
                max_ao = test_AO
            if test_AA > max_aa:
                max_aa = test_AA
            if (test_AO + test_AA) > max_aoplusaa:
                max_aoplusaa = test_AO + test_AA

        t1 = time.time()
        spend1 = t1 - t0

        # Logging interval of 1: this branch runs every epoch.
        if (epoch + 1) % 1 == 0:
            print(
                'Epoch [{}/{}], corr:{:.4f}, trace:{:.4f}, Loss: {:.4f}, Accuracy: {:.4f}, Val_Loss: {:.4f}, Val_trace: {:.4f}, Val_corr: {:.4f}, Val_Accuuracy: {:.4f}, Val_AA: {:.4f}, Learning_rate: {:.4e}, Epoch time: {:.4f}s'.format(
                    epoch, num_epochs, corr, tra, train_loss, train_accuracy, val_loss, val_tra, val_corr, test_AO,
                    test_AA, optimizer.param_groups[0]["lr"], spend1))
            print('max_ao:{:.4f}, max_aa:{:.4f}, Bsave: {}'.format(max_ao, max_aa, bmatsaved))
            train_losslist.append(train_loss)
            train_acc.append(train_accuracy)
            val_losslist.append(val_loss)
            val_acc.append(test_AO)
            val_aa.append(test_AA)
            with open(save_path + '/' + filename_prefix + 'logger.txt', 'a') as file:
                file.write(
                    'Epoch [{}/{}], corr:{:.4f}, trace:{:.4f}, Loss: {:.4f}, Accuracy: {:.4f}, Val_Loss: {:.4f}, Val_trace: {:.4f}, Val_corr: {:.4f}, Val_Accuuracy: {:.4f}, Val_AA: {:.4f}, Learning_rate: {:.4e}, Epoch time: {}s, max_ao:{:.4f}, max_aa:{:.4f}, Bsave: {:.4f}\n'.format(
                        epoch, num_epochs, corr, tra, train_loss, train_accuracy, val_loss, val_tra, val_corr, test_AO,
                        test_AA, optimizer.param_groups[0]["lr"], spend1, max_ao, max_aa, bmatsaved))
    with open(save_path + '/' + filename_prefix + 'logger.txt', 'a') as file:
        # Training-set size before augmentation, derived from kaugmentation.
        rtrainlen = int(len(x1_train_set) / (1 + kaugmentation))
        file.write(
            'train_set: {}, test_set: {}, Proportion of training set:{:4f}%\n'.format(rtrainlen, y_test_set.shape[0],
                                                                                      rtrainlen * 100 / (
                                                                                              rtrainlen +
                                                                                              y_test_set.shape[0])))


    # Always keep the final model, regardless of the checkpoint rule above.
    torch.save(model, save_path + '/' + filename_prefix + 'last.pt')
    return val_losslist, val_acc, val_aa, test_joint_layer_all
