import os
import numpy as np
import math
import torch
from my_dataset import MyDatasetThreePre, MyDataset, MyDatasetOnePre


class LoadData:
	"""Load windowed point data plus multi-task labels and build k-fold splits.

	Every ``*.npy`` file in ``<data_dir>/point_data`` (in sorted filename order)
	is loaded and transposed with ``transpose(1, 0)``. Rows 0-2 of the result form
	the model input and row 3 is the label of the fine-grained event detection
	task (main task). For each such file ``<p_name>.npy``, three companion label
	files ``<prefix><p_name>.npy`` are read from ``<data_dir>/preception_label``
	and their first four rows are used. Everything is cut along its last axis into
	consecutive, non-overlapping windows of ``window_size`` samples; trailing
	samples that do not fill a whole window are dropped.

	Attributes:
		fold_num: number of folds used by ``gen_fold_data``.
		x: float32 tensor (N, 3, window_size), model input; N is the total
			number of windows over all files.
		y: float32 tensor (N, window_size), fine-grained event detection label.
		mask: float32 tensor (N, window_size), ``y`` with every positive value
			set to 1 (label of the coarse-grained event perception task).
		pre1_label: float32 tensor (N, 4, window_size), center localization label.
		pre2_label: float32 tensor (N, 4, window_size), boundary localization label.
		pre3_label: float32 tensor (N, 4, window_size), lifetime analysis label.
	"""

	def __init__(self, data_dir, window_size, fold_num):
		"""Load, window and stack all point data and label files under ``data_dir``.

		Args:
			data_dir: dataset root containing the ``point_data`` and
				``preception_label`` sub-directories (a trailing path separator
				is optional).
			window_size: samples per window; windows do not overlap and an
				incomplete trailing window is dropped.
			fold_num: number of folds for ``gen_fold_data``.

		Raises:
			FileNotFoundError: if ``point_data`` holds no ``.npy`` file, or a
				companion label file is missing.
		"""
		self.fold_num = fold_num

		point_dir = os.path.join(data_dir, "point_data")
		label_dir = os.path.join(data_dir, "preception_label")

		data = []        # per file: input windows, (3, n_i, window_size)
		label = []       # per file: fine-grained event detection label (main task), (n_i, window_size)
		pre1_label = []  # per file: center localization label, (4, n_i, window_size)
		pre2_label = []  # per file: boundary localization label, (4, n_i, window_size)
		pre3_label = []  # per file: lifetime analysis label, (4, n_i, window_size)

		# Sorting makes the sample order (and thus the contiguous fold split)
		# independent of the arbitrary order returned by os.listdir.
		for name in sorted(os.listdir(point_dir)):
			p_name, ext = os.path.splitext(name)
			if ext != ".npy":
				continue  # ignore stray non-data files
			ex = np.load(os.path.join(point_dir, name)).transpose(1, 0)
			# List (fancy) indexing keeps the IndexError on too few rows.
			data.append(self._windows(ex[[0, 1, 2]], window_size))
			label.append(self._windows(ex[3], window_size))
			pre1_label.append(self._load_pre_label(label_dir, "center_localization_label_", p_name, window_size))
			pre2_label.append(self._load_pre_label(label_dir, "boundary_localization_label_", p_name, window_size))
			pre3_label.append(self._load_pre_label(label_dir, "lifetime_analysis_label_", p_name, window_size))

		if not data:
			raise FileNotFoundError(f"no .npy files found in {point_dir}")

		# Concatenate once instead of re-copying the accumulated arrays per file.
		data = np.concatenate(data, axis=1)
		label = np.concatenate(label, axis=0)
		pre1_label = np.concatenate(pre1_label, axis=1)
		pre2_label = np.concatenate(pre2_label, axis=1)
		pre3_label = np.concatenate(pre3_label, axis=1)

		print("load data shape:" + str(data.shape))

		# Label for the coarse-grained event perception task, called "mask" here:
		# a copy of the main label with every positive value replaced by 1.
		mask = np.array(label)
		mask[mask > 0] = 1

		# Put the sample axis first: (C, N, window_size) -> (N, C, window_size).
		x = data.transpose(1, 0, 2)
		pre1_label = pre1_label.transpose(1, 0, 2)
		pre2_label = pre2_label.transpose(1, 0, 2)
		pre3_label = pre3_label.transpose(1, 0, 2)

		# Convert to float32 CPU tensors (no device transfer happens here).
		self.x = torch.from_numpy(x).float()
		self.y = torch.from_numpy(label).float()
		self.mask = torch.from_numpy(mask).float()
		self.pre1_label = torch.from_numpy(pre1_label).float()
		self.pre2_label = torch.from_numpy(pre2_label).float()
		self.pre3_label = torch.from_numpy(pre3_label).float()

	@staticmethod
	def _windows(arr, window_size):
		"""Cut the last axis of ``arr`` into consecutive non-overlapping windows.

		An array of shape (..., T) becomes (..., T // window_size, window_size);
		trailing samples that do not fill a whole window are dropped.
		"""
		view = np.lib.stride_tricks.sliding_window_view(arr, window_size, axis=-1)
		return view[..., ::window_size, :]

	@classmethod
	def _load_pre_label(cls, label_dir, prefix, p_name, window_size):
		"""Load ``<label_dir>/<prefix><p_name>.npy`` and window its rows 0-3."""
		pre = np.load(os.path.join(label_dir, prefix + p_name + ".npy"))
		return cls._windows(pre[[0, 1, 2, 3]], window_size)

	@staticmethod
	def _without(t, begin, end):
		"""Return ``t`` with rows ``[begin, end)`` removed along dim 0."""
		return torch.cat((t[:begin], t[end:]), dim=0)

	def gen_fold_data(self, fold, train_val_cut_rate):
		"""Build the train/val/test datasets for one fold of k-fold cross-validation.

		With ``k = int(N / fold_num)``, samples ``[fold * k, (fold + 1) * k)`` form
		the test set. The remaining samples, in order, are split into the first
		``int(len * train_val_cut_rate)`` for training and the rest for
		validation. If N is not divisible by ``fold_num``, the last
		``N - k * fold_num`` samples are never in a test fold.

		The training set stacks five tasks, tagged by task ID:
		0 = fine-grained event detection (``y``), and 1/2/3 = pre1/pre2/pre3
		labels. Tasks 0-3 all use the class-balanced subset: training windows
		whose mask has at least one positive value, repeated
		``ceil(n_train / n_balanced)`` times. Task 4 = ``mask`` on all training
		windows. ``y`` and ``mask`` are broadcast to 4 rows so they can be
		concatenated with the 4-row pre labels.

		Args:
			fold: index of the test fold, expected in ``range(fold_num)``.
			train_val_cut_rate: fraction of the non-test samples used for training.

		Returns:
			``(MyDatasetOnePre(train_x, train_y, train_taskID),
			MyDataset(val_x, val_y), MyDataset(test_x, test_y))``.

		Raises:
			ValueError: if no training window contains an event (including an
				empty training split), so no class-balanced set can be built.
		"""
		print(f"===============gen fold {fold} data===============")

		one_fold_num = int(len(self.x) / self.fold_num)
		print("one_fold_num: ", one_fold_num)

		# One fold for testing, the other folds for training and validation.
		test_cut_begin = one_fold_num * fold
		test_cut_end = one_fold_num * (fold + 1)
		print("test_cut_begin: ", test_cut_begin, "test_cut_end: ", test_cut_end)

		train_val_x = self._without(self.x, test_cut_begin, test_cut_end)
		val_cut_begin = int(len(train_val_x) * train_val_cut_rate)
		train_x = train_val_x[:val_cut_begin]
		val_x = train_val_x[val_cut_begin:]
		test_x = self.x[test_cut_begin:test_cut_end]

		train_val_y = self._without(self.y, test_cut_begin, test_cut_end)
		train_y = train_val_y[:val_cut_begin]
		val_y = train_val_y[val_cut_begin:]
		test_y = self.y[test_cut_begin:test_cut_end]

		# Only the training part of mask / pre labels is needed: the val and
		# test datasets are built from y alone.
		train_mask, train_pre1, train_pre2, train_pre3 = (
			self._without(t, test_cut_begin, test_cut_end)[:val_cut_begin]
			for t in (self.mask, self.pre1_label, self.pre2_label, self.pre3_label)
		)

		print("In unbalanced environment, the length of train_x: ", len(train_x), "the length of val_x: ", len(val_x), "the length of test_x: ", len(test_x))

		# Build the class-balanced training environment.
		# Except for mask, which is trained on the full data, the other tasks are
		# trained on windows containing at least one event (max of mask > 0).
		# Boolean indexing keeps the original row order.
		has_event = train_mask.flatten(start_dim=1).max(dim=1).values > 0
		balance_train_x = train_x[has_event]
		balance_train_y = train_y[has_event]
		balance_train_pre1 = train_pre1[has_event]
		balance_train_pre2 = train_pre2[has_event]
		balance_train_pre3 = train_pre3[has_event]
		unbalance_count = len(train_x) - len(balance_train_x)

		print("total training data: ", len(train_x), "pure background training data: ", unbalance_count)

		if len(balance_train_x) == 0:
			raise ValueError(
				"no training window contains an event (mask > 0); "
				"cannot build the class-balanced training set"
			)

		print("data with events (for class-balanced training): ", len(balance_train_x), "all data (for class-unbalanced training): ", len(train_x))

		# Data enhancement: repeat the balanced data so that it is at least as
		# large as the full training data, for inter-task balance.
		multiple = math.ceil(len(train_x) / len(balance_train_x))
		balance_train_x, balance_train_y, balance_train_pre1, balance_train_pre2, balance_train_pre3 = (
			torch.cat([t] * multiple, dim=0)
			for t in (balance_train_x, balance_train_y, balance_train_pre1, balance_train_pre2, balance_train_pre3)
		)

		print("enhanced data (for class-balanced training): ", len(balance_train_x))

		# Task IDs: 0-3 for the balanced copies (y, pre1, pre2, pre3), 4 for mask.
		n_balance = len(balance_train_x)
		train_taskID = torch.cat(
			[torch.full((n_balance, 1), task_id) for task_id in range(4)]
			+ [torch.full((len(train_x), 1), 4)],
			dim=0,
		)

		# Only mask is trained on the full training data.
		train_x = torch.cat([balance_train_x] * 4 + [train_x], dim=0)

		# Broadcast y and mask to 4 rows to match the pre labels' shape.
		balance_train_y = balance_train_y.unsqueeze(1).expand(-1, 4, -1)
		train_mask = train_mask.unsqueeze(1).expand(-1, 4, -1)

		train_y = torch.cat((balance_train_y, balance_train_pre1, balance_train_pre2, balance_train_pre3, train_mask), dim=0)

		train_dataset = MyDatasetOnePre(train_x, train_y, train_taskID)  # special dataset which contains taskID
		val_dataset = MyDataset(val_x, val_y)
		test_dataset = MyDataset(test_x, test_y)

		print("train data shape: ", len(train_x))
		print("val data shape: ", len(val_x))
		print("test data shape: ", len(test_x))

		return train_dataset, val_dataset, test_dataset