"""
In this script, a pretrained LSTM network is tested on predicting an MSO
(multiple superimposed oscillators) signal from noisy observations, once with
plain teacher forcing and once with Active Tuning. The trained model weights
are loaded from file.
"""


import numpy as np
import torch as th
from torch.utils.data import TensorDataset, DataLoader
import os
from lstm_model import LSTM
import matplotlib.pyplot as plt
import tools
import sys
sys.path.append("../../active_tuning")
import active_tuning


def load_model(input_size, hidden_size, output_size, device, model_name,
               model_dir="models/"):
    """Build an ``LSTM`` and load pretrained weights into it.

    Args:
        input_size: Forwarded to ``LSTM`` as ``_input_size``.
        hidden_size: Forwarded to ``LSTM`` as ``_hidden_size``.
        output_size: Forwarded to ``LSTM`` as ``_output_size``.
        device: Device the model is moved to. It is also used as
            ``map_location`` so that weights saved on another device
            (e.g. a GPU) can be loaded here.
        model_name: File name of the saved state dict inside ``model_dir``.
        model_dir: Directory that holds the weight files. Defaults to
            ``"models/"``, the prefix previously hard-coded here.

    Returns:
        The ``LSTM`` instance on ``device`` with the loaded weights.
    """
    model = LSTM(_input_size=input_size,
                 _hidden_size=hidden_size,
                 _output_size=output_size).to(device)

    # os.path.join("models/", name) yields the same path as the old
    # "models/" + name, but also accepts a directory without a trailing slash.
    weights_file = os.path.join(model_dir, model_name)
    state_dict = th.load(weights_file, map_location=device)
    model.load_state_dict(state_dict)

    return model

# --- Run configuration -------------------------------------------------------
model_path = "models/"  # directory that holds the pretrained weight files
platform = "GPU"  # "GPU" (only honoured when CUDA is available) or "CPU"

# Sizes of the LSTM that is built before its saved weights are loaded.
input_size, output_size, hidden_size = 1, 1, 32

# Std of the observation noise, as a fraction of the training-data std.
noise_ratio = 0.1

# Number of time steps taken from each test sequence (and plotted).
plot_length = 200

# Steps of teacher forcing / active tuning before closed-loop prediction.
# NOTE(review): this is larger than plot_length, which bounds the number of
# simulated steps, so with these settings closed-loop prediction never runs.
washout = 2000

# Restrict PyTorch to a single intra-op thread (cheap and fast enough for
# this small model when running on the CPU).
th.set_num_threads(1)

# Pick the first CUDA device when it exists and GPU use was requested,
# otherwise fall back to the CPU.
use_cuda = th.cuda.is_available() and platform == "GPU"
device = th.device("cuda:0" if use_cuda else "cpu")


#
# Get statistics from training data.
#
# NOTE(review): only the first training sequence ([0]) enters the std below;
# confirm this is intended rather than the std over the whole training set.
data_train = np.load("data/train/mso-5_train.npy")[0]

# Std of the Gaussian noise added to the observations during washout.
std_noise = np.std(data_train) * noise_ratio

data_test = np.load("data/test/mso-5_test.npy")

print(np.std(data_train))
print(noise_ratio)
print(std_noise)

# Take test sequences 1..9 (row 0 is skipped), each truncated to the first
# plot_length time steps. The result is 2-D: (sequences, time steps).
# The explicit float32 cast keeps the noisy observations that are fed back
# as network input in the same dtype as the float32 zero input the loop
# starts with. If the .npy file held float64 data, the model would
# otherwise presumably fail on a dtype mismatch at the second step
# (assuming float32 model parameters).
net_inputs = th.tensor(
    data_test[1:10, :plot_length], dtype=th.float32, device=device
)

# Per-step result buffers with the same shape, dtype and device as the input.
net_outputs = th.zeros_like(net_inputs)
net_targets = th.zeros_like(net_inputs)
net_outputs_at = th.zeros_like(net_inputs)
observations = th.zeros_like(net_inputs)

# Batch size (number of selected test sequences) and sequence length.
batch_size, sequence_length = net_inputs.size()

# --- Models ------------------------------------------------------------------
# Both networks use the same checkpoint name. "model" runs with teacher
# forcing / output feedback and "model_at" is driven by Active Tuning. Each is
# built by its own load_model call, so they are separate module instances.
_checkpoint_spec = dict(
    model_type="LSTM",
    hidden_size=hidden_size,
    experiment_tag="mso-" + str(5),
    noise_ratio=0.0,
    rep=1,
)
model_name = tools.generate_model_name(**_checkpoint_spec)
model_name_at = tools.generate_model_name(**_checkpoint_spec)

# Build and load the two networks in order: first model, then model_at.
model, model_at = (
    load_model(
        input_size=input_size, hidden_size=hidden_size,
        output_size=output_size, device=device, model_name=checkpoint,
    )
    for checkpoint in (model_name, model_name_at)
)


# --- Initial states ----------------------------------------------------------
# Plain network: hidden state h and cell state c both start at zero.
lstm_h = th.zeros(batch_size, hidden_size, device=device)
lstm_c = th.zeros_like(lstm_h)
lstm_state = [lstm_h, lstm_c]

# Active Tuning network: h starts as small Gaussian noise (scaled by 0.1) and
# c starts at zero. No random seed is set in this script, so results vary
# from run to run.
lstm_h_at = th.randn(batch_size, hidden_size, device=device) * 0.1
lstm_c_at = th.zeros_like(lstm_h)
lstm_state_at = [lstm_h_at, lstm_c_at]

# The very first network input is all zeros.
x_t = th.zeros((batch_size, input_size), device=device)


def at_opt_accessor(out, state):
    """Accessor handed to ActiveTuning as ``opt_accessor``.

    Ignores the model output and returns the hidden state ``state[0]``.
    Presumably this selects the quantity Active Tuning adapts; confirm
    against the active_tuning module.
    """
    return state[0]


at = active_tuning.ActiveTuning(
    model=model_at,
    initial_model_state=lstm_state_at,
    initial_model_output=x_t.clone(),
    tuning_length=1,
    tf_rate=0.0,
    opt_accessor=at_opt_accessor,
    bias_correction=False,
    tuning_cycles=1,
    eta=0.001,
    beta1=0.9,
    beta2=0.999,
    epsilon=1e-8,
)

# Fallbacks so the closed-loop branch is well-defined even when washout <= 0;
# otherwise x_at_t / state_at_t would be unbound at t == 0. With washout > 0
# both are overwritten at t == 0 before they are read.
x_at_t = x_t.clone()
state_at_t = lstm_state_at

for t in range(sequence_length):

    print("Current time step: " + str(t))

    # Current ground-truth target value, kept 2-D as (batch, 1).
    z_t = net_inputs[:, t:t + 1]

    # Predict z_t from the previous step's input: the noisy observation
    # during washout, the network's own output afterwards. The plain network
    # is never optimized here, so run it without autograd. Its state is
    # carried across steps without being detached, so recording it would
    # grow the autograd graph at every step.
    with th.no_grad():
        y_t, lstm_state = model.forward(_net_input=x_t, _lstm_state=lstm_state)

    if t < washout:

        # Noisy observation s_t = z_t + Gaussian noise with std std_noise.
        s_t = z_t.detach() + th.randn_like(z_t) * std_noise

        # Teacher forcing for the regular network: next input is s_t.
        x_t = s_t
        observations[:, t:t + 1] = x_t

        # Active Tuning can also use the current observation: per the
        # original design, tune() yields a prediction of the true signal
        # z_t from the noisy s_t, plus the model state.
        y_at_t, state_at_t = at.tune(observation_t=x_t)

        # Kept so that, once washout ends, closed-loop prediction starts
        # from the last tuned output and state.
        x_at_t = y_at_t.detach().clone()

    else:
        # Closed loop for both networks: each feeds back its own output.
        # No tuning happens any more, so no gradients are needed here.
        with th.no_grad():
            y_at_t, state_at_t = model_at.forward(
                _net_input=x_at_t, _lstm_state=state_at_t
            )

        x_at_t = y_at_t.detach().clone()
        x_t = y_t.detach().clone()

    net_outputs[:, t:t + 1] = y_t.detach()
    net_outputs_at[:, t:t + 1] = y_at_t.detach()
    net_targets[:, t:t + 1] = z_t.detach()
    

# Copy all recorded tensors to host memory as NumPy arrays.
net_outputs, net_outputs_at, net_targets, observations = (
    recorded.cpu().numpy()
    for recorded in (net_outputs, net_outputs_at, net_targets, observations)
)


def _rmse(prediction, target, skip=100):
    """Root-mean-square error over all rows, ignoring the first ``skip`` steps."""
    residual = prediction[:, skip:] - target[:, skip:]
    return np.sqrt(np.mean(np.square(residual)))


rmse_output = _rmse(net_outputs, net_targets)
rmse_output_at = _rmse(net_outputs_at, net_targets)

print("rmse teacher forcing : ", rmse_output)
print("rmse active tuning   : ", rmse_output_at)

# ----

import mso_plot

# Disabled alternative: export the figure via mso_plot instead of plt.show().
# mso_plot.generatePlot(
#     filename= "mso-results.pdf",
#     groundtruth = net_targets[0,0:200],
#     observations = observations[0,0:200],
#     outputs = net_outputs[0,0:200],
#     outputs_at = net_outputs_at[0,0:200]
# )

# Visualization of the first sequence. Clamp the x-range to the recorded
# length so the plot also works when the data has fewer than plot_length
# steps.
n_plot = min(plot_length, net_targets.shape[1])
time_axis = range(n_plot)
plt.plot(time_axis, net_targets[0, :n_plot], label="Ground truth")
# This curve was previously mislabelled "Ground truth", which produced a
# duplicate legend entry.
plt.plot(time_axis, observations[0, :n_plot], label="Observations")
plt.plot(time_axis, net_outputs[0, :n_plot], label="Network outputs")
plt.plot(time_axis, net_outputs_at[0, :n_plot], label="Network outputs (AT)")
plt.legend()
plt.show()



