"""
In this script, an LSTM network is used/tested to predict an MSO (multi
superimposed oscillator) signal. The trained model weights are loaded from
file.
"""


import numpy as np
import torch as th
from torch.utils.data import TensorDataset, DataLoader
import os
from lstm_model import LSTM
import matplotlib.pyplot as plt
import tools
import sys
sys.path.append("../../active_tuning")
import active_tuning


def load_model(input_size, hidden_size, output_size, device, model_name,
               model_path="models/"):
    """Create an LSTM model and load pretrained weights into it.

    :param input_size: Passed to the LSTM constructor as ``_input_size``
    :param hidden_size: Passed to the LSTM constructor as ``_hidden_size``
    :param output_size: Passed to the LSTM constructor as ``_output_size``
    :param device: Device the model is moved to; also used as
        ``map_location`` when reading the weights file
    :param model_name: File name of the stored state dict
    :param model_path: Directory that contains the weights file. Defaults to
        "models/", the location that was previously hard-coded here.
    :return: The model with the stored weights loaded
    """

    # Initialize the model
    model = LSTM(_input_size=input_size,
                 _hidden_size=hidden_size,
                 _output_size=output_size).to(device)

    # Load pretrained weights into the model. os.path.join also copes with a
    # directory that is given without a trailing separator.
    weights_file = os.path.join(model_path, model_name)
    model.load_state_dict(th.load(weights_file, map_location=device))

    return model


def evaluate(number_of_oscillators, hidden_size=32, noise_ratio_training=0.0,
             noise_ratio_inference=0.0, repeats=range(10), washout=400,
             skip_error_length=100, use_active_tuning=False, at_tf_rate=0.0,
             at_tuning_length=5, at_tuning_cycles=10, at_eta=0.01,
             at_beta1=0.9, at_beta2=0.999, at_epsilon=1e-8):
    """Evaluate pretrained LSTM models on the MSO test data.

    For every repetition in ``repeats`` the corresponding model is loaded and
    run over the test sequences: during the first ``washout`` steps it is
    driven by (noisy) observations, either via teacher forcing or via Active
    Tuning, afterwards it runs in closed loop on its own outputs.

    :param number_of_oscillators: Selects the data files
        "data/train/mso-<n>_train.npy" and "data/test/mso-<n>_test.npy" and
        is part of the experiment tag used to build the model name
    :param hidden_size: Hidden size of the LSTM (also used for the model name)
    :param noise_ratio_training: Only used to build the model name
    :param noise_ratio_inference: Factor applied to the standard deviation of
        the training data to obtain the standard deviation of the Gaussian
        noise added to the observations
    :param repeats: Iterable of repetition indices; one model per index
    :param washout: Number of initial time steps in which the model receives
        observations (teacher forcing or Active Tuning)
    :param skip_error_length: Number of initial time steps that are excluded
        from the RMSE computation
    :param use_active_tuning: If True, ``at.tune`` is used during the washout
        phase instead of a teacher-forced forward pass
    :param at_tf_rate: Forwarded to ``ActiveTuning`` as ``tf_rate``
    :param at_tuning_length: Forwarded to ``ActiveTuning`` as
        ``tuning_length``
    :param at_tuning_cycles: Forwarded to ``ActiveTuning`` as
        ``tuning_cycles``
    :param at_eta: Forwarded to ``ActiveTuning`` as ``eta``
    :param at_beta1: Forwarded to ``ActiveTuning`` as ``beta1``
    :param at_beta2: Forwarded to ``ActiveTuning`` as ``beta2``
    :param at_epsilon: Forwarded to ``ActiveTuning`` as ``epsilon``
    :return: Numpy array holding one RMSE value per entry of ``repeats``
    """

    #
    # Specify parameters for evaluating a model

    platform = "GPU"  # can be "GPU" - if available - or "CPU"

    input_size = 1
    output_size = 1

    # Limit the number of threads to one (resulting in reasonable acceleration
    # on CPU)
    th.set_num_threads(1)

    # Determine whether to run the model on CPU or GPU
    device = th.device(
        "cuda:0" if th.cuda.is_available() and platform == "GPU" else "cpu"
    )

    # Get statistics from training data: the noise level at inference is
    # defined relative to the standard deviation of the training signal.
    data_train = np.load("data/train/mso-" + str(number_of_oscillators)
                         + "_train.npy")
    std_noise = np.std(data_train) * noise_ratio_inference

    # Load test data
    data_test = th.tensor(
        np.load("data/test/mso-" + str(number_of_oscillators) + "_test.npy"),
        device=device
    )

    # All test sequences without their last time step
    data_inputs = data_test[:, :-1]

    # Get batch size (all test samples) and sequence length of the data
    batch_size, sequence_length = data_inputs.size()

    # Initialize an empty error list to store the rmse scores for all models
    rmse_list = []

    # Iterate over all repetitions to evaluate all trained models
    for rep in repeats:

        # Determine the file name of the model for this repetition
        model_name = tools.generate_model_name(
            model_type="LSTM", hidden_size=hidden_size,
            experiment_tag="mso-" + str(number_of_oscillators),
            noise_ratio=noise_ratio_training, rep=rep
        )

        # Initialize the model and load its weights
        model = load_model(
            input_size=input_size, hidden_size=hidden_size,
            output_size=output_size, device=device, model_name=model_name
        )

        # Initialize the LSTM state [h, c]. With Active Tuning the hidden
        # state starts from small random values, otherwise from zeros; the
        # cell state always starts from zeros.
        if use_active_tuning:
            lstm_h = 0.1 * th.randn(batch_size, hidden_size, device=device)
        else:
            lstm_h = th.zeros(batch_size, hidden_size, device=device)

        lstm_c = th.zeros(batch_size, hidden_size, device=device)
        lstm_state = [lstm_h, lstm_c]

        # Initialize a zero initial network input
        x_t = th.zeros((batch_size, input_size), device=device)

        # Initialize zero arrays to store the network outputs, the targets
        # and the noisy observations
        net_outputs = th.zeros_like(data_inputs, device=device)
        net_targets = th.zeros_like(data_inputs, device=device)
        net_observations = th.zeros_like(data_inputs, device=device)

        # Initialize active tuning if required
        if use_active_tuning:

            # Accessor function that returns the state which is to be optimized
            at_opt_accessor = lambda out, state: state[0]

            # Create an active tuning instance
            at = active_tuning.ActiveTuning(
                model=model, initial_model_state=lstm_state,
                initial_model_output=x_t.clone(),
                tuning_length=at_tuning_length, tf_rate=at_tf_rate,
                opt_accessor=at_opt_accessor, tuning_cycles=at_tuning_cycles,
                eta=at_eta, beta1=at_beta1, beta2=at_beta2, epsilon=at_epsilon
            )

        # Evaluate model
        for t in range(sequence_length):
            # Get the current observation that should be produced by the model
            z_t = data_inputs[:, t:t+1]

            # Generate current noisy observation
            s_t = z_t.detach().clone()

            # Enable this for normal distributed noise.
            s_t = s_t + th.randn(x_t.size(), device=device) * std_noise

            # Enable this for salt and pepper noise.
            #s_t = s_t + (1.0 - (th.randint(size=x_t.size(), low=0, high=2) * 2.0)) * std_noise

            if t < washout:
                if use_active_tuning:

                    # With Active Tuning we can also incorporate the current
                    # observation, since the tune-call generates a prediction
                    # for the true signal z_t within the current observation
                    # s_t. This call stays outside of no_grad on purpose.
                    x_t = s_t.detach().clone()
                    y_t, lstm_state = at.tune(observation_t=x_t)
                else:
                    # Generate prediction for the current signal z_t based on
                    # the input from the previous time step (noisy teacher
                    # forcing input). y_t should be z_t in case of an optimal
                    # prediction. No gradients are needed for a plain forward
                    # pass; disabling them prevents the autograd graph from
                    # growing over the whole sequence via lstm_state.
                    with th.no_grad():
                        y_t, lstm_state = model.forward(
                            _net_input=x_t, _lstm_state=lstm_state
                        )

                    # Teacher forcing for regular network
                    x_t = s_t.detach().clone()
            else:
                # After teacher forcing/active tuning, the model performs
                # just closed loop prediction (again without gradients).
                with th.no_grad():
                    y_t, lstm_state = model.forward(
                        _net_input=x_t, _lstm_state=lstm_state
                    )
                x_t = y_t.detach().clone()

            # Store the current noisy observation, target and network output.
            net_observations[:, t:t+1] = s_t.detach()
            net_targets[:, t:t+1] = z_t.detach()
            net_outputs[:, t:t+1] = y_t.detach()

        # Convert tensors to numpy arrays.
        net_outputs = net_outputs.cpu().numpy()
        net_targets = net_targets.cpu().numpy()
        net_observations = net_observations.cpu().numpy()

        # Exclude the first skip_error_length time steps from the error
        net_outputs_loss = net_outputs[:, skip_error_length:]
        net_targets_loss = net_targets[:, skip_error_length:]

        # Calculate root mean squared error between network outputs and targets
        rmse = np.sqrt(np.mean(np.square(net_outputs_loss - net_targets_loss)))
        rmse_list.append(rmse)

        # Visualization
        #plt.plot(range(sequence_length), net_targets[0, :], label="Ground truth")
        #plt.plot(range(sequence_length), net_outputs[0, :], label="Network outputs")
        #plt.legend()
        #plt.show()

    # Return the per-repetition errors as a numpy array; aggregation (mean,
    # standard deviation, ...) is left to the caller.
    return np.array(rmse_list)

if __name__ == "__main__":

    # Settings of the evaluation run: a model trained with noise ratio 0.05
    # is evaluated under inference noise ratio 0.2 using Active Tuning.
    run_settings = dict(
        number_of_oscillators=5,
        repeats=[0],  # only repetition 0; the function default is range(10)
        noise_ratio_training=0.05,
        noise_ratio_inference=0.2,
        use_active_tuning=True,
        at_tf_rate=0.0,
        at_tuning_length=16,
        at_tuning_cycles=10,
        at_eta=0.004,
        at_beta1=0.5,
        at_beta2=0.9,
    )

    rmse_scores = evaluate(**run_settings)

    # Print the raw scores followed by mean +- standard deviation and median
    print(rmse_scores)
    print(
        np.mean(rmse_scores), "+-", np.std(rmse_scores),
        "     ", np.median(rmse_scores)
    )
