# %% imports
import argparse
import itertools
import json
import os
import datetime
import glob
import sys

import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
import scipy.io
from einops.layers.tensorflow import Rearrange
from tensorflow_addons.layers import GELU

sys.path.append("../")
from data_generator import DataGenerator
from metric import calculate_metric
from model import HeartBeat, DeepPhy, DeepPhy_3DCNN, DeepPhy_3DCNN_MT, DeepPhys_2DCNN_MT, \
    Hybrid_CAN, Hybrid_CAN_MT, Hybrid_CAN_MT_Dual, TS_CAN, MTTS_CAN, MTTS_CAN_Dual, Hybrid_CAN_MT_Dual_RNN, TS_CAN_PEAKDETECTION
from pre_process import get_nframe_video, retrive_labels, read_from_txt
from model import Attention_mask, dice_coef_loss, first_and_second_derivative_loss, first_derivative_loss, second_derivative_loss, second_derivative_peak_loss

# Number CUDA devices in PCI bus order (see issue #152) and set an empty
# visible-device list, so CUDA exposes no GPUs and inference runs on the CPU.
# NOTE(review): tensorflow is already imported above; this only takes effect
# if nothing has initialised the GPU before this point -- confirm, or move
# these settings above `import tensorflow`.
os.environ.update(
    {
        "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
        "CUDA_VISIBLE_DEVICES": "",
    }
)

# %% parse command-line arguments
parser = argparse.ArgumentParser()
# data I/O
# The Windows default paths below are raw strings: in a normal literal,
# sequences such as "\D", "\A" or "\c" are invalid escapes (DeprecationWarning,
# SyntaxWarning on Python 3.12+). The resulting string values are unchanged.
parser.add_argument('-exp', '--exp_name', type=str, default='Test',
                    help='experiment name')
parser.add_argument('-m', '--model_file', type=str,
                    help='Path to saved model file',
                    default=r'C:\Users\username\Downloads\trainSyn2345678900_testAFRL_ppg_resp_ecg_abp_8epoch_modelv1\cv_0_epoch08_model.hdf5')
parser.add_argument('-i', '--data_dir', type=str,
                    default=r'C:\Users\username\Downloads\Datasets\AFRL\ProcessedInputFiles\AFRLChunks36x36',
                    help='Location for the dataset')
parser.add_argument('-o', '--output_dir', type=str, default="./", help='Location to save the predictions')
parser.add_argument('-img', '--img_size', type=int, default=36, help='img_size')
parser.add_argument('-x', '--batch_size', type=int, default=8,
                    help='batch')
parser.add_argument('-fd', '--frame_depth', type=int, default=30,
                    help='frame_depth for 3DCNN')
parser.add_argument('-sst', '--step_size_test', type=int, default=None,
                    help='number of frames in between windows in video segment during testing')
parser.add_argument('-temp', '--temporal', type=str, default='MIX-MT-Dual-RNN_v2',
                    help='3DCNN, 2DCNN or mix')
parser.add_argument('-resp', '--respiration', type=int, default=0,
                    help='train with resp or not')
parser.add_argument('-crp_img', '--cropped_size', type=int, default=36, help='cropped image size')
parser.add_argument('-tr_data', '--tr_dataset', type=str, default='AFRL', help='training dataset name')
parser.add_argument('-ts_data', '--ts_dataset', type=str, default='AFRL', help='test dataset name')
# NOTE(review): --exp_name, --tr_dataset, --predict_first_derivative and
# --use_second_derivative_frames are parsed (and echoed below) but not used
# anywhere else in this script -- confirm whether they should be forwarded
# to DataGenerator.
parser.add_argument('--predict_first_derivative', action='store_true', help='Add additional output target signal for first derivative')
parser.add_argument('--predict_second_derivative', action='store_true', help='Add additional output target signal for second derivative')
parser.add_argument('--predict_raw_signal', action='store_true', help='Add additional output target signal for raw output (0th derivative)')
parser.add_argument('--use_second_derivative_frames', action='store_true', help='Use second derivative frames as input')
parser.add_argument('-tss', '--signals_to_use_test', nargs="+", default=["dysub", "drsub"],
                    help='List of target signals to use for testing - note for backwards compatibility, use ppg->dysub and resp->drsub')

args = parser.parse_args()
print('input args:\n', json.dumps(vars(args), indent=4, separators=(',', ':')))  # pretty print args

# If the test step size is not supplied (None, or 0), default to the frame
# depth, i.e. no overlapping windows in the test output.
if not args.step_size_test:
    args.step_size_test = args.frame_depth

# create output_dir if it does not exist
os.makedirs(args.output_dir, exist_ok=True)

# glob.glob returns entries in arbitrary, filesystem-dependent order. Sort so
# the order of windows in the saved .mat file is reproducible across runs and
# machines.
path_of_video_test = sorted(glob.glob(os.path.join(args.data_dir, "*")))
if not path_of_video_test:
    # Fail with a clear message instead of an IndexError on path_of_video_test[0].
    raise FileNotFoundError(f"No test files found in data_dir: {args.data_dir!r}")
print(path_of_video_test[0:10])
# NOTE(review): the frame count of the first file is passed for all files;
# presumably every file in data_dir has the same length -- confirm.
nframe_per_video_ts = get_nframe_video(path_of_video_test[0], dataset=args.ts_dataset)

signals_to_use_test = args.signals_to_use_test

# Generator fed to model.predict. shuffle=False keeps batch order fixed so the
# ground-truth pass later in the script can line up with the predictions.
validation_generator = DataGenerator(path_of_video_test, nframe_per_video_ts, (args.img_size, args.img_size),
                                    num_gpu=1,
                                    batch_size=args.batch_size, frame_depth=args.frame_depth,
                                    temporal=args.temporal, respiration=args.respiration, shuffle=False,
                                    crop_size=(args.cropped_size, args.cropped_size), data_aug=False,
                                    dataset=args.ts_dataset, step_size=args.step_size_test,
                                    use_second_derivative=args.predict_second_derivative,
                                    use_raw_signal=args.predict_raw_signal,
                                    return_file_names=False,
                                    signals_to_use=signals_to_use_test)


# %% load the trained model and run inference
model_file = args.model_file
# Custom layers and losses used when the model was saved must be supplied so
# Keras can deserialize it. 'my_loss_function' maps to
# first_and_second_derivative_loss -- presumably the name an older checkpoint
# was saved under (confirm).
model = tf.keras.models.load_model(model_file, custom_objects={
    'Attention_mask': Attention_mask,
    'dice_coef_loss': dice_coef_loss,
    'Rearrange': Rearrange,
    'GELU': GELU,
    'first_and_second_derivative_loss': first_and_second_derivative_loss,
    'my_loss_function': first_and_second_derivative_loss,
    'first_derivative_loss': first_derivative_loss,
    'second_derivative_loss': second_derivative_loss,
    'second_derivative_peak_loss': second_derivative_peak_loss})
# summary() prints the table itself and returns None; wrapping it in print()
# only emitted a stray "None" line.
model.summary()

# Debug: show which target keys the generator yields for a batch.
print(validation_generator[0][1].keys())

# None -> predict over every batch of the generator. The ground-truth loop
# later in the script reads this to stop after the same number of batches.
num_steps = None
# batch_size is intentionally not passed: the generator already yields
# batches, and Keras documents that batch_size must not be specified for
# generator / Sequence inputs.
yptest = model.predict(validation_generator, steps=num_steps, verbose=1)

# With a single output, Keras returns a bare array instead of a list. Wrap it
# so yptest[i] always selects the i-th *output*; indexing a bare array would
# select the i-th sample instead. Previously only arrays with < 4 dims were
# wrapped, so a 4-D single output was mis-indexed.
if not isinstance(yptest, list):
    yptest = [yptest]
for output_name, prediction in zip(model.output_names, yptest):
    print("yptest shape:", output_name, prediction.shape)

# %% get ground truth values from generator
# Same configuration as the prediction generator, but return_file_names=True
# so each window's source file is saved alongside its targets.
validation_generator = DataGenerator(path_of_video_test, nframe_per_video_ts, (args.img_size, args.img_size),
                                    num_gpu=1,
                                    batch_size=args.batch_size, frame_depth=args.frame_depth,
                                    temporal=args.temporal, respiration=args.respiration, shuffle=False,
                                    crop_size=(args.cropped_size, args.cropped_size), data_aug=False,
                                    dataset=args.ts_dataset, step_size=args.step_size_test,
                                    use_second_derivative=args.predict_second_derivative,
                                    use_raw_signal=args.predict_raw_signal, return_file_names=True,
                                    signals_to_use=signals_to_use_test)


def _squeeze_keep_batch(batch):
    """Remove size-1 axes from *batch*, except the leading (batch) axis.

    Plain np.squeeze also drops axis 0 when a batch holds a single window
    (e.g. a short final batch, or batch_size=1). That result can no longer be
    concatenated with the other batches along axis 0.
    """
    batch = np.asarray(batch)
    singleton_axes = tuple(ax for ax in range(1, batch.ndim) if batch.shape[ax] == 1)
    return np.squeeze(batch, axis=singleton_axes)


# Iterate through the generator (same order as used in model.predict, since
# shuffle=False), grab each target from every batch, and collect them per key.
labels_to_save = signals_to_use_test + ["file"]  # new list; args are not mutated
if args.predict_raw_signal:
    labels_to_save += [f"{sig}_raw" for sig in signals_to_use_test]
if args.predict_second_derivative:
    labels_to_save += [f"{sig}_SD" for sig in signals_to_use_test]
ground_truth_labels = {f"{label}_label": [] for label in labels_to_save}

for batch_idx in range(len(validation_generator)):
    # Stop after as many batches as model.predict consumed (num_steps None ->
    # all). Checked before loading so no batch is read only to be discarded.
    if num_steps is not None and batch_idx >= num_steps:
        break
    _, y_batch = validation_generator[batch_idx]
    for label in labels_to_save:
        ground_truth_labels[f"{label}_label"].append(y_batch[label])

    # Optional extra targets, stored under their bare key when the generator
    # provides them.
    if "ecg30" in y_batch:
        if "ecg30" not in ground_truth_labels:
            print("adding {} to labels".format("ecg30"))
            ground_truth_labels["ecg30"] = []
        ground_truth_labels["ecg30"].append(_squeeze_keep_batch(y_batch["ecg30"]))
    if "ecgBeats" in y_batch:
        ground_truth_labels.setdefault("ecgBeats", []).append(y_batch["ecgBeats"])

# %% merge the per-batch lists into one array per key
for key in list(ground_truth_labels):
    per_batch = ground_truth_labels[key]
    print(key, len(per_batch))
    print(type(per_batch[0]))

    if key != "ecgBeats":
        print("merging ", key)
        ground_truth_labels[key] = np.concatenate(per_batch, axis=0)
        continue

    # Beat annotations go through an object array rather than np.concatenate,
    # presumably because their length varies between batches (TODO confirm).
    beats = np.squeeze(np.array(per_batch, dtype=object))
    ground_truth_labels[key] = beats.flatten()
    print(ground_truth_labels[key].shape)

# store each model output's predictions next to the ground truth
for out_idx, out_name in enumerate(model.output_names):
    prediction = yptest[out_idx]
    print("Saving ", out_name, prediction.shape)
    ground_truth_labels[f"{out_name}_pred"] = prediction

result_path = os.path.join(args.output_dir, 'metric_HR.mat')
scipy.io.savemat(result_path, mdict=ground_truth_labels)
print("Finished.")