import jax
import jax.numpy as jnp
from jax import random, grad, vmap, jit
from new_natgrad.domains import Hyperrectangle, HyperrectangleBoundary, HyperrectangleParabolicBoundary
import jax
import jax.numpy as jnp
from jax import random
import jax
import jax.numpy as jnp
import numpy as np
from jax.flatten_util import ravel_pytree
from jax import jit, vmap, lax
import jax.numpy as jnp

from ngrad.domains import Hyperrectangle, HypercubeBoundary
from ngrad.models import mlp, init_params
from ngrad.integrators import EvolutionaryIntegrator
from ngrad.utility import laplace, grid_line_search_factory
from ngrad.inner import model_laplace, model_identity
from ngrad.gram import gram_factory, nat_grad_factory

import argparse
# Command-line options of the experiment script.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--LM",
    help="Levenberg-Marquardt regularization",
    default=1e-03,
    type=float,
)
parser.add_argument(
    "--iter",
    help="number of iterations",
    default=500,
    type=int,
)
parser.add_argument(
    "--maxiter",
    help="maximum number of iterations",
    default=250,
    type=int,
)
parser.add_argument(
    "--seed",
    help="random seed",
    default=250,
    type=int,
)
args = parser.parse_args()
# Enable 64-bit floating point for all subsequent JAX computations.
jax.config.update("jax_enable_x64", True)

# Random seed, taken from the --seed command-line option.
seed = args.seed

# Spatial dimension and the unit-hypercube domain objects.
# NOTE: `interior` is reassigned further down in this file.
dim = 2
interior = Hyperrectangle([(0., 1.) for _ in range(0, dim)])
boundary = HypercubeBoundary(dim)

# integrators
# x_Omega = EvolutionaryIntegrator(interior, key=random.PRNGKey(0), N=20000)


# Fixed iteration count (independent of the --iter option).
ITER=500


# Network architecture: `dim` inputs, four hidden layers of width 50 and
# three outputs, with tanh as the nonlinearity.
dim=2
activation = lambda x : jnp.tanh(x)

layer_sizes = [dim, 50,50,50,50, 3]



# Glorot (Xavier) initialization for layer parameters
def glorot_layer_params(m: int, n: int, key):
    """Sample the weight matrix and bias vector of one dense layer.

    Both arrays are drawn from a zero-mean normal distribution whose
    standard deviation is ``sqrt(2 / (m + n))``.

    Args:
        m: Number of inputs of the layer.
        n: Number of outputs of the layer.
        key: JAX PRNG key; it is split once, one half per array.

    Returns:
        Tuple ``(w, b)`` with ``w`` of shape ``(n, m)`` and ``b`` of shape
        ``(n,)``.
    """
    weight_key, bias_key = random.split(key)
    scale = jnp.sqrt(2.0 / (m + n))
    weights = scale * random.normal(weight_key, (n, m))
    bias = scale * random.normal(bias_key, (n,))
    return weights, bias

# Initialize all layers for a fully-connected neural network with sizes "sizes"
def init_params(sizes, key):
    """Create the parameters of every layer of a fully-connected network.

    Args:
        sizes: Layer widths, input width first and output width last.
        key: JAX PRNG key used to derive the per-layer keys.

    Returns:
        List with one ``(w, b)`` pair per consecutive pair of widths.
    """
    # One key is generated per entry of `sizes`; zip() stops after the
    # len(sizes) - 1 layers, so the last key stays unused.
    layer_keys = random.split(key, len(sizes))
    layers = []
    for fan_in, fan_out, layer_key in zip(sizes[:-1], sizes[1:], layer_keys):
        layers.append(glorot_layer_params(fan_in, fan_out, layer_key))
    return layers

# Initial network parameters, seeded from the command-line seed.
params = init_params(layer_sizes, random.PRNGKey(seed))
def mlp(activation):
    """Build the forward function of a fully-connected network.

    Args:
        activation: Nonlinearity applied after every layer except the last.

    Returns:
        A function ``model(params, inpt)`` where ``params`` is a list of
        ``(w, b)`` pairs and ``inpt`` is a single input vector. The last
        layer is affine only (no activation).
    """
    def model(params, inpt):
        state = inpt
        for weight, bias in params[:-1]:
            state = activation(jnp.dot(weight, state) + bias)

        out_weight, out_bias = params[-1]
        return jnp.dot(out_weight, state) + out_bias
    return model



# Forward function for a single input point, and its version batched over
# the leading axis of the input (parameters are shared across the batch).
model = mlp(activation)
v_model = vmap(model, (None, 0))
# model(params,x)


#%%





#%%
# compute residual



#%%

def tanh(x):
    """Hyperbolic tangent activation (delegates to ``jnp.tanh``)."""
    value = jnp.tanh(x)
    return value

def tanh_prime(x):
    """First derivative of tanh, ``1 - tanh(x)**2``."""
    t = jnp.tanh(x)
    return 1.0 - t**2

def tanh_double_prime(x):
    """Second derivative of tanh, ``-2 tanh(x) (1 - tanh(x)**2)``."""
    t = jnp.tanh(x)
    return -2.0 * t * (1.0 - t**2)

# Module-level aliases for the activation and its first two derivatives.
activation = tanh
activation_prime = tanh_prime
activation_double_prime = tanh_double_prime



#%%
# Flat parameter vector and the function that restores the pytree layout.
f_params, unravel = ravel_pytree(params)

# Computational domain: x in [-0.5, 1], y in [-0.5, 1.5].
box = [(-0.5, 1.), (-0.5, 1.5)]
interior = Hyperrectangle(box)
rectangle_boundary = HyperrectangleBoundary(intervals=box)

# Collocation and evaluation points.
# NOTE(review): every call below uses PRNGKey(0), and x_eval is assigned
# twice with identical arguments — confirm whether independent keys were
# intended.
x_Omega = interior.random_integration_points(random.PRNGKey(0), N=400)
x_eval = interior.random_integration_points(random.PRNGKey(0), N=9000)
x_Gamma = rectangle_boundary.random_integration_points(random.PRNGKey(0), N=400)
x_Omegac = interior.random_integration_points(random.PRNGKey(0), N=26)
x_eval = interior.random_integration_points(random.PRNGKey(0), N=9000)
x_Gammac = rectangle_boundary.random_integration_points(random.PRNGKey(0), N=40)
from functools import partial

@partial(jax.jit, static_argnums=(2, 3))
def resample_hyperrectangle(
    key: jax.random.PRNGKey,
    intervals: list[tuple[float, float]],
    N_interior: int,
    N_boundary: int
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """
    Draw uniform random points inside a d-dimensional hyperrectangle and
    on its boundary.

    A boundary point is produced by sampling a point in the box, picking
    one of the 2d faces uniformly at random and pinning the matching
    coordinate to that face's bound.

    Args:
      key: a JAX PRNGKey.
      intervals: list of (low, high) bounds, length d.
      N_interior: number of interior points (static under jit).
      N_boundary: number of boundary points (static under jit).
    Returns:
      x_interior: (N_interior, d)
      x_boundary: (N_boundary, d)
    """
    n_dims = len(intervals)
    lower = jnp.array([bounds[0] for bounds in intervals])
    upper = jnp.array([bounds[1] for bounds in intervals])
    extent = upper - lower

    # Independent keys for interior samples, face choice and boundary samples.
    key_interior, key_face, key_boundary = random.split(key, 3)

    def _uniform_in_box(sample_key, count):
        return random.uniform(sample_key, (count, n_dims)) * extent + lower

    x_interior = _uniform_in_box(key_interior, N_interior)

    # Face index in [0, 2d): index // 2 is the pinned axis, index % 2 the side.
    face = random.randint(key_face, (N_boundary,), 0, 2 * n_dims)
    pinned_axis = face // 2
    side = face % 2

    x_boundary = _uniform_in_box(key_boundary, N_boundary)
    pinned_value = jnp.where(side == 0, lower[pinned_axis], upper[pinned_axis])
    rows = jnp.arange(N_boundary)
    x_boundary = x_boundary.at[rows, pinned_axis].set(pinned_value)

    return x_interior, x_boundary


# Unit-hypercube bounds; the sampling call below uses `box`, not this list.
intervals = [(0.0, 1.0)] * dim

# Seeded PRNG key and the first set of interior / boundary collocation
# points (this replaces the x_Omega and x_Gamma sampled earlier).
key = random.PRNGKey(seed)
x_Omega, x_Gamma = resample_hyperrectangle(
    key,
    box,
    N_interior=400,
    N_boundary=400
)



# All-ones vector with the same length as the flattened parameters.
v=jnp.zeros_like(ravel_pytree(params)[0])+1
# Reynolds number, viscosity and the exponent used by the analytic
# reference solution (u_star / p_star).
Re = 40
nu = 1 / Re
l = 1 / (2 * nu) - jnp.sqrt(1 / (4 * nu ** 2) + 4 * jnp.pi ** 2)

@jit
def p_star(xy):
    """Reference pressure at a single point ``xy``; depends on ``xy[0]`` only.

    Uses the module-level exponent ``l``.
    """
    x_coord = xy[0]
    decay = jnp.exp(2 * l * x_coord)
    return 1 / 2 * (1 - decay)






@jit
def u_star(xy):
    """Reference velocity ``[u, v]`` at a single point ``xy = (x, y)``.

    Uses the module-level exponent ``l``; the formula matches the
    Kovasznay flow solution.
    """
    x_coord, y_coord = xy[0], xy[1]
    decay = jnp.exp(l * x_coord)

    u_component = 1 - decay * jnp.cos(2 * jnp.pi * y_coord)
    v_component = l / (2 * jnp.pi) * decay * jnp.sin(2 * jnp.pi * y_coord)
    return jnp.array([u_component, v_component])

#%%


# Split the network output: the first two components are the velocity,
# the remaining one is the pressure. `model` is looked up when the lambda
# is called, so later reassignments of `model` are picked up.
model_u = lambda params, x: model(params, x)[0:2]
model_p = lambda params, x: model(params, x)[2:]



def tanh(x):
    """Hyperbolic tangent activation (delegates to ``jnp.tanh``)."""
    value = jnp.tanh(x)
    return value

def tanh_prime(x):
    """First derivative of tanh, ``1 - tanh(x)**2``."""
    t = jnp.tanh(x)
    return 1.0 - t**2

def tanh_double_prime(x):
    """Second derivative of tanh, ``-2 tanh(x) (1 - tanh(x)**2)``."""
    t = jnp.tanh(x)
    return -2.0 * t * (1.0 - t**2)

# Module-level aliases for the activation and its first two derivatives;
# derivative_propagation reads activation_prime / activation_double_prime.
activation = tanh
activation_prime = tanh_prime
activation_double_prime = tanh_double_prime







#%%



# Reynolds number and the network actually used for training:
# 2 inputs, four hidden layers of width 50, 3 outputs, tanh activation.
Re=40
layer_sizes = [2]+[50]*4+[3]
# layer_sizes = [2]+[64]+[3]
# layer_sizes = [2]+[500]*5+[3]
activation = lambda x : jnp.tanh(x)
# layer_sizes = [2 * K * 2 + 1] +[100]*4+[3]
params = init_params(layer_sizes, random.PRNGKey(seed))
model = mlp(activation)


# x_Omega = interior.random_integration_points(random.PRNGKey(0), N=10000)
# x_eval = interior.random_integration_points(random.PRNGKey(0), N=900)
# First interior collocation point, kept as a module-level sample input.
x=x_Omega[0]
@jax.jit
def derivative_propagation(params, x):
    """Propagate value, Jacobian and Hessian of the MLP through its layers.

    The network value and its first and second derivatives with respect
    to the input ``x`` are pushed forward layer by layer using the chain
    rule. The module-level ``activation``, ``activation_prime`` and
    ``activation_double_prime`` are used for the nonlinearity and must
    describe the same function.

    Args:
        params: List of ``(w, b)`` pairs; the last pair is applied without
            an activation.
        x: Single input point, a 1-D array of length ``d``.

    Returns:
        Tuple ``(z, dz_dx, d2z_dxx)`` with shapes ``(n_out,)``,
        ``(n_out, d)`` and ``(n_out, d, d)``.
    """
    d = len(x)

    # Derivatives of the identity map x -> x.
    z = x
    dz_dx = jnp.eye(d)
    d2z_dxx = jnp.zeros((d, d, d))

    for w, b in params[:-1]:
        # Affine part: all three quantities transform linearly with w.
        pre = jnp.dot(w, z) + b
        dpre_dx = jnp.dot(w, dz_dx)
        d2pre_dxx = jnp.einsum('ij,jkl->ikl', w, d2z_dxx)

        sigma_prime = activation_prime(pre)
        sigma_double_prime = activation_double_prime(pre)

        # Chain rule for the second derivative of sigma(pre(x)):
        # sigma'' * (dpre ⊗ dpre) + sigma' * d2pre.
        outer = jnp.einsum('ij,ik->ijk', dpre_dx, dpre_dx)
        d2z_dxx = (sigma_double_prime[:, None, None] * outer
                   + sigma_prime[:, None, None] * d2pre_dxx)

        # Chain rule for the first derivative.
        dz_dx = sigma_prime[:, None] * dpre_dx

        z = activation(pre)

    # Final affine layer without activation.
    final_w, final_b = params[-1]
    z = jnp.dot(final_w, z) + final_b
    dz_dx = jnp.dot(final_w, dz_dx)
    d2z_dxx = jnp.einsum('ij,jkl->ikl', final_w, d2z_dxx)

    return z, dz_dx, d2z_dxx




# @jit
def gramian_res(lin_params, params,Re, x):
    """Interior residual linearized at ``params``, applied to ``lin_params``.

    Every term is linear in the network defined by ``lin_params``; the
    advective term uses the product rule with the network defined by
    ``params`` as the fixed state.

    Args:
        lin_params: Parameters of the network the linear operator acts on.
        params: Parameters of the network providing the fixed state.
        Re: Scalar dividing the second-derivative (viscous) term.
        x: Single input point.

    Returns:
        Array of length 3: divergence entry followed by the two
        momentum entries.
    """
    lin_value, lin_jac, lin_hess = derivative_propagation(lin_params, x)
    base_value, base_jac, _ = derivative_propagation(params, x)

    # Viscous term: trace of the Hessian of each velocity component.
    viscous = (1./Re) * jnp.trace(lin_hess[:2, :2, :], axis1=1, axis2=2)
    pressure_gradient = jnp.reshape(lin_jac[2], (-1,))
    # Two halves of the linearized advective term.
    advect_lin_by_base = jnp.sum(base_value[:2] * lin_jac[:2], axis=-1)
    advect_base_by_lin = jnp.sum(lin_value[:2] * base_jac[:2], axis=-1)
    divergence = jnp.trace(lin_jac[:2, :2], axis1=0, axis2=1)

    momentum = viscous - advect_lin_by_base - advect_base_by_lin - pressure_gradient
    return jnp.hstack((divergence, momentum))




# @jit
def interior_res(params,Re, x):
    """Interior PDE residual of the network at one point.

    Args:
        params: Network parameters (list of ``(w, b)`` pairs).
        Re: Scalar dividing the second-derivative (viscous) term.
        x: Single input point.

    Returns:
        Array of length 3: the divergence of the velocity followed by the
        two momentum residuals (viscous - advective - pressure gradient).
    """
    value, jacobian, hessian = derivative_propagation(params, x)
    velocity = value[:2]
    velocity_jac = jacobian[:2]

    viscous = (1./Re) * jnp.trace(hessian[:2, :2, :], axis1=1, axis2=2)
    advective = jnp.sum(velocity * velocity_jac, axis=-1)
    pressure_gradient = jacobian[2]

    momentum = viscous - advective - pressure_gradient
    divergence = jnp.trace(jacobian[:2, :2], axis1=0, axis2=1)
    return jnp.hstack((divergence, momentum))



# Interior residual batched over points; parameters and Re are shared.
v_interior_res = vmap(interior_res, (None,None, 0))


# boundary residual
@jit
def boundary_res(params, x):
    """Boundary residual at one point: predicted minus reference velocity."""
    predicted = model_u(params, x)
    reference = u_star(x)
    return predicted - reference


# Boundary residual batched over points; parameters are shared.
v_boundary_res_u = vmap(boundary_res, (None,0))
# Advance the PRNG key.
key, subkey = random.split(key)

def loss(params,Re):
    """Least-squares loss on the module-level collocation points.

    Sum of half the mean squared interior residual on ``x_Omega`` and the
    boundary term (weight 4 times one half) on ``x_Gamma``.
    """
    interior_term = 0.5 * jnp.mean(v_interior_res(params,Re, x_Omega)**2)
    boundary_term = 4. * 0.5 * jnp.mean(v_boundary_res_u(params, x_Gamma)**2)
    return interior_term + boundary_term

# Upper bound for the damping parameter; set here as a literal rather than
# from the --LM command-line option.
LM=1e-5

# Evaluate the loss once at the initial parameters (result is discarded).
loss(params,Re)
#%%
from new_natgrad.gram import *




import jax
from jax import vmap, linearize, linear_transpose, jit
from jax.flatten_util import ravel_pytree






# Flatten the (re-initialised) parameters and report their count.
f_params, unravel = ravel_pytree(params)
n=f_params.shape[0]


print('n params',n)




# Repeats the flattening above with the same `params`.
f_params, unravel = ravel_pytree(params)




def gramian_res_fn(params, x):
    """Flat-parameter wrapper around ``gramian_res`` (result flattened).

    NOTE(review): ``gramian_res`` is defined in this file with four
    parameters ``(lin_params, params, Re, x)`` but receives only two here,
    so calling this wrapper as written raises a ``TypeError``. Confirm the
    intended linearization point and ``Re`` before using it.
    """
    return gramian_res(unravel(params), x).reshape(-1)
def boundary_res_fn(params, x):
    """Boundary residual as a function of the flat parameter vector."""
    residual = boundary_res(unravel(params), x)
    return residual.reshape(-1)


# Per-sample weights: one over the square root of the number of interior
# and boundary points respectively.
scale_Omega = 1 / jnp.sqrt(x_Omega.shape[0])
scale_Gamma = 1 / jnp.sqrt(x_Gamma.shape[0])


import jax
import jax.numpy as jnp


def V(f_params, Re, x, y):
    """
    Interior-interior pairing for one pair of points.

    Builds the Jacobian of ``gramian_res`` with respect to the flat
    parameter vector at ``x`` and at ``y`` (linearized at ``f_params`` in
    both cases) and contracts the two over the parameter axis.

    Parameters
    ----------
    f_params : array
        Flat parameter vector (compatible with the module-level ``unravel``).
    Re : scalar
        Forwarded to ``gramian_res``.
    x, y : array
        Single interior points.

    Returns
    -------
    result : array, shape (3, 3)
        ``result[i, j]`` is the inner product of row ``i`` of the Jacobian
        at ``y`` with row ``j`` of the Jacobian at ``x``.
    """
    # One cotangent per residual component recovers the Jacobian row by row.
    basis = jnp.eye(3)

    def jacobian_rows(point):
        _, pullback = jax.vjp(
            lambda fp: gramian_res(unravel(fp), unravel(f_params), Re, point),
            f_params
        )
        return jax.vmap(lambda cotangent: pullback(cotangent)[0])(basis)

    jac_x = jacobian_rows(x)
    jac_y = jacobian_rows(y)
    return jax.vmap(lambda row: jax.lax.dot(jac_x, row))(jac_y)


#%%
def V2(f_params,x, y):
    """
    Interior-boundary pairing for one pair of points.

    The Jacobian at ``x`` comes from ``gramian_res`` (linearized at
    ``f_params``, using the module-level ``Re``); the Jacobian at ``y``
    comes from ``boundary_res``. Both are taken with respect to the flat
    parameter vector.

    Parameters
    ----------
    f_params : array
        Flat parameter vector (compatible with the module-level ``unravel``).
    x : array
        Single interior point.
    y : array
        Single boundary point.

    Returns
    -------
    result : array, shape (2, 3)
        ``result[i, j]`` is the inner product of boundary Jacobian row ``i``
        with interior Jacobian row ``j``.
    """
    def jacobian_rows(residual, n_outputs):
        _, pullback = jax.vjp(residual, f_params)
        return jax.vmap(lambda cotangent: pullback(cotangent)[0])(jnp.eye(n_outputs))

    boundary_jac = jacobian_rows(
        lambda fp: boundary_res(unravel(fp), y), 2
    )
    interior_jac = jacobian_rows(
        lambda fp: gramian_res(unravel(fp), unravel(f_params), Re, x), 3
    )
    return jax.vmap(lambda row: jax.lax.dot(interior_jac, row))(boundary_jac)


#%%

def V3(f_params,x, y):
    """
    Boundary-boundary pairing for one pair of points.

    Parameters
    ----------
    f_params : array
        Flat parameter vector (compatible with the module-level ``unravel``).
    x, y : array
        Single boundary points.

    Returns
    -------
    result : array, shape (2, 2)
        ``result[i, j]`` is the inner product of row ``i`` of the
        ``boundary_res`` Jacobian at ``y`` with row ``j`` of the one at
        ``x`` (Jacobians taken with respect to the flat parameters).
    """
    basis = jnp.eye(2)

    def jacobian_rows(point):
        _, pullback = jax.vjp(
            lambda fp: boundary_res(unravel(fp), point),
            f_params
        )
        return jax.vmap(lambda cotangent: pullback(cotangent)[0])(basis)

    jac_x = jacobian_rows(x)
    jac_y = jacobian_rows(y)
    return jax.vmap(lambda row: jax.lax.dot(jac_x, row))(jac_y)


def V4(f_params, Re, x, y):
    """Boundary-interior pairing for one pair of points.

    ``x`` is a boundary point (Jacobian of ``boundary_res``) and ``y`` an
    interior point (Jacobian of ``gramian_res`` linearized at ``f_params``).

    Returns an array of shape (3, 2): interior Jacobian rows contracted
    with boundary Jacobian rows over the parameter axis.
    """
    def jacobian_rows(residual, n_outputs):
        _, pullback = jax.vjp(residual, f_params)
        return vmap(lambda cotangent: pullback(cotangent)[0])(jnp.eye(n_outputs))

    boundary_jac = jacobian_rows(
        lambda fp: boundary_res(unravel(fp), x), 2
    )
    interior_jac = jacobian_rows(
        lambda fp: gramian_res(unravel(fp), unravel(f_params), Re, y), 3
    )
    return jnp.dot(interior_jac, boundary_jac.T)






# Divisor for the lax.map batch size in the Gram-block builders:
# batch_size = number_of_points // chunk.
chunk = 1

# 
from jax import lax

def GG(f_params, Re, x_Omega_I, x_Omega_J):
    """Interior-interior Gram block.

    Evaluates ``V`` for every pair of points; the outer map runs over
    ``x_Omega_J`` and the inner one over ``x_Omega_I``. The result is
    multiplied by ``scale_Omega`` twice.
    """
    def row_for(y):
        return lax.map(lambda x: V(f_params, Re, x, y), x_Omega_I)

    pairwise = lax.map(
        row_for, x_Omega_J, batch_size=x_Omega_J.shape[0] // chunk
    )
    return scale_Omega * scale_Omega * pairwise


def GB(f_params, Re, x_Omega_I, x_Gamma_J):
    """Interior-boundary Gram block.

    Evaluates ``V2`` for every pair; the outer map runs over the boundary
    points ``x_Gamma_J`` and the inner one over the interior points
    ``x_Omega_I``. ``Re`` is accepted but not forwarded, since ``V2`` takes
    no ``Re`` argument.
    """
    def row_for(xGamma):
        return lax.map(lambda x: V2(f_params, x, xGamma), x_Omega_I)

    pairwise = lax.map(
        row_for, x_Gamma_J, batch_size=x_Gamma_J.shape[0] // chunk
    )
    return scale_Omega * scale_Gamma * pairwise


def BB(f_params, x_Gamma_I, x_Gamma_J):
    """Boundary-boundary Gram block.

    Evaluates ``V3`` for every pair; the outer map runs over ``x_Gamma_J``
    and the inner one over ``x_Gamma_I``. The result is multiplied by
    ``scale_Gamma`` twice.
    """
    def row_for(xGamma):
        return lax.map(lambda x: V3(f_params, x, xGamma), x_Gamma_I)

    pairwise = lax.map(
        row_for, x_Gamma_J, batch_size=x_Gamma_J.shape[0] // chunk
    )
    return scale_Gamma * scale_Gamma * pairwise


def GB_flip(f_params, Re, x_Gamma_I, x_Omega_J):
    """Boundary-interior Gram block (the flipped counterpart of ``GB``).

    Evaluates ``V4`` for every pair; the outer map runs over the interior
    points ``x_Omega_J`` and the inner one over the boundary points
    ``x_Gamma_I``.
    """
    def row_for(y):
        return lax.map(lambda x: V4(f_params, Re, x, y), x_Gamma_I)

    pairwise = lax.map(
        row_for, x_Omega_J, batch_size=x_Omega_J.shape[0] // chunk
    )
    return scale_Omega * scale_Gamma * pairwise
@jit
def compute_G_dual(f_params,x_Omega,x_Gamma):
    """Assemble the full sample-space Gram matrix from its four blocks.

    Each block is a 4-D array indexed (outer point, inner point, outer
    component, inner component); it is reshaped so rows enumerate
    (outer point, component) and columns (inner point, component). The
    module-level ``Re`` is used for the interior residuals.

    Returns:
        The transposed block matrix ``[[S_ii, S_bi], [S_ib, S_bb]]``.
    """
    def as_matrix(block):
        n_outer, n_inner, d_outer, d_inner = block.shape
        reordered = jnp.transpose(block, (0, 2, 1, 3))
        return reordered.reshape(n_outer * d_outer, n_inner * d_inner)

    ii = as_matrix(GG(f_params, Re, x_Omega, x_Omega))       # interior rows, interior cols
    ib = as_matrix(GB(f_params, Re, x_Omega, x_Gamma))       # boundary rows, interior cols
    bi = as_matrix(GB_flip(f_params, Re, x_Gamma, x_Omega))  # interior rows, boundary cols
    bb = as_matrix(BB(f_params, x_Gamma, x_Gamma))           # boundary rows, boundary cols

    upper = jnp.hstack([ii, bi])
    lower = jnp.hstack([ib, bb])
    return jnp.vstack([upper, lower]).T
#%%


def rl2_norm(f,fs, x_eval):
    """Ratio of the root-mean-square of ``f(x_eval)`` to that of ``fs(x_eval)``."""
    numerator = jnp.mean((f(x_eval))**2.)**0.5
    denominator = jnp.mean((fs(x_eval))**2.)**0.5
    return numerator / denominator
# NOTE: despite its name this lambda returns the model velocity only (no
# reference is subtracted); `error_u` is redefined later in the file.
error_u = lambda x: model_u(params, x) 
# Model velocity batched over points.
v_u = vmap(model_u, (None,0))


# First and second component of the reference velocity, batched over points.
U_star= vmap(lambda x: u_star(x)[0])

V_star=vmap(lambda x: u_star(x)[1])

# Column slices (kept 2-D) of the batched velocity error. `v_error_u` is
# defined further down in the file and looked up when these are called.
v_error_U =  lambda x: v_error_u(x)[:,0:1]

v_error_V = lambda x: v_error_u(x)[:,1:2]


#%%
import timeit

from jax.tree_util import Partial


import lineax as lx
@jax.jit
def solve_lineax(matrix, vector):
    """Solve ``matrix @ x = vector`` with lineax using its QR solver."""
    result = lx.linear_solve(lx.MatrixLinearOperator(matrix), vector, lx.QR())
    return result.value


#%%





# Experiment flags and bookkeeping containers.
Train=True


its=[]
l2_errors=[]
# List of Reynolds-number-like values. NOTE(review): meaning inferred from
# the name only — confirm.
Rel=[100,400,1000,3000]
import numpy as np

def l2_norm(f, x_eval):
    """Relative root-mean-square error of ``f`` against ``u_star``.

    Args:
        f: Batched function; ``f(x_eval)`` must have the same shape as the
            reference velocity evaluated on ``x_eval`` (one row per point).
        x_eval: Evaluation points, one per row.

    Returns:
        RMS of ``f(x_eval) - u_star(x_eval)`` divided by the RMS of
        ``u_star(x_eval)``.
    """
    # u_star works on a single point, so it is mapped over the rows; the
    # original expression used the function object itself as an operand.
    reference = vmap(u_star)(x_eval)
    return jnp.mean((f(x_eval) - reference)**2.)**0.5 / jnp.mean(reference**2.)**0.5
checkpoints = {}




# Candidate step sizes for the grid line search: 0.5**k for k = 0, ..., 30.
grid = jnp.linspace(0, 30, 31)
steps = 0.5**grid


# from new_natgrad.utility import grid_line_search_factory,grid_line_search_factorym
def grid_line_search_factorym(loss, steps):
    """Build a grid line search for parameters stored as ``(w, b)`` pairs.

    Args:
        loss: Callable ``loss(params, **kwargs)`` returning a scalar.
        steps: 1-D array of candidate step sizes.

    Returns:
        ``update(params, tangent_params, **kwargs)`` which evaluates the
        loss at ``params - step * tangent_params`` for every candidate
        step, picks the step with the smallest loss and returns
        ``(updated_params, step_size)``.
    """
    def _move(params, tangent_params, step):
        return [(w - step * dw, b - step * db)
                for (w, b), (dw, db) in zip(params, tangent_params)]

    def loss_at_step(step, params, tangent_params, kwargs):
        return loss(_move(params, tangent_params, step), **kwargs)

    # Vectorise over the step size only; everything else is broadcast.
    v_loss_at_steps = jit(vmap(loss_at_step, in_axes=(0, None, None, None)))

    def grid_line_search_update(params, tangent_params, **kwargs):
        candidate_losses = v_loss_at_steps(steps, params, tangent_params, dict(kwargs))
        step_size = steps[jnp.argmin(candidate_losses)]
        return _move(params, tangent_params, step_size), step_size

    return grid_line_search_update
# Line search bound to the module-level loss and step grid.
ls_update = grid_line_search_factorym(loss, steps)

# Enables the progress printouts below.
VERBOSE=True







def compute_XTv(f_params, v, Re,x_Omega,x_Gamma):
    """Apply the scaled residual Jacobian to a parameter-space vector.

    Uses forward-mode (``jax.jvp``) on the batched ``gramian_res``
    (linearized at ``f_params``) and on the batched ``boundary_res``.

    Returns:
        1-D array: the flattened interior part scaled by ``scale_Omega``
        followed by the flattened boundary part scaled by ``scale_Gamma``.
    """
    def interior_map(p):
        batched = jax.vmap(gramian_res, in_axes=(None,None,None,0))
        return batched(unravel(p), unravel(f_params), Re, x_Omega)

    def boundary_map(p):
        return jax.vmap(boundary_res, in_axes=(None,0))(unravel(p), x_Gamma)

    _, interior_tangent = jax.jvp(interior_map, (f_params,), (v,))
    _, boundary_tangent = jax.jvp(boundary_map, (f_params,), (v,))

    return jnp.concatenate([
        scale_Omega * interior_tangent.reshape(-1),
        scale_Gamma * boundary_tangent.reshape(-1),
    ])

def compute_Xv(f_params, w,  Re,x_Omega,x_Gamma,d_int=3, d_bnd=2):
    """Apply the transposed scaled residual Jacobian to a sample-space vector.

    Args:
        f_params: Flat parameter vector (also the linearization point).
        w: 1-D vector; the first ``n_interior * d_int`` entries belong to
            the interior residuals, the rest to the boundary residuals.
        Re: Forwarded to ``gramian_res``.
        x_Omega: Interior points, one per row.
        x_Gamma: Boundary points, one per row.
        d_int: Number of residual components per interior point.
        d_bnd: Number of residual components per boundary point.

    Returns:
        1-D array with the length of ``f_params``.
    """
    # Split position derived from d_int (previously hard-coded as 3).
    n_interior = x_Omega.shape[0]
    w_interior, w_boundary = jnp.split(w, [n_interior * d_int])

    def interior_map(p):
        batched = jax.vmap(gramian_res, in_axes=(None,None,None,0))
        return batched(unravel(p), unravel(f_params), Re, x_Omega)

    def boundary_map(p):
        return jax.vmap(boundary_res, in_axes=(None,0))(unravel(p), x_Gamma)

    _, pull_interior = jax.vjp(interior_map, f_params)
    _, pull_boundary = jax.vjp(boundary_map, f_params)

    from_interior = scale_Omega * pull_interior(w_interior.reshape(-1, d_int))[0].reshape(-1)
    from_boundary = scale_Gamma * pull_boundary(w_boundary.reshape(-1, d_bnd))[0].reshape(-1)
    return from_interior + from_boundary




def directional_second_derivative(fun, params, v):
    """
    Return   vᵀ ∇²[fun](params) v      without materialising any Hessian.

    The result is obtained by differentiating the directional derivative
    ``p -> J(p) v`` once more along ``v`` (two nested ``jax.jvp`` calls).
    It has the same shape as ``fun(params)``.
    """
    def first_directional(p):
        # J(p) · v
        return jax.jvp(fun, (p,), (v,))[1]

    # d/dt [J(params + t v) · v] at t = 0
    _, second = jax.jvp(first_directional, (params,), (v,))
    return second

# ---------- fvv for interior + boundary -------------------------------
def compute_fvv(f_params, v,Re,x_Omega,x_Gamma,
                d_int=3, d_bnd=2,
                scale_int=scale_Omega, scale_bnd=scale_Gamma):
    """Directional 2nd derivative of the *stacked* residual vector.

    Differentiates the flattened, batched ``interior_res`` and
    ``boundary_res`` twice along ``v`` in parameter space and concatenates
    the interior part followed by the boundary part. ``d_int`` and
    ``d_bnd`` are accepted but not used.

    NOTE(review): each part is multiplied by its scale twice, whereas the
    Jacobian products in this file apply the scale once — confirm the
    squared factor is intended.
    """
    batched_interior = jax.vmap(interior_res,  in_axes=(None,None, 0))
    batched_boundary = jax.vmap(boundary_res, in_axes=(None, 0))

    def stacked_interior(p):
        return batched_interior(unravel(p),Re, x_Omega).reshape(-1)

    def stacked_boundary(p):
        return batched_boundary(unravel(p), x_Gamma).reshape(-1)

    interior_part = scale_int *scale_int * directional_second_derivative(stacked_interior, f_params, v)
    boundary_part = scale_bnd *scale_bnd * directional_second_derivative(stacked_boundary, f_params, v)

    return jnp.concatenate([interior_part, boundary_part])


#%%
def compute_Xv(f_params, w,  Re,x_Omega,x_Gamma,d_int=3, d_bnd=2):
    """Apply the transposed scaled residual Jacobian to a sample-space vector.

    The previous body referred to names that are not parameters of this
    function and are not defined in this file (``x_Omega_I``, ``uv_I``, ...)
    and called ``boundary_res`` with three arguments; it now works on the
    function's own arguments.

    Args:
        f_params: Flat parameter vector (also the linearization point).
        w: 1-D vector; the first ``n_interior * d_int`` entries belong to
            the interior residuals, the rest to the boundary residuals.
        Re: Forwarded to ``gramian_res``.
        x_Omega: Interior points, one per row.
        x_Gamma: Boundary points, one per row.
        d_int: Number of residual components per interior point.
        d_bnd: Number of residual components per boundary point.

    Returns:
        1-D array with the length of ``f_params``.
    """
    n_interior = x_Omega.shape[0]
    w_interior, w_boundary = jnp.split(w, [n_interior * d_int])

    def interior_map(p):
        batched = jax.vmap(gramian_res, in_axes=(None,None,None,0))
        return batched(unravel(p), unravel(f_params), Re, x_Omega)

    def boundary_map(p):
        return jax.vmap(boundary_res, in_axes=(None,0))(unravel(p), x_Gamma)

    _, pull_interior = jax.vjp(interior_map, f_params)
    _, pull_boundary = jax.vjp(boundary_map, f_params)

    from_interior = scale_Omega * pull_interior(w_interior.reshape(-1, d_int))[0].reshape(-1)
    from_boundary = scale_Gamma * pull_boundary(w_boundary.reshape(-1, d_bnd))[0].reshape(-1)
    return from_interior + from_boundary



#%%


# Flat parameter vector and the function that restores the pytree layout.
f_params, unravel = ravel_pytree(params)









# Zero vector with the same length as the flattened parameters.
f_lm_grad = jnp.zeros_like(ravel_pytree(params)[0])

def compute_XTv(f_params, v, Re,x_Omega,x_Gamma):
    """Apply the scaled residual Jacobian to a parameter-space vector.

    Uses forward-mode (``jax.jvp``) on the batched ``gramian_res``
    (linearized at ``f_params``) and on the batched ``boundary_res``.

    Returns:
        1-D array: the flattened interior part scaled by ``scale_Omega``
        followed by the flattened boundary part scaled by ``scale_Gamma``.
    """
    def interior_map(p):
        batched = jax.vmap(gramian_res, in_axes=(None,None,None,0))
        return batched(unravel(p), unravel(f_params), Re, x_Omega)

    def boundary_map(p):
        return jax.vmap(boundary_res, in_axes=(None,0))(unravel(p), x_Gamma)

    _, interior_tangent = jax.jvp(interior_map, (f_params,), (v,))
    _, boundary_tangent = jax.jvp(boundary_map, (f_params,), (v,))

    return jnp.concatenate([
        scale_Omega * interior_tangent.reshape(-1),
        scale_Gamma * boundary_tangent.reshape(-1),
    ])

def compute_Xv(f_params, w,  Re,x_Omega,x_Gamma,d_int=3, d_bnd=2):
    """Apply the transposed scaled residual Jacobian to a sample-space vector.

    Args:
        f_params: Flat parameter vector (also the linearization point).
        w: 1-D vector; the first ``n_interior * d_int`` entries belong to
            the interior residuals, the rest to the boundary residuals.
        Re: Forwarded to ``gramian_res``.
        x_Omega: Interior points, one per row.
        x_Gamma: Boundary points, one per row.
        d_int: Number of residual components per interior point.
        d_bnd: Number of residual components per boundary point.

    Returns:
        1-D array with the length of ``f_params``.
    """
    # Split position derived from d_int (previously hard-coded as 3).
    n_interior = x_Omega.shape[0]
    w_interior, w_boundary = jnp.split(w, [n_interior * d_int])

    def interior_map(p):
        batched = jax.vmap(gramian_res, in_axes=(None,None,None,0))
        return batched(unravel(p), unravel(f_params), Re, x_Omega)

    def boundary_map(p):
        return jax.vmap(boundary_res, in_axes=(None,0))(unravel(p), x_Gamma)

    _, pull_interior = jax.vjp(interior_map, f_params)
    _, pull_boundary = jax.vjp(boundary_map, f_params)

    from_interior = scale_Omega * pull_interior(w_interior.reshape(-1, d_int))[0].reshape(-1)
    from_boundary = scale_Gamma * pull_boundary(w_boundary.reshape(-1, d_bnd))[0].reshape(-1)
    return from_interior + from_boundary




from jax.scipy.linalg import solve_triangular



import jax
import jax.numpy as jnp
from jax import lax
from functools import partial


@partial(jax.jit,
         static_argnames=('trans', 'lower', 'unit_diagonal',
                          'left_side', 'check_finite', 'overwrite_b'))
def solve_triangular(a, b,
                         *,
                         trans: int | str = 0,
                         lower: bool = False,
                         unit_diagonal: bool = False,
                         overwrite_b: bool = False,
                         check_finite: bool = False,
                         left_side: bool = True):
    """
    GPU-only replacement for jax.scipy.linalg.solve_triangular.

    Parameters match the SciPy/JAX-SciPy function, but `overwrite_b`
    and `check_finite` are ignored (they would force a host
    copy/scan anyway).

    Returns
    -------
    x : array with the same shape as `b`
        Solution to  A @ x = b   (or  x @ A = b when left_side=False).
    """
    # ----- turn trans into the two internal boolean flags -------------
    transpose_a = trans in (1, 'T', 't')
    conjugate_a = trans in (2, 'C', 'c')

    # ----- actual GPU kernel ------------------------------------------
    return lax.linalg.triangular_solve(
        a, b,
        left_side=left_side,
        lower=lower,
        transpose_a=transpose_a,
        conjugate_a=conjugate_a,
        unit_diagonal=unit_diagonal,
    )

@jit
def Optim_step(params,f_lm_grad,key):
    """One damped natural-gradient step with a geodesic-acceleration correction.

    Steps:
      1. resample 400 interior and 400 boundary collocation points from
         ``box`` with ``key``;
      2. solve the damped system in sample space (Cholesky of
         ``K + mu I`` where ``K`` is the Gram matrix from
         ``compute_G_dual``) and map the solution back to parameter space
         to obtain the step ``wprimal``;
      3. compute the second-order correction ``a`` from the directional
         second derivative of the residuals along ``wprimal`` and add
         ``0.5 * a`` only when ``2 |a| / |wprimal| <= 0.5``;
      4. choose the step length with a grid line search on the loss.

    Args:
        params: Network parameters (list of ``(w, b)`` pairs).
        f_lm_grad: Unused; kept for interface compatibility.
        key: PRNG key for the resampling.

    Returns:
        ``(new_params, step_size)``.
    """
    x_Omega, x_Gamma = resample_hyperrectangle(
        key,
        box,
        N_interior=400,
        N_boundary=400
    )

    # Loss on the freshly sampled points (shadows the module-level loss).
    def loss(params,Re):
        return(
            0.5 * jnp.mean(v_interior_res(params,Re, x_Omega)**2) +
            4. * 0.5 * jnp.mean(v_boundary_res_u(params, x_Gamma)**2)
        )

    ls_update = grid_line_search_factorym(loss, steps)
    f_grads, unravel = ravel_pytree(grad(loss)(params,Re))
    f_params, unravel = ravel_pytree(params)

    # Damping: the current loss, capped at LM.
    mu = jnp.min(jnp.array([loss(params,Re), LM]))

    # Sample-space Gram matrix K and its damped version K + mu I.
    kernel = compute_G_dual(f_params,x_Omega,x_Gamma)
    Gb = mu * jnp.identity(len(kernel)) + kernel
    L = jnp.linalg.cholesky(Gb)

    # First-order step: (1/mu) g - (1/mu) J^T (K + mu I)^{-1} J g.
    w=compute_XTv(f_params, f_grads, Re,x_Omega,x_Gamma)
    y = solve_triangular(L, w, lower=True)          # solves L y = w
    wdual = solve_triangular(L.T, y, lower=False)   # solves L.T x = y
    wprimal = (1/mu)*f_grads - (1/mu)*compute_Xv(f_params, wdual,Re,x_Omega,x_Gamma)

    # Second-order correction. The directional second derivative has to be
    # taken along the step that was just computed; the previous code used
    # the module-level all-ones vector `v` here. The result is quadratic
    # in the direction, so the sign of wprimal is irrelevant.
    fvv=compute_fvv(f_params, wprimal,Re,x_Omega,x_Gamma)
    # Use K directly instead of recovering it as (K + mu I) - mu I.
    rhs=-kernel@fvv
    ya = solve_triangular(L, rhs, lower=True)        # solves L y = rhs
    adual = solve_triangular(L.T, ya, lower=False)   # solves L.T x = y
    a = (1/mu)*compute_Xv(f_params, adual+fvv,Re,x_Omega,x_Gamma)

    # Accept the correction only when it is small relative to the step.
    ratio = 2.0 * jnp.linalg.norm(a) / jnp.linalg.norm(wprimal)
    mult     = jnp.where(ratio <= 0.5, 0.5, 0.0)   # pure JAX, no int()
    nat_grad_flat = wprimal + mult * a
    nat_grad = unravel(nat_grad_flat)

    params, actual_step = ls_update(params, nat_grad,Re=Re)

    return params, actual_step

# errors
# Pointwise velocity / pressure errors against the reference solution and
# the Euclidean norm of their input Jacobians. `params` is looked up when
# the lambdas are called, so they follow later parameter updates.
# jacrev is referenced through the jax module because it is not among the
# names this file imports explicitly.
error_u = lambda x: model_u(params, x) - u_star(x)
v_error_u = vmap(error_u, (0))
v_error_u_abs_grad = vmap(
        lambda x: jnp.sum(jax.jacrev(error_u)(x)**2.)**0.5
        )

error_p = lambda x: model_p(params, x) - p_star(x)
v_error_p = vmap(error_p, (0))
v_error_p_abs_grad = vmap(
        lambda x: jnp.sum(jax.jacrev(error_p)(x)**2.)**0.5
        )

def l2_norm(f, x_eval):
    """Root-mean-square of ``f(x_eval)`` over all of its entries."""
    values = f(x_eval)
    mean_square = jnp.mean((values)**2.)
    return mean_square**0.5
def rl2_norm(f, x_eval):
    """RMS of ``f(x_eval)`` relative to the RMS of the reference velocity.

    Args:
        f: Batched function evaluated on all points at once.
        x_eval: Evaluation points, one per row.

    Returns:
        ``rms(f(x_eval)) / rms(u_star at every row of x_eval)``; the
        denominator averages over both velocity components.
    """
    # u_star expects a single point. Calling it on the whole array would
    # read rows 0 and 1 as the x and y coordinates, so it is mapped over
    # the rows instead.
    reference = vmap(u_star)(x_eval)
    return jnp.mean((f(x_eval))**2.)**0.5/ jnp.mean(reference**2.)**0.5
# Errors of the initial network on the evaluation points: RMS velocity
# error, and that value plus the RMS of the error-gradient magnitude.
l2_error_u = l2_norm(v_error_u, x_eval)
h1_error_u = l2_error_u + l2_norm(v_error_u_abs_grad, x_eval)
if VERBOSE == True:
    print(
        f'Before training: loss: {loss(params,Re)} with error '
        f'L2: {l2_error_u} and error H1: {h1_error_u}.'
    )










#%%
import os
import pickle
import timeit
import numpy as np

# --- LOGGING SETUP ---
# Output directory and the per-run log containers.
os.makedirs("runs/kovaz", exist_ok=True)
iterations = []
avg_relative_l2_errors = []
simulation_times = []

# The output file name is keyed by the command-line seed.
seed = args.seed
file_name = f"runs/kovaz/ng_seed_{seed}.pkl"
# ----------------------

# Zero vector with the same length as the flattened parameters; passed to
# Optim_step on every iteration.
f_lm_grad = jnp.zeros_like(ravel_pytree(params)[0])

# start the global timer
start_time0 = timeit.default_timer()
iteration = 0

while True:
    # Enforce the 3000 s wall-clock budget (checked before every step).
    total_elapsed = timeit.default_timer() - start_time0
    if total_elapsed > 3000:
        print(f"Time budget exceeded ({total_elapsed:.1f}s) at iteration {iteration}, stopping.")
        break

    # per-step timer
    step_start = timeit.default_timer()

    # Perform one NG optimization step. The freshly split subkey is
    # consumed by the step; `key` is only ever used for further splitting.
    key, subkey = random.split(key)
    params, actual_step = Optim_step(params, f_lm_grad, subkey)
    # JAX dispatches asynchronously: wait for the result so that the
    # timings below measure the computation and not just the dispatch.
    jax.block_until_ready(params)

    # measure step duration and refresh the total so the log line includes
    # the step that was just taken
    elapsed = timeit.default_timer() - step_start
    total_elapsed = timeit.default_timer() - start_time0

    if VERBOSE and iteration % 50 == 0:
        # relative L2 errors of the two velocity components, then averaged
        err_u = float(rl2_norm(v_error_U, x_eval))
        err_v = float(rl2_norm(v_error_V, x_eval))
        avg_err = 0.5 * (err_u + err_v)

        # --- LOGGING APPEND ---
        iterations.append(iteration)
        avg_relative_l2_errors.append(avg_err)
        simulation_times.append(elapsed)
        print(
            f"NG Iter {iteration} | Avg Rel L2 Err (U,V): {avg_err:.3e} | "
            f"Step time: {elapsed:.2f}s | Total time: {total_elapsed:.2f}s"
        )

    iteration += 1  # increment counter

# save all logged data
# Everything that was logged, plus the final parameters.
results = {
    "seed": seed,
    "iterations": iterations,
    "avg_relative_l2_errors": avg_relative_l2_errors,
    "simulation_times": simulation_times,
    "params": params
}
# Serialise the results with pickle into the per-seed file.
with open(file_name, "wb") as f:
    pickle.dump(results, f)

print(f"All results saved to {file_name}")


#%%
