"""
This script contains the parametric right-hand side example.

This example corresponds to Example 2 in Section 3 of the manuscript.

"""

import math

import tensorflow as tf
from tensorflow.keras.optimizers import Adam

from domains import UnitInterval, Hyperrectangle, ProductDomain
from integrators import accuracy, RandomDeterministicIntegrator, RandomRandomIntegrator
from derivatives import grad_square, prime
from models import create_model

# global 'hyperparameters' for this experiment
def _squared_relu(x):
    """Return relu(x)**2, the ReLU^2 activation used by the network."""
    return tf.keras.activations.relu(x)**2


# dimension of the parameter domain; the network input has one more entry
PARAM_DIMENSION    = 6
ACTIVATION         = _squared_relu
# architecture settings forwarded to create_model
HIDDEN_LAYERS      = 4
NETWORK_WIDTH      = 64
# number of optimizer steps taken by the training loop
ITERATIONS         = 10000
LEARNING_RATE      = 0.001
# NOTE(review): not referenced in the visible part of this script -- confirm
# whether it is still needed
INTEGRATION_POINTS = 200000
OPTIMIZER          = Adam(learning_rate=LEARNING_RATE)

# computational domain: product of the parameter domain and the physical one
domain_phys = UnitInterval()
# the same ones-vector serves (negated) as first and (as is) as second
# argument -- presumably the lower and upper corner of the box [-1, 1]^d
_param_ones = tf.ones(PARAM_DIMENSION)
domain_para = Hyperrectangle(-_param_ones, _param_ones)
domain = ProductDomain(domain_para=domain_para, domain_phys=domain_phys)

# integration routine, handed to the loss factory further below
integrator = RandomDeterministicIntegrator(prod_domain=domain, N_param=1000, N_phys=150)

# accurate L2 norm computation routine, used for evaluation
@accuracy(precision=1e-4, max_iter=1000)
def accurate_l2_norm(func):
    """
    Return one estimate of the L2 norm of ``func`` over ``domain``.

    The estimate is the square root of the integral of ``func(x)**2``. A
    new ``RandomRandomIntegrator`` is constructed on every call.

    NOTE(review): the ``accuracy`` decorator presumably repeats this
    estimate until ``precision`` or ``max_iter`` is reached -- confirm in
    ``integrators.accuracy``.
    """
    def squared(x):
        return func(x)**2

    quadrature = RandomRandomIntegrator(prod_domain=domain, N_param=1000, N_phys=1000)
    return quadrature(squared)**(0.5)

# accurate H1 norm computation routine, used for evaluation
@accuracy(precision=1e-4, max_iter=200)
def accurate_h1_norm(func):
    """
    Return one estimate of the H1 norm of ``func`` over ``domain``.

    The estimate is the square root of the integral of
    ``func(x)**2 + prime(func, x)**2``. A new ``RandomRandomIntegrator`` is
    constructed on every call.

    NOTE(review): the ``accuracy`` decorator presumably repeats this
    estimate until ``precision`` or ``max_iter`` is reached -- confirm in
    ``integrators.accuracy``.
    """
    def value_and_derivative_squared(x):
        return func(x)**2 + prime(func, x)**2

    quadrature = RandomRandomIntegrator(prod_domain=domain, N_param=1000, N_phys=1000)
    return quadrature(value_and_derivative_squared)**(0.5)

# the parametric right-hand side
def f(X):
    """
    Evaluate the parametric right-hand side, a cosine series in x.

    Every row of ``X`` is read as ``(alpha_0, ..., alpha_N, x)`` with
    ``N = X.shape[1] - 2``. For that row the result is
    ``sum_{k=0}^{N} alpha_k * cos(k * pi * x)``.

    Parameters
    ----------
    X : tensor of shape (n, N + 2)
        Coefficients in the leading columns, evaluation point in the last.

    Returns
    -------
    tensor of shape (n, 1)
        One value of the right-hand side per row, in the dtype of ``X``.
    """
    X = tf.convert_to_tensor(X)
    N = tf.shape(X)[1] - 2

    # -1 lets TensorFlow infer the batch size, so this also works inside a
    # tf.function whose batch dimension is not known statically
    x = tf.reshape(X[:, -1], shape=(-1, 1))
    k_s = tf.reshape(tf.cast(tf.range(0, N + 1), dtype=X.dtype), shape=(1, N + 1))
    x_matrix = x * math.pi * k_s

    alpha_matrix = X[:, 0:N + 1]
    # all-ones weights: the coefficients enter the series without any decay
    l2_decay = tf.ones(shape=(1, N + 1), dtype=X.dtype)
    alpha_decay_matrix = alpha_matrix * l2_decay

    # keepdims keeps the summed axis, giving the (n, 1) column directly
    return tf.reduce_sum(tf.cos(x_matrix) * alpha_decay_matrix, axis=1, keepdims=True)

# the parametric solution
def u_star(X):
    """
    Evaluate the parametric reference solution, a damped cosine series.

    Every row of ``X`` is read as ``(alpha_0, ..., alpha_N, x)`` with
    ``N = X.shape[1] - 2``. For that row the result is
    ``sum_{k=0}^{N} alpha_k * cos(k * pi * x) / (k**2 * pi**2 + 1)``,
    i.e. each cosine mode of the right-hand side ``f`` divided by
    ``k**2 * pi**2 + 1``.

    Parameters
    ----------
    X : tensor of shape (n, N + 2)
        Coefficients in the leading columns, evaluation point in the last.

    Returns
    -------
    tensor of shape (n, 1)
        One value of the solution per row, in the dtype of ``X``.
    """
    X = tf.convert_to_tensor(X)
    N = tf.shape(X)[1] - 2

    # -1 lets TensorFlow infer the batch size, so this also works inside a
    # tf.function whose batch dimension is not known statically
    x = tf.reshape(X[:, -1], shape=(-1, 1))
    modes = tf.range(0, N + 1)
    k_s = tf.reshape(tf.cast(modes, dtype=X.dtype), shape=(1, N + 1))
    x_matrix = x * math.pi * k_s

    # built with tensor ops (no Python loop over a tensor-valued N) so the
    # function can be traced; evaluated in float64 and cast afterwards so
    # every coefficient is rounded only once
    modes_64 = tf.reshape(tf.cast(modes, dtype=tf.float64), shape=(1, N + 1))
    factor = tf.cast(1. / (modes_64 * modes_64 * math.pi**2 + 1.), dtype=X.dtype)

    alpha_matrix = X[:, 0:N + 1]
    # all-ones weights: the coefficients enter the series without any decay
    l2_decay = tf.ones(shape=(1, N + 1), dtype=X.dtype)
    alpha_decay_matrix = alpha_matrix * l2_decay

    # keepdims keeps the summed axis, giving the (n, 1) column directly
    return tf.reduce_sum(factor * tf.cos(x_matrix) * alpha_decay_matrix, axis=1, keepdims=True)

# the loss function
def loss_factory(integrator):
    """
    Build the loss for a given integration routine.

    The returned callable takes a function ``u_theta`` and returns
    ``0.5 * integral(u_theta'**2 + u_theta**2) - integral(u_theta * f)``,
    where both integrals are evaluated by separate calls to ``integrator``.
    """
    @tf.function
    def DRM_loss(u_theta):
        def quadratic_integrand(X):
            return prime(u_theta, X)**2 + u_theta(X)**2

        def source_integrand(X):
            return -u_theta(X) * f(X)

        quadratic_term = 0.5 * integrator(quadratic_integrand)
        source_term = integrator(source_integrand)
        return quadratic_term + source_term

    return DRM_loss

# the training loss, bound to the integration routine defined above
custom_loss = loss_factory(integrator=integrator)

# the model, we use a MLP with ReLU^2 activation
# (one input per parameter plus one additional input)
_network_settings = {
    'input_dim': PARAM_DIMENSION + 1,
    'width': NETWORK_WIDTH,
    'hidden_layers': HIDDEN_LAYERS,
    'activation': ACTIVATION,
}
model = create_model(**_network_settings)

# the norms of the reference solution, used to normalise the reported errors
l2_norm  = accurate_l2_norm(u_star)
h1_norm  = accurate_h1_norm(u_star)


# the pointwise error between the network and the reference solution
def error(x):
    return model(x) - u_star(x)
    
# the training loop: one optimizer step per iteration, with a progress
# report (loss, loss change, relative L2 and H1 errors) every 50 iterations
prev_loss = 0.
for iteration in range(ITERATIONS):
    with tf.GradientTape() as tape:
        loss = custom_loss(model)

    weights = model.trainable_variables
    gradients = tape.gradient(loss, weights)
    OPTIMIZER.apply_gradients(zip(gradients, weights))

    if iteration % 50 == 0:
        # errors are reported relative to the norms of the reference solution
        rel_l2_error = accurate_l2_norm(error)/l2_norm
        rel_h1_error = accurate_h1_norm(error)/h1_norm
        report = (
            f'Iteration {iteration + 1} '
            f'Loss {loss} '
            f'Diff {prev_loss - loss} '
            f'L2 Error {rel_l2_error} '
            f'H1 Error {rel_h1_error}'
        )
        print(report)
    prev_loss = loss







