"""
HFNGD
10 dimensional Poisson equation example. Solution given by
  u*(x) = sum(x_{2k-1} * x_{2k})


"""
import jax
import jax.numpy as jnp
from jax import random, grad, vmap, jit

from ngrad.domains import Hyperrectangle, HypercubeBoundary
from ngrad.models import mlp, init_params
from ngrad.integrators import EvolutionaryIntegrator
from ngrad.utility import laplace, grid_line_search_factory
from ngrad.inner import model_laplace, model_identity
from ngrad.gram import gram_factory, nat_grad_factory

import argparse
parser = argparse.ArgumentParser()
# NOTE(review): --LM and --maxiter are parsed but not read anywhere else in this
# script; the training loop is bounded by --budget (seconds), not an iteration count.
parser.add_argument(
    "--LM",
    help="Levenberg-Marquardt regularization",
    default=1e-03,
    type=float,
)
parser.add_argument(
    "--iter",
    help="number of iterations",
    default=500,
    type=int,
)
parser.add_argument(
    "--maxiter",
    help="maximum number of iterations",
    default=250,
    type=int,
)
parser.add_argument(
    "--tau",
    help="tau parameter (integer)",
    default=250,
    type=int,
)
parser.add_argument(
    "--budget",
    help="wall-clock time budget for the optimization loop, in seconds",
    default=3000,
    type=float,
)
parser.add_argument(
    "--solver",
    help="optimizer to use: adam, lbfgs or sgd",
    default="adam",
    # Reject unknown solvers at parse time instead of failing later with a NameError.
    choices=["adam", "lbfgs", "sgd"],
    type=str,
)
parser.add_argument(
    "--seed",
    help="random seed for the network parameter initialization",
    default=0,
    type=int,
)
args = parser.parse_args()
jax.config.update("jax_enable_x64", True)

# random seed
seed = args.seed

# domains
dim = 10
interior = Hyperrectangle([(0., 1.) for _ in range(0, dim)])
boundary = HypercubeBoundary(dim)

# integrators
interior_integrator = EvolutionaryIntegrator(interior, key=random.PRNGKey(0), N=4000)
boundary_integrator = EvolutionaryIntegrator(boundary, key= random.PRNGKey(1), N=500)
eval_integrator = EvolutionaryIntegrator(interior, key=random.PRNGKey(0), N=  4000)

ITER=500

tau=args.tau




#%%
# model
# Hidden-layer activation for the MLP and its architecture: 5 hidden layers of width 512.
activation = jnp.tanh
layer_sizes = [dim] + [512] * 5 + [1]
# Previously tried: [dim, 256, 256, 128, 128, 1], [dim, 50, 50, 50, 1], [dim, 64, 1]
import jax
import jax.numpy as jnp
from jax import random

# Glorot (Xavier) initialization for layer parameters
def glorot_layer_params(m: int, n: int, key):
    """Draw (weights, biases) for one dense layer mapping R^m -> R^n.

    Weights have shape (n, m) and are Glorot (Xavier) normal with standard
    deviation sqrt(2 / (m + n)). Biases, shape (n,), are drawn from the same
    normal distribution rather than being set to zero.
    """
    weight_key, bias_key = random.split(key)
    scale = jnp.sqrt(2.0 / (m + n))
    weights = random.normal(weight_key, (n, m)) * scale
    biases = random.normal(bias_key, (n,)) * scale
    return weights, biases

# Initialize all layers for a fully-connected neural network with sizes "sizes"
def init_params(sizes, key):
    """Initialise every dense layer of an MLP with layer widths `sizes`.

    Returns a list of (w, b) pairs, one per consecutive pair in `sizes`, each
    drawn with its own subkey by `glorot_layer_params`.
    """
    layer_keys = random.split(key, len(sizes))
    params = []
    for idx, (fan_in, fan_out) in enumerate(zip(sizes[:-1], sizes[1:])):
        params.append(glorot_layer_params(fan_in, fan_out, layer_keys[idx]))
    return params

# Network parameters (seeded from --seed) and the single-point forward map.
params = init_params(layer_sizes, random.PRNGKey(seed))
model = mlp(activation)
# Batched forward pass: params shared, points batched along axis 0.
v_model = vmap(model, in_axes=(None, 0))

# solution
@jit
def u_star(x):
    """
    Exact solution u*(x) = sum_{k=1}^{5} x_{2k-1} * x_{2k} (1-based indices).

    In 0-based terms this is x[0]*x[1] + x[2]*x[3] + ...; the input must have
    even length (10 in this script) so the two slices line up.

    :param x: input vector
    :return: scalar u*(x)
    """
    left = x[0::2]   # x_1, x_3, x_5, ... (1-based)
    right = x[1::2]  # x_2, x_4, x_6, ... (1-based)
    return jnp.sum(left * right)



# Batched exact solution.
v_u_star = vmap(u_star, 0)


def _grad_norm_u_star(x):
    """Euclidean norm |grad u*(x)|, evaluating the gradient only once."""
    g = grad(u_star)(x)
    return jnp.dot(g, g) ** 0.5


# Batched gradient magnitude of the exact solution.
v_grad_u_star = vmap(_grad_norm_u_star, 0)

# rhs

# gramians


from new_natgrad.gram import gram_factory, gvp_factory

from jax.flatten_util import ravel_pytree




# Flatten the parameter pytree into a single vector; `unravel` maps a flat vector back.
f_params,unravel=ravel_pytree(params)


# NOTE(review): this definition is dead -- `interior_res` is redefined further down
# (as the Laplacian residual) before `v_residual` is built from it.
@jit
def interior_res(params,x):
    """Network output at a single point x (superseded by the later definition)."""
    return model(params, x) #- u_star(x)


# Laplacian of the network output with respect to x, for fixed params.
laplace_model = lambda params: laplace(lambda x: model(params, x))
# Squared PDE residual (Laplacian + f)^2 with f = 0: the exact solution
# u*(x) = sum_k x_{2k-1} x_{2k} is harmonic, so the right-hand side vanishes.
# This replaces a reference to an undefined name `f`, which raised NameError.
residual = lambda params, x: laplace_model(params)(x) ** 2.

@jit
def gramian_res(params, x):
    """Laplacian of the network at a single point x, via the autodiff `laplace` helper."""
    lap_at = laplace_model(params)
    return lap_at(x)


#%%
def tanh(x):
    """Hyperbolic tangent activation."""
    return jnp.tanh(x)


def tanh_prime(x):
    """First derivative of tanh: 1 - tanh(x)^2."""
    t = jnp.tanh(x)
    return 1.0 - t**2


def tanh_double_prime(x):
    """Second derivative of tanh: -2 tanh(x) (1 - tanh(x)^2)."""
    t = jnp.tanh(x)
    return -2.0 * t * (1.0 - t**2)


# Activation and its first/second derivatives, read as globals by `Laplace`.
activation, activation_prime, activation_double_prime = (
    tanh,
    tanh_prime,
    tanh_double_prime,
)
def Laplace(params, x, *, sigma=None, sigma_prime=None, sigma_double_prime=None):
    """Laplacian of an MLP with respect to its input, by forward propagation.

    For every input coordinate i, propagates layer by layer the first
    derivatives dz/dx_i and the pure second derivatives d^2 z/dx_i^2 (mixed
    second derivatives are not needed for the Laplacian), then sums the latter
    at the output.

    Assumes every layer but the last computes ``sigma(w @ z + b)`` and the last
    layer is affine without activation, with ``w`` of shape (fan_out, fan_in) --
    the layout ``glorot_layer_params`` produces.
    NOTE(review): confirm this matches ``ngrad.models.mlp``.

    :param params: list of (w, b) pairs, one per layer
    :param x: 1-D input point of length n
    :param sigma: activation; defaults to the module-level ``activation``
    :param sigma_prime: its first derivative; defaults to ``activation_prime``
    :param sigma_double_prime: its second derivative; defaults to
        ``activation_double_prime`` (all defaults are read at call time)
    :return: scalar Laplacian, summed over output components
    """
    sigma = activation if sigma is None else sigma
    sigma_prime = activation_prime if sigma_prime is None else sigma_prime
    sigma_double_prime = (
        activation_double_prime if sigma_double_prime is None else sigma_double_prime
    )

    n = x.shape[0]
    # Row i of `jac` is dz/dx_i and row i of `hess_diag` is d^2 z/dx_i^2, stored
    # as one (n, width) matrix each so every layer needs a single matmul.
    jac = jnp.eye(n, dtype=x.dtype)
    hess_diag = jnp.zeros((n, n), dtype=x.dtype)
    z = x

    for w, b in params[:-1]:
        a = jnp.dot(w, z) + b
        # The affine part maps derivatives linearly (its own second derivative is 0).
        jac = jnp.dot(jac, w.T)
        hess_diag = jnp.dot(hess_diag, w.T)

        s1 = sigma_prime(a)
        s2 = sigma_double_prime(a)
        # Chain rule for sigma(a): d2 = sigma''(a) * (da)^2 + sigma'(a) * d2a.
        # `hess_diag` must be updated with the pre-activation `jac`, so it goes first.
        hess_diag = s2 * jac**2 + s1 * hess_diag
        jac = s1 * jac

        z = sigma(a)

    # Final affine layer: the bias does not contribute to second derivatives.
    final_w, _ = params[-1]
    return jnp.sum(jnp.dot(hess_diag, final_w.T))
# x_Omega=interior_integrator._x
# Laplace(params, x_Omega[0])

# laplace(lambda x: model(params, x))(x_Omega[0])
#%%
@jit
def interior_res(params, x):
    """Interior residual at a single point x: the network Laplacian (no source term subtracted)."""
    lap = Laplace(params, x)
    return lap


# loss
@jit
def interior_loss(params):
    """Half the integrated squared interior residual, 0.5 * integrator(r^2).

    The residual is squared so this agrees with the interior term of `loss`;
    integrating the signed residual (as before) lets positive and negative
    errors cancel and is not a valid loss.
    """
    return 0.5 * interior_integrator(lambda x: v_residual(params, x) ** 2)
# @jit
def boundary_res(params, x):
    """Dirichlet boundary residual at x: network output minus the exact solution."""
    prediction = model(params, x)
    target = u_star(x)
    return prediction - target


# Batched residuals: params shared, points batched along axis 0 of x.
v_residual = jit(vmap(interior_res, in_axes=(None, 0)))
v_boundary_res = jit(vmap(boundary_res, in_axes=(None, 0)))



@jit
def loss(params):
    """Least-squares loss: 0.5*mean(interior residual^2) + 0.5*mean(boundary residual^2).

    NOTE(review): the integrators' `_x` point sets are read while tracing, so
    under @jit they become compile-time constants; later changes to `_x` are not
    seen unless the function is re-traced.
    """
    interior_term = jnp.mean(v_residual(params, interior_integrator._x) ** 2)
    boundary_term = jnp.mean(v_boundary_res(params, boundary_integrator._x) ** 2)
    return 0.5 * interior_term + 0.5 * boundary_term

# Evaluate the loss once at the initial parameters; the result is discarded.
# Presumably a warm-up that triggers JIT compilation (and surfaces errors) early.
loss(params)



#%%

# Flat parameter vector, its inverse map, and an all-ones vector of the same size/dtype.
f_params, unravel = ravel_pytree(params)
v = jnp.ones_like(f_params)

#%%
# set up grid line search
# Candidate step sizes 1, 1/2, ..., 2^-30 for the grid line search.
grid = jnp.linspace(0, 30, 31)
steps = 0.5**grid
ls_update = grid_line_search_factory(loss, steps)

# errors -- `error` reads the module-level `params` when called, so it always
# evaluates whatever parameters are current at that moment.
error = lambda x: model(params, x) - u_star(x)
v_error = vmap(error, 0)


def _abs_grad_error(x):
    """Euclidean norm |grad error(x)|, evaluating the gradient only once."""
    g = grad(error)(x)
    return jnp.dot(g, g) ** 0.5


v_error_abs_grad = vmap(_abs_grad_error)

def l2_norm(f, integrator):
    """L2 norm of f: square root of `integrator` applied to f(x)**2."""
    def squared(x):
        return f(x) ** 2
    return integrator(squared) ** 0.5

# Reference norms of the exact solution (not used by the optimization loop below).
# NOTE: the "h1" value is ||u*||_L2 + || |grad u*| ||_L2 -- a sum of norms, not
# sqrt(||u*||^2 + ||grad u*||^2); the two are equivalent norms but differ in value.
norm_sol_l2 = l2_norm(v_u_star, eval_integrator)
norm_sol_h1 = norm_sol_l2 + l2_norm(v_grad_u_star, eval_integrator)    









from jax.tree_util import Partial



#%%


import json


def save_checkpoint(params, loss_value, l2_error, h1_error, step, iteration, kit, file_name="checkpoints.json"):
    """Write one checkpoint record as indented JSON, overwriting `file_name`.

    `params` is accepted but not written. Numeric metrics are coerced to plain
    int/float so they serialize; `kit` is stored as-is under 'cg_iteration' and
    must therefore already be JSON-serializable.
    """
    record = dict(
        iteration=int(iteration),
        loss=float(loss_value),
        l2_error=float(l2_error),
        h1_error=float(h1_error),
        step=float(step),
        cg_iteration=kit,
    )
    with open(file_name, 'w') as handle:
        json.dump(record, handle, indent=4)
# The grid line search, error functions, `l2_norm` and reference norms are
# defined once, earlier in this file. A verbatim copy used to live here; it
# rebuilt identical objects and re-ran two norm integrals, so it was removed.









#%%
import jax
import jax.numpy as jnp
from jax import grad, jit, value_and_grad
import optax
from jax.flatten_util import ravel_pytree
import timeit
import json

# Initialize a dictionary to store results
# Metric history: appended to every 100 iterations and dumped to JSON at the end.
results = {name: [] for name in ('iterations', 'loss', 'l2_errors', 'optim_time')}

# Define the loss function (replace this with your specific loss function)

# Wall-clock budget for the optimization loop, in seconds (compared against
# timeit.default_timer differences). The optimizer itself comes from --solver.
time_budget = args.budget

# Create learning rate schedule for optax (replace with your own schedule)
def create_learning_rate_schedule(initial_lr, warmup_steps=0, decay_steps=10000, decay_rate=0.1):
    """
    Build an optax learning-rate schedule: optional linear warmup, then exponential decay.

    :param initial_lr: Peak learning rate, reached at the end of warmup (or at step 0 without warmup).
    :param warmup_steps: Length of the linear 0 -> initial_lr ramp; no warmup when <= 0.
    :param decay_steps: Decay period: the rate is multiplied by `decay_rate` every
        `decay_steps` steps (smoothly, not in a staircase). It is not a total length.
    :param decay_rate: Multiplicative factor applied per decay period.
    :return: An optax schedule (callable step -> learning rate).
    """
    decay = optax.exponential_decay(init_value=initial_lr, transition_steps=decay_steps, decay_rate=decay_rate)

    # Without warmup the decay schedule is used directly.
    if warmup_steps <= 0:
        return decay

    warmup = optax.linear_schedule(init_value=0.0, end_value=initial_lr, transition_steps=warmup_steps)
    # join_schedules restarts the step count of `decay` at the boundary, so decay
    # begins from initial_lr exactly when warmup ends.
    return optax.join_schedules([warmup, decay], boundaries=[warmup_steps])

# Create a learning rate schedule
# Learning-rate schedule for the optax optimizers: linear warmup 0 -> 1e-3 over
# the first 1000 steps, then exponential decay halving the rate every 20000 steps.
lr_schedule = create_learning_rate_schedule(
    initial_lr=1e-3,
    warmup_steps=1000,  # length of the linear warmup
    decay_steps=20000,  # decay period (rate is scaled by decay_rate per 20000 steps), not a total
    decay_rate=0.5     # multiplicative factor per decay period
)
import time

selected_optimizer = args.solver
# Every branch defines `optimizer`, `state` and a jitted
# `optimize_step(params, state) -> (params, state, loss_value)`.
if selected_optimizer in ('adam', 'sgd'):
    # Both optax optimizers share the same step; only the update rule differs.
    optimizer = {'adam': optax.adam, 'sgd': optax.sgd}[selected_optimizer](learning_rate=lr_schedule)
    state = optimizer.init(params)

    @jit
    def optimize_step(params, state):
        # loss_value is the loss *before* this update.
        loss_value, gradients = value_and_grad(loss)(params)
        updates, state = optimizer.update(gradients, state)
        params = optax.apply_updates(params, updates)
        return params, state, loss_value

elif selected_optimizer == 'lbfgs':
    from jaxopt import LBFGS

    optimizer = LBFGS(fun=loss, tol=1e-5, maxiter=50000)
    state = optimizer.init_state(params)

    @jit
    def optimize_step(params, state):
        params, state = optimizer.update(params, state)
        # Unlike the optax branch, this is the loss *after* the update.
        loss_value = loss(params)
        return params, state, loss_value

else:
    # Fail fast: without this, an unknown solver only surfaced later as a
    # NameError on `optimize_step`.
    raise ValueError(
        f"Unknown solver {selected_optimizer!r}; expected 'adam', 'lbfgs' or 'sgd'."
    )

# Optimization loop: take steps until the wall-clock budget (seconds) is used up.
# The budget also covers the periodic L2-error evaluations below.
start_time0 = timeit.default_timer()
elapsed_time = 0
iteration = 0
while elapsed_time < time_budget:
    start_time = timeit.default_timer()

    # JAX dispatches work asynchronously; wait for the step to finish so that
    # `elapsed` measures the actual computation rather than just the dispatch.
    params, state, loss_value = optimize_step(params, state)
    jax.block_until_ready((params, loss_value))
    elapsed = timeit.default_timer() - start_time

    # Record metrics every 100 iterations.
    if iteration % 100 == 0:
        # `v_error` reads the module-level `params`, i.e. the values just updated.
        l2_error = l2_norm(v_error, eval_integrator)
        total_time = timeit.default_timer() - start_time0

        results['iterations'].append(int(iteration))
        results['loss'].append(float(loss_value))
        results['l2_errors'].append(float(l2_error))
        results['optim_time'].append(float(total_time))

        print(
            f'Iteration: {iteration}, Loss: {loss_value}, '
            f'L2 Error: {l2_error}, Iteration Time: {elapsed}, '
            f'Optimization Time: {total_time}'
        )

    elapsed_time = timeit.default_timer() - start_time0
    iteration += 1

# Save final results after the optimization loop. The file name embeds the
# solver and the full argparse namespace, so runs with different settings
# write to different files.
final_filename = f'final_optimization_results_{selected_optimizer}_{str(args)}.json'
with open(final_filename, 'w', encoding='utf-8') as fh:
    json.dump(results, fh, indent=4)

print(f"Final results saved to {final_filename}")















