import numpy as np
from tensorflow import keras
import h5py
import hdf5storage
import math
import scipy.io
from scipy import signal
import skimage
from skimage.transform import resize
import os
import cv2

import sys
sys.path.append("Post-Processing")
from calculate_metrics import calc_PPG_peaks
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"
#
# gpus = tf.config.experimental.list_physical_devices('GPU')
# if gpus:
#   # Restrict TensorFlow to only use the first GPU
#   try:
#     tf.config.experimental.set_visible_devices(gpus[0], 'GPU')
#     logical_gpus = tf.config.experimental.list_logical_devices('GPU')
#     print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPU")
#   except RuntimeError as e:
#     # Visible devices must be set before GPUs have been initialized
#     print(e)

def get_zero_crossings(sig):
    """Return the indices on both sides of every sign change in ``sig``.

    The result is a flat list: first every index immediately *before* a
    sign change, then every index immediately *after* one (each "before"
    index plus one). An index can therefore appear twice when two sign
    changes are adjacent.
    """
    # a non-zero difference between consecutive signs marks a crossing
    sign_steps = np.diff(np.sign(sig))
    before = np.where(sign_steps)[0]
    after = before + 1
    return [*before, *after]


class DataGenerator(keras.utils.Sequence):
    'Generates data for Keras'
    def __init__(self, paths_of_videos, nframe_per_video, dim, num_gpu, batch_size=32, frame_depth=10,
                 shuffle=True, temporal=True, respiration=0, crop_size=36, data_aug=0, dataset='AFRL',
                 step_size=1, split_iter=False, return_file_names=False, seg_masks=None, use_first_derivative=True,
                 use_second_derivative=False, use_raw_signal=False, signals_to_use=["dysub", "drsub"]):
        """Store the generator configuration and build the first epoch's index order.

        Args:
            paths_of_videos: iterable of file paths; falsy entries (``None``, ``""``) are dropped.
            nframe_per_video: number of frames taken from each file.
            dim: frame size, indexed as ``dim[0]`` and ``dim[1]`` when batches are allocated.
            num_gpu: batch lengths in the TSM-style modes are truncated to a
                multiple of ``num_gpu * frame_depth``.
            batch_size: number of window indexes consumed per batch.
            frame_depth: number of frames in one window.
            shuffle: if true, the window indexes are shuffled at the end of every epoch.
            temporal: mode selector; the batch builder compares it against
                names such as ``'3DCNN'``, ``'TSM'`` or ``'MIX-MT-Dual-RNN'``.
            respiration: when equal to 1 the single-task modes read the
                ``"drsub"`` label, otherwise ``"dysub"``.
            crop_size: crop window used by the crop/translate augmentations
                (they unpack it as ``dy, dx``).
            data_aug: if truthy, video augmentation is applied to each window.
            dataset: dataset name; only stored here.
            step_size: number of frames in between windows within a video segment.
            split_iter: if true, split file iteration across files (i.e. iterate files out-of-order).
            return_file_names: if true, return the names of the files used for
                each video as part of the batch label dict.
            seg_masks: list of segmentation masks to use, or ``None``.
            use_first_derivative: if True, include first derivative as target signal.
            use_second_derivative: if True, include second derivative as target signal.
            use_raw_signal: if True, include raw signal (0th derivative) as target signal.
            signals_to_use: list of the signals to return from the generator.
        """
        # let the Keras base class set up whatever internal state it needs
        super().__init__()
        self.dim = dim
        self.batch_size = batch_size
        self.paths_of_videos = [path for path in paths_of_videos if path]
        self.nframe_per_video = nframe_per_video
        self.shuffle = shuffle
        self.temporal = temporal
        self.frame_depth = frame_depth
        self.respiration = respiration
        self.num_gpu = num_gpu
        self.random_crop_size = crop_size
        self.data_aug = data_aug
        self.dataset = dataset
        self.step_size = step_size
        self.split_iter = split_iter
        self.return_file_names = return_file_names
        self.seg_masks = seg_masks
        self.use_first_derivative = use_first_derivative
        self.use_second_derivative = use_second_derivative
        self.use_raw_signal = use_raw_signal
        # copy so that instances never share (and cannot mutate) the default list
        self.signals_to_use = list(signals_to_use)
        # number of window start positions that fit in one video at the given stride
        self.num_window = int((self.nframe_per_video - (self.frame_depth)) / self.step_size) + 1
        self.on_epoch_end()
        print("NUM WINDOWS", self.num_window)

    def data_load_func(self, path):
        """Open one recording and return ``(handle, frames)``.

        The file is opened read-only with h5py; if that raises ``OSError``
        it is loaded with ``hdf5storage.loadmat`` instead. ``frames`` is the
        ``"dXsub"`` entry as a NumPy array. The handle is returned as well
        because callers read further entries from it.
        """
        try:
            handle = h5py.File(path, 'r')
        except OSError:
            handle = hdf5storage.loadmat(path)

        stored = handle["dXsub"]
        # a leading axis of length 6 is taken to mean the axes are stored in
        # reverse order (presumably channels first -- TODO confirm), so the
        # axis order is reversed
        axes_reversed = stored.shape[0] == 6
        frames = np.array(stored)
        if axes_reversed:
            frames = np.transpose(frames)
        return handle, frames

    def __len__(self):
        """Return the number of batches per epoch (the last one may be partial)."""
        total_windows = len(self.paths_of_videos) * self.num_window
        return math.ceil(total_windows / self.batch_size)

    def __getitem__(self, index):
        """Build and return batch number ``index`` as ``(X, y)``."""
        # slice this batch's window indexes out of the epoch ordering
        first = index * self.batch_size
        batch_indexes = self.indexes[first:first + self.batch_size]

        # map every window index to the file it belongs to
        batch_files = []
        for k in batch_indexes:
            if self.split_iter:
                # consecutive indexes cycle through the files
                file_pos = k % len(self.paths_of_videos)
            else:
                # all windows of one file come before the next file
                file_pos = k // (self.num_window)
            batch_files.append(self.paths_of_videos[file_pos])

        features, targets = self.__data_generation(batch_files, batch_indexes)
        return features, targets

    def on_epoch_end(self):
        # 'Updates indexes after each epoch'
        self.indexes = np.arange(len(self.paths_of_videos)*self.num_window)
        if self.shuffle:
            np.random.shuffle(self.indexes)
        #self.central_fraction = np.random.uniform(0, 0.5)


    def random_crop(self, img):
        assert img.shape[2] == 6
        motion_img = img[:, :, :3]
        appearance_img = img[:, :, 3:]
        height, width = motion_img.shape[0], motion_img.shape[1]
        dy, dx = self.random_crop_size
        x = np.random.randint(0, width - dx + 1)
        y = np.random.randint(0, height - dy + 1)
        cropped_motion = motion_img[y:(y + dy), x:(x + dx), :]
        cropped_apperance = appearance_img[y:(y + dy), x:(x + dx), :]
        cropped_img = np.concatenate((cropped_motion, cropped_apperance), axis=-1)
        return cropped_img

    def central_crop(self, img):
        assert img.shape[2] == 6
        motion_img = img[:, :, :3]
        appearance_img = img[:, :, 3:]
        height, width = motion_img.shape[0], motion_img.shape[1]
        bbox_h_start = int((height - height * self.central_fraction) / 2)
        bbox_w_start = int((width - width * self.central_fraction) / 2)
        bbox_h_end = height - bbox_h_start * 2
        bbox_w_end = width - bbox_w_start * 2
        cropped_motion = motion_img[bbox_h_start:bbox_h_end, bbox_w_start:bbox_w_end, :]
        cropped_apperance = appearance_img[bbox_h_start:bbox_h_end, bbox_w_start:bbox_w_end, :]
        cropped_img = np.concatenate((cropped_motion, cropped_apperance), axis=-1)
        return cropped_img

    def random_flip(self, img):
        random_seed = np.random.uniform(0, 1)
        if random_seed < 0.5:
            return np.fliplr(img)
        else:
            return img

    def data_aug_func(self, img):
        img = self.random_crop(img)
        img = self.random_flip(img)
        return img

    def visualize_modification(self, orig_vid, mod_vid, title=""):
        """Show an original and an augmented clip in a 2x2 figure.

        Top row is ``orig_vid``, bottom row is ``mod_vid``. The left column
        shows the first three channels of frame 1 of the first clip; the
        right column shows the remaining channels of the first clip averaged
        over its first axis. Axis ticks are hidden and the call blocks in
        ``plt.show()``.
        """
        import matplotlib.pyplot as plt

        fig, axes = plt.subplots(2, 2, figsize=(16, 8))
        fig.suptitle("{} Augmentation".format(title))

        # (clip, left-panel title, right-panel title or None for untitled)
        panels = (
            (orig_vid, "Motion", "Appearance"),
            (mod_vid, "Augmented", None),
        )
        for row, (clip, left_title, right_title) in enumerate(panels):
            left_ax, right_ax = axes[row, 0], axes[row, 1]
            left_ax.set_title(left_title)
            left_ax.imshow(clip[0, 1, :, :, 0:3])
            if right_title is not None:
                right_ax.set_title(right_title)
            right_ax.imshow(clip[0, :, :, :, 3:].mean(axis=0))

        for panel in axes.flat:
            panel.get_xaxis().set_visible(False)
            panel.get_yaxis().set_visible(False)
        plt.show()

    def random_translate_vid(self, vid, chance=0.1):
        for i in range(vid.shape[0]):
            if np.random.uniform() < chance:
                # order should be [batch, time, rows, columns, channels]
                height, width = vid.shape[2], vid.shape[3]
                shifted_vid = np.zeros(shape=vid.shape[1:])
                dy, dx = self.random_crop_size
                x = np.random.randint(0, width - dx + 1)
                y = np.random.randint(0, height - dy + 1)
                # calculate random shift 
                shift_x = np.random.randint(0, x+1)
                shift_y = np.random.randint(0, y+1)
                # fill in empty video with cropped video
                shifted_vid[:, shift_y:(shift_y + dy), shift_x:(shift_x + dx), :] = vid[i, :, y:(y + dy), x:(x + dx), :]
                # replace original video with cropped video
                vid[i] = shifted_vid
        return vid

    def random_crop_vid(self, vid, chance=0.1):
        for i in range(vid.shape[0]):
            if np.random.uniform() < chance:
                # order should be [batch, time, rows, columns, channels]
                height, width = vid.shape[2], vid.shape[3]
                cropped_vid = np.zeros(shape=vid.shape[1:])
                dy, dx = self.random_crop_size
                x = np.random.randint(0, width - dx + 1)
                y = np.random.randint(0, height - dy + 1)
                # fill in empty video with cropped video
                cropped_vid[:, y:(y + dy), x:(x + dx), :] = vid[i, :, y:(y + dy), x:(x + dx), :]
                # replace original video with cropped video
                vid[i] = cropped_vid
        return vid

    def random_flip_vid(self, vid, chance=0.5):
        random_seed = np.random.uniform(0, 1)
        if random_seed < chance:
            # flip the image column ordering 
            # order should be [batch, time, rows, columns, channels]
            vid = vid[:, :, :, ::-1, :]
        return vid

    def random_noise_vid(self, vid, max_noise=0.05):
        # draw variance from uniform distribution
        noise_amount = np.random.uniform(0, max_noise)
        # draw per-pixel noise value from normal distribution 
        vid[:, :, :, :, 3:] = vid[:, :, :, :, 3:] + np.random.normal(0, noise_amount, size=vid[:, :, :, :, 3:].shape).astype('float32')
        # clip negative pixel values
        # vid[vid < 0] = 0
        return vid

    def random_contrast_vid(self, vid, random_contrast_delta=0.2):
        contrast_range = -random_contrast_delta, random_contrast_delta
        contrast = np.random.uniform(*contrast_range)
        # order should be [batch, time, rows, columns, channels]
        vid[:, :, :, :, 3:] = (vid[:, :, :, :, -3:] - 0.5) * (contrast + 1) + 0.5
        return vid

    def random_brightness_vid(self, vid, random_brightness_delta=0.2):
        brightness_range = -random_brightness_delta, random_brightness_delta
        brightness = np.random.uniform(*brightness_range)
        # order should be [batch, time, rows, columns, channels]
        vid[:, :, :, :, -3:] += brightness
        return vid

    def random_hue_vid(self, vid, chance=0.1, hue_shift=180):
        """Randomly rotate the hue of clips' last three channels, in place.

        For each clip, with probability ``chance``, one rotation is drawn
        uniformly from ``[-hue_shift, hue_shift)`` and applied to every
        frame: the last three channels are converted RGB -> HSV as float32,
        the hue channel is shifted and wrapped modulo 360, and the frame is
        converted back.

        The sampled rotation is kept in its own variable; previously it
        overwrote ``hue_shift``, so every later clip in the batch drew from
        the range of the previous draw instead of the requested one.

        Args:
            vid: array ordered [batch, time, rows, columns, channels].
            chance: per-clip probability of applying the augmentation.
            hue_shift: maximum absolute hue rotation.

        Returns:
            ``vid`` itself, modified in place.
        """
        # for each video in batch
        for i in range(vid.shape[0]):
            if np.random.uniform() < chance:
                clip_shift = np.random.uniform(-hue_shift, hue_shift)
                # for each time step
                for t in range(vid.shape[1]):
                    hsv = cv2.cvtColor(np.float32(vid[i, t, :, :, -3:]), cv2.COLOR_RGB2HSV)
                    hsv[..., 0] = np.mod(hsv[..., 0] + clip_shift, 360)
                    vid[i, t, :, :, -3:] = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)
        return vid

    def random_saturation_vid(self, vid, chance=0.1, saturation_shift=0.1):
        """Randomly shift the saturation of clips' last three channels, in place.

        For each clip, with probability ``chance``, one offset is drawn
        uniformly from ``[-saturation_shift, saturation_shift)`` and applied
        to every frame: the last three channels are converted RGB -> HSV as
        float32, the offset is added to the saturation channel, the result
        is clipped to ``[0, 1]``, and the frame is converted back.

        The sampled offset is kept in its own variable; previously it
        overwrote ``saturation_shift``, so every later clip in the batch
        drew from the range of the previous draw instead of the requested
        one.

        Args:
            vid: array ordered [batch, time, rows, columns, channels].
            chance: per-clip probability of applying the augmentation.
            saturation_shift: maximum absolute saturation offset.

        Returns:
            ``vid`` itself, modified in place.
        """
        # for each video in batch
        for i in range(vid.shape[0]):
            if np.random.uniform() < chance:
                clip_shift = np.random.uniform(-saturation_shift, saturation_shift)
                # for each time step
                for t in range(vid.shape[1]):
                    hsv = cv2.cvtColor(np.float32(vid[i, t, :, :, -3:]), cv2.COLOR_RGB2HSV)
                    hsv[..., 1] = np.clip(hsv[..., 1] + clip_shift, 0, 1.0)
                    vid[i, t, :, :, -3:] = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)
        return vid

    def random_jpeg_vid(self, vid, chance=0.1, quality_range=(10, 90)):
        """Randomly JPEG-compress the last three channels of clips, in place.

        For each clip, with probability ``chance``, every frame's last three
        channels are cast to uint8, encoded as JPEG at a quality drawn
        uniformly from ``quality_range`` (a fresh draw per frame), decoded
        again and written back divided by 255.

        The sampled quality is rounded to an ``int`` before it is handed to
        ``cv2.imencode``: the encoder parameters are an integer list, and a
        float entry is rejected by current OpenCV Python bindings.

        NOTE(review): the input is cast to uint8 without scaling while the
        output is divided by 255, so the value range going in and coming out
        does not match -- confirm the intended pixel range before enabling
        this augmentation.

        Args:
            vid: array ordered [batch, time, rows, columns, channels].
            chance: per-clip probability of applying the augmentation.
            quality_range: ``(low, high)`` bounds for the JPEG quality.

        Returns:
            ``vid`` itself, modified in place.
        """
        # for each video in batch
        for i in range(vid.shape[0]):
            if np.random.uniform() < chance:
                # for each time step
                for t in range(vid.shape[1]):
                    img_uint8 = (vid[i, t, :, :, -3:]).astype(np.uint8)

                    # encoder parameters must be integers
                    quality = int(round(np.random.uniform(*quality_range)))
                    encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), quality]
                    _, encoded = cv2.imencode('.jpg', img_uint8, encode_param)

                    img_uint8_decoded = cv2.imdecode(encoded, cv2.IMREAD_UNCHANGED)

                    # imdecode will convert (H, W, 1) to (H, W) handling that case here
                    img_uint8_decoded = np.reshape(img_uint8_decoded, img_uint8.shape)

                    vid[i, t, :, :, -3:] = np.float32(img_uint8_decoded / 255.0)
        return vid

    def random_blur_vid(self, vid, chance=0.1, blur_amount_range=(0.1, 0.2)):
        def _random_kernel_size(image_size, blur_amount_range):
            # Base kernel size on image size for comparable results accross image resolutions
            blur_amount = np.random.uniform(blur_amount_range[0], blur_amount_range[1])
            ksize = round(blur_amount * image_size)
            ksize = ksize if ksize % 2 == 1 else ksize + 1  # Make sure kernel size is odd
            return ksize

        ksize = _random_kernel_size(vid.shape[2], blur_amount_range=blur_amount_range)
        for i in range(vid.shape[0]):
            if np.random.uniform() < chance:
                for t in range(vid.shape[1]):
                    # blur each raw frame
                    # order should be [batch, time, rows, columns, channels]
                    vid[i, t, :, :, -3:] = cv2.GaussianBlur(vid[i, t, :, :, -3:], ksize=(ksize, ksize), sigmaX=0, sigmaY=0)
            # Support for (H, W, 1) (mono) images as cv2 will perform an implicit squeeze
            # sample.bgr_img = filtered.reshape(sample.bgr_img.shape)
        return vid

    def vid_aug_func(self, vid):
        """Run the video augmentation pipeline on a batch and return the result.

        Steps, in order: translate, crop, flip, noise, blur, contrast,
        brightness, hue, saturation. The noise, contrast, brightness and
        saturation magnitudes are scaled by the maximum of the last three
        channels of the first frame of the first clip, measured before any
        augmentation is applied. The JPEG augmentation is not part of the
        pipeline.

        To inspect a single step while debugging, call it on ``vid.copy()``
        with ``chance=1.01`` (or an enlarged magnitude) and pass the original
        and modified batch to ``visualize_modification``.
        """
        peak = np.max(vid[0, 0, :, :, -3:])

        # (augmentation, keyword overrides), applied top to bottom
        pipeline = (
            (self.random_translate_vid, {}),
            (self.random_crop_vid, {}),
            (self.random_flip_vid, {}),
            (self.random_noise_vid, {"max_noise": 0.05 * peak}),
            (self.random_blur_vid, {}),
            (self.random_contrast_vid, {"random_contrast_delta": 0.2 * peak}),
            (self.random_brightness_vid, {"random_brightness_delta": 0.2 * peak}),
            (self.random_hue_vid, {"chance": 0.1, "hue_shift": 180}),
            (self.random_saturation_vid, {"chance": 0.1, "saturation_shift": 0.1 * peak}),
        )
        for augment, overrides in pipeline:
            vid = augment(vid, **overrides)

        return vid

    def __data_generation(self, list_video_temp, indexes):
        'Generates data containing batch_size samples'

        if self.respiration == 1:
            label_key = "drsub"
        else:
            label_key = 'dysub'

        if self.temporal == '3DCNN':
            num_window = self.nframe_per_video - (self.frame_depth + 1)
            data = np.zeros((num_window*len(list_video_temp), self.dim[0], self.dim[1], self.frame_depth, 6),
                            dtype=np.float32)
            label = np.zeros((num_window*len(list_video_temp), self.frame_depth), dtype=np.float32)
            for index, temp_path in enumerate(list_video_temp):
                f1, dXsub = self.data_load_func(temp_path)
                dysub = np.array(f1[label_key])
                tempX = np.array([dXsub[f:f + self.frame_depth, :, :, :] # (169, 10, 36, 36, 6)
                                  for f in range(num_window)])
                tempY = np.array([dysub[f:f + self.frame_depth] # (169, 10, 1)
                                  for f in range(num_window)])
                tempX = np.swapaxes(tempX, 1, 3) # (169, 36, 36, 10, 6)
                tempX = np.swapaxes(tempX, 1, 2) # (169, 36, 36, 10, 6)
                tempY = np.reshape(tempY, (num_window, self.frame_depth)) # (169, 10)
                data[index*num_window:(index+1)*num_window, :, :, :, :] = tempX
                label[index*num_window:(index+1)*num_window, :] = tempY
            output = (data[:, :, :, :, :3], data[:, :, :, :, -3:])
        elif self.temporal == '3DCNN-MT':
            num_window = self.nframe_per_video - (self.frame_depth + 1)
            data = np.zeros((num_window*len(list_video_temp), self.dim[0], self.dim[1], self.frame_depth, 6),
                            dtype=np.float32)
            label_y = np.zeros((num_window*len(list_video_temp), self.frame_depth), dtype=np.float32)
            label_r = np.zeros((num_window * len(list_video_temp), self.frame_depth), dtype=np.float32)
            for index, temp_path in enumerate(list_video_temp):
                f1, dXsub = self.data_load_func(temp_path)
                drsub = np.array(f1['drsub'])
                dysub = np.array(f1['dysub'])
                tempX = np.array([dXsub[f:f + self.frame_depth, :, :, :] # (169, 10, 36, 36, 6)
                                  for f in range(num_window)])
                tempY_y = np.array([dysub[f:f + self.frame_depth] # (169, 10, 1)
                                  for f in range(num_window)])
                tempY_r = np.array([drsub[f:f + self.frame_depth] # (169, 10, 1)
                                  for f in range(num_window)])
                tempX = np.swapaxes(tempX, 1, 3) # (169, 36, 36, 10, 6)
                tempX = np.swapaxes(tempX, 1, 2) # (169, 36, 36, 10, 6)
                tempY_y = np.reshape(tempY_y, (num_window, self.frame_depth)) # (169, 10)
                tempY_r = np.reshape(tempY_r, (num_window, self.frame_depth))  # (169, 10)
                data[index*num_window:(index+1)*num_window, :, :, :, :] = tempX
                label_y[index*num_window:(index+1)*num_window, :] = tempY_y
                label_r[index * num_window:(index + 1) * num_window, :] = tempY_r
            output = (data[:, :, :, :, :3], data[:, :, :, :, 3:])
            label = (label_y, label_r)
        elif self.temporal == '2DCNN':
            data = np.zeros((self.nframe_per_video * len(list_video_temp), self.dim[0], self.dim[1], 6), dtype=np.float32)
            label = np.zeros((self.nframe_per_video * len(list_video_temp), 1), dtype=np.float32)
            for index, temp_path in enumerate(list_video_temp):
                f1, dXsub = self.data_load_func(temp_path)
                dysub = np.array(f1[label_key])
                data[index*self.nframe_per_video:(index+1)*self.nframe_per_video, :, :, :] = dXsub
                label[index*self.nframe_per_video:(index+1)*self.nframe_per_video, :] = dysub
            output = (data[:, :, :, :3], data[:, :, :, -3:])
        elif self.temporal == '2DCNN-MT':
            data = np.zeros((self.nframe_per_video * len(list_video_temp), self.dim[0], self.dim[1], 6),
                            dtype=np.float32)
            label_y = np.zeros((self.nframe_per_video * len(list_video_temp), 1), dtype=np.float32)
            label_r = np.zeros((self.nframe_per_video * len(list_video_temp), 1), dtype=np.float32)
            for index, temp_path in enumerate(list_video_temp):
                f1, dXsub = self.data_load_func(temp_path)
                drsub = np.array(f1['drsub'])
                dysub = np.array(f1['dysub'])
                data[index * self.nframe_per_video:(index + 1) * self.nframe_per_video, :, :, :] = dXsub
                label_y[index*self.nframe_per_video:(index+1)*self.nframe_per_video, :] = dysub
                label_r[index * self.nframe_per_video:(index + 1) * self.nframe_per_video, :] = drsub
            output = (data[:, :, :, :3], data[:, :, :, -3:])
            label = (label_y, label_r)
        elif self.temporal == 'TSM':
            data = np.zeros((self.nframe_per_video * len(list_video_temp), self.dim[0], self.dim[1], 6), dtype=np.float32)
            label = np.zeros((self.nframe_per_video * len(list_video_temp), 1), dtype=np.float32)
            num_window = int(self.nframe_per_video / self.frame_depth) * len(list_video_temp)
            w = signal.gaussian(4,1)
            for index, temp_path in enumerate(list_video_temp):
                f1, dXsub = self.data_load_func(temp_path)
                dysub = np.array(f1[label_key])
                data[index*self.nframe_per_video:(index+1)*self.nframe_per_video, :, :, :] = dXsub
                label[index*self.nframe_per_video:(index+1)*self.nframe_per_video, :] = dysub
            motion_data = data[:, :, :, :3]
            apperance_data = data[:, :, :, -3:]
            apperance_data = np.reshape(apperance_data, (num_window, self.frame_depth, self.dim[0], self.dim[1], 3))
            apperance_data = np.average(apperance_data, axis=1)
            apperance_data = np.repeat(apperance_data[:, np.newaxis, :, :, :], self.frame_depth, axis=1)
            apperance_data = np.reshape(apperance_data, (apperance_data.shape[0] * apperance_data.shape[1],
                                                         apperance_data.shape[2], apperance_data.shape[3],
                                                         apperance_data.shape[4]))
            base_len = self.num_gpu * self.frame_depth
            new_len = base_len * (apperance_data.shape[0] // base_len)
            motion_data = motion_data[:new_len]
            apperance_data = apperance_data[:new_len]
            output = (motion_data, apperance_data)
            label = label[:new_len]
            label = (label)
        elif self.temporal == 'TS_CAN_PEAKDETECTION':
            data = np.zeros((self.nframe_per_video * len(list_video_temp), self.dim[0], self.dim[1], 6), dtype=np.float32)
            label = np.zeros((self.nframe_per_video * len(list_video_temp), 1), dtype=np.float32)
            peaks = np.zeros((self.nframe_per_video * len(list_video_temp), 1), dtype=np.float32)
            num_window = int(self.nframe_per_video / self.frame_depth) * len(list_video_temp)
            w = signal.gaussian(4,1)
            for index, temp_path in enumerate(list_video_temp):
                f1, dXsub = self.data_load_func(temp_path)
                dysub = np.array(f1[label_key])
                data[index*self.nframe_per_video:(index+1)*self.nframe_per_video, :, :, :] = dXsub
                label[index*self.nframe_per_video:(index+1)*self.nframe_per_video, :] = dysub
                x=signal.find_peaks(dysub.flatten(),height=2,distance=10)
                psub = np.zeros(dysub.shape)
                psub[x[0]] = 1
                peaks[index*self.nframe_per_video:(index+1)*self.nframe_per_video, :] = signal.lfilter(w,1,psub)
            motion_data = data[:, :, :, :3]
            apperance_data = data[:, :, :, -3:]
            apperance_data = np.reshape(apperance_data, (num_window, self.frame_depth, self.dim[0], self.dim[1], 3))
            apperance_data = np.average(apperance_data, axis=1)
            apperance_data = np.repeat(apperance_data[:, np.newaxis, :, :, :], self.frame_depth, axis=1)
            apperance_data = np.reshape(apperance_data, (apperance_data.shape[0] * apperance_data.shape[1],
                                                         apperance_data.shape[2], apperance_data.shape[3],
                                                         apperance_data.shape[4]))
            base_len = self.num_gpu * self.frame_depth
            new_len = base_len * (apperance_data.shape[0] // base_len)
            motion_data = motion_data[:new_len]
            apperance_data = apperance_data[:new_len]
            output = (motion_data, apperance_data)
            label = label[:new_len]
            peaks = peaks[:new_len]
            label = (label, peaks)
        elif self.temporal == 'TSM-MT' or self.temporal == 'TSM-MT-Dual':
            data = np.zeros((self.nframe_per_video * len(list_video_temp), self.dim[0], self.dim[1], 6), dtype=np.float32)
            label_y = np.zeros((self.nframe_per_video * len(list_video_temp), 1), dtype=np.float32)
            label_r = np.zeros((self.nframe_per_video * len(list_video_temp), 1), dtype=np.float32)
            num_window = int(self.nframe_per_video / self.frame_depth) * len(list_video_temp)
            for index, temp_path in enumerate(list_video_temp):
                f1, dXsub = self.data_load_func(temp_path)
                drsub = np.array(f1['drsub'])
                dysub = np.array(f1['dysub'])
                data[index*self.nframe_per_video:(index+1)*self.nframe_per_video, :, :, :] = dXsub
                label_y[index*self.nframe_per_video:(index+1)*self.nframe_per_video, :] = dysub
                label_r[index * self.nframe_per_video:(index + 1) * self.nframe_per_video, :] = drsub
            if np.isnan(label_r).all():
                label_r = np.nan_to_num(label_r)
            motion_data = data[:, :, :, :3]
            apperance_data = data[:, :, :, -3:]
            apperance_data = np.reshape(apperance_data, (num_window, self.frame_depth, self.dim[0], self.dim[1], 3))
            apperance_data = np.average(apperance_data, axis=1)
            apperance_data = np.repeat(apperance_data[:, np.newaxis, :, :, :], self.frame_depth, axis=1)
            apperance_data = np.reshape(apperance_data, (apperance_data.shape[0] * apperance_data.shape[1],
                                                         apperance_data.shape[2], apperance_data.shape[3],
                                                         apperance_data.shape[4]))
            base_len = self.num_gpu * self.frame_depth
            new_len = base_len * (apperance_data.shape[0] // base_len)
            motion_data = motion_data[:new_len]
            apperance_data = apperance_data[:new_len]
            label_y = label_y[:new_len]
            label_r = label_r[:new_len]
            output = (motion_data, apperance_data)
            label = (label_y, label_r)
        elif self.temporal == 'MIX-MT' or self.temporal == 'MIX-MT-Dual':
            num_window = self.nframe_per_video - (self.frame_depth + 1)
            data = np.zeros((num_window*len(list_video_temp), self.dim[0], self.dim[1], self.frame_depth, 6),
                            dtype=np.float32)
            label_y = np.zeros((num_window*len(list_video_temp), self.frame_depth), dtype=np.float32)
            label_r = np.zeros((num_window * len(list_video_temp), self.frame_depth), dtype=np.float32)
            for index, temp_path in enumerate(list_video_temp):
                f1, dXsub = self.data_load_func(temp_path)
                drsub = np.array(f1['drsub'])
                dysub = np.array(f1['dysub'])
                tempX = np.array([dXsub[f:f + self.frame_depth, :, :, :] # (169, 10, 36, 36, 6)
                                  for f in range(num_window)])
                tempY_y = np.array([dysub[f:f + self.frame_depth] # (169, 10, 1)
                                  for f in range(num_window)])
                tempY_r = np.array([drsub[f:f + self.frame_depth] # (169, 10, 1)
                                  for f in range(num_window)])
                tempX = np.swapaxes(tempX, 1, 3) # (169, 36, 36, 10, 6)
                tempX = np.swapaxes(tempX, 1, 2) # (169, 36, 36, 10, 6)
                tempY_y = np.reshape(tempY_y, (num_window, self.frame_depth)) # (169, 10)
                tempY_r = np.reshape(tempY_r, (num_window, self.frame_depth))  # (169, 10)
                data[index*num_window:(index+1)*num_window, :, :, :, :] = tempX
                label_y[index*num_window:(index+1)*num_window, :] = tempY_y
                label_r[index * num_window:(index + 1) * num_window, :] = tempY_r
            motion_data = data[:, :, :, :, :3]
            apperance_data = np.average(data[:, :, :, :, -3:], axis=-2)
            output = (motion_data, apperance_data)
            # label = (label_y, label_r)
            label = {"dysub": label_y, "drsub": label_r}
            if self.return_file_names:
                label["file"] = ["{}".format(os.path.splitext(os.path.basename(vid))[0]) for vid in list_video_temp]

        elif self.temporal == 'MIX-MT-Dual-RNN' or self.temporal == 'MIX-MT-Dual-RNN_v2' or self.temporal == 'MIX-MT-Dual-RNN_v3':
            num_window = self.nframe_per_video - (self.frame_depth + 1)
            data = np.zeros((self.batch_size, self.dim[0], self.dim[1], self.frame_depth, 6),
                            dtype=np.float32)
            # data = np.zeros((num_window*len(list_video_temp), self.frame_depth, self.dim[0], self.dim[1], 6),
            #                 dtype=np.float32)
            label = {}
            for sig in self.signals_to_use:
                label[sig] = np.zeros((self.batch_size, self.frame_depth), dtype=np.float32)
                if self.use_raw_signal:
                    label[f"{sig}_raw"] = np.zeros((self.batch_size, self.frame_depth), dtype=np.float32)
                if self.use_second_derivative:
                    label[f"{sig}_SD"] = np.zeros((self.batch_size, self.frame_depth-1), dtype=np.float32)

            if self.seg_masks is not None:
                for mask in self.seg_masks:
                    label[mask] = np.zeros((self.batch_size, self.dim[0], self.dim[1]), dtype=np.float32)
                    if mask == "skin_mask":
                        label[mask] = np.zeros((self.batch_size, int(self.dim[0]/2), int(self.dim[1]/2)), dtype=np.float32)
            
            if self.return_file_names:
                label["file"] = []

            for index, temp_path in enumerate(list_video_temp):
                f1, dXsub = self.data_load_func(temp_path)
                # get the start position within the file
                # if split_iter, index differently into file
                if self.split_iter:
                    f = (indexes[index] // len(self.paths_of_videos)) * self.step_size
                else:
                    # iterate in order
                    f = (indexes[index] % (self.num_window)) * self.step_size

                # if return_file_names flag is set, also return the file name for each video in the batch
                if self.return_file_names:
                    label["file"].append("{}_{}".format(os.path.splitext(os.path.basename(temp_path))[0], f))

                # for each signal, get frames that correspond to the time window of interest
                for sig in self.signals_to_use:
                    try:
                        temp_y = np.array(f1[sig][f:f + self.frame_depth])
                        label[sig][index] = temp_y[:, 0]
                        # if including the raw signal, cumulative sum the first derivative
                        if self.use_raw_signal:
                            label[f"{sig}_raw"][index] = np.cumsum(temp_y[:, 0], axis=0)
                            # standardize signal to reduce magnitude relative to other signals 
                            label[f"{sig}_raw"][index] = (label[f"{sig}_raw"][index] - np.mean(label[f"{sig}_raw"][index])) / np.std(label[f"{sig}_raw"][index])
                        # if including the second derivative, calculate from the first derivative
                        if self.use_second_derivative:
                            label[f"{sig}_SD"][index] = np.diff(temp_y[:, 0], axis=0)
                    except ValueError as e:
                        print(e)
                        print("guilty file: ", temp_path)
                        print("label shape:", temp_y.shape)
                        # if there are not enough frames for entire window,
                        # pad remaining frames with zeros
                        label[sig][index][:temp_y.shape[0]] = temp_y[:, 0]
                # print below helps for debugging iteration
                # print("indexes[index]: {:.2f} path: {} i: {} f: {}".format(indexes[index], os.path.basename(temp_path), index, f))
                tempX = np.array([dXsub[f:f + self.frame_depth, :, :, :]]) # (1, 10, 36, 36, 6)
                # if data augmentation flag set, apply augmentation functions 
                if self.data_aug:
                    tempX = self.vid_aug_func(tempX)
                tempX = np.swapaxes(tempX, 1, 3) # (1, 36, 36, 10, 6)
                tempX = np.swapaxes(tempX, 1, 2) # (1, 36, 36, 10, 6)
                # if there are not enough frames for entire window,
                # pad remaining frames with zeros
                try:
                    data[index, :, :, :, :] = tempX
                except ValueError as e:
                    print(e)
                    print("Data shape:", tempX.shape)
                    data[index, :, :, :tempX.shape[3], :] = tempX
                # optionally load segmentation masks
                if self.seg_masks is not None:
                    # for each segmentation mask, take mean of segmentation 
                    # mask over all frames to match appearance branch image
                    for mask in self.seg_masks:
                        # for skin segmentation mask, we need to resize mask 
                        # to smaller image as the size reduces by factor of 2
                        if mask == "skin_mask":
                            label[mask][index] = resize(np.mean(np.transpose(f1[mask])[f:f + self.frame_depth], axis=0), output_shape=(18, 18))
                        else:
                            label[mask][index] = np.mean(np.transpose(f1[mask])[f:f + self.frame_depth], axis=0)
                # if the .mat file contains ground truth ECG data and we are
                # evaluating the model, additionally return the raw ECG data
                # and the beat measurements
                if self.return_file_names and 'ecg30' in f1.keys():
                    if "ecg30" not in label.keys():
                        label['ecg30'] = [np.array(f1['ecg30'][f:f + self.frame_depth])]
                        label['ecgBeats'] = [np.squeeze(f1['ecgBeats'])]
                    else:
                        label['ecg30'].append(np.array(f1['ecg30'][f:f + self.frame_depth]))
                        label['ecgBeats'].append(np.squeeze(f1['ecgBeats']))

            motion_data = data[:, :, :, :, :3]
            apperance_data = np.average(data[:, :, :, :, -3:], axis=-2)
            output = (motion_data, apperance_data)
            # if return_file_names flag is set, also return the file name for each video in the batch
            # if self.return_file_names:
            #     label["file"] = ["{}_{}".format(os.path.splitext(os.path.basename(vid))[0], f) for vid in list_video_temp]
        elif self.temporal == 'MIX-MT-Dual-RNN_MD':
            num_window = self.nframe_per_video - (self.frame_depth + 1)
            data = np.zeros((self.batch_size, self.dim[0], self.dim[1], self.frame_depth, 6),
                            dtype=np.float32)
            # data = np.zeros((num_window*len(list_video_temp), self.frame_depth, self.dim[0], self.dim[1], 6),
            #                 dtype=np.float32)
            label = {}
            for sig in self.signals_to_use:
                label[sig] = np.zeros((self.batch_size, self.frame_depth, 2), dtype=np.float32)
                if self.use_raw_signal:
                    label[f"{sig}_raw"] = np.zeros((self.batch_size, self.frame_depth, 2), dtype=np.float32)
                if self.use_second_derivative:
                    label[f"{sig}_SD"] = np.zeros((self.batch_size, self.frame_depth-1, 2), dtype=np.float32)

            if self.seg_masks is not None:
                for mask in self.seg_masks:
                    label[mask] = np.zeros((self.batch_size, self.dim[0], self.dim[1]), dtype=np.float32)
                    if mask == "skin_mask":
                        label[mask] = np.zeros((self.batch_size, int(self.dim[0]/2), int(self.dim[1]/2)), dtype=np.float32)
            
            if self.return_file_names:
                label["file"] = []

            for index, temp_path in enumerate(list_video_temp):
                f1, dXsub = self.data_load_func(temp_path)
                # get the start position within the file
                # if split_iter, index differently into file
                if self.split_iter:
                    f = (indexes[index] // len(self.paths_of_videos)) * self.step_size
                else:
                    # iterate in order
                    f = (indexes[index] % (self.num_window)) * self.step_size

                # if return_file_names flag is set, also return the file name for each video in the batch
                if self.return_file_names:
                    label["file"].append("{}_{}".format(os.path.splitext(os.path.basename(temp_path))[0], f))

                # for each signal, get frames that correspond to the time window of interest
                for sig in self.signals_to_use:
                    try:
                        temp_y = np.array(f1[sig][f:f + self.frame_depth])
                        label[sig][index][:, 0] = temp_y[:, 0]
                        if sig == "bpraw":
                            label[sig][index][:, 0] = (label[sig][index][:, 0] - 75.) / 15.
                        zero_crossings = get_zero_crossings(temp_y[:, 0])
                        np.put(label[sig][index][:, 1], zero_crossings, 1.)
                        # if including the raw signal, cumulative sum the first derivative
                        if self.use_raw_signal:
                            np.put(label[f"{sig}_raw"][index][:, 1], zero_crossings, 1.)
                            label[f"{sig}_raw"][index][:, 0] = np.cumsum(temp_y[:, 0], axis=0)
                            import matplotlib.pyplot as plt
                            # standardize signal to reduce magnitude relative to other signals 
                            label[f"{sig}_raw"][index] = (label[f"{sig}_raw"][index] - np.mean(label[f"{sig}_raw"][index])) / np.std(label[f"{sig}_raw"][index])
                        # if including the second derivative, calculate from the first derivative
                        if self.use_second_derivative:
                            label[f"{sig}_SD"][index][:, 0] = np.diff(temp_y[:, 0], axis=0)
                            # calculate systolic peak times
                            sys_upstroke = calc_PPG_peaks(-label[f"{sig}_SD"][index][:, 0], height=(0.5,))
                            # calculate dicrotic notch times 
                            dicrotic_peaks = calc_PPG_peaks(label[f"{sig}_SD"][index][:, 0], distance=3, height=(0.,1))
                            peak_indices = np.array(list(sys_upstroke) + list(dicrotic_peaks), dtype=np.int)
                            np.put(label[f"{sig}_SD"][index][:, 1], peak_indices, 1.)
                            # standardize signal to reduce magnitude relative to other signals 
                            label[f"{sig}_SD"][index] = (label[f"{sig}_SD"][index] - np.mean(label[f"{sig}_SD"][index])) / np.std(label[f"{sig}_SD"][index])
                    except ValueError as e:
                        print(e)
                        print("guilty file: ", temp_path)
                        print("label shape:", temp_y.shape)
                        # if there are not enough frames for entire window,
                        # pad remaining frames with zeros
                        label[sig][index][:temp_y.shape[0]] = temp_y[:, 0]
                # print below helps for debugging iteration
                # print("indexes[index]: {:.2f} path: {} i: {} f: {}".format(indexes[index], os.path.basename(temp_path), index, f))
                tempX = np.array([dXsub[f:f + self.frame_depth, :, :, :]]) # (1, 10, 36, 36, 6)
                # if data augmentation flag set, apply augmentation functions 
                if self.data_aug:
                    tempX = self.vid_aug_func(tempX)
                tempX = np.swapaxes(tempX, 1, 3) # (1, 36, 36, 10, 6)
                tempX = np.swapaxes(tempX, 1, 2) # (1, 36, 36, 10, 6)
                # if there are not enough frames for entire window,
                # pad remaining frames with zeros
                try:
                    data[index, :, :, :, :] = tempX
                except ValueError as e:
                    print(e)
                    print("Data shape:", tempX.shape)
                    data[index, :, :, :tempX.shape[3], :] = tempX
                # optionally load segmentation masks
                if self.seg_masks is not None:
                    # for each segmentation mask, take mean of segmentation 
                    # mask over all frames to match appearance branch image
                    for mask in self.seg_masks:
                        # for skin segmentation mask, we need to resize mask 
                        # to smaller image as the size reduces by factor of 2
                        if mask == "skin_mask":
                            label[mask][index] = resize(np.mean(np.transpose(f1[mask])[f:f + self.frame_depth], axis=0), output_shape=(18, 18))
                        else:
                            label[mask][index] = np.mean(np.transpose(f1[mask])[f:f + self.frame_depth], axis=0)
                # if the .mat file contains ground truth ECG data and we are
                # evaluating the model, additionally return the raw ECG data
                # and the beat measurements
                if self.return_file_names and 'ecg30' in f1.keys():
                    if "ecg30" not in label.keys():
                        label['ecg30'] = [np.array(f1['ecg30'][f:f + self.frame_depth])]
                        label['ecgBeats'] = [np.squeeze(f1['ecgBeats'])]
                    else:
                        label['ecg30'].append(np.array(f1['ecg30'][f:f + self.frame_depth]))
                        label['ecgBeats'].append(np.squeeze(f1['ecgBeats']))

            # zeroth derivative is raw video data
            zeroth_deriv_data = data[:, :, :, :, -3:]
            # import matplotlib.pyplot as plt
            # plt.imshow(zeroth_deriv_data[0, :, :, 0, -3:])
            # plt.savefig("test.png")
            # plt.close()
            # first derivative is normalized diff between consec raw frames
            first_deriv_data = data[:, :, :, :, :3]
            # second derivative is normalized diff between consec first deriv frames
            second_deriv_data = np.diff(first_deriv_data, axis=3)
            output = (zeroth_deriv_data, first_deriv_data, second_deriv_data)
            if not self.use_first_derivative:
                for sig in self.signals_to_use:
                    del label[sig]

        elif self.temporal == 'ViViT' or self.temporal == 'MIX-MT-Dual-RNN_v4':
            num_window = self.nframe_per_video - (self.frame_depth + 1)
            data = np.zeros((self.batch_size, self.frame_depth, self.dim[0], self.dim[1], 6),
                            dtype=np.float32)
            # data = np.zeros((num_window*len(list_video_temp), self.frame_depth, self.dim[0], self.dim[1], 6),
            #                 dtype=np.float32)
            label = {}
            for sig in self.signals_to_use:
                label[sig] = np.zeros((self.batch_size, self.frame_depth), dtype=np.float32)

            if self.seg_masks is not None:
                for mask in self.seg_masks:
                    label[mask] = np.zeros((self.batch_size, self.dim[0], self.dim[1]), dtype=np.float32)
                    if mask == "skin_mask":
                        label[mask] = np.zeros((self.batch_size, int(self.dim[0]/2), int(self.dim[1]/2)), dtype=np.float32)

            if self.return_file_names:
                label["file"] = []

            for index, temp_path in enumerate(list_video_temp):
                f1, dXsub = self.data_load_func(temp_path)
                # get the start position within the file
                # if split_iter, index differently into file
                if self.split_iter:
                    f = (indexes[index] // len(self.paths_of_videos)) * self.step_size
                else:
                    # iterate in order
                    f = (indexes[index] % (self.num_window)) * self.step_size

                # if return_file_names flag is set, also return the file name for each video in the batch
                if self.return_file_names:
                    label["file"].append("{}_{}".format(os.path.splitext(os.path.basename(temp_path))[0], f))

                # for each signal, get frames that correspond to the time window of interest
                for sig in self.signals_to_use:
                    try:
                        temp_y = np.array(f1[sig][f:f + self.frame_depth])
                        label[sig][index] = temp_y[:, 0]
                    except ValueError as e:
                        print(e)
                        print("signal {} (at index {} - {}): ".format(sig, index, f))
                        print("guilty file: ", temp_path)
                        exit()
                # print below helps for debugging iteration
                # print("indexes[index]: {:.2f} path: {} i: {} f: {}".format(indexes[index], os.path.basename(temp_path), index, f))
                tempX = np.array([dXsub[f:f + self.frame_depth, :, :, :]]) # (1, 10, 36, 36, 6)
                # if data augmentation flag set, apply augmentation functions 
                if self.data_aug:
                    tempX = self.vid_aug_func(tempX)
                # tempX = np.swapaxes(tempX, 1, 3) # (1, 36, 36, 10, 6)
                # tempX = np.swapaxes(tempX, 1, 2) # (1, 36, 36, 10, 6)
                data[index, :, :, :, :] = tempX
                # optionally load segmentation masks
                if self.seg_masks is not None:
                    # for each segmentation mask, take mean of segmentation 
                    # mask over all frames to match appearance branch image
                    for mask in self.seg_masks:
                        # for skin segmentation mask, we need to resize mask 
                        # to smaller image as the size reduces by factor of 2
                        if mask == "skin_mask":
                            label[mask][index] = resize(np.mean(np.transpose(f1[mask])[f:f + self.frame_depth], axis=0), output_shape=(18, 18))
                        else:
                            label[mask][index] = np.mean(np.transpose(f1[mask])[f:f + self.frame_depth], axis=0)
                # if the .mat file contains ground truth ECG data and we are
                # evaluating the model, additionally return the raw ECG data
                # and the beat measurements
                if self.return_file_names and 'ecg30' in f1.keys():
                    if "ecg30" not in label.keys():
                        label['ecg30'] = [np.array(f1['ecg30'][f:f + self.frame_depth])]
                        label['ecgBeats'] = [np.squeeze(f1['ecgBeats'])]
                    else:
                        label['ecg30'].append(np.array(f1['ecg30'][f:f + self.frame_depth]))
                        label['ecgBeats'].append(np.squeeze(f1['ecgBeats']))

            motion_data = data[:, :, :, :, :3]
            # add in additional dimension for time
            apperance_data = np.expand_dims(np.average(data[:, :, :, :, -3:], axis=1), 1)
            # concatenate appearance frame as additional time step
            output = np.concatenate((motion_data, apperance_data), axis=1)

        elif self.temporal == 'MIX':
            num_window = self.nframe_per_video - (self.frame_depth + 1)
            data = np.zeros((num_window*len(list_video_temp), self.dim[0], self.dim[1], self.frame_depth, 6),
                            dtype=np.float32)
            label = np.zeros((num_window*len(list_video_temp), self.frame_depth), dtype=np.float32)
            for index, temp_path in enumerate(list_video_temp):
                f1, dXsub = self.data_load_func(temp_path)
                dysub = np.array(f1[label_key])
                tempX = np.array([dXsub[f:f + self.frame_depth, :, :, :] # (169, 10, 36, 36, 6)
                                  for f in range(num_window)])
                tempY = np.array([dysub[f:f + self.frame_depth] # (169, 10, 1)
                                  for f in range(num_window)])
                tempX = np.swapaxes(tempX, 1, 3) # (169, 36, 36, 10, 6)
                tempX = np.swapaxes(tempX, 1, 2) # (169, 36, 36, 10, 6)
                tempY = np.reshape(tempY, (num_window, self.frame_depth)) # (169, 10)
                data[index*num_window:(index+1)*num_window, :, :, :, :] = tempX
                label[index*num_window:(index+1)*num_window, :] = tempY
            motion_data = data[:, :, :, :, :3]
            apperance_data = np.average(data[:, :, :, :, -3:], axis=-2)
            output = (motion_data, apperance_data)
        else:
            raise ValueError('Unsupported Model!')

        # if self.data_aug == 1:
        #     output = np.concatenate((output[0], output[1]), axis=-1)
        #     if len(output.shape) == 4:
        #         output = np.array([self.data_aug_func(data) for data in output])
        #         motion_data = output[:, :, :, :3]
        #         apperance_data = output[:, :, :, 3:]
        #     else:
        #         raise('Unsupported!')
        #     output = (motion_data, apperance_data)
        return output, label
