"""
File to run multiple prediction tests with the same parameters.
Stores the following in the results folder named by `save_dirname`:
    1. multiple_gifs: folder containing all the predicted gifs
    2. multiple_img_npys: folder containing all the npy files for the predicted images and variance images
    3. multiple_errors: folder containing npy files that have all the error rates for every prediction sequence
    4. multiple_mean_stds: folder containing the npy files that have all the mean stds for every image in the prediction sequence


Functionality:
All settable parameters are at the top:
1. Filenames and datasets
2. Training parameters
3. Test parameters
4. Flag for training and testing using random parts of the sequence or just the start/some specific point

NOTE: look to the ipynb files for guidance. 
"""

import os 
import sys 
import numpy as np 
import cv2 
import matplotlib.pyplot as plt
import GPy
from copy import deepcopy
from scipy.io import loadmat
import pickle
import time 
import pdb

# Make the project's own modules importable. The paths are resolved relative to
# the current working directory (os.getcwd()), not to this file's location, so
# presumably the script is meant to be launched from its own folder.
current_dir = os.getcwd()
parent_dir = os.path.join(current_dir, "..")
parent_parent_dir = os.path.join(parent_dir, "..")
# sibling folder holding the separate predictor classes
separate_dir = os.path.join(parent_dir, "separate_model_prediction")
sys.path.extend([current_dir, parent_dir, separate_dir, parent_parent_dir])

# import our written code 
from patchify_unpatchify import patcher 
from predict import predict as base_predict 
from processing import processing 
from nonparam_wrapper import nonparam_predictor, gaussian_kernel_2d

from nonparam_wrapper_particle import nonparam_predictor_sampling
from nonparam_wrapper_particle_difference import nonparam_predictor_sampling_difference
from nonparam_wrapper_separate import nonparam_predictor_separate, gaussian_kernel_2d
from nonparam_wrapper_particle_difference_separate import nonparam_predictor_sampling_difference_separate

# import data processing and display functions
from display_functions import * 
from ns_dataset_processing import * 

#################################################### START: Modifiable parameters ####################################################
num_predictions = 25 # number of consecutive sequences to run prediction on (starting at start_prediction_seq)
prediction_steps = 15 # how many steps in the future to predict
num_training_imgs = 10 # number of images taken from each sequence (starting at start_point) to build the training set
# NOTE(review): random_img_start is only recorded in all_parameters; nothing else in this script reads it,
# so the training window always starts at start_point regardless of this flag.
random_img_start = False # boolean: intended meaning - if True choose a random part of each sequence to train and test with; if False start at index 'start_point'
start_point = 0 # index within each sequence at which the training window starts

viscosity = 1e-3 # recorded in all_parameters; presumably should match the viscosity of the data file below
nvstokes_filename = "../../../fourier_neural_operator_paper/dataset/ns_data_32by32_visc1e-3_1e-4dt_200T_50samples_2.mat"
nvstokes_filenames = [nvstokes_filename] # list of data files to load
multiple_seq_in_file = True # boolean: if True each file stores multiple data sequences; if False each file stores a single data sequence


# Range of sequence indices to predict on: [start_prediction_seq, end_prediction_seq)
# NOTE(review): end_prediction_seq must not exceed the number of loaded sequences. The file name above
# suggests 50 samples, while 27 + 25 = 52 - confirm the intended range.
start_prediction_seq = 27
end_prediction_seq = start_prediction_seq + num_predictions

separate = False # boolean: if True create a separate kernel for the prediction of each output dimension; if False use a single shared kernel
difference_prediction = False # boolean: if True predict the difference between images; if False predict the image directly

save_dirname = "50_comparison_10train_trainstride(2,2)_WAYLATER" # name of the directory to create and store all the results in
save_filename_suffix_start = "10train_32by32batchtest_visc1e-3_50test_WAYLATER" # starting suffix for all saved files ("_<sequence number>" is appended per sequence)

# Inversion methodology to use:
svd_inverse = True # boolean: if True use the SVD to compute the inverse and threshold it with sigma_threshold; if False use GPy's thresholding
sigma_threshold = 1e-4
plot_for_thresholding = False

# boolean: if True the test sequence starts right after the training sequence;
# if False the test start is specified explicitly via test_start_img_num.
start_test_after_train = True
if not start_test_after_train: 
    # index (within each sequence) of the image that will be the first image predicted
    test_start_img_num = 75

#################################################### END: Modifiable parameters ####################################################
#################################################### START: Train and Test Parameters ####################################################
##### Model Training Parameters #####
# Initial kernel hyper-parameters handed to create_kernel
starting_kernel_lengthscale = 72
starting_kernel_variance = 69

# Geometry of the training patches
img_dim = (32, 32)
patch_dim = (15, 15)
patch_border = (7, 7)
img_padlen = (0, 0)
img_padtype = 'wrap'
wrap_x = wrap_y = True
stride = (2, 2)
patch_weight = np.ones(patch_dim)

patch_parameters = dict(
    img_dim=img_dim,
    patch_dim=patch_dim,
    patch_border=patch_border,
    img_padlen=img_padlen,
    img_padtype=img_padtype,
    wrap_x=wrap_x,
    wrap_y=wrap_y,
    stride=stride,
    patch_weight=patch_weight,
)

# Data-processor settings. The xtypes name the input images (presumably the
# current frame and the two before it); ytype selects the regression target:
# 'diff' for the difference-prediction model, 'img' for direct prediction.
xtypes = ['img0', 'img-1', 'img-2']
ytype = 'diff' if difference_prediction else 'img'

# NOTE: legacy Farneback optical-flow parameters - not really needed anymore,
# but the processor still takes them.
farneback_flow_params = [None, 0.5, 3, 15, 3, 5, 1.2, 0]
processor_parameters = dict(
    xtypes=xtypes,
    ytype=ytype,
    farneback_flow_params=farneback_flow_params,
)

max_all_together = 1000

# Choose the predictor class from the two modelling switches:
#   separate              -> one kernel per output dimension vs. a single shared kernel
#   difference_prediction -> predict frame differences (ytype 'diff') vs. frames directly
if separate:
    if difference_prediction:
        print("Separate Kernels, Difference Prediction Model \n\n")
        # Use the difference-aware separate-kernel class (imported at the top of the file).
        # Building nonparam_predictor_separate here made this branch identical to the
        # direct-prediction branch despite ytype being 'diff'.
        # NOTE(review): assumes this class takes the same constructor keywords as the
        # other predictor classes - confirm.
        predictor_class = nonparam_predictor_sampling_difference_separate
    else:
        print("Separate Kernels, Direct Prediction Model \n\n")
        predictor_class = nonparam_predictor_separate
else:
    if difference_prediction:
        print("Single Kernel, Difference Prediction Model \n\n")
        predictor_class = nonparam_predictor_sampling_difference
    else:
        print("Single Kernel, Direct Prediction Model \n\n")
        predictor_class = nonparam_predictor_sampling

np_predictor = predictor_class(patch_parameters=patch_parameters,
                               processor_parameters=processor_parameters,
                               max_all_together=max_all_together)

##### Testing Parameters #####
# Patcher used at prediction time. It reuses the training patch geometry but
# steps across the image with stride (1, 1) instead of the training `stride`.
img_dim_test, patch_dim_test, patch_border_test = img_dim, patch_dim, patch_border
img_padlen_test = img_padlen
img_padtype_test = 'wrap'
wrap_x_test = wrap_y_test = True
stride_test = (1, 1)
patch_weight_test = patch_weight

test_patch_obj = patcher(
    img_dim=img_dim_test,
    patch_dim=patch_dim_test,
    patch_border=patch_border_test,
    img_padlen=img_padlen_test,
    img_padtype=img_padtype_test,
    wrap_x=wrap_x_test,
    wrap_y=wrap_y_test,
    stride=stride_test,
    patch_weight=patch_weight_test,
)


#################################################### END: Train and Test Parameters ####################################################
#################################################### START: Train and Test Data ####################################################
##### Training Data #####

# Load data from file
if not multiple_seq_in_file:
    # Each file stores a single data sequence
    nvstokes_img_seqs = [np_predictor.data_processor.nvstokes_file_to_image_seq(filename) for filename in nvstokes_filenames]
else:
    # Each file stores multiple data sequences; flatten them all into one list.
    # The loop variable is deliberately NOT named `nvstokes_filenames`: reusing that
    # name would rebind the list to its last (string) element.
    nvstokes_img_seqs = []
    for nvstokes_file in nvstokes_filenames:
        nvstokes_img_seqs.extend(np_predictor.data_processor.nvstokes_file_to_image_sequences(nvstokes_file))

# One entry per loaded *sequence* (not per file - a file may hold many sequences)
num_seqs = len(nvstokes_img_seqs)

# Training window of every sequence: [start_point, start_point + num_training_imgs)
train_img_starts = [start_point] * num_seqs
train_img_stops = [start + num_training_imgs for start in train_img_starts]
train_imgseqs = [seq[start:stop] for seq, start, stop in zip(nvstokes_img_seqs, train_img_starts, train_img_stops)]


##### Testing #####
# The prediction is seeded with the first 3 frames of each test sequence (one
# per entry of xtypes; run_prediction slices them with [:3]), so a test window
# starts 3 frames before the first frame to be predicted.
num_seed_imgs = 3
if start_test_after_train:
    # Seed with the last training frames so the first prediction is the frame right after training
    test_img_starts = [stop - num_seed_imgs for stop in train_img_stops]
else:
    # test_start_img_num is the index of the first image to be *predicted*;
    # back up by the seed length exactly like the branch above.
    if test_start_img_num < num_seed_imgs:
        raise ValueError("test_start_img_num must be at least " + str(num_seed_imgs)
                         + " (frames needed to seed the prediction)")
    test_img_starts = [test_start_img_num - num_seed_imgs] * num_seqs
# NOTE(review): a stop of -1 excludes the final frame of every sequence; use None to keep it if that was not intended.
test_img_stops = [-1] * num_seqs
test_imgseqs = [seq[start:stop] for seq, start, stop in zip(nvstokes_img_seqs, test_img_starts, test_img_stops)]

# Ground-truth test targets are always raw images (ytype 'img'), even for the difference model
test_xtypes = xtypes
test_ytype = 'img'
test_data_processor = processing(xtypes=test_xtypes,
                                 ytype=test_ytype,
                                 farneback_flow_params=farneback_flow_params)

#################################################### END: Train and Test Data ####################################################

#################################################### START: FOR LOOP  ####################################################
# NOTE: Start the loop here - the loop goes through every image sequence:
# creates the training and test data,
# creates a new model and trains it,
# performs the prediction,
# saves all the metrics,
# moves on to the next sequence.
# start_prediction_seq / end_prediction_seq let a run be resumed at a given sequence if things crash.

##### Create the directory to save #####
# Sub-directories for the different result types
gif_folder = os.path.join(save_dirname, "multiple_gifs") # all the predicted gifs
npy_folder = os.path.join(save_dirname, "multiple_img_npys") # npy files for the predicted images and variance images
error_folder = os.path.join(save_dirname, "multiple_errors") # npy files with the error rates for every prediction sequence
mean_stds_folder = os.path.join(save_dirname, "multiple_mean_stds") # npy files with the mean stds for every image in the prediction sequence

# makedirs creates any missing parents (including save_dirname itself) and,
# with exist_ok=True, does not fail when a folder already exists.
for results_folder in (gif_folder, npy_folder, error_folder, mean_stds_folder):
    os.makedirs(results_folder, exist_ok=True)

print("Created directory structure")

# Save all the run parameters into the results folder (avoids long suffixes on every file name).
# Stored as a pickled dictionary; a plain-text copy is written further below.
all_parameters = {
    'nvstokes_filenames': nvstokes_filenames,
    'num_predictions': num_predictions,
    'prediction_steps': prediction_steps,
    'num_training_imgs': num_training_imgs,
    'random_img_start': random_img_start,
    'start_point': start_point,
    'viscosity': viscosity,
    'start_prediction_seq': start_prediction_seq,
    'end_prediction_seq': end_prediction_seq,
    'separate': separate,
    'difference_prediction': difference_prediction,
    'patch_parameters': patch_parameters,
    'processor_parameters': processor_parameters,
    'stride_test': stride_test,
    'start_test_after_train': start_test_after_train,
    # Settings below also influence the predictions but were previously not recorded
    'svd_inverse': svd_inverse,
    'sigma_threshold': sigma_threshold,
    'starting_kernel_lengthscale': starting_kernel_lengthscale,
    'starting_kernel_variance': starting_kernel_variance,
    'max_all_together': max_all_together,
}
if not start_test_after_train:
    all_parameters["test_start_img_num"] = test_start_img_num

parameter_pkl_filename = os.path.join(save_dirname, "all_parameters.pkl")
with open(parameter_pkl_filename, "wb") as pkl_file:  # the with-block closes the file
    pickle.dump(all_parameters, pkl_file)

# save as a text file to read easily 
def save_dictionary_to_text(save_dic: dict, save_dic_filename) -> None:
    """
    Save a dictionary's keys and values as plain text for easy reading.

    Each entry is written as ``key: value`` on its own line. A value that is
    itself a dictionary is expanded one level: a blank line, ``key: ``, one
    ``sub_key: sub_value`` line per entry, then another blank line. Deeper
    nesting is written with ``str()``. The file ends with ``EOF`` (no newline).

    args:
        - save_dic: dictionary to save
        - save_dic_filename: txt filename to save the dictionary to (overwritten)
    """
    with open(save_dic_filename, "w", encoding="utf-8") as dic_file:
        for key, value in save_dic.items():
            if isinstance(value, dict):
                # extra blank line before a sub-dictionary to make reading clearer
                dic_file.write("\n" + str(key) + ": \n")
                # NOTE: only one level of subdictionaries is expanded
                for sub_key, sub_value in value.items():
                    dic_file.write(str(sub_key) + ": " + str(sub_value) + "\n")
                dic_file.write("\n")
            else:
                dic_file.write(str(key) + ": " + str(value) + "\n")
        dic_file.write("EOF")
parameter_txt_filename = os.path.join(save_dirname, "all_parameters.txt")
save_dictionary_to_text(save_dic=all_parameters, save_dic_filename=parameter_txt_filename)


# The loop below indexes the per-sequence train/test lists with every value in
# [start_prediction_seq, end_prediction_seq), so all of those sequences must exist.
# Comparing against num_predictions alone is not enough (it passes even when the
# range runs past the last sequence), and a raise, unlike an assert, survives `python -O`.
num_loaded_seqs = len(nvstokes_img_seqs)
if not 0 <= start_prediction_seq <= end_prediction_seq <= num_loaded_seqs:
    raise ValueError("Requested sequences [" + str(start_prediction_seq) + ", " + str(end_prediction_seq)
                     + ") but only " + str(num_loaded_seqs) + " sequences were loaded")

# GP model training settings (read by run_prediction)
optimize_model = True
model_variance = 2e-15 # passed to train_model as noise_var

use_sparse_GP = False
num_inducing_points = 512
max_opt_iters = 2000

# put model prediction into a function so that memory can be cleared after each sequence
def run_prediction(seq_num):
    """
    Train a GP model and run sequential mean/variance propagation for one sequence.

    Most inputs come from module-level globals: the loop sets train_datapoints
    and save_filename_suffix; the configuration section sets np_predictor, the
    kernel/model settings, test_imgseqs, test_patch_obj, npy_folder and the
    inversion settings. The work is wrapped in a function so its locals become
    unreachable when it returns - presumably to help the garbage collector free
    (CUDA) memory between sequences.

    args:
        - seq_num: index into test_imgseqs of the sequence to predict
    returns:
        - (predicted_seq_imgs, predicted_seq_vars): the pair returned by
          np_predictor.mean_var_propogation_Faster (names suggest predicted mean
          images and variance images)
    """
    base_kernel = GPy.kern.RBF
    if separate:
        # Create a separate kernel for the prediction of each output dimension of a y-patch.
        # np.prod replaces np.product, which was removed in NumPy 2.0.
        num_output_dims = int(np.prod(np_predictor.patch_obj.get_ypatch_dim()))
        kernels = [
            np_predictor.create_kernel(base_kernel=base_kernel,
                                       kernel_groups=[xtypes],
                                       kernel_lengthscales=[starting_kernel_lengthscale],
                                       # NOTE(review): passed as a scalar here but as a list in the
                                       # single-kernel branch - confirm which form create_kernel expects
                                       kernel_variances=starting_kernel_variance)
            for _ in range(num_output_dims)
        ]
        # The propagation below is called on np_predictor without the returned model
        np_predictor.train_model(kernels=kernels,
                                 datapoints=train_datapoints,
                                 optimize=optimize_model,
                                 noise_var=model_variance,  # was the undefined name `noise_var` (NameError)
                                 use_sparse_GP=use_sparse_GP,
                                 num_inducing_points=num_inducing_points,
                                 max_opt_iters=max_opt_iters)
    else:
        # Use a single kernel - every output dimension shares the same similarity metric/kernel
        kernel = np_predictor.create_kernel(base_kernel=base_kernel,
                                            kernel_groups=[xtypes],
                                            kernel_lengthscales=[starting_kernel_lengthscale],
                                            kernel_variances=[starting_kernel_variance])
        np_predictor.train_model(kernel=kernel,
                                 datapoints=train_datapoints,
                                 optimize=optimize_model,
                                 noise_var=model_variance,
                                 use_sparse_GP=use_sparse_GP,
                                 num_inducing_points=num_inducing_points,
                                 max_opt_iters=max_opt_iters)

    #################################################### START: Sequential Prediction ####################################################
    # Seed the propagation with the first three frames of the test sequence. The
    # shift of the test window relative to the training data is done where the
    # test sequences are declared.
    starting_x_images = test_imgseqs[seq_num][:3]
    predict_separate = True  # compute each column of the matrices separately - keep this True
    predicted_seq_imgs, predicted_seq_vars = np_predictor.mean_var_propogation_Faster(
        starting_x_images=starting_x_images,
        steps=prediction_steps,
        all_together=False,
        use_variance_weighting=False,
        test_patch_obj=test_patch_obj,
        save_filename_suffix=save_filename_suffix,
        dirname=npy_folder,
        num_patches_list=[1, 1, 1],
        seperate=predict_separate,  # keyword spelling is the predictor's API
        svd_inverse=svd_inverse,
        sigma_threshold=sigma_threshold,
        plot_for_thresholding=plot_for_thresholding,
        show_intermediate_imgs=False)
    return predicted_seq_imgs, predicted_seq_vars

for seq_num in range(start_prediction_seq, end_prediction_seq):

    ##### Build the training set from this sequence's training window #####
    data_proc = np_predictor.data_processor
    train_patcher = np_predictor.patch_obj
    train_x, train_y = data_proc.create_xy(train_imgseqs[seq_num])
    train_x_patch = train_patcher.patchify_dataset(dataset=train_x, dataset_type='x')
    train_y_patch = train_patcher.patchify_dataset(dataset=train_y, dataset_type='y')
    # run_prediction reads this global
    train_datapoints = {
        'x_patchvecs_dataset': data_proc.convert_imgdataset_to_vecdataset(dataset=train_x_patch),
        'y_patchvecs_dataset': data_proc.convert_imgdataset_to_vecdataset(dataset=train_y_patch),
    }

    print("\n\nTRAINING DATAPOINTS: " + str(len(train_datapoints['x_patchvecs_dataset'])) + "\n\n")

    ##### Ground-truth images shown next to the prediction in the gif #####
    _, test_yimages = test_data_processor.create_xy(image_seq=test_imgseqs[seq_num])
    # Per-sequence file suffix (run_prediction reads this global); the full
    # parameter set lives in the all_parameters files of the results folder.
    save_filename_suffix = "_".join([save_filename_suffix_start, str(seq_num)])

    ##### Train the model and propagate the prediction #####
    tic = time.time()
    predicted_seq_imgs, predicted_seq_vars = run_prediction(seq_num)
    print("Time taken for prediction: " + str(seq_num) + " : " + str(time.time() - tic))

    ##### Save ground truth vs. prediction as a gif #####
    gif_savepath = os.path.join(gif_folder, "mean_var_propagation_" + str(save_filename_suffix) + ".gif")
    save_imageseq_gif(gif_savepath,
                      prediction_steps,
                      xdataset=None,
                      ydataset=test_yimages,
                      predicted=predicted_seq_imgs,
                      predicted_var=predicted_seq_vars,
                      show_difference=True,
                      show_colorbar=True,
                      show_error_graph=True,
                      show_whole_error_graph=True,
                      show_stds_off=True,
                      figsize=(16, 2.1),
                      duration=1)

print("Completed")


