'''Training script for CNN-based iPPG models (DeepPhys, TS-CAN and related variants).
'''
# %%
from __future__ import print_function

import argparse
import itertools
import json
import os
import sys
import datetime

import numpy as np
import pandas as pd
import scipy.io
import tensorflow as tf
from einops.layers.tensorflow import Rearrange
from tensorflow_addons.layers import GELU

from data_generator import DataGenerator
from metric import calculate_metric
from model import HeartBeat, DeepPhy, DeepPhy_3DCNN, DeepPhy_3DCNN_MT, DeepPhys_2DCNN_MT, \
    Hybrid_CAN, Hybrid_CAN_MT, Hybrid_CAN_MT_Dual, TS_CAN, MTTS_CAN, MTTS_CAN_Dual, Hybrid_CAN_MT_Dual_RNN, Hybrid_CAN_MT_Dual_RNN_v2, Hybrid_CAN_MT_Dual_RNN_v3, Hybrid_CAN_MT_Dual_RNN_v4, Hybrid_CAN_MT_Dual_RNN_MD, TS_CAN_PEAKDETECTION, first_and_second_derivative_loss, first_derivative_loss, second_derivative_loss
from model import Attention_mask, dice_coef_loss, first_and_second_derivative_loss, first_derivative_loss, second_derivative_loss, second_derivative_peak_loss
from pre_process import get_nframe_video, retrive_labels, read_from_txt
sys.path.append("ViViT-tensorflow")
# from vivit import ViViT

# Fix the random seeds for reproducibility. Seeding numpy alone leaves
# TensorFlow's own random ops (e.g. Keras weight initialisation, dropout)
# unseeded, so the TF global seed is set as well. It is set after
# clear_session() so the Keras state reset does not affect it.
_RANDOM_SEED = 100
np.random.seed(_RANDOM_SEED)
print(tf.__version__)
print("GPU available: ", tf.config.list_physical_devices('GPU'))

tf.keras.backend.clear_session()
tf.random.set_seed(_RANDOM_SEED)
# %%
# Command-line configuration. It is parsed once, at import time, into the
# module-level ``args`` namespace.


def _str2bool(value):
    """Convert a command-line string such as 'True', 'false', '1' or 'no' to a bool.

    With ``type=str`` (or ``type=bool``) argparse yields a truthy value for any
    non-empty string, including "False". Boolean options therefore need an
    explicit converter. Bool values (e.g. a non-string default) pass through.

    Raises:
        argparse.ArgumentTypeError: if the string is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser()
# data I/O
parser.add_argument('-exp', '--exp_name', type=str, default='Test',
                    help='experiment name')
# The backslashes are escaped explicitly. The value is unchanged (C:\Data\ippg\),
# but it no longer relies on invalid escapes such as '\D', which warn on newer Pythons.
parser.add_argument('-i', '--data_dir', type=str,
                    default='C:\\Data\\ippg\\', help='Location for the dataset')
parser.add_argument('-img', '--img_size', type=int, default=36, help='img_size')
parser.add_argument('-crp_img', '--cropped_size', type=int, default=36, help='cropped image size')
parser.add_argument('-tr_data', '--tr_dataset', type=str, default='AFRL', help='training dataset name')
parser.add_argument('-ts_data', '--ts_dataset', type=str, default='AFRL', help='test dataset name')
parser.add_argument('-tr_txt', '--train_txt', type=str, default='./filelists/Train.txt', help='train file')
parser.add_argument('-ts_txt', '--test_txt', type=str, default='./filelists/Test.txt', help='test file')
parser.add_argument('-o', '--save_dir', type=str, default='./rPPG-checkpoints',
                    help='Location for parameter checkpoints and samples')
parser.add_argument('-a', '--nb_filters1', type=int, default=32,
                    help='number of convolutional filters to use')
parser.add_argument('-b', '--nb_filters2', type=int, default=64,
                    help='number of convolutional filters to use')
parser.add_argument('-c', '--dropout_rate1', type=float, default=0.25,
                    help='dropout rates')
parser.add_argument('-d', '--dropout_rate2', type=float, default=0.5,
                    help='dropout rates')
parser.add_argument('-l', '--lr', type=float, default=0.001,
                    help='learning rate')
parser.add_argument('-e', '--nb_dense', type=int, default=128,
                    help='number of dense units')
parser.add_argument('-f', '--cv_split', type=int, default=0,
                    help='cv_split')
parser.add_argument('-g', '--nb_epoch', type=int, default=8,
                    help='nb_epoch')
parser.add_argument('-ie', '--initial_epoch', type=int, default=0,
                    help="Initial epoch when starting model training")
parser.add_argument('-t', '--nb_task', type=int, default=12,
                    help='nb_task')
parser.add_argument('-x', '--batch_size', type=int, default=8,
                    help='batch')
parser.add_argument('-fd', '--frame_depth', type=int, default=10,
                    help='frame_depth for 3DCNN')
parser.add_argument('-temp', '--temporal', type=str, default='TS_CAN_PEAKDETECTION',
                    help='model type, e.g. 2DCNN, 2DCNN-MT, 3DCNN, 3DCNN-MT, TSM, TSM-MT, TSM-MT-Dual, '
                         'TS_CAN_PEAKDETECTION, MIX, MIX-MT, MIX-MT-Dual, MIX-MT-Dual-RNN[_v2|_v3|_v4|_MD]')
parser.add_argument('-save', '--save_all', type=int, default=1,
                    help='save all or not')
parser.add_argument('-resp', '--respiration', type=int, default=0,
                    help='train with resp or not')
# Parsed to a real bool so that "-shuf False" actually disables shuffling.
parser.add_argument('-shuf', '--shuffle', type=_str2bool, default=True,
                    help='shuffle samples (True/False)')
parser.add_argument('-da', '--data_aug', type=int, default=0,
                    help='data augmentation')
parser.add_argument('-ds', '--data_fs', type=int, default=30,
                    help='data frequency')
parser.add_argument('-dw', '--eval_window', type=int, default=360,
                    help='evaluation window size')
parser.add_argument('-ss', '--step_size_train', type=int, default=1,
                    help='number of frames in between windows in video segment during training')
parser.add_argument('-sst', '--step_size_test', type=int, default=None,
                    help='number of frames in between windows in video segment during testing')
parser.add_argument('-trs', '--signals_to_use_train', nargs="+", default=["dysub", "drsub"],
                    help='List of target signals to use for training - note for backwards compatability, use ppg->dysub and resp->drsub')
parser.add_argument('-tss', '--signals_to_use_test', nargs="+", default=["dysub", "drsub"],
                    help='List of target signals to use for testing - note for backwards compatability, use ppg->dysub and resp->drsub')
parser.add_argument('--second_derivative_loss', action='store_true', help='Use first and second derivative loss')
parser.add_argument('--predict_first_derivative', action='store_true', help='Add additional output target signal for first derivative')
parser.add_argument('--predict_second_derivative', action='store_true', help='Add additional output target signal for second derivative')
parser.add_argument('--predict_raw_signal', action='store_true', help='Add additional output target signal for raw output (0th derivative)')
parser.add_argument('--use_second_derivative_frames', action='store_true', help='Use second derivative frames as input')
parser.add_argument('--use_second_derivative_frames_only', action='store_true', help='Use ONLY second derivative frames as input')
parser.add_argument('--add_first_derivative_loss', action='store_true', help='Add additional loss penalty for first derivative signal')
parser.add_argument('--add_second_derivative_loss', action='store_true', help='Add additional loss penalty for second derivative signal')
parser.add_argument('-m', '--model_file', type=str, default=None,
                    help="Path to pre-trained model for continued training")

args = parser.parse_args()
print('input args:\n', json.dumps(vars(args), indent=4, separators=(',', ':')))  # pretty print args

# A missing (or zero) test step size falls back to the frame depth. This makes
# consecutive test windows non-overlapping in the test output.
args.step_size_test = args.step_size_test or args.frame_depth

# %% Splitting Data

print('Spliting Data...')
# NOTE(review): presumably the subject IDs available in the dataset; 1-27 with
# 19 and 24 left out — confirm why those two are excluded.
subNum = np.array([subject for subject in range(1, 28) if subject not in (19, 24)])
# Task indices 1..nb_task (inclusive).
taskList = [*range(1, args.nb_task + 1)]

# %% Training


def train(args, cv_split, img_rows, img_cols):
    print('================================')
    print('Train...')

    input_shape = (img_rows, img_cols, 3)

    # Reading Data
    path_of_video_tr, path_of_video_test = read_from_txt(args.train_txt, args.test_txt, args.data_dir)
    nframe_per_video_tr = get_nframe_video(path_of_video_tr[0], dataset=args.tr_dataset)
    nframe_per_video_ts = get_nframe_video(path_of_video_test[0], dataset=args.ts_dataset)
    # path_of_video_tr = path_of_video_tr[:50]
    # path_of_video_test = path_of_video_test[:50]

    print('sample path: ', path_of_video_tr[0])
    print('Train Length: ', len(path_of_video_tr))
    print('Test Length: ', len(path_of_video_test))
    print('nframe_per_video_tr', nframe_per_video_tr)
    print('nframe_per_video_ts', nframe_per_video_ts)

    strategy = tf.distribute.MirroredStrategy()
    print('Number of devices: {}'.format(strategy.num_replicas_in_sync))

    # make sure the command line arguments for target signals to use are valid 
    valid_signal_list = ["dysub", "drsub", "ecg", "abp", "r", "g", "b", "pitch", "yaw", "roll", 
    "Eyebrows_Raised", "Eyebrows_Frown", "Eyes_Squint", "Smile_Lips_Closed", 
    "Frown_Left", "Frown_Right", "Chin_Raised", "Kiss_Left", "Kiss_Right", 
    "Mouth_Little_Opened", "Mouth_Large_Opened", "blink", "bpraw"]
    signals_to_use_train = args.signals_to_use_train
    signals_to_use_test = args.signals_to_use_test
    print("Using these target signals for training: {}".format(signals_to_use_train))
    for s in signals_to_use_train:
        if "_raw" in s or "_SD" in s:
            s = s.split("_")[0]
        assert s in valid_signal_list, f"Training target {s} is not the the valid_signal_list"
    print("Using these target signals for testing: {}".format(signals_to_use_test))
    for s in signals_to_use_test:
        if "_raw" in s or "_SD" in s:
            s = s.split("_")[0]
        assert s in valid_signal_list, f"Testing target {s} is not the the valid_signal_list"

    with strategy.scope():
        if strategy.num_replicas_in_sync == 4:
            print("Using 4 GPUs for training")
            print('self.temporal: ', args.temporal)
            if args.temporal == '2DCNN' or args.temporal == '2DCNN-MT':
                args.batch_size = 32
            elif args.temporal == '3DCNN' or args.temporal == '3DCNN-MT':
                args.batch_size = 12
            elif args.temporal == 'TSM' or args.temporal == 'TSM-MT' or args.temporal == 'TSM-MT-Dual':
                args.batch_size = 32
            elif args.temporal == 'MIX' or args.temporal == 'MIX-MT':
                args.batch_size = 16
            elif args.temporal == 'MIX-MT-Dual':
                args.batch_size = 8
            elif args.temporal.startswith('MIX-MT-Dual-RNN'):
                args.batch_size = 8
            elif args.temporal.startswith('ViViT'):
                args.batch_size = 32
            else:
                raise ValueError('Unsupported Model Type!')
        elif  strategy.num_replicas_in_sync == 1:
            print("Using 1 GPUs for training")
            if args.temporal == '3DCNN_MT_SlowFast' or args.temporal == '3DCNN-MT':
                args.batch_size = 1
        elif strategy.num_replicas_in_sync == 8:
            print('Using 8 GPUs for training!')
            args.batch_size = args.batch_size * 2
        elif strategy.num_replicas_in_sync == 2:
            args.batch_size = args.batch_size // 2
        elif strategy.num_replicas_in_sync == 1:
            args.batch_size = args.batch_size // 4
        else:
            raise Exception('Only supporting 4 GPUs or 8 GPUs now. Please adjust learning rate in the training script!')

        # if model file supplied, load the existing model
        if args.model_file:
            print("Loading existing model from file:", args.model_file)
            model_file = args.model_file
            model = tf.keras.models.load_model(model_file, custom_objects={
                'Attention_mask': Attention_mask, 
                'dice_coef_loss': dice_coef_loss,
                'Rearrange': Rearrange, 
                'GELU': GELU, 
                'first_and_second_derivative_loss': first_and_second_derivative_loss, 
                'my_loss_function': first_and_second_derivative_loss, 
                'first_derivative_loss': first_derivative_loss,
                'second_derivative_loss': second_derivative_loss})
        elif args.temporal == '2DCNN':
            model = DeepPhy(args.nb_filters1, args.nb_filters2, input_shape, dropout_rate1=args.dropout_rate1,
                            dropout_rate2=args.dropout_rate2, nb_dense=args.nb_dense)
        elif args.temporal == '2DCNN-MT':
            model = DeepPhys_2DCNN_MT(args.nb_filters1, args.nb_filters2, input_shape, dropout_rate1=args.dropout_rate1,
                            dropout_rate2=args.dropout_rate2, nb_dense=args.nb_dense)
        elif args.temporal == '3DCNN':
            print('Using 3DCNN!')
            input_shape = (img_rows, img_cols, args.frame_depth, 3)
            model = DeepPhy_3DCNN(args.frame_depth, args.nb_filters1, args.nb_filters2, input_shape,
                                  dropout_rate1=args.dropout_rate1, dropout_rate2=args.dropout_rate2, nb_dense=args.nb_dense)
        elif args.temporal == '3DCNN-MT':
            print('Using 3DCNN Multi-Task!')
            input_shape = (img_rows, img_cols, args.frame_depth, 3)
            model = DeepPhy_3DCNN_MT(args.frame_depth, args.nb_filters1, args.nb_filters2, input_shape,
                                  dropout_rate1=args.dropout_rate1, dropout_rate2=args.dropout_rate2, nb_dense=args.nb_dense)
        elif args.temporal == 'TSM':
            print('Using TS-CAN')
            input_shape = (img_rows, img_cols, 3)
            model = TS_CAN(args.frame_depth, args.nb_filters1, args.nb_filters2, input_shape,
                                 dropout_rate1=args.dropout_rate1, dropout_rate2=args.dropout_rate2, nb_dense=args.nb_dense)
        elif args.temporal == 'TS_CAN_PEAKDETECTION':
            print('Using TS_CAN_PEAKDETECTION')
            input_shape = (img_rows, img_cols, 3)
            model = TS_CAN_PEAKDETECTION(args.frame_depth, args.nb_filters1, args.nb_filters2, input_shape,
                                 dropout_rate1=args.dropout_rate1, dropout_rate2=args.dropout_rate2, nb_dense=args.nb_dense)
        elif args.temporal == 'TSM-MT':
            print('Using MTTS-CAN!')
            input_shape = (img_rows, img_cols, 3)
            model = MTTS_CAN(args.frame_depth, args.nb_filters1, args.nb_filters2, input_shape,
                                 dropout_rate1=args.dropout_rate1, dropout_rate2=args.dropout_rate2, nb_dense=args.nb_dense)
        elif args.temporal == 'TSM-MT-Dual':
            print('Using MTTS-CAN-Dual Attention!')
            input_shape = (img_rows, img_cols, 3)
            model = MTTS_CAN_Dual(args.frame_depth, args.nb_filters1, args.nb_filters2, input_shape,
                                 dropout_rate1=args.dropout_rate1, dropout_rate2=args.dropout_rate2, nb_dense=args.nb_dense)
        elif args.temporal == 'MIX':  # Mix
            print('Using Mix!!')
            input_shape_motion = (img_rows, img_cols, args.frame_depth, 3)
            input_shape_app = (img_rows, img_cols, 3)
            model = Hybrid_CAN(args.frame_depth, args.nb_filters1, args.nb_filters2, input_shape_motion, input_shape_app,
                                dropout_rate1=args.dropout_rate1, dropout_rate2=args.dropout_rate2,
                                nb_dense=args.nb_dense)
        elif args.temporal == 'MIX-MT':
            print('Using MIX-MT!')
            input_shape_motion = (img_rows, img_cols, args.frame_depth, 3)
            input_shape_app = (img_rows, img_cols, 3)
            model = Hybrid_CAN_MT(args.frame_depth, args.nb_filters1, args.nb_filters2, input_shape_motion, input_shape_app,
                                dropout_rate1=args.dropout_rate1, dropout_rate2=args.dropout_rate2, nb_dense=args.nb_dense)
        elif args.temporal == 'MIX-MT-Dual':
            print('Using MIX-MT-Dual Attention!')
            input_shape_motion = (img_rows, img_cols, args.frame_depth, 3)
            input_shape_app = (img_rows, img_cols, 3)
            model = Hybrid_CAN_MT_Dual(args.frame_depth, args.nb_filters1, args.nb_filters2, input_shape_motion, input_shape_app,
                                dropout_rate1=args.dropout_rate1, dropout_rate2=args.dropout_rate2, nb_dense=args.nb_dense)
        elif args.temporal == 'MIX-MT-Dual-RNN':
            print('Using MIX-MT-Dual-RNN Attention!')
            input_shape_motion = (img_rows, img_cols, args.frame_depth, 3)
            # input_shape_motion = (args.frame_depth, img_rows, img_cols, 3)
            input_shape_app = (img_rows, img_cols, 3)
            model = Hybrid_CAN_MT_Dual_RNN(args.frame_depth, args.nb_filters1, args.nb_filters2, input_shape_motion, input_shape_app,
                                dropout_rate1=args.dropout_rate1, dropout_rate2=args.dropout_rate2, nb_dense=args.nb_dense, pool_size_1=(2, 2, 1), 
                                use_second_derivative=args.predict_second_derivative, 
                                use_raw_signal=args.predict_raw_signal,
                                target_signals=signals_to_use_train)
        elif args.temporal == 'MIX-MT-Dual-RNN_v2':
            print('Using MIX-MT-Dual-RNN Attention v2!')
            input_shape_motion = (img_rows, img_cols, args.frame_depth, 3)
            # input_shape_motion = (args.frame_depth, img_rows, img_cols, 3)
            input_shape_app = (img_rows, img_cols, 3)
            model = Hybrid_CAN_MT_Dual_RNN_v2(args.frame_depth, args.nb_filters1, args.nb_filters2, input_shape_motion, input_shape_app,
                                dropout_rate1=args.dropout_rate1, dropout_rate2=args.dropout_rate2, nb_dense=args.nb_dense, pool_size_1=(2, 2, 1),)
        elif args.temporal == 'MIX-MT-Dual-RNN_v3':
            print('Using MIX-MT-Dual-RNN Attention v3!')
            input_shape_motion = (img_rows, img_cols, args.frame_depth, 3)
            # input_shape_motion = (args.frame_depth, img_rows, img_cols, 3)
            input_shape_app = (img_rows, img_cols, 3)
            model = Hybrid_CAN_MT_Dual_RNN_v3(args.frame_depth, args.nb_filters1, args.nb_filters2, input_shape_motion, input_shape_app,
                                dropout_rate1=args.dropout_rate1, dropout_rate2=args.dropout_rate2, nb_dense=args.nb_dense, pool_size_1=(2, 2, 1),)
        elif args.temporal == 'MIX-MT-Dual-RNN_v4':
            print('Using MIX-MT-Dual-RNN Attention!')
            input_shape_motion = (args.frame_depth, img_rows, img_cols, 3)
            # input_shape_motion = (args.frame_depth, img_rows, img_cols, 3)
            input_shape_app = (img_rows, img_cols, 3)
            model = Hybrid_CAN_MT_Dual_RNN_v4(args.frame_depth, args.nb_filters1, args.nb_filters2, input_shape_motion, input_shape_app,
                                dropout_rate1=args.dropout_rate1, dropout_rate2=args.dropout_rate2, nb_dense=args.nb_dense, pool_size_1=(2, 2, 1), 
                                use_second_derivative=args.predict_second_derivative, 
                                use_raw_signal=args.predict_raw_signal,
                                target_signals=signals_to_use_train)
        # elif args.temporal == "ViViT":
        #     print('Using ViViT!')
        #     input_shape_motion = (args.frame_depth, img_rows, img_cols, 3)
        #     # input_shape_motion = (args.frame_depth, img_rows, img_cols, 3)
        #     input_shape_app = (img_rows, img_cols, 3)
        #     model = ViViT(image_size=img_rows, 
        #         patch_size=int(img_rows/6), 
        #         num_classes=len(signals_to_use_train), 
        #         num_frames=args.frame_depth+1, 
        #         batch_size=args.batch_size, 
        #         dim=128, depth=10, heads=8, pool='time', 
        #         in_channels=3, dim_head=64, 
        #         dropout=args.dropout_rate1, emb_dropout=args.dropout_rate2, 
        #         scale_dim=4, output_names=signals_to_use_train)
        #     # model.build(input_shape=(args.batch_size, args.frame_depth+1, img_rows, img_cols, 3))
        elif args.temporal == 'MIX-MT-Dual-RNN_MD':
            print('Using MIX-MT-Dual-RNN MD Attention!')
            input_shape_app = (img_rows, img_cols, args.frame_depth, 3)
            input_shape_motion = (img_rows, img_cols, args.frame_depth, 3)
            input_shape_accel = (img_rows, img_cols, args.frame_depth-1, 3)
            model = Hybrid_CAN_MT_Dual_RNN_MD(args.frame_depth, args.nb_filters1, args.nb_filters2, input_shape_1=input_shape_app, input_shape_2=input_shape_motion, input_shape_3=input_shape_accel,
                                dropout_rate1=args.dropout_rate1, dropout_rate2=args.dropout_rate2, nb_dense=args.nb_dense, pool_size_1=(2, 2, 1), 
                                use_second_derivative_frames=args.use_second_derivative_frames,
                                use_second_derivative_frames_only=args.use_second_derivative_frames_only, 
                                predict_raw_signal=args.predict_raw_signal, 
                                predict_first_derivative=args.predict_first_derivative,
                                predict_second_derivative=args.predict_second_derivative,
                                target_signals=signals_to_use_train)
        else:
            raise ValueError('Unsupported Model Type!')

        seg_masks = None
        #optimizer = tf.keras.optimizers.Adadelta(learning_rate=args.lr)
        optimizer = tf.keras.optimizers.Adam(learning_rate=args.lr) 
        if args.temporal == 'TSM-MT' or args.temporal == 'MIX-MT' or args.temporal == '3DCNN-MT' or \
                args.temporal == '2DCNN-MT' or args.temporal == 'MIX-MT-Dual' or args.temporal == 'TSM-MT-Dual' \
                or args.temporal == 'TS_CAN_PEAKDETECTION' or args.temporal == 'MIX-MT-Dual-RNN' or args.temporal == 'MIX-MT-Dual-RNN_v4':
            # losses = {"pulse": "mean_squared_error", "resp": "mean_squared_error"}
            # loss_weights = {"pulse": 1.0, "resp": 1.0}
            losses = {sig: "mean_squared_error" for sig in signals_to_use_train}
            if args.predict_second_derivative:
                for sig in signals_to_use_train:
                    losses[f"{sig}_SD"] = "mean_squared_error"
            if args.predict_raw_signal:
                for sig in signals_to_use_train:
                    losses[f"{sig}_raw"] = "mean_squared_error"
            loss_weights = {sig: 1.0 for sig in losses.keys()}
            model.compile(loss=losses, loss_weights=loss_weights, optimizer=optimizer)
        elif args.temporal == 'MIX-MT-Dual-RNN_v3':
            print("Including DICE coef loss")
            seg_masks = ["combined_mask", "skin_mask"]
            losses = {sig: "mean_squared_error" for sig in signals_to_use_train}
            losses['combined_mask'] = "mean_squared_error"
            losses['skin_mask'] = "mean_squared_error"
            # losses['combined_mask'] = dice_coef_loss
            # losses['skin_mask'] = dice_coef_loss
            loss_weights = {sig: 1.0 for sig in signals_to_use_train}
            model.compile(loss=losses, loss_weights=loss_weights, optimizer=optimizer)
        elif args.second_derivative_loss:
            print("Using modified loss")
            losses = {sig: first_and_second_derivative_loss for sig in signals_to_use_train}
            loss_weights = {sig: 1.0 for sig in signals_to_use_train}
            model.compile(loss=losses, loss_weights=loss_weights, optimizer=optimizer)
        elif args.temporal == 'MIX-MT-Dual-RNN_MD':
            print("Using custom loss for multi-derivative model")
            losses = {}
            if args.predict_raw_signal:
                for sig in signals_to_use_train:
                    losses[f"{sig}_raw"] = "mean_squared_error"
            if args.predict_first_derivative:
                for sig in signals_to_use_train:
                    losses[f"{sig}"] = "mean_squared_error"
            if args.predict_second_derivative:
                for sig in signals_to_use_train:
                    losses[f"{sig}_SD"] = "mean_squared_error"
            # losses["dysub_raw"] = second_derivative_peak_loss
            # if modified loss selected with command line flags, change function
            if args.add_first_derivative_loss:
                print("Using modified loss for first derivative signal")
                for sig in signals_to_use_train:
                    losses[f"{sig}"] = second_derivative_peak_loss
            if args.add_second_derivative_loss:
                print("Using modified loss for second derivative signal")
                for sig in signals_to_use_train:
                    losses[f"{sig}_SD"] = second_derivative_peak_loss

            model.compile(loss=losses, optimizer=optimizer)
        else:
            model.compile(loss='mean_squared_error', optimizer=optimizer)
        print('learning rate: ', args.lr)
        print(model.summary())

        #%% Create data genener
        training_generator = DataGenerator(path_of_video_tr, nframe_per_video_tr, (args.img_size, args.img_size),
                                           num_gpu=strategy.num_replicas_in_sync,
                                           batch_size=args.batch_size, frame_depth=args.frame_depth,
                                           temporal=args.temporal, respiration=args.respiration, shuffle=args.shuffle,
                                           crop_size=(args.cropped_size, args.cropped_size), data_aug=args.data_aug,
                                           dataset=args.tr_dataset, step_size=args.step_size_train, split_iter=True, seg_masks=seg_masks,
                                           use_first_derivative=args.predict_first_derivative, 
                                           use_second_derivative=args.predict_second_derivative, 
                                           use_raw_signal=args.predict_raw_signal, 
                                           signals_to_use=signals_to_use_train)
        validation_generator = DataGenerator(path_of_video_test, nframe_per_video_ts, (args.img_size, args.img_size),
                                            num_gpu=strategy.num_replicas_in_sync,
                                            batch_size=args.batch_size, frame_depth=args.frame_depth,
                                            temporal=args.temporal, respiration=args.respiration, shuffle=args.shuffle,
                                            crop_size=(args.cropped_size, args.cropped_size), data_aug=False,
                                            dataset=args.ts_dataset, step_size=args.step_size_train, split_iter=True, 
                                            use_first_derivative=args.predict_first_derivative, 
                                            use_second_derivative=args.predict_second_derivative, 
                                            use_raw_signal=args.predict_raw_signal, 
                                            signals_to_use=signals_to_use_test)

        # generator_types = (tf.float32, {tv: tf.float32 for tv in signals_to_use_train})
        # training_generator = tf.data.Dataset.from_generator(lambda: (x for x in training_generator), output_types=generator_types)
        # validation_generator = tf.data.Dataset.from_generator(lambda: (x for x in validation_generator), output_types=generator_types)

        print("Training generator shape: ", training_generator.__getitem__(0)[1].keys())
        # print("Training generator shape: ", training_generator.__getitem__(0)[1]['dysub'].shape)
        # for _, y_batch in validation_generator:
        #     pass
        # %%  Checkpoint Folders
        checkpoint_folder = str(os.path.join(args.save_dir, args.exp_name))
        if not os.path.exists(checkpoint_folder):
            os.makedirs(checkpoint_folder)
        cv_split_path = str(os.path.join(checkpoint_folder, "cv_" + str(cv_split)))


        #%% Callbacks
        if args.save_all == 1:
            save_best_callback = tf.keras.callbacks.ModelCheckpoint(filepath=cv_split_path+"_epoch{epoch:02d}_model.hdf5",
                                                                save_best_only=False, verbose=1)
        else:
            save_best_callback = tf.keras.callbacks.ModelCheckpoint(filepath=cv_split_path+"_last_model.hdf5",
                                                                save_best_only=False, verbose=1)
        csv_logger = tf.keras.callbacks.CSVLogger(filename=cv_split_path+'_train_loss_log.csv')
        hb_callback = HeartBeat(training_generator, validation_generator, args, str(cv_split), checkpoint_folder)

        #%% save tensorbboard
        log_dir = "./logs/tensorboard/" + str(args.exp_name) +"-" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)

        #%% Model Training and Saving Results
        history = model.fit(x=training_generator, validation_data=validation_generator, epochs=args.nb_epoch, verbose=1, validation_steps=100, 
                            callbacks=[csv_logger, save_best_callback,], steps_per_epoch=None,
                            initial_epoch=args.initial_epoch, validation_freq=1, 
                            use_multiprocessing=False, workers=8, max_queue_size=1000)

        print(history.history)
        val_loss_history = history.history['val_loss']
        val_loss = np.array(val_loss_history)
        np.savetxt((cv_split_path+'_val_loss_log.csv'), val_loss, delimiter=",")
        # save train/test metrics as .csv file
        train_metric_keys = [i for i in history.history.keys() if not i.startswith("val_")]
        test_metric_keys = [i for i in history.history.keys() if i.startswith("val_")]
        train_metric_df = pd.DataFrame.from_dict({k: history.history[k] for k in train_metric_keys}, orient="columns")
        test_metric_df = pd.DataFrame.from_dict({k: history.history[k] for k in test_metric_keys}, orient="columns")
        metric_df = pd.merge(train_metric_df, test_metric_df, how="outer", left_index=True, right_index=True)
        metric_df.to_csv(cv_split_path+'_loss.csv', header=True)

        score = model.evaluate(validation_generator, steps=1000, verbose=1)
        print(score)

        print('****************************************')
        if isinstance(score, float):
            score = [score]
        for metric_val, sig in zip(score, model.metrics_names):
            print("{} value: {:.3f}".format(sig, metric_val))

        print('****************************************')
        print('Start saving predicitions from the last epoch')

        training_generator = DataGenerator(path_of_video_tr, nframe_per_video_tr, (args.img_size, args.img_size),
                                           num_gpu=strategy.num_replicas_in_sync,
                                           batch_size=args.batch_size, frame_depth=args.frame_depth,
                                           temporal=args.temporal, respiration=args.respiration, shuffle=False,
                                           crop_size=(args.cropped_size, args.cropped_size), data_aug=args.data_aug,
                                           dataset=args.tr_dataset, step_size=args.step_size_test, 
                                           use_second_derivative=args.predict_second_derivative, 
                                           use_raw_signal=args.predict_raw_signal,  
                                           signals_to_use=signals_to_use_train)

        validation_generator = DataGenerator(path_of_video_test, nframe_per_video_ts, (args.img_size, args.img_size),
                                             num_gpu=strategy.num_replicas_in_sync,
                                             batch_size=args.batch_size, frame_depth=args.frame_depth,
                                             temporal=args.temporal, respiration=args.respiration, shuffle=False,
                                             crop_size=(args.cropped_size, args.cropped_size), data_aug=False,
                                             dataset=args.ts_dataset, step_size=args.step_size_test, 
                                             use_second_derivative=args.predict_second_derivative, 
                                             use_raw_signal=args.predict_raw_signal,
                                             signals_to_use=signals_to_use_test)

        num_steps = None
        # yptrain = model.predict(training_generator, steps=num_steps, verbose=1)
        # # fix issue with saving segmentation masks to .mat file
        # if seg_masks is not None:
        #     yptrain = yptrain[:-2]
        # for i, x in enumerate(yptrain):
        #     print(i, x.shape)
        # scipy.io.savemat(checkpoint_folder + '/yptrain_best_' + '_cv' + str(cv_split) + '.mat',
        #                  mdict={'yptrain': yptrain})
        yptest = model.predict(validation_generator, steps=num_steps, verbose=1)
        # if we only have one signal, make sure a dummy axis is added 
        if not isinstance(yptest, list):
            if len(yptest.shape) == 3 and yptest.shape[2] == 1:
                yptest = np.expand_dims(yptest, axis=0)
        # fix issue with saving segmentation masks to .mat file
        if seg_masks is not None:
            yptest = yptest[:-2]
        # scipy.io.savemat(checkpoint_folder + '/yptest_best_' + '_cv' + str(cv_split) + '.mat',
        #                  mdict={'yptest': yptest})

        print('Finish saving the results from the last epoch')

        print('=====Start Evaulating======')

        # Metric Evaluation
        if args.temporal in ['2DCNN-MT', 'TSM-MT', 'MIX-MT', 'MIX-MT-Dual','ViViT'] or args.temporal.startswith('MIX-MT-Dual-RNN'):
            # pulse_labels, resp_labels = retrive_labels(path_of_video_test, multi_task=True)

            validation_generator = DataGenerator(path_of_video_test, nframe_per_video_ts, (args.img_size, args.img_size),
                                             num_gpu=strategy.num_replicas_in_sync,
                                             batch_size=args.batch_size, frame_depth=args.frame_depth,
                                             temporal=args.temporal, respiration=args.respiration, shuffle=False,
                                             crop_size=(args.cropped_size, args.cropped_size), data_aug=False,
                                             dataset=args.ts_dataset, step_size=args.step_size_test, return_file_names=True,
                                             use_second_derivative=args.predict_second_derivative, 
                                             use_raw_signal=args.predict_raw_signal, 
                                             signals_to_use=signals_to_use_test)

            # iterate through validation generator (should be same order as used in model.predict)
            # and grab each target vector from each batch, and concatenate together
            labels_to_save = signals_to_use_test + ["file"]
            if args.predict_raw_signal:
                for sig in signals_to_use_test:
                    labels_to_save.append(f"{sig}_raw")
            if args.predict_second_derivative:
                for sig in signals_to_use_test:
                    labels_to_save.append(f"{sig}_SD")

            ground_truth_labels = {f"{l}_label": [] for l in labels_to_save}
            batch_count = 0
            for i in range(len(validation_generator)):
                print("Val gen batch {}/{}".format(i, len(validation_generator)))
                _, y_batch = validation_generator.__getitem__(i)
                if batch_count == num_steps:
                    break
                else:
                    for l in labels_to_save:
                        ground_truth_labels[f"{l}_label"].append(y_batch[l])
                batch_count +=1

            # concatenate all batches together
            for l in labels_to_save:
                ground_truth_labels[f"{l}_label"] = np.concatenate(ground_truth_labels[f"{l}_label"], axis=0)

            # save the predicted values for each signal
            for i, sig in enumerate(model.output_names):
            # for i, sig in enumerate(signals_to_use_train):
                ground_truth_labels[f"{sig}_pred"] = yptest[i]

        else:
            pulse_labels = retrive_labels(path_of_video_test, multi_task=False)
            pulse_preds = yptest[0]

        # pulse_labels = ground_truth_labels["dysub_label"]
        # resp_labels = ground_truth_labels["drsub_label"]

        # MAE, RMSE, meanSNR, HR_SNR, HR0, HR = calculate_metric(pulse_preds, pulse_labels, signal='pulse',
        #                                                        window_size=args.eval_window,
        #                                                        fs=args.data_fs, bpFlag=True)
        # ground_truth_labels['hr_SNR'] = HR_SNR
        # ground_truth_labels['hr_GT'] = HR0
        # ground_truth_labels['hr_pred'] = HR
        # print('========HR Report==========')
        # print('MAE: ', MAE)
        # print('RMSE: ', RMSE)
        # print('meanSNR: ', meanSNR)
        # print('========HR Report==========')
        scipy.io.savemat(checkpoint_folder + '/metric_HR.mat', mdict=ground_truth_labels)

        # if args.temporal in ['2DCNN-MT', 'TSM-MT', 'MIX-MT', 'MIX-MT-Dual-RNN', 'MIX-MT-Dual-RNN_v2', 'MIX-MT-Dual-RNN_v3']:
            # MAE_RR, RMSE_RR, meanSNR_RR, RR_SNR, RR0, RR = calculate_metric(resp_preds, resp_labels, signal='resp',
            #                                                                 window_size=args.eval_window,
            #                                                                 fs=args.data_fs, bpFlag=True)
            # ground_truth_labels['rr_SNR'] = RR_SNR
            # ground_truth_labels['rr_GT'] = RR0
            # ground_truth_labels['rr_pred'] = RR
            # print('========RR Report==========')
            # print('MAE: ', MAE_RR)
            # print('RMSE: ', RMSE_RR)
            # print('meanSNR: ', meanSNR_RR)
            # print('========RR Report==========')

            # scipy.io.savemat(checkpoint_folder + '/metric_RR.mat', mdict=ground_truth_labels)


# %% Training
if __name__ == "__main__":
    # Script entry point: run training/evaluation for the cross-validation
    # split selected on the command line (``args`` is parsed at module level).
    split_id = args.cv_split
    print('Using Split ', str(split_id))

    # The model input is sized from --img_size (square frames). An earlier
    # variant switched to --cropped_size when data_aug == 1; that branch is
    # currently disabled, so img_size is used unconditionally.
    frame_side = args.img_size
    train(args, split_id, img_rows=frame_side, img_cols=frame_side)




