"""
Object recognition Things-EEG dataset

use 250 Hz data
"""

import os
import argparse
import random
import itertools
import datetime
import time
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch import Tensor
from torch.autograd import Variable
from einops.layers.torch import Rearrange
from models.ViEEG import ViEEG

# gpus = [0]
# os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
# os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(map(str, gpus))
# model_id = 'ablution'
# train_type = 'dependent'
# if train_type == 'dependent':
#     epoch_num = 100
#     val_num = 740
# elif train_type == 'independent':
#     epoch_num = 50
#     val_num = 6660
# fold_num = 5
# n_layer = 2
# n_head = 1
# ---- Active run configuration -------------------------------------------------
# Physical GPU ids made visible to CUDA; devices are enumerated in PCI bus order.
gpus = [2]
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(gpu) for gpu in gpus)

# Experiment identity: these two values also name the output sub-directories below.
model_id, train_type = 'EyeEEG', 'dependent'

# Epoch budget and validation hold-out size for each protocol.
if train_type == 'independent':
    epoch_num, val_num = 40, 6660
elif train_type == 'dependent':
    epoch_num, val_num = 80, 740

fold_num = 1
n_layer = n_head = 1

# Output locations: <root>/<train_type>/<model_id>
trained_model_path = os.path.join('/EyeEEG/trained_models', train_type, model_id)
results_path = os.path.join('/EyeEEG/results/', train_type, model_id)

# Command-line interface. The defaults mirror the module-level configuration, so
# running the script without arguments reproduces that configuration.
parser = argparse.ArgumentParser(description='Experiment Stimuli Recognition test with CLIP encoder')
parser.add_argument('--epoch', default=epoch_num, type=int)
parser.add_argument('--n_layer', default=n_layer, type=int)
parser.add_argument('--n_head', default=n_head, type=int)
parser.add_argument('--fold', default=fold_num, type=int)
parser.add_argument('--train_type', default=train_type, type=str)
parser.add_argument('--clip_dir', default='CLIP-VIT-H-14', type=str)
parser.add_argument('--eeg_data_path', default='/Data/Things_EEG2/Preprocessed_data_250Hz',
                    type=str)
parser.add_argument('--img_data_path', default='/EyeEEG/data', type=str)
parser.add_argument('--trained_model_path', default=trained_model_path, type=str)
parser.add_argument('--results_path', default=results_path, type=str)
parser.add_argument('--num_sub', default=10, type=int, help='number of subjects used in the experiments. ')
# The help text previously advertised "default: 256" while the real default is 1000;
# %(default)s lets argparse print whatever the actual default is.
parser.add_argument('-batch_size', '--batch-size', default=1000, type=int, metavar='N',
                    help='mini-batch size (default: %(default)s), this is the total '
                         'batch size of all GPUs on the current node when '
                         'using Data Parallel or Distributed Data Parallel')
parser.add_argument('--lr', default=2e-4, type=float)
parser.add_argument('--seed', default=2023, type=int, help='seed for initializing training. ')


class IE():
    def __init__(self, args, nsub):
        """Set up hyper-parameters, losses, the log file and the model for one subject.

        Args:
            args: Parsed command-line namespace. This constructor reads ``batch_size``,
                ``epoch``, ``lr``, ``eeg_data_path``, ``img_data_path``, ``clip_dir``
                and ``results_path``; ``trained_model_path``, ``n_layer`` and ``n_head``
                are used when present.
            nsub: Subject number; it selects the ``sub-XX`` data folder and names the
                per-subject log file.

        Raises:
            OSError: If the results directory cannot be created or the log file opened.
        """
        super().__init__()
        self.args = args

        # Several of the attributes below (num_class, batch_size_img, lambda_cen, alpha,
        # proj_dim, start_epoch, test_center_path, pretrain, criterion_l1/l2, centers)
        # are not referenced by the methods in this file; they are kept so that code
        # reading them elsewhere keeps working.
        self.num_class = 200
        self.batch_size = args.batch_size
        self.batch_size_test = 400
        self.batch_size_img = 500
        self.n_epochs = args.epoch
        self.lambda_cen = 0.003
        self.alpha = 0.5
        self.proj_dim = 256
        self.lr = args.lr
        self.b1 = 0.5  # Adam beta1
        self.b2 = 0.999  # Adam beta2
        self.nSub = nsub
        self.start_epoch = 0
        self.test_center_path = '/data/'
        self.pretrain = False
        self.eeg_data_path = args.eeg_data_path
        self.img_data_path = args.img_data_path
        self.clip_dir = args.clip_dir

        # Create the output directories up front: opening the log (and later saving
        # the checkpoint) fails with FileNotFoundError on a fresh machine otherwise.
        os.makedirs(args.results_path, exist_ok=True)
        checkpoint_dir = getattr(args, 'trained_model_path', None)
        if checkpoint_dir:
            os.makedirs(checkpoint_dir, exist_ok=True)
        self.log_write = open(args.results_path + "/log_subject%d.txt" % self.nSub, "w")

        self.Tensor = torch.cuda.FloatTensor
        self.LongTensor = torch.cuda.LongTensor

        self.criterion_l1 = torch.nn.L1Loss().cuda()
        self.criterion_l2 = torch.nn.MSELoss().cuda()
        self.criterion_cls = torch.nn.CrossEntropyLoss().cuda()

        # The original code instantiated `EyeEEG()`, a name that is neither imported
        # nor defined in this file (only ViEEG is), so construction raised NameError.
        # Build the imported ViEEG encoder, honouring the CLI depth/head settings and
        # falling back to the module-level defaults when args does not carry them.
        self.model = ViEEG(n_layer=getattr(args, 'n_layer', n_layer),
                           n_head=getattr(args, 'n_head', n_head)).cuda()
        self.model = nn.DataParallel(self.model, device_ids=list(range(len(gpus))))

        # Log-temperature of the contrastive logits, initialised to log(1 / 0.07).
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.centers = {}
        print('initial define done.')

    def get_eeg_data(self):
        """Load this subject's training and test EEG, averaged over axis 1.

        Both files are read from ``<eeg_data_path>/sub-XX/`` and are expected to
        load as a mapping with a ``'preprocessed_eeg_data'`` entry. The averaged
        axis is re-inserted with length 1, so the arrays keep their rank.
        (Axis 1 is presumably the stimulus-repetition axis -- TODO confirm.)

        Returns:
            tuple: ``(train_data, train_label, test_data, test_label)`` where
            ``train_label`` is always an empty list and ``test_label`` is
            ``np.arange(200)``.
        """
        subject_dir = self.eeg_data_path + '/sub-' + format(self.nSub, '02')

        def _repetition_mean(file_name):
            # NOTE: allow_pickle=True unpickles the file; only use trusted data.
            record = np.load(subject_dir + '/' + file_name, allow_pickle=True)
            averaged = np.mean(record['preprocessed_eeg_data'], axis=1)
            return np.expand_dims(averaged, axis=1)

        train_data = _repetition_mean('preprocessed_eeg_training.npy')
        test_data = _repetition_mean('preprocessed_eeg_test.npy')

        return train_data, [], test_data, np.arange(200)

    def get_image_data(self):
        """Load the pre-computed feature files for the subject-dependent setting.

        Eight ``.npy`` files are read from ``<img_data_path>/<clip_dir>/`` and
        each array is passed through ``np.squeeze``.

        Returns:
            tuple: four ``[train, test]`` lists, in the order image, label,
            mask-image and mask01-image features.
        """
        feature_dir = os.path.join(self.img_data_path, self.clip_dir)

        # (train file, test file) for each feature family, in return order.
        file_pairs = (
            ('clip_feature_maps_image_train.npy', 'clip_feature_maps_image_test.npy'),
            ('clip_feature_maps_label_train.npy', 'clip_feature_maps_label_test.npy'),
            ('clip_feature_maps_mask_image_train.npy', 'clip_feature_maps_mask_image_test.npy'),
            ('clip_feature_maps_mask01_image_train.npy', 'clip_feature_maps_mask01_image_test.npy'),
        )

        # Read everything first, then squeeze, mirroring a plain load-then-clean pass.
        loaded = [
            [np.load(os.path.join(feature_dir, file_name), allow_pickle=True) for file_name in pair]
            for pair in file_pairs
        ]
        image, label, mask, mask01 = (
            [np.squeeze(train_part), np.squeeze(test_part)] for train_part, test_part in loaded
        )

        return image, label, mask, mask01

    def get_eeg_data_cross_subject(self):
        """Load EEG for the leave-one-subject-out setting.

        Subjects 1..10 are visited in order. For the held-out subject
        (``self.nSub``) the test file and then the training file go into the
        test pool; for every other subject the training file and then the test
        file go into the training pool. Each file is averaged over axis 1 and
        that axis is re-inserted with length 1.

        Returns:
            tuple: ``(train_data, train_label, test_data, test_label)``, all
            concatenated along axis 0. The labels hold the subject number of
            each sample, not an object class.
        """
        print('Begin reading...')

        training_file = 'preprocessed_eeg_training.npy'
        test_file = 'preprocessed_eeg_test.npy'

        def _repetition_mean(subject, file_name):
            path = os.path.join(self.eeg_data_path, 'sub-' + format(subject, '02'), file_name)
            record = np.load(path, allow_pickle=True)
            return np.expand_dims(np.mean(record['preprocessed_eeg_data'], axis=1), axis=1)

        train_data, train_label = [], []
        test_data, test_label = [], []

        for sub in range(1, 11):
            print('Sub-' + str(sub) + ' Loading...')
            if sub == self.nSub:
                data_pool, label_pool = test_data, test_label
                file_order = (test_file, training_file)
            else:
                data_pool, label_pool = train_data, train_label
                file_order = (training_file, test_file)

            for file_name in file_order:
                samples = _repetition_mean(sub, file_name)
                data_pool.append(samples)
                # The label is the number of the subject the samples came from.
                label_pool.append(np.full(samples.shape[0], sub))

        return (np.concatenate(train_data, axis=0), np.concatenate(train_label, axis=0),
                np.concatenate(test_data, axis=0), np.concatenate(test_label, axis=0))

    def get_image_data_cross_subject(self):
        """Load the pre-computed feature files for the leave-one-subject-out setting.

        For each feature family (image, label, mask-image, mask01-image) the
        train and test files under ``<img_data_path>/<clip_dir>/`` are joined
        along axis 0 and that joined array is repeated once per training
        subject, matching the per-subject "training then test" sample order
        produced by ``get_eeg_data_cross_subject``. The test part is returned
        unrepeated.

        Every loaded array is passed through ``np.squeeze``, exactly as
        ``get_image_data`` does. Without it, files that carry singleton axes
        reach ``train`` with a different rank in this setting than in the
        subject-dependent one.

        Returns:
            tuple: four ``[train, test]`` lists, in the order image, label,
            mask-image and mask01-image features.
        """
        feature_dir = os.path.join(self.img_data_path, self.clip_dir)
        # Nine of the ten subjects contribute training data (one is held out).
        n_train_subjects = 9

        def _load(file_name):
            # NOTE: allow_pickle=True unpickles the file; only use trusted data.
            return np.squeeze(np.load(os.path.join(feature_dir, file_name), allow_pickle=True))

        file_pairs = (
            ('clip_feature_maps_image_train.npy', 'clip_feature_maps_image_test.npy'),
            ('clip_feature_maps_label_train.npy', 'clip_feature_maps_label_test.npy'),
            ('clip_feature_maps_mask_image_train.npy', 'clip_feature_maps_mask_image_test.npy'),
            ('clip_feature_maps_mask01_image_train.npy', 'clip_feature_maps_mask01_image_test.npy'),
        )

        families = []
        for train_file, test_file in file_pairs:
            train_part = _load(train_file)
            test_part = _load(test_file)
            per_subject = np.concatenate([train_part, test_part], axis=0)
            repeated = np.concatenate([per_subject] * n_train_subjects, axis=0)
            families.append([repeated, test_part])

        image, label, mask, mask01 = families
        return image, label, mask, mask01

    def update_lr(self, optimizer, lr):
        """Overwrite the ``'lr'`` entry of every parameter group of ``optimizer`` with ``lr``."""
        for group in optimizer.param_groups:
            group.update(lr=lr)

    def train(self):
        """Train the EEG encoder for this subject and score retrieval on the test set.

        The method loads the EEG and image features for the configured protocol,
        shuffles the training samples and holds out the first ``val_num`` of them
        for validation. It then optimises the contrastive loss for
        ``self.n_epochs`` epochs, saving the weights whenever the epoch's
        validation loss is the lowest so far. Finally it reloads that checkpoint
        and scores the test set.

        Side effects: writes the checkpoint ``model_<sub>.pth`` under
        ``args.trained_model_path``, writes ``similarity_sub<sub>.npy`` and
        ``indices10_sub<sub>.npy`` under ``args.results_path``, and appends to
        the per-subject log file.

        Returns:
            tuple[float, float, float]: top-1, top-3 and top-5 accuracy obtained
            with the three feature branches fused.
        """
        # Build a fresh encoder. Use the depth/head count from the CLI (they were
        # previously ignored in favour of the module-level constants).
        self.model = ViEEG(n_layer=getattr(self.args, 'n_layer', n_layer),
                           n_head=getattr(self.args, 'n_head', n_head)).cuda()

        dependent = self.args.train_type == 'dependent'

        if dependent:
            train_eeg, _, test_eeg, test_label = self.get_eeg_data()
        else:
            # Train on the other subjects, but always test on this subject's own
            # test file only.
            train_eeg, _, _, _ = self.get_eeg_data_cross_subject()
            _, _, test_eeg, test_label = self.get_eeg_data()

        print("Train EEG Size:", train_eeg.shape)
        print("Test EEG Size:", test_eeg.shape)
        print("Test label Size:", test_label.shape)

        # Load the image-side features once (they used to be loaded twice).
        image_loader = self.get_image_data if dependent else self.get_image_data_cross_subject
        [train_image_feature, test_image_feature], [train_label_feature, test_label_feature], \
            [train_mask_image_feature, test_mask_image_feature], \
            [train_mask01_image_feature, test_mask01_image_feature] = image_loader()

        print("Train Image Size:", train_image_feature.shape)
        print("Test Image Size:", test_image_feature.shape)
        print("Train Mask Image Size:", train_mask_image_feature.shape)
        print("Test Mask Image Size:", test_mask_image_feature.shape)
        print("Train Mask01 Image Size:", train_mask01_image_feature.shape)
        print("Test Mask01 Image Size:", test_mask01_image_feature.shape)
        print("Train Label Size:", train_label_feature.shape)
        print("Test Label Size:", test_label_feature.shape)

        # Shuffle all training arrays with one shared permutation so that EEG and
        # image rows stay paired, then split off the validation hold-out.
        train_shuffle = np.random.permutation(len(train_eeg))
        shuffled = [array[train_shuffle] for array in (
            train_eeg, train_image_feature, train_mask_image_feature,
            train_mask01_image_feature, train_label_feature)]
        # val_num is the module-level hold-out size.
        val_tensors = [torch.from_numpy(array[:val_num]) for array in shuffled]
        train_tensors = [torch.from_numpy(array[val_num:]) for array in shuffled]
        # The last test tensor carries the integer labels used for scoring.
        test_tensors = [torch.from_numpy(array) for array in (
            test_eeg, test_image_feature, test_mask_image_feature,
            test_mask01_image_feature, test_label_feature, test_label)]

        dataset = torch.utils.data.TensorDataset(*train_tensors)
        self.dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=self.batch_size,
                                                      shuffle=True)
        val_dataset = torch.utils.data.TensorDataset(*val_tensors)
        self.val_dataloader = torch.utils.data.DataLoader(dataset=val_dataset, batch_size=self.batch_size,
                                                          shuffle=False)
        test_dataset = torch.utils.data.TensorDataset(*test_tensors)
        self.test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset,
                                                           batch_size=self.batch_size_test,
                                                           shuffle=False)

        # Only the encoder's parameters are optimised; self.logit_scale is not
        # handed to the optimizer, so the temperature keeps its initial value.
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, betas=(self.b1, self.b2))

        os.makedirs(self.args.trained_model_path, exist_ok=True)
        checkpoint_path = self.args.trained_model_path + '/model_' + str(self.nSub) + '.pth'

        best_loss_val = np.inf
        best_epoch = 0  # stays 0 if no checkpoint is ever written

        for e in range(self.n_epochs):
            self.model.train()
            for eeg, img, img_m, img_01, _ in self.dataloader:
                eeg, img, img_m, img_01 = self._to_gpu(eeg, img, img_m, img_01)

                starttime = datetime.datetime.now()
                loss_cos = self._batch_loss(eeg, img, img_m, img_01)

                self.optimizer.zero_grad()
                loss_cos.backward()
                self.optimizer.step()
                endtime = datetime.datetime.now()
                print('1:', str(endtime - starttime))

            # * validation part
            self.model.eval()
            val_loss_sum = 0.0
            val_seen = 0
            with torch.no_grad():
                for veeg, vimg, vimg_m, vimg_01, _ in self.val_dataloader:
                    veeg, vimg, vimg_m, vimg_01 = self._to_gpu(veeg, vimg, vimg_m, vimg_01)

                    starttime = datetime.datetime.now()
                    vloss = self._batch_loss(veeg, vimg, vimg_m, vimg_01)
                    endtime = datetime.datetime.now()
                    print('2:', str(endtime - starttime))

                    val_loss_sum += vloss.item() * veeg.shape[0]
                    val_seen += veeg.shape[0]

            # Select the checkpoint on the sample-weighted mean over the whole
            # validation set. Comparing batch by batch let a smaller final batch
            # (whose in-batch contrastive loss is naturally lower) decide alone.
            val_loss = val_loss_sum / max(val_seen, 1)
            if val_loss <= best_loss_val:
                best_loss_val = val_loss
                best_epoch = e + 1
                torch.save(self.model.state_dict(), checkpoint_path)

            train_loss = loss_cos.item()  # loss of the last training batch
            print('Epoch:', e,
                  '  loss visual train: %.4f' % train_loss,
                  '  loss visual val: %.4f' % val_loss,
                  )
            self.log_write.write('Epoch %d: loss Train: %.4f, loss val: %.4f\n' % (e, train_loss, val_loss))
            self.log_write.flush()  # keep the log usable if the run is interrupted

        # * test part
        self.model.load_state_dict(torch.load(checkpoint_path), strict=False)
        self.model.eval()

        with torch.no_grad():
            # Encode batch by batch, but rank every EEG sample against ALL test
            # images at once. The integer labels index the full test set, so
            # scoring inside each batch is only right when one batch holds it all.
            eeg_parts = ([], [], [])
            img_parts = ([], [], [])
            label_parts = []
            for teeg, timg, timg_m, timg_01, _, tlab_ce in self.test_dataloader:
                teeg, timg, timg_m, timg_01 = self._to_gpu(teeg, timg, timg_m, timg_01)
                eeg_branches, img_branches = self._normalized_branches(teeg, timg, timg_m, timg_01)
                for store, features in zip(eeg_parts, eeg_branches):
                    store.append(features)
                for store, features in zip(img_parts, img_branches):
                    store.append(features)
                label_parts.append(tlab_ce.cuda().type(self.LongTensor))

            v_teeg_features, m_teeg_features, b_teeg_features = (torch.cat(part, dim=0) for part in eeg_parts)
            timg_features, timg_m_features, timg_01_features = (torch.cat(part, dim=0) for part in img_parts)
            tlabels = torch.cat(label_parts, dim=0)

            eeg_features = torch.cat((v_teeg_features, m_teeg_features, b_teeg_features), dim=1)
            img_features = torch.cat((timg_features, timg_m_features, timg_01_features), dim=1)
            v_tlogits_per_eeg = eeg_features @ img_features.t()
            similarity = (100.0 * v_tlogits_per_eeg).softmax(dim=-1)

            _, v_indices = similarity.topk(5)
            _, indices10 = similarity.topk(10)

            # Saved once for the whole test set (raw similarities and the ten
            # best-ranked image indices per EEG sample).
            np.save(os.path.join(self.args.results_path, 'similarity_sub' + str(self.nSub) + '.npy'),
                    v_tlogits_per_eeg.cpu().numpy())
            np.save(os.path.join(self.args.results_path, 'indices10_sub' + str(self.nSub) + '.npy'),
                    indices10.cpu().numpy())

            tt_label = tlabels.view(-1, 1)
            total = tlabels.size(0)
            top1 = (tt_label == v_indices[:, :1]).sum().item()
            top3 = (tt_label == v_indices[:, :3]).sum().item()
            top5 = (tt_label == v_indices).sum().item()

            # Ablation: each branch alone and each pair of branches.
            top1_ablution, top3_ablution, top5_ablution = comput_ablution_acc(
                [0] * 6, [0] * 6, [0] * 6, tt_label,
                v_teeg_features, m_teeg_features, b_teeg_features,
                timg_features, timg_m_features, timg_01_features)

            top1_acc = float(top1) / float(total)
            top3_acc = float(top3) / float(total)
            top5_acc = float(top5) / float(total)
            top1_ab, top3_ab, top5_ab = comput_ablution_acc2(top1_ablution, top3_ablution, top5_ablution, total)

            print(tlabels.cpu().shape)
            print(similarity.cpu().shape)

        print('The test overall Top1-%.6f, Top3-%.6f, Top5-%.6f' % (top1_acc, top3_acc, top5_acc))
        print('The test visual Top1-%.6f, Top3-%.6f, Top5-%.6f' % (top1_ab[0], top3_ab[0], top5_ab[0]))
        print('The test mask Top1-%.6f, Top3-%.6f, Top5-%.6f' % (top1_ab[1], top3_ab[1], top5_ab[1]))
        print('The test mask01 Top1-%.6f, Top3-%.6f, Top5-%.6f' % (top1_ab[2], top3_ab[2], top5_ab[2]))
        print('The test 12 Top1-%.6f, Top3-%.6f, Top5-%.6f' % (top1_ab[3], top3_ab[3], top5_ab[3]))
        print('The test 13 Top1-%.6f, Top3-%.6f, Top5-%.6f' % (top1_ab[4], top3_ab[4], top5_ab[4]))
        print('The test 23 Top1-%.6f, Top3-%.6f, Top5-%.6f' % (top1_ab[5], top3_ab[5], top5_ab[5]))
        self.log_write.write('The best epoch is: %d\n' % best_epoch)
        self.log_write.write('The test Top1-%.6f, Top3-%.6f, Top5-%.6f\n' % (top1_acc, top3_acc, top5_acc))
        self.log_write.flush()

        return top1_acc, top3_acc, top5_acc

    def _to_gpu(self, *tensors):
        """Move each tensor to the GPU and cast it to ``self.Tensor``."""
        return tuple(tensor.cuda().type(self.Tensor) for tensor in tensors)

    @staticmethod
    def _unit_rows(features):
        """Scale ``features`` so that its norm along dim 1 is one."""
        return features / features.norm(dim=1, keepdim=True)

    def _normalized_branches(self, eeg, img, img_m, img_01):
        """Encode ``eeg`` and return normalised (EEG branches, image branches) triples."""
        v_eeg, m_eeg, b_eeg = self.model(eeg)
        eeg_branches = tuple(self._unit_rows(f) for f in (v_eeg, m_eeg, b_eeg))
        img_branches = tuple(self._unit_rows(f) for f in (img, img_m, img_01))
        return eeg_branches, img_branches

    def _batch_loss(self, eeg, img, img_m, img_01):
        """Return the cross-entropy of the fused EEG-to-image similarity logits for one batch."""
        eeg_branches, img_branches = self._normalized_branches(eeg, img, img_m, img_01)
        eeg_features = torch.cat(eeg_branches, dim=1)
        img_features = torch.cat(img_branches, dim=1)
        # Cosine similarity, scaled by the temperature, serves as the logits.
        logits_per_eeg = self.logit_scale.exp() * eeg_features @ img_features.t()
        # Row i of the batch is paired with image i, so the target class is i.
        targets = torch.arange(eeg.shape[0]).cuda().type(self.LongTensor)
        return self.criterion_cls(logits_per_eeg, targets)


def comput_ablution_acc(top1_ablution, top3_ablution, top5_ablution, tt_label,
                        v_teeg_features, m_teeg_features, b_teeg_features,
                        timg_features, timg_m_features, timg_01_features):
    """Add top-1/3/5 hit counts for six branch combinations to the running totals.

    Slots 0-2 use one branch each (visual, mask, mask01); slots 3-5 use the
    branch pairs (visual+mask, visual+mask01, mask+mask01), concatenated along
    dim 1. For every slot the EEG features are compared with the matching image
    features, the five best-ranked images are taken per row and matched
    against ``tt_label``.

    The three count lists are updated in place and also returned.
    """
    eeg_branches = (v_teeg_features, m_teeg_features, b_teeg_features)
    img_branches = (timg_features, timg_m_features, timg_01_features)
    branch_subsets = ((0,), (1,), (2,), (0, 1), (0, 2), (1, 2))

    def _gather(branches, members):
        if len(members) == 1:
            return branches[members[0]]
        return torch.cat([branches[k] for k in members], dim=1)

    # Rank everything first, then touch the counters.
    rankings = []
    for members in branch_subsets:
        eeg_part = _gather(eeg_branches, members)
        img_part = _gather(img_branches, members)
        probabilities = (100.0 * eeg_part @ img_part.t()).softmax(dim=-1)
        rankings.append(probabilities.topk(5)[1])

    for slot, ranked in enumerate(rankings):
        hits = tt_label == ranked
        top1_ablution[slot] += hits[:, :1].sum().item()
        top3_ablution[slot] += hits[:, :3].sum().item()
        top5_ablution[slot] += hits.sum().item()

    return top1_ablution, top3_ablution, top5_ablution


def comput_ablution_acc2(top1_ablution, top3_ablution, top5_ablution, total):
    """Turn the six hit counts of each list into accuracies by dividing by ``total``.

    Returns three new six-element lists (top-1, top-3, top-5); the inputs are
    left untouched.
    """
    def _as_rates(counts):
        return [float(counts[slot]) / float(total) for slot in range(6)]

    return _as_rates(top1_ablution), _as_rates(top3_ablution), _as_rates(top5_ablution)


def main():
    """Train and evaluate every subject, then write the accuracy table.

    Parses the command line, runs ``IE(...).train()`` for subjects
    ``1..args.num_sub`` and saves ``result.csv`` under ``args.results_path``.
    The CSV has one row each for top-1, top-3 and top-5 accuracy, one column
    per subject and a final ``ave`` column holding the mean.
    """
    args = parser.parse_args()
    print('\nInput arguments:')
    for key, val in vars(args).items():
        print('{:16} {}'.format(key, val))

    # Make sure the output locations exist before anything tries to write the
    # per-subject log, the checkpoints or the result table.
    for out_dir in (args.results_path, args.trained_model_path):
        os.makedirs(out_dir, exist_ok=True)

    aver = []
    aver3 = []
    aver5 = []

    for subject in range(1, args.num_sub + 1):
        starttime = datetime.datetime.now()

        # NOTE(review): the seed actually used is drawn from NumPy's global RNG with
        # args.seed as the exclusive upper bound. For the first subject that RNG has
        # not been seeded, so runs are not reproducible from --seed alone. Confirm
        # whether this is intended before changing it.
        seed_n = np.random.randint(args.seed)

        print('seed is ' + str(seed_n))
        random.seed(seed_n)
        np.random.seed(seed_n)
        torch.manual_seed(seed_n)
        torch.cuda.manual_seed(seed_n)
        torch.cuda.manual_seed_all(seed_n)

        print('Subject %d' % subject)
        ie = IE(args, subject)

        Acc, Acc3, Acc5 = ie.train()
        print('THE BEST ACCURACY IS ' + str(Acc))

        endtime = datetime.datetime.now()
        print('subject %d duration: ' % subject + str(endtime - starttime))

        aver.append(Acc)
        aver3.append(Acc3)
        aver5.append(Acc5)

    # One column per subject plus the mean over subjects.
    column = list(range(1, len(aver) + 1))
    column.append('ave')
    aver.append(np.mean(aver))
    aver3.append(np.mean(aver3))
    aver5.append(np.mean(aver5))

    pd_all = pd.DataFrame(columns=column, data=[aver, aver3, aver5])
    pd_all.to_csv(args.results_path + '/result.csv')


if __name__ == "__main__":
    # Run the complete multi-subject experiment once per fold, printing a
    # wall-clock stamp before and after each run and a separator line at the end.
    for _fold in range(fold_num):
        print(time.asctime(time.localtime(time.time())))
        main()
        print(time.asctime(time.localtime(time.time())))
        print('-' * 100)
