"""
In this script, a neural network model is created to simulate the propagation
of a wave by using the PyTorch library.

To model the waves, a grid of Prediction Kernels (PKs) is initialized; the PKs
are laterally connected to propagate spatial information. Temporal information
is processed by the PKs, which consist of some feed forward layers and a long
short-term memory (LSTM) core. A crucial component of this model is that all
PKs share weights and thus all realize the same computations, just differently
parameterized by dynamic and optional static inputs.
"""

import numpy as np
import torch as th
import torch.nn as nn
import glob
import os
import time
import multiprocessing
import kernel_variables
import kernel_net
import configuration as cfg
import helper_functions as helpers

# Limit PyTorch to a single thread per process
th.set_num_threads(1)

# When the configuration file requests the CPU, make the GPU(s) invisible by
# clearing the list of visible CUDA devices
if cfg.DEVICE == "CPU":
    os.environ.update({
        "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
        "CUDA_VISIBLE_DEVICES": "",
    })


def optimize_model(pk_batches, data_inputs, data_targets, device, noise_ratio,
                   rep, lstm_size):
    """
    Train one kernel network on the given sequences and write its weights to
    file.

    The network is trained for cfg.EPOCHS epochs. In every epoch, each index
    of data_inputs is visited exactly once in a newly shuffled order. If the
    error returned by helpers.evaluate is NaN, a message is printed, the
    training is aborted and no model is written to file.

    :param pk_batches: Number of PK batches; forwarded to KernelParameters and
        helpers.evaluate
    :param data_inputs: Indexable collection of network inputs (one entry per
        training sequence)
    :param data_targets: Indexable collection of targets, aligned with
        data_inputs by index
    :param device: Device that is handed to KernelParameters
    :param noise_ratio: Only forwarded to helpers.save_model_to_file
    :param rep: Identifier of the repetition; used in the abort message and
        forwarded to helpers.save_model_to_file
    :param lstm_size: Only forwarded to helpers.save_model_to_file
    :return: None
    """

    # Set a new seed for PyTorch's random number generator
    th.seed()

    # th.seed() does not touch NumPy. Worker processes that are forked from
    # the same parent inherit identical NumPy global random states, so the
    # shuffling below uses a private generator that is seeded from fresh OS
    # entropy. This keeps the sequence order independent across repetitions.
    shuffle_rng = np.random.RandomState()

    # Set up the parameter and tensor classes
    params = kernel_variables.KernelParameters(
        pk_batches=pk_batches,
        device=device
    )
    tensors = kernel_variables.KernelTensors(params=params)

    # Initialize and set up the kernel network
    net = kernel_net.KernelNetwork(
        params=params,
        tensors=tensors
    )

    #
    # Set up the optimizer and the criterion (loss)
    optimizer = th.optim.Adam(net.parameters(), lr=cfg.LEARNING_RATE)
    criterion = nn.MSELoss()

    indices = np.arange(len(data_inputs))

    #
    # Start the training and iterate over all epochs
    for epoch in range(cfg.EPOCHS):

        # Visit the sequences in a new random order in every epoch
        shuffle_rng.shuffle(indices)

        sequence_errors = []

        # Iterate over all training batches
        for index in indices:

            # Get the current network input and target pair
            inputs, targets = data_inputs[index], data_targets[index]

            # Evaluate and train the network for the given training data
            mse, _ = helpers.evaluate(
                net=net,
                net_input=inputs,
                net_label=targets,
                params=params,
                tensors=tensors,
                pk_batches=pk_batches,
                criterion=criterion,
                optimizer=optimizer,
            )

            # Convert the error to a Python number only once
            error = mse.item()

            # A NaN error cannot be recovered from; stop this training run,
            # but tell the user instead of returning silently
            if np.isnan(error):
                print("Training of repetition " + str(rep)
                      + " aborted in epoch " + str(epoch)
                      + " because the error is NaN. No model is saved.")
                return

            sequence_errors.append(error)

        # Report the mean error over all sequences of this epoch
        print(epoch, np.mean(np.array(sequence_errors)))

    # Write the model weights to file
    helpers.save_model_to_file(lstm_size=lstm_size, noise_ratio=noise_ratio,
                               rep=rep, net=net)


def run_training(lstm_size=4, epochs=200, noise_ratio=0.0, repeats=range(10),
                 poolsize=5):
    """
    Load the wave training data and train one model per entry of "repeats" in
    a pool of worker processes.

    The function overwrites several values of the configuration module
    (TRAINING, PK_NUM_LSTM_CELLS, EPOCHS, NOISE_RATIO, DEVICE), loads
    "data/train/wave_train.npy", splits it into inputs and targets that are
    shifted by one time step, adds normally distributed noise and calls
    optimize_model once per repetition.

    :param lstm_size: Value written to cfg.PK_NUM_LSTM_CELLS and forwarded to
        optimize_model
    :param epochs: Value written to cfg.EPOCHS
    :param noise_ratio: Factor that is multiplied with the standard deviation
        of the inputs to obtain the standard deviation of the added noise
    :param repeats: Iterable of repetition identifiers; one training is
        started per element
    :param poolsize: Number of worker processes in the multiprocessing pool
    :return: None
    """

    # Set globally reachable flags for the training
    cfg.TRAINING = True
    cfg.PK_NUM_LSTM_CELLS = lstm_size
    cfg.EPOCHS = epochs
    cfg.NOISE_RATIO = noise_ratio

    # The device is identified by a string (the module-level check compares
    # cfg.DEVICE against "CPU"); the bare name CPU is not defined anywhere
    cfg.DEVICE = "CPU"

    # Determine the device to compute on
    # NOTE(review): presumably this resolves to the CPU because cfg.DEVICE was
    # just set to "CPU" -- confirm in helpers.determine_device
    device = helpers.determine_device()

    # Compute batch size for the PKs (every PK is processed in a separate batch
    # to parallelize computation)
    pk_batches = cfg.PK_ROWS * cfg.PK_COLS

    # Load train data
    data = np.array(np.load("data/train/wave_train.npy"), dtype=np.float32)

    # Prepare the data to be in the correct format: the first two axes are
    # kept, everything behind them is flattened to pk_batches entries and a
    # trailing axis of size one is appended
    data = np.reshape(data, (data.shape[0], data.shape[1], pk_batches))
    data = np.expand_dims(data, axis=-1)

    # Split the data into inputs and targets, where the targets are the inputs
    # shifted by one step along the second axis (t becomes t+1)
    data_inputs = data[:, :-1]
    data_targets = data[:, 1:]

    # Add noise to data
    # NOTE(review): data_inputs and data_targets are overlapping views of
    # "data", so this in-place addition also changes the targets at every step
    # that both views share. Confirm that noisy targets are intended.
    std_noise = np.std(data_inputs) * noise_ratio
    data_inputs += np.random.normal(loc=0, scale=std_noise,
                                    size=np.shape(data_inputs))

    # Convert the numpy arrays to pytorch tensors
    data_inputs = th.tensor(data_inputs, device=device)
    data_targets = th.tensor(data_targets, device=device)

    # Set up a list of arguments for the multiprocessing
    arguments = [
        (pk_batches, data_inputs, data_targets, device, noise_ratio, rep,
         lstm_size)
        for rep in repeats
    ]

    # Set up the multiprocessor to perform "poolsize" trainings in parallel
    with multiprocessing.Pool(poolsize) as pool:
        # starmap blocks until every training has returned
        pool.starmap(optimize_model, arguments)

        # Shut the workers down gracefully (instead of terminating them) so
        # that they can exit normally, and wait until they are gone
        pool.close()
        pool.join()

    print("Training of models with " + str(noise_ratio) + " noise finished.")
