# %% imports
import argparse
import json
import os
import glob
import sys

import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
from einops.layers.tensorflow import Rearrange
from tensorflow_addons.layers import GELU
import imageio

sys.path.append("../")
from data_generator import DataGenerator
from metric import calculate_metric
from model import HeartBeat, DeepPhy, DeepPhy_3DCNN, DeepPhy_3DCNN_MT, DeepPhys_2DCNN_MT, \
    Hybrid_CAN, Hybrid_CAN_MT, Hybrid_CAN_MT_Dual, TS_CAN, MTTS_CAN, MTTS_CAN_Dual, Hybrid_CAN_MT_Dual_RNN, TS_CAN_PEAKDETECTION
from pre_process import get_nframe_video, retrive_labels, read_from_txt
from model import Attention_mask, dice_coef_loss, first_and_second_derivative_loss, first_derivative_loss, second_derivative_loss
sys.path.append("./")
from calculate_metrics import calc_PPG_peaks, get_participant_id, get_chunk_id, get_task_id, get_window_number

# Order CUDA devices by PCI bus id (see issue #152) and expose an empty
# device list to CUDA.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "",
})

# %% create args class
# Command-line options of the attention-mask export script.
# The Windows path defaults are raw strings so that no backslash is ever
# interpreted as (or warned about as) an escape sequence.
parser = argparse.ArgumentParser()
# data I/O
parser.add_argument('-exp', '--exp_name', type=str, default='Test',
                    help='experiment name')
parser.add_argument('-m', '--model_file', type=str,
                    help='Path to saved model file',
                    default=r'C:\Users\username\Downloads\trainSyn2345678900_testAFRL_ppg_resp_ecg_abp_8epoch_modelv1\cv_0_epoch08_model.hdf5')
parser.add_argument('-i', '--data_dir', type=str,
                    default=r'C:\Users\username\Downloads\Datasets\AFRL\ProcessedInputFiles\AFRLChunks36x36',
                    help='Location for the dataset')
parser.add_argument('-o', '--output_dir', type=str, default="./", help='Location to save the predictions')
parser.add_argument('-img', '--img_size', type=int, default=36, help='img_size')
parser.add_argument('-x', '--batch_size', type=int, default=8,
                    help='batch')
parser.add_argument('-fd', '--frame_depth', type=int, default=30,
                    help='frame_depth for 3DCNN')
# default None means "not supplied"; the script later substitutes frame_depth
parser.add_argument('-sst', '--step_size_test', type=int, default=None,
                    help='number of frames in between windows in video segment during testing')
parser.add_argument('-temp', '--temporal', type=str, default='MIX-MT-Dual-RNN_v2',
                    help='3DCNN, 2DCNN or mix')
parser.add_argument('-resp', '--respiration', type=int, default=0,
                    help='train with resp or not')
parser.add_argument('-crp_img', '--cropped_size', type=int, default=36, help='img_size')
parser.add_argument('-tr_data', '--tr_dataset', type=str, default='AFRL', help='training dataset name')
parser.add_argument('-ts_data', '--ts_dataset', type=str, default='AFRL', help='test dataset name')
parser.add_argument('-tss', '--signals_to_use_test', nargs="+", default=["dysub", "drsub"],
                    help='List of target signals to use for testing - note for backwards compatability, use ppg->dysub and resp->drsub')
# optional extra output targets (both off unless the flag is given)
parser.add_argument('--predict_second_derivative', action='store_true', help='Add additional output target signal for second derivative')
parser.add_argument('--predict_raw_signal', action='store_true', help='Add additional output target signal for raw output (0th derivative')
args = parser.parse_args()
# echo the parsed options as indented JSON
args_as_json = json.dumps(vars(args), indent=4, separators=(',', ':'))
print('input args:\n', args_as_json)

# A test step size that was not supplied (or is zero) falls back to the frame
# depth, i.e. no overlapping windows in the test output.
args.step_size_test = args.step_size_test or args.frame_depth

# make sure the directory receiving the predictions exists
os.makedirs(args.output_dir, exist_ok=True)

# Test inputs: every entry of data_dir whose name starts with "P004T".
# Sorted so that, with shuffle=False below, the batch contents do not depend
# on the order in which the file system happens to list the entries.
path_of_video_test = sorted(glob.glob(os.path.join(args.data_dir, "P004T*")))
if not path_of_video_test:
    # fail with an explicit message instead of an IndexError on [0] below
    raise FileNotFoundError(
        f"No test files matching 'P004T*' found in data_dir: {args.data_dir}")
print(path_of_video_test[0:10])
# The frame count is read from the first file only.
# NOTE(review): presumably every file has the same number of frames - confirm.
nframe_per_video_ts = get_nframe_video(path_of_video_test[0], dataset=args.ts_dataset)

signals_to_use_test = args.signals_to_use_test

# Generator over the test files: no shuffling and no augmentation, and
# return_file_names=True so each batch carries the labels of its source files.
validation_generator = DataGenerator(path_of_video_test, nframe_per_video_ts, (args.img_size, args.img_size),
                                    num_gpu=1,
                                    batch_size=args.batch_size, frame_depth=args.frame_depth,
                                    temporal=args.temporal, respiration=args.respiration, shuffle=False,
                                    crop_size=(args.cropped_size, args.cropped_size), data_aug=False,
                                    dataset=args.ts_dataset, step_size=args.step_size_test,
                                    use_second_derivative=args.predict_second_derivative,
                                    use_raw_signal=args.predict_raw_signal,
                                    return_file_names=True,
                                    signals_to_use=signals_to_use_test)


model_file = args.model_file
# Custom layers and losses handed to the loader so it can resolve them by name.
# 'my_loss_function' is a second name for first_and_second_derivative_loss.
custom_objects = {
    'Attention_mask': Attention_mask,
    'dice_coef_loss': dice_coef_loss,
    'Rearrange': Rearrange,
    'GELU': GELU,
    'first_and_second_derivative_loss': first_and_second_derivative_loss,
    'my_loss_function': first_and_second_derivative_loss,
    'first_derivative_loss': first_derivative_loss,
    'second_derivative_loss': second_derivative_loss,
}
model = tf.keras.models.load_model(model_file, custom_objects=custom_objects)
# summary() prints the table itself and returns None, so wrapping it in
# print() would only add a stray "None" line.
model.summary()

# show which label entries the generator returns next to the model inputs
print(validation_generator[0][1].keys())

num_batches = 1
num_img_per_batch = 6
waveform_res = {}
# layers whose outputs are exported as attention masks
attention_layer_names = ['attention_mask_1_shared', 'attention_mask_2_dysub']


def _parse_file_label(file_label):
    """Return (participant_id, chunk_id, window_number) parsed from a file label.

    chunk_id falls back to 0 when extracting the number after "C" raises
    IndexError (e.g. the chunk identifier contains no "C").
    """
    participant_id = get_participant_id(file_label)
    try:
        chunk_id = int(get_chunk_id(file_label).split("C")[1])
    except IndexError:
        chunk_id = 0
    window_number = int(get_window_number(file_label))
    return participant_id, chunk_id, window_number


# One sub-model per attention layer: same input as the loaded model, output
# taken at that layer. Built once here instead of once per batch.
attention_models = {
    layer_name: tf.keras.models.Model(inputs=model.input,
                                      outputs=model.get_layer(layer_name).output)
    for layer_name in attention_layer_names
}

for i in range(num_batches):
    print("Batch:", i)
    # fetch the batch once; it provides both the model inputs and the labels
    batch = validation_generator[i]
    img = batch[0]
    file_labels = batch[1]["file"]
    print(file_labels)

    # never index past the end of a batch shorter than num_img_per_batch
    num_to_save = min(num_img_per_batch, len(file_labels))

    for j in range(num_to_save):
        print(file_labels[j])
        participant_id, chunk_id, window_number = _parse_file_label(file_labels[j])
        os.makedirs(os.path.join(args.output_dir, participant_id), exist_ok=True)
        # save raw data (second element of the model inputs)
        # NOTE(review): unlike the attention-mask files below, this file name
        # carries no task id, so equal chunk/window numbers from different
        # tasks of one participant write to the same file - confirm intended.
        np.save(os.path.join(args.output_dir, participant_id, f"{participant_id}_{chunk_id:05d}_{window_number:05d}_input.npy"), img[1][j])

    for layer_name, attention_model in attention_models.items():
        # generate intermediate layer prediction (i.e. attention mask)
        pred_attention_mask = attention_model.predict(img)
        print(layer_name, "attention mask shape:", pred_attention_mask.shape)
        # for each image in batch
        for j in range(num_to_save):
            print(file_labels[j])
            participant_id, chunk_id, window_number = _parse_file_label(file_labels[j])

            # if there is also a task, include task in output file name
            if "T" in file_labels[j]:
                task_id = int(get_task_id(file_labels[j]).split("T")[1])
                stem = f"{participant_id}_{task_id:05d}_{chunk_id:05d}_{window_number:05d}"
            else:
                stem = f"{participant_id}_{chunk_id:05d}_{window_number:05d}"
            # save the raw data
            np.save(os.path.join(args.output_dir, participant_id, f"{stem}_{layer_name}.npy"), pred_attention_mask[j])

# %% plot sequence of attention masks, equally spaced in time
def plot_attention_masks(dir, num_masks):
    """Plot ``num_masks`` evenly spaced attention masks of a directory in one row.

    The files matching ``*_attention_mask_1_shared.npy`` in ``dir`` are sorted
    by name, ``num_masks`` of them are picked at evenly spaced positions and
    drawn side by side without ticks or spacing. The figure is saved as
    ``attention_mask_plot.svg`` inside ``dir`` and then shown.

    Args:
        dir: directory holding the saved masks; also receives the svg. (The
            name shadows the builtin but is kept for existing callers.)
        num_masks: number of masks (subplot columns) to draw.

    Returns:
        None. Nothing is plotted or written when no mask file is found.
    """
    files = sorted(glob.glob(os.path.join(dir, "*_attention_mask_1_shared.npy")))
    print(f"Found {len(files)} files")
    if not files:
        return
    print(files)
    # get num_masks evenly spaced attention masks
    indices = np.round(np.linspace(0, len(files)-1, num_masks)).astype(int)
    print(indices)
    # squeeze=False always yields a 2-D axes array, so num_masks == 1 works too
    fig, ax = plt.subplots(1, num_masks, figsize=(num_masks*3, 3), squeeze=False,
        gridspec_kw=dict(hspace=0, wspace=0))
    fig.subplots_adjust(hspace=0, wspace=0)
    for axis, idx in zip(ax[0], indices):
        # NOTE(review): assumes each file holds an array imshow accepts - confirm
        axis.imshow(np.load(files[idx]))
        axis.tick_params(axis='both', which='both', bottom=False, top=False, labelbottom=False, right=False, left=False, labelleft=False)

    fig.savefig(os.path.join(dir, "attention_mask_plot.svg"))
    plt.show()
    # release the figure; this function is called once per directory
    plt.close(fig)


# Plot the mask overview for every entry of the V4V attention-mask folder.
# Raw string: the backslashes are path separators, not escape sequences.
data_dirs = glob.glob(r"C:\Users\username\Downloads\V4V_test_attention_masks\*")
for d in data_dirs:
    plot_attention_masks(d, num_masks=7)

# %% render each frame of the attention masks as a png (frames for a gif)
def plot_attention_mask_videos(dir, participant_id, task, save_dir):
    """Save one side-by-side png per frame of a task's two attention masks.

    For every ``{participant_id}_{task:05d}_*_attention_mask_1_shared.npy``
    file in ``dir`` (sorted by name) the ``..._attention_mask_2_dysub.npy``
    file with the same name prefix is loaded as well. For each index ``t``
    along axis 2 of the first mask, ``mask[:, :, t, 0]`` of both masks is
    drawn next to each other and written to
    ``save_dir/attention_mask_plot_{window:07d}_{t:07d}.png``.

    Args:
        dir: directory holding the saved masks. (The name shadows the builtin
            but is kept for existing callers.)
        participant_id: participant prefix of the mask file names.
        task: integer task number; zero-padded to five digits in file names.
        save_dir: directory receiving the png frames; created if missing.

    Returns:
        None. Nothing is rendered when either mask has no files. A window
        whose second mask file is missing is skipped with a message.
    """
    os.makedirs(save_dir, exist_ok=True)

    mask1_suffix = "_attention_mask_1_shared.npy"
    mask2_suffix = "_attention_mask_2_dysub.npy"
    task_string = f"{task:05d}"
    window_prefix = os.path.join(dir, f"{participant_id}_{task_string}_*")

    mask1_files = sorted(glob.glob(window_prefix + mask1_suffix))
    print(f"Found {len(mask1_files)} files")
    if not mask1_files:
        return
    print(mask1_files)

    mask2_files = sorted(glob.glob(window_prefix + mask2_suffix))
    print(f"Found {len(mask2_files)} files")
    if not mask2_files:
        return
    print(mask2_files)
    available_mask2 = set(mask2_files)

    # for each window
    for i, mask1_file in enumerate(mask1_files):
        # Pair the masks by file name rather than by list position, so one
        # missing file cannot shift every following pair or overrun the list.
        mask2_file = mask1_file[:-len(mask1_suffix)] + mask2_suffix
        if mask2_file not in available_mask2:
            print(f"Skipping {mask1_file}: no matching {mask2_suffix} file")
            continue
        print(mask1_file)
        print(mask2_file)
        pred_attention_mask1 = np.load(mask1_file)
        print("pam1 shape:", pred_attention_mask1.shape)

        pred_attention_mask2 = np.load(mask2_file)
        print("pam2 shape:", pred_attention_mask2.shape)
        # for each frame in window
        for t in range(pred_attention_mask1.shape[2]):
            fig, ax = plt.subplots(1, 2, figsize=(16, 8),
                gridspec_kw=dict(hspace=0, wspace=0))
            fig.subplots_adjust(hspace=0, wspace=0)
            for axis, mask in zip(ax, (pred_attention_mask1, pred_attention_mask2)):
                axis.imshow(mask[:, :, t, 0])
                axis.tick_params(axis='both', which='both', bottom=False, top=False, labelbottom=False, right=False, left=False, labelleft=False)

            fig.savefig(os.path.join(save_dir, f"attention_mask_plot_{i:07d}_{t:07d}.png"))
            # close every figure right away; one is created per frame
            plt.close(fig)


# Folder receiving the per-frame images of participant P004, task 1 ...
save_dir = "/data1/ippg/pttraining/username_results/trainSynFull_testAFRL_ppgSD_8epoch_modelv1_SDframes/attention_masks/P004/T1/images_for_attention_gif"
# ... and the folder the saved attention masks are read from.
d = "/data1/ippg/pttraining/username_results/trainSynFull_testAFRL_ppgSD_8epoch_modelv1_SDframes/attention_masks/P004/"
plot_attention_mask_videos(d, "P004", 1, save_dir)
# %% create gif animation
# Read every png of save_dir in file-name order and write them as one 30 fps gif
# into the current working directory.
images = sorted(glob.glob(os.path.join(save_dir, '*.png')))
image_list = [imageio.imread(file_name) for file_name in images]

gif_file_name = 'P004T1.gif'
imageio.mimwrite(os.path.join("./", gif_file_name), image_list, fps=30)
