import jax
import jax.numpy as jnp
from jax import random, grad, vmap, jit, jacrev, value_and_grad
from functools import partial

# Assuming these are your necessary imports for the domain and model
from new_natgrad.domains import Hyperrectangle, HyperrectangleBoundary
from jax.flatten_util import ravel_pytree
import argparse
# Command-line configuration. Each entry is (flag, keyword arguments for
# `add_argument`); entries are registered in this order, so `--help` lists them
# in the same order.
parser = argparse.ArgumentParser()
_CLI_OPTIONS = (
    ("--seed", dict(help="Random seed for reproducibility", default=0, type=int)),
    ("--solver", dict(choices=["adam", "sgd", "bfgs"], default="bfgs", help="Optimizer to use")),
    ("--adam_lr", dict(default=1e-3, type=float, help="Adam learning rate")),
    ("--initial_lr", dict(default=1e-2, type=float, help="Initial learning rate for SGD schedule")),
    ("--warmup_steps", dict(default=1000, type=int, help="Warmup steps for SGD schedule")),
    ("--decay_steps", dict(default=20000, type=int, help="Decay steps for SGD schedule")),
    ("--decay_rate", dict(default=0.5, type=float, help="Decay rate for SGD schedule")),
    ("--momentum", dict(default=0.9, type=float, help="Momentum for SGD")),
    ("--max_grad_norm", dict(default=1.0, type=float, help="Max gradient norm for SGD clipping")),
    ("--bfgs_maxiter", dict(default=5, type=int, help="Max inner iterations per L-BFGS step")),
)
for _flag, _kwargs in _CLI_OPTIONS:
    parser.add_argument(_flag, **_kwargs)

# Unrecognised command-line arguments are reported rather than rejected.
args, unknown = parser.parse_known_args()
if unknown:
    print(f"Unknown arguments: {unknown}")

# Enable double precision before the PRNG key and any other arrays are created.
jax.config.update("jax_enable_x64", True)

key = random.PRNGKey(args.seed)

dim = 2  # input dimension: points (x, y)


def activation(x):
    """Hidden-layer nonlinearity of the network."""
    return jnp.tanh(x)


# Input of size `dim`, four hidden layers of 50 units, and 3 outputs (u, v, p).
layer_sizes = [dim, *([50] * 4), 3]

def glorot_layer_params(m: int, n: int, key_glorot):
    """Return Gaussian-initialised ``(W, b)`` for a dense layer mapping m -> n units.

    W has shape ``(n, m)`` and b has shape ``(n,)``. Both are drawn from
    N(0, 2 / (m + n)), the Glorot/Xavier normal variance, using independent
    subkeys of `key_glorot`.
    """
    weight_key, bias_key = random.split(key_glorot)
    scale = jnp.sqrt(2.0 / (m + n))
    return (
        scale * random.normal(weight_key, (n, m)),
        scale * random.normal(bias_key, (n,)),
    )

def init_params_custom(sizes, key_init):
    """Initialise an MLP as a list of ``(W, b)`` pairs, one per consecutive pair in `sizes`.

    ``len(sizes)`` subkeys are drawn; layer i uses subkey i, so the final subkey
    is never used.
    """
    layer_keys = random.split(key_init, len(sizes))
    fan_pairs = zip(sizes[:-1], sizes[1:])
    return [
        glorot_layer_params(fan_in, fan_out, layer_key)
        for (fan_in, fan_out), layer_key in zip(fan_pairs, layer_keys)
    ]

# Split off a dedicated key for weight initialisation; `key` keeps the remainder.
# (The order of the assignment targets fixes which subkey initialises the weights.)
params_init_key, key = random.split(key)
# Master parameter PyTree: a list of (W, b) pairs, one per layer. Adam/SGD
# update it directly, and it is what gets evaluated.
current_params_pytree = init_params_custom(layer_sizes, params_init_key)
# L-BFGS works on a single flat vector instead. `ravel_pytree` returns that
# vector together with `unravel`, which maps a flat vector of the same length
# back to the PyTree layout above.
initial_params_flat, unravel = ravel_pytree(current_params_pytree)


def mlp_custom(activation_fn):
    """Build the forward pass of a fully connected network.

    The returned ``model_fn(p_mlp, inpt)`` takes a list of ``(W, b)`` layer pairs
    and a single input vector. It applies ``h -> activation_fn(W @ h + b)`` for
    every layer except the last, which is purely affine.
    """
    def model_fn(p_mlp, inpt):
        hidden_layers, (w_out, b_out) = p_mlp[:-1], p_mlp[-1]
        h = inpt
        for w_hidden, b_hidden in hidden_layers:
            h = activation_fn(jnp.dot(w_hidden, h) + b_hidden)
        return jnp.dot(w_out, h) + b_out

    return model_fn

# Network shared by every solver: `activation` on hidden layers, affine output layer.
model_nn = mlp_custom(activation_fn=activation)

from jax.flatten_util import ravel_pytree # Already imported but good to ensure

def tanh(x):
    """Hyperbolic tangent."""
    return jnp.tanh(x)


def tanh_prime(x):
    """First derivative of tanh: 1 - tanh(x)**2."""
    t = jnp.tanh(x)
    return 1.0 - t**2


def tanh_double_prime(x):
    """Second derivative of tanh: -2 tanh(x) (1 - tanh(x)**2)."""
    t = jnp.tanh(x)
    return -2.0 * t * (1.0 - t**2)


# Activation and its closed-form derivatives, as read by `derivative_propagation`.
activation_orig, activation_prime, activation_double_prime = tanh, tanh_prime, tanh_double_prime

# Computational domain: the box [-0.5, 1.0] x [-0.5, 1.5] and its boundary,
# wrapped in the project's domain helpers from new_natgrad.domains.
box = [(-0.5, 1.0), (-0.5, 1.5)]
interior = Hyperrectangle(box)
rectangle_boundary = HyperrectangleBoundary(intervals=box)

# Collocation points per training batch (interior / boundary) and the size of
# the fixed evaluation set used for the relative L2 error.
N_interior_train = 400
N_boundary_train = 400
N_eval = 9000

# NOTE: the order of these splits determines every downstream random stream;
# reordering them changes the sampled points for a given --seed.
eval_key, key = random.split(key)
omega_key_init, key = random.split(key)
gamma_key_init, key = random.split(key)
# Initial interior/boundary points (used only for the "Before training" loss
# print) and the fixed evaluation points. `random_integration_points` is defined
# in new_natgrad.domains; presumably it samples uniformly — TODO confirm.
x_Omega_init = interior.random_integration_points(omega_key_init, N=N_interior_train)
x_Gamma_init = rectangle_boundary.random_integration_points(gamma_key_init, N=N_boundary_train)
x_eval = interior.random_integration_points(eval_key, N=N_eval)

@partial(jax.jit, static_argnums=(2, 3))
def resample_hyperrectangle_local(
    key_resample: jax.random.PRNGKey,
    intervals_resample: list[tuple[float, float]],
    N_interior_resample: int,
    N_boundary_resample: int
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Draw fresh uniform collocation points inside and on an axis-aligned box.

    Args:
        key_resample: PRNG key; split into three independent subkeys.
        intervals_resample: one ``(low, high)`` pair per dimension.
        N_interior_resample: number of interior points (static under jit).
        N_boundary_resample: number of boundary points (static under jit).

    Returns:
        ``(x_int, x_bnd)`` with shapes ``(N_interior_resample, d)`` and
        ``(N_boundary_resample, d)``. Interior points are uniform in the box.
        Each boundary point picks one of the ``2*d`` faces uniformly: face f pins
        dimension ``f // 2`` to its low (``f % 2 == 0``) or high end, and the other
        coordinates are uniform. Faces are equally likely regardless of their
        size, so points per unit boundary length differ between faces when the
        box's sides have different lengths.
    """
    n_dims = len(intervals_resample)
    lows = jnp.array([lo for lo, _ in intervals_resample])
    highs = jnp.array([hi for _, hi in intervals_resample])
    widths = highs - lows
    key_int, key_face, key_bnd = random.split(key_resample, 3)

    x_int = random.uniform(key_int, (N_interior_resample, n_dims)) * widths + lows

    face_ids = random.randint(key_face, (N_boundary_resample,), 0, 2 * n_dims)
    pinned_dim = face_ids // 2
    side = face_ids % 2
    x_bnd = random.uniform(key_bnd, (N_boundary_resample, n_dims)) * widths + lows
    # Coordinate value of the chosen face for each boundary point.
    face_coord = jnp.where(side == 0, lows[pinned_dim], highs[pinned_dim])
    x_bnd = x_bnd.at[jnp.arange(N_boundary_resample), pinned_dim].set(face_coord)
    return x_int, x_bnd

# Kovasznay flow parameters: Reynolds number, viscosity nu = 1/Re, and the decay
# rate lambda = 1/(2 nu) - sqrt(1/(4 nu^2) + 4 pi^2) of the exact solution.
Re = 40.0
nu = 1.0 / Re
l_kovaz = 1.0 / (2.0 * nu) - jnp.sqrt(1.0 / (4.0 * nu ** 2) + 4.0 * jnp.pi ** 2)


@jit
def p_star(xy):
    """Exact pressure p(x, y) = (1 - exp(2 lambda x)) / 2; `xy[..., 0]` is x."""
    return 1.0 / 2.0 * (1.0 - jnp.exp(2.0 * l_kovaz * xy[..., 0]))


@jit
def u_star(xy):
    """Exact velocity, stacked as ``(..., 2)`` = (u, v).

    u = 1 - exp(lambda x) cos(2 pi y),  v = lambda / (2 pi) exp(lambda x) sin(2 pi y).
    """
    x, y = xy[..., 0], xy[..., 1]
    decay = jnp.exp(l_kovaz * x)
    phase = 2.0 * jnp.pi * y
    u_val = 1.0 - decay * jnp.cos(phase)
    v_val = l_kovaz / (2.0 * jnp.pi) * decay * jnp.sin(phase)
    return jnp.stack([u_val, v_val], axis=-1)

# Model output functions (expect PyTree params)
def model_u(p_pytree, x_in):
    """Predicted velocity (u, v): the first two network outputs (PyTree params)."""
    return model_nn(p_pytree, x_in)[:2]


def model_p(p_pytree, x_in):
    """Predicted pressure as a length-1 array: the third network output (PyTree params)."""
    return model_nn(p_pytree, x_in)[2:]

@jax.jit
def derivative_propagation(p_deriv_pytree, x_deriv):
    """Forward-propagate the network value, Jacobian and Hessian w.r.t. its input.

    Uses the same layer structure as `mlp_custom`: hidden layers
    ``z -> act(W z + b)`` and an affine output layer. First and second input
    derivatives are carried alongside the activations, using the closed-form
    `activation_prime` / `activation_double_prime` instead of nested autodiff.

    Args:
        p_deriv_pytree: list of ``(W, b)`` layer pairs (as built by
            `init_params_custom`).
        x_deriv: a single input point, a 1-D array of length ``n_in``.

    Returns:
        ``(z, dz_dx, d2z_dxx)`` with shapes ``(n_out,)``, ``(n_out, n_in)`` and
        ``(n_out, n_in, n_in)``: the network output, its Jacobian, and the
        Hessian of every output component.
    """
    n_in = len(x_deriv)
    z = x_deriv
    # At the input, z = x: the Jacobian is the identity and second derivatives vanish.
    dz_dx = jnp.eye(n_in)
    d2z_dxx = jnp.zeros((n_in, n_in, n_in))
    for w, b_layer in p_deriv_pytree[:-1]:
        # Affine map: value, Jacobian and Hessian all transform linearly by W.
        z = jnp.dot(w, z) + b_layer
        dz_dx = jnp.dot(w, dz_dx)
        d2z_dxx = jnp.einsum('ik,kjl->ijl', w, d2z_dxx)
        # Elementwise activation, chain rule for each unit i with pre-activation a_i:
        #   d/dx_j        act(a_i) = act'(a_i) a_i,j
        #   d2/dx_j dx_k  act(a_i) = act''(a_i) a_i,j a_i,k + act'(a_i) a_i,jk
        sigma_prime_val = activation_prime(z)
        sigma_double_prime_val = activation_double_prime(z)
        # term1 needs the *pre-activation* Jacobian, so dz_dx is rescaled only
        # afterwards (JAX arrays are immutable, so no defensive copy is needed).
        term1 = sigma_double_prime_val[:, None, None] * jnp.einsum('ij,ik->ijk', dz_dx, dz_dx)
        term2 = sigma_prime_val[:, None, None] * d2z_dxx
        d2z_dxx = term1 + term2
        dz_dx = sigma_prime_val[:, None] * dz_dx
        z = activation_orig(z)
    # Output layer is affine: no activation term.
    final_w, final_b = p_deriv_pytree[-1]
    z = jnp.dot(final_w, z) + final_b
    dz_dx = jnp.dot(final_w, dz_dx)
    d2z_dxx = jnp.einsum('ik,kjl->ijl', final_w, d2z_dxx)
    return z, dz_dx, d2z_dxx

# Residual functions (expect PyTree params)
def interior_res(p_int_res_pytree, Re_in, x_in):
    """Pointwise residuals of the steady 2-D incompressible Navier-Stokes system.

    The three network outputs are read as (u, v, p) at the single point `x_in`.
    Returns ``[div u, (1/Re) lap u - (u . grad) u - dp/dx,
    (1/Re) lap v - (u . grad) v - dp/dy]``.
    """
    out, jac, hess = derivative_propagation(p_int_res_pytree, x_in)
    u_vel, v_vel = out[0], out[1]
    # Viscous term: (1/Re) times the Hessian trace (Laplacian) of u and v.
    viscous = (1.0/Re_in) * jnp.trace(hess[:2, :2, :], axis1=1, axis2=2)
    # Convective term (u . grad) applied to each velocity component.
    conv_u = u_vel * jac[0, 0] + v_vel * jac[0, 1]
    conv_v = u_vel * jac[1, 0] + v_vel * jac[1, 1]
    # Row 2 of the Jacobian is the pressure gradient.
    dp_dx, dp_dy = jac[2, 0], jac[2, 1]
    continuity = jnp.trace(jac[:2, :2])
    return jnp.array([
        continuity,
        viscous[0] - conv_u - dp_dx,
        viscous[1] - conv_v - dp_dy,
    ])


# Batched over points (axis 0 of x); params and Re are shared.
v_interior_res = vmap(interior_res, in_axes=(None, None, 0))

@jit
def boundary_res(p_bound_res_pytree, x_in):
    """Dirichlet residual: predicted (u, v) minus the exact Kovasznay velocity at `x_in`."""
    predicted_uv = model_u(p_bound_res_pytree, x_in)
    exact_uv = u_star(x_in)
    return predicted_uv - exact_uv


# Batched over boundary points (axis 0); params are shared.
v_boundary_res_u = vmap(boundary_res, in_axes=(None, 0))

def rl2_norm_custom(pred_fn_outputs, true_fn_vals):
    """Relative L2 error: RMS(pred - true) / (RMS(true) + 1e-8).

    The 1e-8 keeps the ratio finite when the reference is identically zero.
    """
    rms_err = jnp.mean((pred_fn_outputs - true_fn_vals)**2.)**0.5
    rms_ref = jnp.mean(true_fn_vals**2.)**0.5
    return rms_err / (rms_ref + 1e-8)

# Batched velocity predictor (params shared, inputs mapped over axis 0; expects
# PyTree params) and the exact velocities on the fixed evaluation points.
v_model_u_eval = vmap(model_u, in_axes=(None, 0))
v_u_star_eval = u_star(x_eval)

import os
import pickle
import timeit
import optax
import inspect
# jaxopt is optional: only the "bfgs" solver needs it, so a missing install is
# reported (when relevant) instead of aborting the script.
try:
    from jaxopt import LBFGS
except ImportError:
    LBFGS = None

if LBFGS is None and args.solver == "bfgs":
    print("jaxopt is not installed. BFGS solver is unavailable.")

VERBOSE = True  # print the initial loss and periodic progress lines

import os
import pickle
import timeit
import numpy as np

# --- LOGGING SETUP ---
# NOTE(review): these three lists are not referenced by the training or saving
# code in this file (it uses the *_log lists below); kept so the module
# namespace is unchanged.
iterations = []
avg_relative_l2_errors = []
simulation_times = []

output_dir = "runs/kovaz"
# Create the output directory (and parents) if missing; no-op if it already exists.
os.makedirs(output_dir, exist_ok=True)
# Prefix for this run's artefacts, e.g. runs/kovaz/bfgs_seed_0_results.pkl.
base_file_name = f"{output_dir}/{args.solver}_seed_{args.seed}"

# `current_params_pytree` holds the PyTree parameters for Adam/SGD and for evaluation.
# `initial_params_flat` is for L-BFGS.
start_time_global = timeit.default_timer()
# Wall-clock budget in seconds for training, measured from `start_time_global`.
MAX_TIME_BUDGET = 3000.0
# Per-step logs, one entry per optimizer step, written to the results pickle.
iterations_log, avg_relative_l2_errors_log, simulation_times_log, loss_values_log = [], [], [], []

# Baseline diagnostics for the untrained network.
initial_preds_uv = v_model_u_eval(current_params_pytree, x_eval)
initial_err_u_scalar, initial_err_v_scalar = (
    rl2_norm_custom(initial_preds_uv[:, comp], v_u_star_eval[:, comp]) for comp in (0, 1)
)
initial_avg_err = 0.5 * (initial_err_u_scalar + initial_err_v_scalar)


def loss_fn_for_initial_print(p_loss_pytree, r_val, x_omega_loss, x_gamma_loss):
    """PINN loss for PyTree params.

    Computes 0.5 * MSE(interior residuals) + 4.0 * 0.5 * MSE(boundary residuals).
    """
    interior_mse = jnp.mean(v_interior_res(p_loss_pytree, r_val, x_omega_loss)**2)
    boundary_mse = jnp.mean(v_boundary_res_u(p_loss_pytree, x_gamma_loss)**2)
    return 0.5 * interior_mse + 4.0 * 0.5 * boundary_mse


initial_loss_val = loss_fn_for_initial_print(current_params_pytree, Re, x_Omega_init, x_Gamma_init)
if VERBOSE:
    print(f"Before training ({args.solver}): Loss: {initial_loss_val:.3e}, Avg Rel L2 Err (U,V): {initial_avg_err:.3e}")

# -----------------------------------------------------------------------------
# Shared training helpers
# -----------------------------------------------------------------------------
def _pinn_loss(params_pytree, x_omega, x_gamma):
    """Return the PINN training loss for PyTree `params_pytree`.

    0.5 * mean(interior residual**2) + 4.0 * 0.5 * mean(boundary residual**2),
    evaluated at interior points `x_omega` and boundary points `x_gamma`.
    """
    res_int = v_interior_res(params_pytree, Re, x_omega)
    res_bnd = v_boundary_res_u(params_pytree, x_gamma)
    return 0.5 * jnp.mean(res_int**2) + 4.0 * 0.5 * jnp.mean(res_bnd**2)


def _record_metrics(step_idx, params_pytree, loss_scalar, step_seconds):
    """Evaluate the velocity error on `x_eval` and append one row to the *_log lists.

    Returns the mean of the relative L2 errors of u and v (for progress prints).
    """
    preds_uv = v_model_u_eval(params_pytree, x_eval)
    err_u = rl2_norm_custom(preds_uv[:, 0], v_u_star_eval[:, 0])
    err_v = rl2_norm_custom(preds_uv[:, 1], v_u_star_eval[:, 1])
    avg_err = 0.5 * (err_u + err_v)
    iterations_log.append(step_idx)
    loss_values_log.append(float(loss_scalar))
    avg_relative_l2_errors_log.append(float(avg_err))
    simulation_times_log.append(step_seconds)
    return avg_err


def _make_first_order_step(optimizer):
    """Build a jitted optax step that trains on freshly resampled collocation points.

    The returned function maps ``(params, opt_state, key)`` to
    ``(new_params, new_opt_state, loss_at_old_params, next_key)``. The key is
    split so each call samples a new batch and returns an unused key for the
    next call.
    """
    @jax.jit
    def step_fn(params_pytree, opt_state, key_step):
        resample_key, next_key = random.split(key_step)
        x_omega, x_gamma = resample_hyperrectangle_local(
            resample_key, box, N_interior_train, N_boundary_train
        )
        loss_val, grads = value_and_grad(_pinn_loss)(params_pytree, x_omega, x_gamma)
        updates, new_opt_state = optimizer.update(grads, opt_state, params_pytree)
        return optax.apply_updates(params_pytree, updates), new_opt_state, loss_val, next_key

    return step_fn


# -----------------------------------------------------------------------------
# Adam / SGD optimizer loop (optax, PyTree parameters)
# -----------------------------------------------------------------------------
if args.solver in ("adam", "sgd"):
    if args.solver == "adam":
        label = "Adam"
        optimizer = optax.adam(args.adam_lr)
    else:
        label = "SGD"
        # Linear warm-up from 0 to initial_lr over warmup_steps, then exponential decay.
        schedule_sgd = optax.join_schedules(
            [optax.linear_schedule(0., args.initial_lr, args.warmup_steps),
             optax.exponential_decay(args.initial_lr, args.decay_steps, args.decay_rate)],
            [args.warmup_steps]
        )
        optimizer = optax.chain(
            optax.clip_by_global_norm(args.max_grad_norm),
            optax.sgd(learning_rate=schedule_sgd, momentum=args.momentum, nesterov=False)
        )
    opt_state = optimizer.init(current_params_pytree)
    step_fn = _make_first_order_step(optimizer)
    iteration = 0

    while True:
        total_elapsed_global = timeit.default_timer() - start_time_global
        if total_elapsed_global > MAX_TIME_BUDGET:
            print(f"{label}: Time budget exceeded ({total_elapsed_global:.1f}s / {iteration} iter), stopping.")
            break
        step_start_time = timeit.default_timer()
        # JAX dispatches asynchronously: wait for the step to finish so the
        # recorded step time measures computation, not just dispatch.
        current_params_pytree, opt_state, current_loss_scalar, key = jax.block_until_ready(
            step_fn(current_params_pytree, opt_state, key)
        )
        elapsed_step = timeit.default_timer() - step_start_time

        avg_err_scalar = _record_metrics(iteration, current_params_pytree, current_loss_scalar, elapsed_step)
        if VERBOSE and iteration % 50 == 0:
            print(
                f"{label} Iter {iteration} | Loss: {current_loss_scalar:.3e} | Avg Rel L2 Err (U,V): {avg_err_scalar:.3e} | "
                f"Step time: {elapsed_step:.2f}s | Total time: {total_elapsed_global:.2f}s"
            )
        iteration += 1

# -----------------------------------------------------------------------------
# L-BFGS optimizer loop (jaxopt.LBFGS with flat parameters)
# -----------------------------------------------------------------------------
elif args.solver == "bfgs":
    # Assigned on every path (including "jaxopt missing") so code after the
    # training loop can always read the final flat parameters.
    current_params_flat = initial_params_flat
    if LBFGS is None:
        print("jaxopt.LBFGS is required for BFGS solver but not found. Skipping training.")
        iterations_log.append(0)
        loss_values_log.append(float(initial_loss_val))
        avg_relative_l2_errors_log.append(float(initial_avg_err))
        simulation_times_log.append(0.0)
    else:
        iteration_bfgs = 0

        def loss_for_lbfgs(p_flat, key_for_resampling):
            """Flat-parameter loss; collocation points are resampled from `key_for_resampling`."""
            x_omega, x_gamma = resample_hyperrectangle_local(
                key_for_resampling, box, N_interior_train, N_boundary_train
            )
            return jnp.asarray(_pinn_loss(unravel(p_flat), x_omega, x_gamma), dtype=jnp.float64)

        # NOTE(review): jaxopt's `update` performs a single L-BFGS iteration and
        # `maxiter` bounds `run`, so --bfgs_maxiter probably has no effect on this
        # loop. Verify against the installed jaxopt version.
        solver_lbfgs = LBFGS(fun=loss_for_lbfgs, maxiter=args.bfgs_maxiter, value_and_grad=False, jit=True, unroll=False)
        # Extra keyword arguments are forwarded to `loss_for_lbfgs`.
        opt_state_lbfgs = solver_lbfgs.init_state(current_params_flat, key_for_resampling=key)

        @jit
        def bfgs_step_fn_jitted(p_flat_step, st_bfgs, key_for_loss_and_next_iter):
            p_flat_new, st_new = solver_lbfgs.update(
                p_flat_step, st_bfgs, key_for_resampling=key_for_loss_and_next_iter
            )
            # NOTE(review): unlike the Adam/SGD step, the key is returned unchanged,
            # so every L-BFGS step resamples the *same* collocation points (a fixed
            # batch). Advancing it would change the objective between steps while
            # the solver state (presumably value/grad/curvature history) still
            # refers to the old batch. Confirm which behaviour is intended first.
            return p_flat_new, st_new, st_new.value, key_for_loss_and_next_iter

        while True:
            total_elapsed_global = timeit.default_timer() - start_time_global
            if total_elapsed_global > MAX_TIME_BUDGET:
                print(f"L-BFGS: Time budget exceeded ({total_elapsed_global:.1f}s / {iteration_bfgs} steps), stopping.")
                break
            step_start_time = timeit.default_timer()
            try:
                # Block so both the timing and any runtime failure happen inside this try.
                current_params_flat, opt_state_lbfgs, current_loss_scalar, key = jax.block_until_ready(
                    bfgs_step_fn_jitted(current_params_flat, opt_state_lbfgs, key)
                )
            except Exception as e:  # top-level training boundary: report and stop
                print(f"Error during L-BFGS step: {e}")
                if "iteration over a 0-d array" in str(e) or "nan" in str(e).lower() or "inf" in str(e).lower():
                    print("L-BFGS encountered numerical issues (NaN/Inf/0-d iteration). Stopping.")
                break
            elapsed_step = timeit.default_timer() - step_start_time

            # Keep the PyTree copy in sync with the flat vector; it is what gets evaluated.
            current_params_pytree = unravel(current_params_flat)
            avg_err_scalar = _record_metrics(iteration_bfgs, current_params_pytree, current_loss_scalar, elapsed_step)
            if VERBOSE and iteration_bfgs % 10 == 0:
                print(
                    f"L-BFGS Step {iteration_bfgs} (inner_maxiter: {args.bfgs_maxiter}) | Loss: {current_loss_scalar:.3e} | Avg Rel L2 Err (U,V): {avg_err_scalar:.3e} | "
                    f"Step time: {elapsed_step:.2f}s | Total time: {total_elapsed_global:.2f}s"
                )
            iteration_bfgs += 1
else:
    # Unreachable through argparse `choices`; kept as a guard.
    print(f"Unknown or unhandled solver: {args.solver}")

# --- Save Results ---

# Every training path leaves the final parameters in `current_params_pytree`:
# Adam/SGD update it directly, and the L-BFGS loop unravels its flat vector into
# it after each successful step. Reading it here, rather than
# `unravel(current_params_flat)`, also works when "bfgs" is selected but jaxopt
# is missing, in which case `current_params_flat` may never have been assigned.
final_params_to_save = current_params_pytree

final_results = {
    "seed": args.seed,
    "solver": args.solver,
    "iterations_steps": iterations_log,
    "losses": loss_values_log,
    "avg_relative_l2_errors": avg_relative_l2_errors_log,
    "simulation_step_times": simulation_times_log,
    # Final network weights as a list of (W, b) NumPy pairs, so the pickle does
    # not contain JAX-specific array objects.
    "final_params": jax.tree_util.tree_map(np.asarray, final_params_to_save),
    "args": vars(args),
}
results_file_path = f"{base_file_name}_results.pkl"
with open(results_file_path, "wb") as f:
    pickle.dump(final_results, f)
print(f"All results saved to {results_file_path}")