"""
Dynamics on distorted S1, LDNet models

Detailed result comparison is done in s1.py
"""
import os
import time

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
tf.keras.backend.set_floatx('float64')

import utils
import optimization

# ----------------------------
# Choose the D parameter; the relevant number of BFGS epochs is looked up
# ----------------------------
# BFGS epoch count paired with each distortion amplitude D used in this study.
# Picking a D that is not listed raises KeyError: add an entry for it first.
_BFGS_EPOCHS_BY_D = {
    0.0: 1300,
    0.1: 4000,
    0.3: 4000,
    0.5: 2500,
}

D = 0.1
num_epochs_BFGS = _BFGS_EPOCHS_BY_D[D]

# ----------------------------
# Make S1 dataset
# ----------------------------
s5 = np.sqrt(5)


def make_dyn(K, D):
    """Build the trajectory function of the distorted-S1 example.

    The angle obeys d(theta)/dt = 3/2 - cos(theta) with theta(0) = 0, whose
    closed-form solution is theta(t) = 2*arctan(tan(sqrt(5)*t/4)/sqrt(5)).
    The point then lies at radius 1 + D*cos(K*theta) in direction theta.

    Args:
        K: angular frequency of the radial distortion.
        D: amplitude of the radial distortion (D = 0 gives the unit circle).

    Returns:
        A function ``dyn(tt, K=K, D=D)`` mapping times to the stacked
        ``[x, y]`` coordinates, transposed; for a 1-D ``tt`` the result has
        shape ``(len(tt), 2)``.
    """
    def dyn(tt, K=K, D=D):
        # arctan returns the principal branch; the jumps it introduces are
        # multiples of 2*pi in theta, so cos/sin below are unaffected.
        theta = 2 * np.arctan(np.tan(s5 * tt / 4) / s5)
        radius = 1 + D * np.cos(K * theta)
        xy = np.array([radius * np.cos(theta), radius * np.sin(theta)])
        return xy.T
    return dyn

# Reference trajectory: distortion frequency K = 3, amplitude D chosen above
fdyn = make_dyn(3, D)

# Uniform time grid on [0, T] used for testing
T, Nt = 128, 3201
tt = np.linspace(0, T, Nt)
dt = tt[1] - tt[0]

# Training only sees the first 2*Ntrain steps of that grid (same step dt)
Ntrain = (Nt - 1) // 16
t_trn = dt * np.arange(2 * Ntrain)

# Full reference trajectory on the test grid
ut = fdyn(tt)

# One sample each; the leading axis of 'output' is the sample axis
dataset_trn = {
    't': t_trn,
    'x': np.array([-1, 1]),
    'output': np.array([fdyn(t_trn)]),
}
dataset_tst = {
    't': tt,
    'x': np.array([-1, 1]),
    'output': np.array([ut]),
}
# ----------------------------
# Tuning parameters for S1 example
# ----------------------------
num_latent_states = 2   # dimension of the latent state evolved by the dynamics network
num_hidden_nodes = 20   # width of the single hidden layer in each of the two networks
num_epochs_Adam = 200   # epochs of the Adam stage that runs before BFGS
# ----------------------------

# ----------------------------
# Settings adapted from TestCase_1a of the original paper
# ----------------------------
# Problem description consumed by the utils helpers: one space dimension,
# no input parameters or signals (autonomous dynamics), one output field 'z'.
# NOTE(review): the datasets use the two points x = -1 and x = +1, so the two
# trajectory components are presumably stored as 'z' at those points -- confirm
# against utils.ADR_create_dataset.
problem = {
    'space': {
        'dimension' : 1 # 1D problem
    },
    'input_parameters': [],
    'input_signals': [],
    'output_fields': [
        { 'name': 'z' }
    ]
}

# Ranges used for normalization; 'time_constant' is the reference time by
# which dt is divided in evolve_dynamics.
normalization = {
    'space': { 'min' : [-1], 'max' : [+1]},
    'time': { 'time_constant' : 2.0 },
    'output_fields': {
        'z': { 'min': -1, 'max': +1 }
    }
}

# Each dataset holds a single sample (index 0).
samples_train = [0]
samples_valid = [0]
samples_tests = [0]
# Training and validation are both built from dataset_trn, so the validation
# loss is evaluated on the training trajectory; only the test set uses the
# full time grid of dataset_tst.
dataset_train = utils.ADR_create_dataset(dataset_trn, samples_train)
dataset_valid = utils.ADR_create_dataset(dataset_trn, samples_valid)
dataset_tests = utils.ADR_create_dataset(dataset_tst, samples_tests) 
# ----------------------------

# ----------------------------
# The rest is the same as TestCase_1a, except that the model architecture is simplified and timing has been added
# ----------------------------
# We re-sample the time transients with timestep dt and we rescale each variable between -1 and 1.
# The return values are discarded, so the dataset dicts are presumably updated
# in place (later code reads 'num_samples', 'num_times', 'num_points',
# 'points_full' and 'out_fields' from them) -- confirm in utils.process_dataset.
utils.process_dataset(dataset_train, problem, normalization, dt = dt)
utils.process_dataset(dataset_valid, problem, normalization, dt = dt)
utils.process_dataset(dataset_tests, problem, normalization, dt = dt)

# # For reproducibility (delete if you want to test other random initializations)
# np.random.seed(0)
# tf.random.set_seed(0)

# dynamics network
# Maps the latent state (plus input parameters and signals, both empty here)
# to a vector of the same size as the latent state, which evolve_dynamics
# uses as the increment direction of its time stepping.
input_shape = (num_latent_states + len(problem['input_parameters']) + len(problem['input_signals']),)
NNdyn = tf.keras.Sequential([
            tf.keras.layers.Dense(num_hidden_nodes, activation = tf.nn.tanh, input_shape = input_shape),
            tf.keras.layers.Dense(num_latent_states)
        ])

# summary
NNdyn.summary()

# reconstruction network
# Maps the latent state concatenated with the space coordinates to one value
# per output field; it is called on a 4-D tensor in reconstruct_output and the
# Dense layers act on the last axis only.
# Note: input_shape is deliberately re-bound here for the second network.
input_shape = (None, None, num_latent_states + problem['space']['dimension'])
NNrec = tf.keras.Sequential([
            tf.keras.layers.Dense(num_hidden_nodes, activation = tf.nn.tanh, input_shape = input_shape),
            tf.keras.layers.Dense(len(problem['output_fields']))
        ])

# summary
NNrec.summary()

def evolve_dynamics(dataset):
    """Integrate the latent dynamics with explicit (forward) Euler steps.

    The latent state starts at zero for every sample and is advanced with
    the global step ``dt`` rescaled by the normalization time constant.

    Args:
        dataset: processed dataset; ``'num_samples'`` and ``'num_times'``
            are read from it.

    Returns:
        Tensor of latent states with axes (sample, time, latent state).
    """
    n_steps = dataset['num_times']
    # non-dimensional step size: dt over the reference time
    step = dt / normalization['time']['time_constant']

    # initial condition
    latent = tf.zeros((dataset['num_samples'], num_latent_states), dtype=tf.float64)
    trajectory = tf.TensorArray(tf.float64, size=n_steps)
    trajectory = trajectory.write(0, latent)

    # time integration (tf.range is kept so the loop stays graph-convertible)
    for k in tf.range(n_steps - 1):
        latent = latent + step * NNdyn(latent)
        trajectory = trajectory.write(k + 1, latent)

    # stack() is time-major; move the sample axis to the front
    return tf.transpose(trajectory.stack(), perm=(1, 0, 2))

def reconstruct_output(dataset, states):
    """Evaluate the reconstruction network at every sample, time and point.

    Args:
        dataset: processed dataset; ``'num_samples'``, ``'num_times'``,
            ``'num_points'`` and ``'points_full'`` are read from it.
        states: latent states with axes (sample, time, latent state).

    Returns:
        The output of ``NNrec`` applied to the latent state concatenated
        with ``dataset['points_full']`` along the last axis.
    """
    target_shape = [
        dataset['num_samples'],
        dataset['num_times'],
        dataset['num_points'],
        num_latent_states,
    ]
    # insert a point axis and repeat the latent state for every space point
    tiled_states = tf.broadcast_to(tf.expand_dims(states, axis=2), target_shape)
    features = tf.concat([tiled_states, dataset['points_full']], axis=3)
    return NNrec(features)

def LDNet(dataset):
    """Run the full model: evolve the latent state, then reconstruct outputs."""
    return reconstruct_output(dataset, evolve_dynamics(dataset))

def MSE(dataset):
    """Return the mean squared error between the LDNet prediction and
    ``dataset['out_fields']``, averaged over all entries."""
    residual = LDNet(dataset) - dataset['out_fields']
    squared = tf.square(residual)
    return tf.reduce_mean(squared)

def loss():
    """Training objective: MSE of the LDNet on the training dataset."""
    return MSE(dataset_train)


def MSE_valid():
    """Validation metric: MSE of the LDNet on the validation dataset."""
    return MSE(dataset_valid)

# Both networks are optimized jointly.
trainable_variables = NNdyn.variables + NNrec.variables
opt = optimization.OptimizationProblem(trainable_variables, loss, MSE_valid)

# Two-stage training (Adam, then BFGS); the wall-clock time of both stages
# together is printed in seconds.
t1 = time.time()
print('training (Adam)...')
adam = tf.keras.optimizers.Adam(learning_rate=1e-2)
opt.optimize_keras(num_epochs_Adam, adam)
print('training (BFGS)...')
opt.optimize_BFGS(num_epochs_BFGS)
t2 = time.time()
print(t2 - t1)

# Loss histories on log-log axes; the vertical line marks the switch from
# Adam to BFGS.
for history, tag in (
    (opt.loss_train_history, 'training loss'),
    (opt.loss_valid_history, 'validation loss'),
):
    plt.loglog(opt.iterations_history, history, 'o-', label=tag)
plt.axvline(num_epochs_Adam)
plt.xlabel('epochs')
plt.ylabel('MSE')
plt.legend()

# Compute predictions on the full test time grid and print the wall-clock
# time of the forward pass in seconds.
t1 = time.time()
out_fields = LDNet(dataset_tests)
t2 = time.time()
print(t2-t1)

# Since the LDNet works with normalized data, we map back the outputs into the original ranges.
out_fields_app = utils.denormalize_output(out_fields, problem, normalization).numpy()
out_fields_ref = utils.denormalize_output(dataset_tests['out_fields'], problem, normalization).numpy()

# ----------------------------
# The post-processing is modified to interface with other baselines.
# ----------------------------
# Root-mean-square error over all entries of the de-normalized prediction.
RMSE = np.sqrt(np.mean(np.square(out_fields_app - out_fields_ref)))
print('RMSE:       %1.3e' % RMSE)

# Drop singleton axes; the plots below index the result as [time, component].
app = out_fields_app.squeeze()
ref = out_fields_ref.squeeze()

# One panel per component: prediction (solid blue) vs reference (dashed black).
f, ax = plt.subplots(nrows=2, sharex=True)
ax[0].plot(tt, app[:,0], 'b-')
ax[0].plot(tt, ref[:,0], 'k--')
ax[1].plot(tt, app[:,1], 'b-')
ax[1].plot(tt, ref[:,1], 'k--')

# np.save does not create missing directories; without this, a missing ./res
# would raise only after the whole training run and before any figure is shown.
os.makedirs('./res', exist_ok=True)
np.save(f'./res/s1_ldn_{int(10*D)}.npy', app)

plt.show()
