"""
In this script, an LSTM network is trained to predict an MSO (multiple
superimposed oscillators) signal. The trained model weights are finally written
to file.
"""


import numpy as np
import torch as th
from torch.utils.data import TensorDataset, DataLoader
import os
from lstm_model import LSTM
import tools


def run_training(number_of_oscillators, hidden_size=32, learning_rate=0.001,
                 adam_beta1=0.9, adam_beta2=0.999, epochs=100, batch_size=128,
                 noise_ratio=0.0, repeats=range(10), model_path="models/",
                 platform="GPU"):
    """Train LSTM models on MSO data and write their weights to disk.

    For every entry in ``repeats`` a freshly initialized LSTM is trained with
    Adam on the mean squared error of one-step-ahead predictions, and its
    ``state_dict`` is saved into ``model_path``.

    Args:
        number_of_oscillators: Selects the data file
            ``data/train/mso-<number_of_oscillators>_train.npy`` and is part
            of the experiment tag used in the model file name.
        hidden_size: Number of LSTM hidden units (also encoded in the model
            file name).
        learning_rate: Adam learning rate.
        adam_beta1: First Adam beta coefficient.
        adam_beta2: Second Adam beta coefficient.
        epochs: Number of passes over the training data per repetition.
        batch_size: Mini-batch size used by the data loader.
        noise_ratio: Standard deviation of the Gaussian noise added to the
            inputs, relative to the standard deviation of the inputs. The
            targets are left noise-free.
        repeats: Iterable of repetition identifiers; one model is trained and
            saved per entry.
        model_path: Directory the weight files are written to (created if
            missing).
        platform: "GPU" to train on CUDA when it is available; any other
            value trains on the CPU.
    """

    input_size = 1
    output_size = 1

    #
    # Set up the device

    # Limit the number of threads to one (resulting in reasonable acceleration
    # for CPU training)
    th.set_num_threads(1)

    # Run on the first CUDA device only when requested and available
    device = th.device(
        "cuda:0" if th.cuda.is_available() and platform == "GPU" else "cpu"
    )

    #
    # Data preparation

    # Load train data; each row is one sequence
    data = np.load("data/train/mso-" + str(number_of_oscillators)
                   + "_train.npy")

    # Split the data into inputs and targets: the target at time step t is
    # the input at time step t+1 (one-step-ahead prediction).
    data_inputs = data[:, :-1]
    data_targets = data[:, 1:]

    # Add noise to the inputs only. This must NOT be done in place: both
    # slices above are views sharing the memory of `data`, so an in-place
    # addition would also add noise to the overlapping part of the targets.
    std_noise = np.std(data_inputs) * noise_ratio
    data_inputs = data_inputs + np.random.normal(
        loc=0, scale=std_noise, size=np.shape(data_inputs)
    )

    # Define data loaders for batch creation
    train_loader = DataLoader(
        dataset=TensorDataset(th.tensor(data_inputs, dtype=th.float32),
                              th.tensor(data_targets, dtype=th.float32)),
        batch_size=batch_size,
        # Pinned memory only speeds up host-to-GPU copies; on a CPU-only run
        # it has no benefit and PyTorch warns about it.
        pin_memory=device.type == "cuda",
        shuffle=True
    )
    batches_per_epoch = len(train_loader)

    # Number of time steps the model is unrolled for
    sequence_length = len(data_inputs[0])

    #
    # Perform training

    # Train the model "repeats" times
    for rep in repeats:

        # Initialize the model
        model = LSTM(_input_size=input_size,
                     _hidden_size=hidden_size,
                     _output_size=output_size).to(device)

        # Set up a mean squared error criterion and an Adam optimizer
        criterion = th.nn.MSELoss()
        optimizer = th.optim.Adam(params=model.parameters(),
                                  lr=learning_rate,
                                  betas=(adam_beta1, adam_beta2))

        for epoch in range(epochs):

            epoch_loss = 0

            for inputs, targets in train_loader:

                # Move the tensors to the appropriate device (GPU/CPU)
                inputs = inputs.to(device=device)
                targets = targets.to(device=device)

                # The last batch may be smaller than batch_size
                real_batch_size = inputs.size(0)

                optimizer.zero_grad()

                # Zero initial (hidden, cell) state for the LSTM
                lstm_state = [
                    th.zeros(real_batch_size, hidden_size, device=device),
                    th.zeros(real_batch_size, hidden_size, device=device),
                ]

                # Unroll the model over the sequence, one time step at a time;
                # the slice t:t + 1 keeps a (batch, 1) input column.
                outputs = []
                for t in range(sequence_length):
                    output, lstm_state = model(_net_input=inputs[:, t:t + 1],
                                               _lstm_state=lstm_state)
                    outputs.append(output)

                # Stack along a new leading time axis; outputs[:, :, 0] is then
                # (time, batch), matching the transposed targets. Presumably
                # each output is (batch, output_size) — confirm in LSTM.
                outputs = th.stack(outputs)

                # Calculate the error and perform a weight update
                loss = criterion(outputs[:, :, 0], targets.t())
                loss.backward()
                optimizer.step()

                epoch_loss += loss.item()

            # Mean batch loss over the epoch
            epoch_loss /= batches_per_epoch

            # Dump the current epoch and loss to console
            print(epoch, epoch_loss)

        #
        # Write weights to file

        os.makedirs(model_path, exist_ok=True)

        # Encode the actual hidden size (not a hard-coded value) so models
        # with different sizes do not get identical file names.
        model_name = tools.generate_model_name(
            model_type="LSTM", hidden_size=hidden_size,
            experiment_tag="mso-" + str(number_of_oscillators),
            noise_ratio=noise_ratio, rep=rep
        )

        th.save(model.state_dict(), os.path.join(model_path, model_name))
