import gc
import random
import sys
from getopt import getopt

import torch
from sklearn.metrics import average_precision_score, f1_score, roc_auc_score
from torch.utils.data import DataLoader
from model.EventComprehensionNet import EventComprehensionNet
from load_data import LoadData
from metrics.cm_and_f1 import plot_and_print_cm, plot_confusion_matrix, f1_scores_from_cm
from metrics.event_level_f1 import cal_multiclass_event_level_f1
import time
import numpy as np
import torch.nn.functional as F
import os

par_dict = {}

# Fixed parameters: default hyper-parameters and paths.  A default is only
# filled in when its key is not already in ``par_dict``; the command-line
# parsing below may still override any of them.  Entries are inserted in the
# order listed here, which is also the order of the "all opts" dump.
_DEFAULT_PARAMS = {
	"seed": 43,
	"kfold_num": 5,
	"window_size": 320,
	"batch_size": 512,
	"num_epochs": 50,
	"stop_patience": 20,
	"learning_rate": 0.001,
	"weight_decay": 0.0002,
	"dropout_rate": 0.0,
	"beta_1": 0.9,
	"beta_2": 0.999,
	"let_go_rate": 0.4,
	"convchannels": 256,
	"kernel_size": 13,
	"sampling_scale": [2, 2, 2, 2],
	"num_class": 4,
	"multi_task_weight_0": 1.0,
	"multi_task_weight_1": 0.6,
	"multi_task_weight_2": 0.4,
	"multi_task_weight_3": 0.8,
	"weight_0": 1.0,
	"weight_1": 1.0,
	"weight_2": 2.0,
	"weight_3": 2.5,
	"mask_weight": torch.tensor(0.5),
	"avg_window": 320,
	"train_val_cutrate": 0.8,
	"data_dir": "/path/data",
	"save_dir": "/path/save_files",
	"save_model_name": "model.pkl",
}

# A pre-supplied mask_weight is coerced to a tensor (it is later moved to the
# GPU with ``.cuda()`` for BCEWithLogitsLoss).
if "mask_weight" in par_dict:
	par_dict["mask_weight"] = torch.tensor(par_dict["mask_weight"])
for _param_name, _param_default in _DEFAULT_PARAMS.items():
	par_dict.setdefault(_param_name, _param_default)

# Get input param: long-form ``--name=value`` overrides for entries of par_dict.
opts, args = getopt(sys.argv[1:], '', ['data_dir=', 'save_dir=', 'seed=', 'gpu_device=', 'kfold_num=',
                                       'window_size=', 'batch_size=',
                                       'num_epochs=', 'stop_patience=', 'learning_rate=', "weight_decay=",
                                       'dropout_rate=',
                                       'beta_1=', 'beta_2=',
                                       'convchannels=', 'kernel_size=', 'sampling_scale=', 'num_class=',
                                       'multi_task_weight_0=', 'multi_task_weight_1=', 'multi_task_weight_2=',
                                       'multi_task_weight_3=',
                                       'weight_0=', 'weight_1=', 'weight_2=', 'weight_3=',
                                       'mask_weight=', 'avg_window=',
                                       'train_val_cutrate=',
                                       'save_model_name=', 'let_go_rate='])


def _cli_scalar_tensor(text):
	"""Parse a command-line value into a 0-d float tensor."""
	return torch.tensor(float(text))


# par_dict key -> converter applied to the raw option string.
# ``--sampling_scale`` needs custom parsing and is handled in the loop below.
# NOTE(review): ``--gpu_device`` is accepted by getopt but has no entry here,
# so it is silently ignored.
_CLI_CONVERTERS = {
	'data_dir': str, 'save_dir': str, 'save_model_name': str,
	'seed': int, 'kfold_num': int, 'window_size': int, 'batch_size': int,
	'num_epochs': int, 'stop_patience': int, 'convchannels': int,
	'kernel_size': int, 'num_class': int, 'avg_window': int,
	'learning_rate': float, 'weight_decay': float, 'dropout_rate': float,
	'beta_1': float, 'beta_2': float, 'let_go_rate': float,
	'train_val_cutrate': float,
	'multi_task_weight_0': _cli_scalar_tensor, 'multi_task_weight_1': _cli_scalar_tensor,
	'multi_task_weight_2': _cli_scalar_tensor, 'multi_task_weight_3': _cli_scalar_tensor,
	'weight_0': _cli_scalar_tensor, 'weight_1': _cli_scalar_tensor,
	'weight_2': _cli_scalar_tensor, 'weight_3': _cli_scalar_tensor,
	'mask_weight': _cli_scalar_tensor,
}

print("==========some new opts changed:==========")
for opt_name, raw_value in opts:
	key = opt_name[2:]  # getopt reports long options with their leading '--'
	if key == 'sampling_scale':
		# Expected form "[a,b,c,d]": strip the brackets, keep the first four ints.
		raw_value = raw_value[1:-1]
		pieces = raw_value.split(',')
		par_dict[key] = [int(pieces[0]), int(pieces[1]), int(pieces[2]), int(pieces[3])]
	elif key in _CLI_CONVERTERS:
		par_dict[key] = _CLI_CONVERTERS[key](raw_value)
	else:
		continue
	print(f"{key}: {raw_value}")

# The weight of different events: per-class weights passed to the
# CrossEntropyLoss of the fine-grained detection task.  Defaults are Python
# floats while command-line overrides are 0-d tensors, so every entry is
# normalised to a float before building the (float32) weight tensor.
weight = [par_dict['weight_0'], par_dict['weight_1'], par_dict['weight_2'], par_dict['weight_3']]
par_dict['weight'] = torch.tensor([float(w) for w in weight])

# exist_ok avoids the race between an existence check and the creation call.
os.makedirs(par_dict["save_dir"], exist_ok=True)

print("==========all opts:==========")
for key, value in par_dict.items():
	print(f"{key}: {value}")


def seed_torch(seed):
	"""Seed every random number generator used by the experiment.

	Seeds Python's ``random``, NumPy, and torch (CPU, current CUDA device and
	all CUDA devices). It also disables cuDNN benchmark autotuning and forces
	deterministic cuDNN kernels so runs can be repeated.

	``PYTHONHASHSEED`` is recorded in the environment as well. Setting it at
	runtime does not change hashing in the already-running interpreter. It
	only affects Python interpreters launched afterwards.

	Args:
		seed: Integer seed applied to all generators.
	"""
	os.environ['PYTHONHASHSEED'] = str(seed)
	for seeder in (random.seed, np.random.seed, torch.manual_seed,
	               torch.cuda.manual_seed, torch.cuda.manual_seed_all):
		seeder(seed)
	cudnn = torch.backends.cudnn
	cudnn.benchmark = False
	cudnn.deterministic = True


# Seed all RNGs before any data is split/shuffled or any weights are created.
seed_torch(par_dict["seed"])
# Fold provider consumed by kfold_train_test(): (data_dir, window_size, kfold_num).
_load_data_args = (par_dict["data_dir"], par_dict["window_size"], par_dict["kfold_num"])
load_data = LoadData(*_load_data_args)

# loss calculate func for HC tasks, which not considers the first output channel (background channel).
def prec_func_cal(pre, label, loss_func):
	pre = pre[1:]
	label = label[1:]
	loss = loss_func(pre, label)
	return loss

def train_fold(cur_fold, train_loader, val_loader):
	"""Train one cross-validation fold with early stopping and keep the best model.

	A fresh EventComprehensionNet is built from ``par_dict`` and optimised with
	Adam on a per-sample multi-task loss. After every epoch the model is
	evaluated on ``val_loader``. Whenever the mean per-class frame-level F1
	improves, the whole module is pickled with ``torch.save``. Training stops
	after ``par_dict["num_epochs"]`` epochs, or after
	``par_dict["stop_patience"]`` consecutive epochs without improvement.

	Task IDs in the training batches select the loss per sample:
		0 -> weighted cross-entropy (fine-grained event detection, main task)
		1, 2, 3 -> MSE on the non-background channels (center / boundary /
		           lifetime tasks)
		4 -> BCE-with-logits on output channel 0 (coarse-grained perception)

	Args:
		cur_fold: Fold index, used in log lines and the checkpoint file name.
		train_loader: Yields ``(batch_x, batch_y, batch_taskID)``.
		val_loader: Yields ``(batch_x, batch_y)`` for the main task.

	Returns:
		None. The best model is written to
		``par_dict["save_dir"] + f"fold{cur_fold}_" + par_dict["save_model_name"]``.
		NOTE(review): no path separator is inserted, so save_dir presumably
		needs a trailing '/'.
	"""
	print(f"===============begin fold {cur_fold} train===============")

	# The literal arguments 3 and 5 are forwarded unchanged to EventComprehensionNet.
	net = EventComprehensionNet(3, par_dict["convchannels"], par_dict["kernel_size"], par_dict["sampling_scale"],
	                            par_dict["num_class"], par_dict["window_size"], 5,
	                            par_dict["dropout_rate"], par_dict["avg_window"]).cuda()

	optimizer = torch.optim.Adam(net.parameters(), lr=par_dict["learning_rate"],
	                             betas=(par_dict["beta_1"], par_dict["beta_2"]),
	                             weight_decay=par_dict["weight_decay"])
	loss_func1 = torch.nn.CrossEntropyLoss(weight=par_dict["weight"].cuda())  # Class-weighted main-task loss.
	loss_func2 = torch.nn.MSELoss()  # Regression loss for task IDs 1-3 (via prec_func_cal).
	loss_func3 = torch.nn.BCEWithLogitsLoss(weight=par_dict["mask_weight"].cuda())  # The loss function for coarse-grained event perception task.

	# Explicit label sets so f1_score always returns one entry per class, even
	# when a class (or a bg/not_bg segment label) is missing from both ground
	# truth and predictions. Without this the per-class indexing in the log
	# line raises IndexError.
	class_labels = list(range(par_dict["num_class"]))
	segment_labels = [0, 1]

	best_f1 = 0
	to_stop = 0  # Consecutive epochs without improvement of val_avg_f1.
	for epoch in range(par_dict["num_epochs"]):
		epoch_start_time = time.time()

		train_sum_loss = 0
		train_sum_num = 0
		net.train()
		net.zero_grad()
		for batch_x, batch_y, batch_taskID in train_loader:
			batch_x = batch_x.cuda()
			batch_y = batch_y.cuda()
			batch_taskID = batch_taskID.cuda()

			y_pre = net(batch_x, batch_taskID)
			# Sum the per-sample losses, each routed by the sample's task ID.
			loss = None
			for i4 in range(len(batch_taskID)):
				if batch_taskID[i4] == 0:
					subloss = loss_func1(y_pre[i4].unsqueeze(0), batch_y[i4][0].unsqueeze(0).long()) * par_dict[
						"multi_task_weight_0"]  # Fine-Grained Event Detection Task (Main Task)
				elif batch_taskID[i4] == 4:
					subloss = loss_func3(y_pre[i4][0], batch_y[i4][0])  # Coarse-Grained Event Perception Task
				else:
					if batch_taskID[i4] == 1:
						subloss = prec_func_cal(y_pre[i4], batch_y[i4], loss_func2) * par_dict[
							"multi_task_weight_1"]  # Center Localization Task
					if batch_taskID[i4] == 2:
						subloss = prec_func_cal(y_pre[i4], batch_y[i4], loss_func2) * par_dict[
							"multi_task_weight_2"]  # Boundary Localization Task
					if batch_taskID[i4] == 3:
						subloss = prec_func_cal(y_pre[i4], batch_y[i4], loss_func2) * par_dict[
							"multi_task_weight_3"]  # Lifetime Analysis Task
				if loss is None:
					loss = subloss
				else:
					loss += subloss

			optimizer.zero_grad()
			loss.backward()
			optimizer.step()

			train_sum_loss += loss.item() * batch_x.shape[0]
			train_sum_num += batch_x.shape[0]

			# Drop references to the batch's device tensors before the next one.
			# (Copying them to the CPU first, as before, only cost transfers.)
			del batch_x, batch_y, batch_taskID, y_pre, loss
			gc.collect()

		val_sum_loss = 0
		val_sum_num = 0
		val_succ_num = 0
		series_len = 0
		val_gt_y = []
		val_pre_y = []
		val_gt_mask = []
		val_pre_mask = []
		y_true_list = []
		y_scores_list = []

		net.eval()
		with torch.no_grad():
			for batch_x, batch_y in val_loader:
				# Binary ground truth: any non-background label becomes 1.
				mask_gt = batch_y.clone().detach()
				mask_gt[mask_gt > 0] = 1
				main_taskID = torch.zeros(batch_x.shape[0], 1).long().cuda()
				mask_taskID = torch.full((batch_x.shape[0], 1), 4).long().cuda()
				batch_x = batch_x.cuda()
				batch_y = batch_y.cuda()

				y_pre = net(batch_x, main_taskID)
				mask_pre = net(batch_x, mask_taskID)

				# Keep output channel 0 of the coarse-grained head and binarise it.
				mask_pre = mask_pre.transpose(0, 1)[0]
				mask_pre = torch.sigmoid(mask_pre)
				mask_pre = torch.where(mask_pre > par_dict['let_go_rate'], 1, 0)  # Straight to 0/1 ('let_go_rate' as the threshold)

				for i6 in range(len(batch_x)):
					# As long as there is a non-zero in mask, it is considered that there may be a waveform, and use event detection result y_pre.
					# If the lines of mask are all zero, the corresponding lines of y_pre are set to zero and background prediction are set to one.
					if torch.max(mask_pre[i6]) == 0:
						y_pre[i6][0] = torch.ones_like(y_pre[i6][0])   # background prediction
						y_pre[i6][1] = torch.zeros_like(y_pre[i6][1])  # class 1 prediction
						y_pre[i6][2] = torch.zeros_like(y_pre[i6][2])  # class 2 prediction
						y_pre[i6][3] = torch.zeros_like(y_pre[i6][3])  # class 3 prediction

				loss = loss_func1(y_pre, batch_y.long())

				main_taskID = main_taskID.cpu()
				mask_taskID = mask_taskID.cpu()
				batch_x = batch_x.cpu()
				batch_y = batch_y.cpu()
				y_pre = y_pre.cpu()
				mask_pre = mask_pre.cpu()
				loss = loss.cpu()

				val_sum_loss += loss * batch_x.shape[0]
				val_sum_num += batch_x.shape[0]

				# Class scores kept for the mAP computation after the loop.
				y_scores = torch.softmax(y_pre, dim=1)
				y_true_list.append(batch_y.cpu().detach().numpy())
				y_scores_list.append(y_scores.cpu().detach().numpy())

				# Store and use to calculate metrics
				y_pre = F.softmax(y_pre, dim=1)
				y_pre = y_pre.argmax(1)
				if len(val_gt_y) == 0:
					val_gt_y = batch_y
					val_pre_y = y_pre
				else:
					val_gt_y = torch.cat((val_gt_y, batch_y), 0)
					val_pre_y = torch.cat((val_pre_y, y_pre), 0)

				# Only when an entire paragraph all zero, it is counted as background prediction for coarse-grained event perception.
				segment_mask_pre = torch.zeros(len(mask_pre))
				segment_mask_gt = torch.zeros(len(mask_gt))
				for i7 in range(len(mask_pre)):
					if torch.max(mask_pre[i7]) == 0:
						segment_mask_pre[i7] = 0
					else:
						segment_mask_pre[i7] = 1
				for i7 in range(len(mask_gt)):
					if torch.max(mask_gt[i7]) == 0:
						segment_mask_gt[i7] = 0
					else:
						segment_mask_gt[i7] = 1

				if len(val_gt_mask) == 0:
					val_gt_mask = segment_mask_gt
					val_pre_mask = segment_mask_pre
				else:
					val_gt_mask = torch.cat((val_gt_mask, segment_mask_gt), 0)
					val_pre_mask = torch.cat((val_pre_mask, segment_mask_pre), 0)

				succ_num = (y_pre == batch_y).sum()
				val_succ_num = val_succ_num + succ_num
				if series_len == 0:
					series_len = batch_y.shape[1]

				del mask_gt, main_taskID, mask_taskID, batch_x, batch_y, y_pre, mask_pre, segment_mask_gt, segment_mask_pre, succ_num
				gc.collect()

		# mAP
		y_true = np.concatenate(y_true_list)
		y_scores = np.concatenate(y_scores_list)
		aps = []

		# The calculation of mAP does not include the background class.
		for class_idx in range(1, par_dict["num_class"]):
			if not np.any(y_true == class_idx):
				continue
			class_true = (y_true == class_idx).astype(int)
			class_scores = y_scores[:, class_idx]
			class_ap = average_precision_score(class_true, class_scores)
			aps.append(class_ap)
		mean_ap = np.mean(aps)

		# Summary
		train_time = (time.time() - epoch_start_time)
		train_avg_loss = train_sum_loss / train_sum_num
		val_avg_loss = val_sum_loss / val_sum_num
		val_acc = float(val_succ_num) / (val_sum_num * series_len)  # Frame-level accuracy.

		val_gt_y = val_gt_y.detach().numpy()
		val_pre_y = val_pre_y.detach().numpy()
		val_event_f1 = cal_multiclass_event_level_f1(val_pre_y, val_gt_y, par_dict["num_class"])
		val_gt_y = val_gt_y.flatten()
		val_pre_y = val_pre_y.flatten()
		val_f1 = f1_score(val_gt_y, val_pre_y, labels=class_labels, average=None, zero_division=0)
		val_avg_f1 = val_f1.mean()
		# One-vs-rest AUC computed on one-hot hard predictions.
		val_gt_oh = np.eye(par_dict["num_class"])[val_gt_y.astype(int)]
		val_pre_oh = np.eye(par_dict["num_class"])[val_pre_y.astype(int)]
		val_auc = roc_auc_score(val_gt_oh, val_pre_oh, multi_class='ovr')

		# For the classification of background or event
		val_mask_acc = float((val_gt_mask == val_pre_mask).sum()) / (val_gt_mask.shape[0])
		val_gt_mask = val_gt_mask.detach().numpy()
		val_pre_mask = val_pre_mask.detach().numpy()
		val_gt_mask = val_gt_mask.flatten()
		val_pre_mask = val_pre_mask.flatten()
		val_mask_f1 = f1_score(val_gt_mask, val_pre_mask, labels=segment_labels, average=None, zero_division=0)
		val_mask_avg_f1 = val_mask_f1.mean()

		# NOTE(review): the log line indexes four per-class F1 values, i.e. it assumes num_class == 4.
		print(
			'[ FOLD {} ] [ EPOCH{:3d} ] time: {:5.2f}s | train_avg_loss {:5.8f} | valid_avg_loss {:5.8f} | valid_acc {:5.8f} | valid_f1 [{:5.8f}, {:5.8f}, {:5.8f}, {:5.8f}] | valid_avg_f1 {:5.8f} | valid_auc {:5.8f} | valid_map {:5.8f} | avg_event_f1 {:5.8f} | segment_mask_acc {:5.8f} | segment_mask_f1 [{:5.8f}, {:5.8f}] | segment_mask_avg_f1 {:5.8f}'.format(
				cur_fold,
				epoch + 1,
				train_time,
				train_avg_loss,
				val_avg_loss,
				val_acc,
				val_f1[0], val_f1[1], val_f1[2], val_f1[3],
				val_avg_f1, val_auc, mean_ap, val_event_f1,
				val_mask_acc, val_mask_f1[0], val_mask_f1[1], val_mask_avg_f1,
			))

		# save: keep the checkpoint with the best mean per-class validation F1.
		if best_f1 < val_avg_f1:
			print("Model Saving....")
			best_f1 = val_avg_f1
			torch.save(net, par_dict["save_dir"] + f"fold{cur_fold}_" + par_dict["save_model_name"])
			print("model saved:", par_dict["save_dir"] + f"fold{cur_fold}_" + par_dict["save_model_name"])
			to_stop = 0
		else:
			to_stop = to_stop + 1
			if to_stop == par_dict["stop_patience"]:
				break


def test_result(cur_fold, test_loader):
	"""Evaluate the best checkpoint saved for ``cur_fold`` on ``test_loader``.

	The fine-grained detection output is gated by the coarse-grained head. A
	sample whose binarised mask is all zero is forced to predict background.

	Args:
		cur_fold: Fold index, used to locate the checkpoint and name the plots.
		test_loader: Yields ``(batch_x, batch_y)``.

	Returns:
		Tuple ``(avg_f1, test_acc, val_auc, mean_ap, bg_avg_f1, bg_acc, cm,
		bg_cm, avg_event_f1)``:
		- ``avg_f1``, ``cm``, ``bg_avg_f1`` and ``bg_cm`` come from
		  ``plot_and_print_cm``.
		- ``test_acc`` is frame-level accuracy.
		- ``val_auc`` is the one-vs-rest AUC on one-hot hard predictions,
		  computed on the test set despite its name.
		- ``mean_ap`` is averaged over non-background classes.
		- ``bg_acc`` is the segment-level background/event accuracy.
		- ``avg_event_f1`` comes from ``cal_multiclass_event_level_f1``.
	"""
	print(f"===============begin fold {cur_fold} test===============")

	model_path = par_dict["save_dir"] + f"fold{cur_fold}_" + par_dict["save_model_name"]
	# The checkpoint is a whole pickled nn.Module (saved with torch.save(net, ...)),
	# not a state_dict. PyTorch >= 2.6 defaults torch.load to weights_only=True,
	# which refuses to unpickle arbitrary modules, so opt out explicitly.
	# weights_only=False can execute code embedded in the pickle: only load
	# checkpoints produced by this script.
	try:
		net = torch.load(model_path, weights_only=False)
	except TypeError:
		# torch < 1.13 does not accept the ``weights_only`` keyword.
		net = torch.load(model_path)
	loss_func1 = torch.nn.CrossEntropyLoss(weight=par_dict["weight"].cuda()).cuda()

	test_sum_loss = 0
	test_sum_num = 0
	test_succ_num = 0
	bg_succ_num = 0
	series_len = 0
	ypre = None
	ytrue = None
	all_mask_pre = None
	all_mask_gt = None
	val_gt_y = []
	val_pre_y = []
	y_true_list = []
	y_scores_list = []

	net.eval()
	with torch.no_grad():
		for batch_x, batch_y in test_loader:
			# Binary ground truth: any non-background label becomes 1.
			mask_gt = batch_y.clone().detach()
			mask_gt[mask_gt > 0] = 1

			main_taskID = torch.zeros(batch_x.shape[0], 1).long().cuda()
			mask_taskID = torch.full((batch_x.shape[0], 1), 4).long().cuda()
			batch_x = batch_x.cuda()
			batch_y = batch_y.cuda()

			y_pre = net(batch_x, main_taskID)
			mask_pre = net(batch_x, mask_taskID)
			# Keep output channel 0 of the coarse-grained head and binarise it.
			mask_pre = mask_pre.transpose(0, 1)[0]
			mask_pre = torch.sigmoid(mask_pre)
			mask_pre = torch.where(mask_pre > par_dict['let_go_rate'], 1, 0)  # Straight to 0/1 ('let_go_rate' as the threshold)

			for i6 in range(len(batch_x)):
				# As long as there is a non-zero in mask, it is considered that there may be a waveform, and use event detection result y_pre.
				# If the lines of mask are all zero, the corresponding lines of y_pre are set to zero and background prediction are set to one.
				if torch.max(mask_pre[i6]) == 0:
					y_pre[i6][0] = torch.ones_like(y_pre[i6][0])
					y_pre[i6][1] = torch.zeros_like(y_pre[i6][1])
					y_pre[i6][2] = torch.zeros_like(y_pre[i6][2])
					y_pre[i6][3] = torch.zeros_like(y_pre[i6][3])

			loss = loss_func1(y_pre, batch_y.long())

			main_taskID = main_taskID.cpu()
			mask_taskID = mask_taskID.cpu()
			batch_x = batch_x.cpu()
			batch_y = batch_y.cpu()
			y_pre = y_pre.cpu()
			mask_pre = mask_pre.cpu()
			loss = loss.cpu()

			test_sum_loss += loss * batch_x.shape[0]
			test_sum_num += batch_x.shape[0]

			# Class scores kept for the mAP computation after the loop.
			y_scores = torch.softmax(y_pre, dim=1)
			y_true_list.append(batch_y.cpu().detach().numpy())
			y_scores_list.append(y_scores.cpu().detach().numpy())

			succ_num = F.softmax(y_pre, dim=1)
			succ_num = (succ_num.argmax(1) == batch_y).sum()
			test_succ_num = test_succ_num + succ_num
			if series_len == 0:
				series_len = batch_y.shape[1]

			ypre = y_pre if ypre is None else torch.cat((ypre, y_pre), dim=0)
			ytrue = batch_y if ytrue is None else torch.cat((ytrue, batch_y), dim=0)

			# Only an entire paragraph is counted as background.
			segment_mask_pre = torch.zeros(len(mask_pre))
			segment_mask_gt = torch.zeros(len(mask_gt))
			for i7 in range(len(mask_pre)):
				if torch.max(mask_pre[i7]) == 0:
					segment_mask_pre[i7] = 0
				else:
					segment_mask_pre[i7] = 1
			for i7 in range(len(mask_gt)):
				if torch.max(mask_gt[i7]) == 0:
					segment_mask_gt[i7] = 0
				else:
					segment_mask_gt[i7] = 1

			# Background statistics
			bg_succ_num += (segment_mask_pre == segment_mask_gt).sum()
			all_mask_pre = segment_mask_pre if all_mask_pre is None else torch.cat((all_mask_pre, segment_mask_pre), dim=0)
			all_mask_gt = segment_mask_gt if all_mask_gt is None else torch.cat((all_mask_gt, segment_mask_gt), dim=0)

			# Store and use to calculate metrics
			y_pre = F.softmax(y_pre, dim=1)
			y_pre = y_pre.argmax(1)
			if len(val_gt_y) == 0:
				val_gt_y = batch_y
				val_pre_y = y_pre
			else:
				val_gt_y = torch.cat((val_gt_y, batch_y), 0)
				val_pre_y = torch.cat((val_pre_y, y_pre), 0)

		del batch_x, batch_y, y_pre, mask_pre, mask_gt, segment_mask_gt, segment_mask_pre, succ_num, loss, mask_taskID, main_taskID
		gc.collect()

	# Tensor to numpy and to same shape
	ypre = ypre.cpu()
	ypre = ypre.detach().numpy()
	ypre = ypre.argmax(axis=1)
	ypre = ypre.flatten()
	ytrue = ytrue.cpu().numpy()
	ytrue = ytrue.flatten()

	# mAP
	y_true = np.concatenate(y_true_list)
	y_scores = np.concatenate(y_scores_list)
	aps = []

	# The calculation of mAP does not include the background class.
	for class_idx in range(1, par_dict["num_class"]):
		if not np.any(y_true == class_idx):
			continue
		class_true = (y_true == class_idx).astype(int)
		class_scores = y_scores[:, class_idx]
		class_ap = average_precision_score(class_true, class_scores)
		aps.append(class_ap)
	mean_ap = np.mean(aps)

	# Summary
	val_gt_y = val_gt_y.cpu().detach().numpy()
	val_pre_y = val_pre_y.cpu().detach().numpy()

	# Event-F1
	avg_event_f1, class_event_f1 = cal_multiclass_event_level_f1(val_pre_y, val_gt_y, par_dict["num_class"],
	                                                             print_result=True,
	                                                             return_class_f1=True)
	val_gt_y = val_gt_y.flatten()
	val_pre_y = val_pre_y.flatten()
	val_gt_oh = np.eye(par_dict["num_class"])[val_gt_y.astype(int)]
	val_pre_oh = np.eye(par_dict["num_class"])[val_pre_y.astype(int)]
	val_auc = roc_auc_score(val_gt_oh, val_pre_oh, multi_class='ovr')

	test_avg_loss = test_sum_loss / test_sum_num
	test_acc = float(test_succ_num) / (test_sum_num * series_len)
	label_class = ['None', 'StartHesitation', 'Turn', 'Walking'] # 4 category labels
	cm, avg_f1 = plot_and_print_cm(ypre, ytrue, par_dict["save_dir"], label_class,
	                               file_name=f"fold{cur_fold}_" + "pw_4class")

	bg_acc = float(bg_succ_num) / test_sum_num
	bg_cm, bg_avg_f1 = plot_and_print_cm(all_mask_pre, all_mask_gt, par_dict["save_dir"], ["bg", "not_bg"],
	                                     file_name=f"fold{cur_fold}_" + "bg")

	print(
		'[ FOLD {} ] end of test | test_avg_f1 {:5.8f} | test_avg_loss {:5.8f} | test_acc {:5.8f} | test_auc {:5.8f} | test_map {:5.8f}'.format(
			cur_fold,
			avg_f1,
			test_avg_loss,
			test_acc, val_auc,
			mean_ap))
	print(
		'[ FOLD {} ] avg_event_f1 {:5.8f} | class_event_f1 [ {:5.8f}, {:5.8f}, {:5.8f} ]'.format(
			cur_fold, avg_event_f1, class_event_f1[0], class_event_f1[1], class_event_f1[2]))
	print(
		'[ FOLD {} ] segment_mask_avg_f1 {:5.8f} | segment_mask_acc {:5.8f}'.format(
			cur_fold, bg_avg_f1,
			bg_acc, ))

	return avg_f1, test_acc, val_auc, mean_ap, bg_avg_f1, bg_acc, cm, bg_cm, avg_event_f1


def cal_fold_mean(fold_result_list):
	"""Return the arithmetic mean of the per-fold results as a NumPy scalar.

	Args:
		fold_result_list: Sequence of per-fold metric values.
	"""
	return np.asarray(fold_result_list).mean()


def kfold_train_test():
	"""Train and test every fold, then report fold means and pooled confusion matrices.

	For each fold in ``range(par_dict["kfold_num"])``:
	- the fold's train/val/test datasets are built with ``load_data.gen_fold_data``;
	- the model is trained with ``train_fold``;
	- its best checkpoint is evaluated with ``test_result``.

	Scalar metrics are averaged across folds with ``cal_fold_mean``. The
	per-fold confusion matrices are summed and plotted to ``par_dict["save_dir"]``.
	"""
	fold_f1_result = []
	fold_acc_result = []
	fold_auc_result = []
	fold_map_result = []
	fold_bg_f1_result = []
	fold_bg_acc_result = []
	fold_event_f1_result = []
	fold_cm = None
	fold_bg_cm = None

	for cur_fold in range(par_dict["kfold_num"]):
		train_dataset, val_dataset, test_dataset = load_data.gen_fold_data(cur_fold, par_dict["train_val_cutrate"])

		# Dataloader (only the training split is shuffled)
		train_loader = DataLoader(train_dataset, batch_size=par_dict["batch_size"], shuffle=True)
		val_loader = DataLoader(val_dataset, batch_size=par_dict["batch_size"], shuffle=False)
		test_loader = DataLoader(test_dataset, batch_size=par_dict["batch_size"], shuffle=False)

		train_fold(cur_fold, train_loader, val_loader)
		avg_f1, test_acc, val_auc, mean_ap, bg_avg_f1, bg_acc, cm, bg_cm, avg_event_f1 = test_result(cur_fold, test_loader)
		fold_f1_result.append(avg_f1)
		fold_acc_result.append(test_acc)
		fold_auc_result.append(val_auc)
		fold_map_result.append(mean_ap)
		fold_bg_f1_result.append(bg_avg_f1)
		fold_bg_acc_result.append(bg_acc)
		fold_event_f1_result.append(avg_event_f1)
		if fold_cm is None:
			fold_cm = cm
			fold_bg_cm = bg_cm
		else:
			# Add this fold's counts to the running totals. Out-of-place so the
			# first fold's matrix object (aliased above) is never mutated.
			fold_cm = fold_cm + cm
			fold_bg_cm = fold_bg_cm + bg_cm

	# The mean value of fold_result
	print("===============k fold mean result===============")
	kfold_f1 = cal_fold_mean(fold_f1_result)
	kfold_acc = cal_fold_mean(fold_acc_result)
	kfold_auc = cal_fold_mean(fold_auc_result)
	kfold_map = cal_fold_mean(fold_map_result)
	kfold_bg_f1 = cal_fold_mean(fold_bg_f1_result)
	kfold_bg_acc = cal_fold_mean(fold_bg_acc_result)
	kfold_event_f1 = cal_fold_mean(fold_event_f1_result)

	# NOTE(review): 'marco'/'mirco' in the log labels look like typos for
	# 'macro'/'micro'. They are left unchanged in case logs are parsed downstream.
	print("[ ALLFOLD ] kfold_mean_marco_f1: ", kfold_f1)
	print("[ ALLFOLD ] kfold_mean_acc: ", kfold_acc)
	print("[ ALLFOLD ] kfold_mean_auc: ", kfold_auc)
	print("[ ALLFOLD ] kfold_mean_map: ", kfold_map)
	print("[ ALLFOLD ] kfold_mean_bg_f1: ", kfold_bg_f1)
	print("[ ALLFOLD ] kfold_mean_bg_acc: ", kfold_bg_acc)
	print("[ ALLFOLD ] kfold_mean_event_f1: ", kfold_event_f1)

	# Plot the total confusion matrix (summed over all folds)
	label_class = ['None', 'StartHesitation', 'Turn', 'Walking']
	bg_label_class = ["bg", "not_bg"]
	plot_confusion_matrix(fold_cm, classes=label_class, title='allfold_pw_4class_' + 'cm_', path=par_dict["save_dir"])
	plot_confusion_matrix(fold_cm, classes=label_class, title='allfold_pw_4class_' + 'cm_num', normalize=False,
	                      path=par_dict["save_dir"])

	plot_confusion_matrix(fold_bg_cm, classes=bg_label_class, title='allfold_bg_' + 'cm_', path=par_dict["save_dir"])
	plot_confusion_matrix(fold_bg_cm, classes=bg_label_class, title='allfold_bg_' + 'cm_num', normalize=False,
	                      path=par_dict["save_dir"])

	# F1 scores derived from the pooled confusion matrix, presumably one per
	# class (depends on f1_scores_from_cm), and their mean.
	mirco_f1 = f1_scores_from_cm(fold_cm)
	mean_mirco_f1 = np.mean(mirco_f1)

	print("[ ALLFOLD ] kfold_mirco_f1: ", mirco_f1)
	print("[ ALLFOLD ] kfold_mean_mirco_f1: ", mean_mirco_f1)



# Run the full k-fold experiment only when this file is executed as a script.
if __name__ == "__main__":
	kfold_train_test()
