"""
HFNGD
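10-dimensional Poisson equation example on the unit cube [0, 1]^10. The exact solution is
  u*(x) = sum_{k=1}^{5} x_{2k-1} * x_{2k},
which is harmonic (its Laplacian vanishes), so the right-hand side is zero.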

"""
import jax
import jax.numpy as jnp
from jax import random, grad, vmap, jit

from ngrad.domains import Hyperrectangle, HypercubeBoundary
from ngrad.models import mlp, init_params
from ngrad.integrators import EvolutionaryIntegrator
from ngrad.utility import laplace, grid_line_search_factory
from ngrad.inner import model_laplace, model_identity
from ngrad.gram import gram_factory, nat_grad_factory

import argparse

# Command-line configuration of the solver.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--LM",
    help="Levenberg-Marquardt regularization: upper bound on the damping mu = min(loss, LM)",
    default=1,
    type=float,
)
parser.add_argument(
    "--tresh",
    help="threshold on the Frobenius norm of the eigen-pair residual; "
         "exceeding it triggers a fresh Lanczos decomposition",
    default=1e4,
    type=float,
)
parser.add_argument(
    "--iter",
    help="number of iterations (not read by this script; training is bounded by --budget)",
    default=500,
    type=int,
)
parser.add_argument(
    "--maxiter",
    help="maximum number of preconditioned CG iterations per optimizer step",
    default=250,
    type=int,
)
parser.add_argument(
    "--budget",
    help="wall-clock training budget in seconds",
    default=3000,
    type=float,
)
parser.add_argument(
    "--tau",
    help="initial rank of the Lanczos low-rank preconditioner",
    default=250,
    type=int,
)
parser.add_argument(
    "--minb",
    help="chunk size (number of columns) for batched operator applications",
    default=50,
    type=int,
)
parser.add_argument(
    "--ssiter",
    help="number of subspace iterations (not read by this script)",
    default=1,
    type=int,
)
parser.add_argument(
    "--seed",
    help="random seed for parameter initialization",
    default=0,
    type=int,
)
args = parser.parse_args()
# Enable float64 before any arrays are created below.
jax.config.update("jax_enable_x64", True)

# random seed (used for parameter initialization below)
seed = args.seed

# domains: the interior is the unit cube [0, 1]^dim
dim = 10
interior = Hyperrectangle([(0., 1.) for _ in range(0, dim)])
# presumably the boundary of the same unit cube — confirm in ngrad.domains
boundary = HypercubeBoundary(dim)

# integrators; the `_x` point sets of the interior/boundary integrators are
# read directly by `loss` and `build_A`, eval_integrator is used by `l2_norm`
interior_integrator = EvolutionaryIntegrator(interior, key=random.PRNGKey(0), N=1000)
boundary_integrator = EvolutionaryIntegrator(boundary, key= random.PRNGKey(1), N=1200)
# NOTE(review): same PRNGKey(0) as interior_integrator, so evaluation points may
# overlap the training points — confirm this is intended.
eval_integrator = EvolutionaryIntegrator(interior, key=random.PRNGKey(0), N=  4000)

# NOTE: ITER is not read anywhere in this script; training is bounded by --budget.
ITER=500

# Initial preconditioner rank (assigned from args.tau again further down).
tau=args.tau




#%%
# model

import time
# Used by `mlp` below; the name is later rebound to the `tanh` function.
activation = lambda x : jnp.tanh(x)

# Full-size architecture; immediately overridden by the small one below.
layer_sizes = [dim, 512,512,512,512,512, 1]

layer_sizes = [dim, 64, 1] # For demonstration on laptops
import jax
import jax.numpy as jnp
from jax import random

# Glorot (Xavier) normal initialization for a single dense layer.
def glorot_layer_params(m: int, n: int, key):
    """Draw Glorot-normal weights and biases for a dense layer R^m -> R^n.

    :param m: fan-in (input width)
    :param n: fan-out (output width)
    :param key: JAX PRNG key, split into a weight key and a bias key
    :return: ``(w, b)`` with ``w.shape == (n, m)`` and ``b.shape == (n,)``;
        both are scaled by ``sqrt(2 / (m + n))``
    """
    weight_key, bias_key = random.split(key)
    scale = jnp.sqrt(2.0 / (m + n))
    weights = scale * random.normal(weight_key, (n, m))
    biases = scale * random.normal(bias_key, (n,))
    return weights, biases


# Parameter list of a fully connected network with layer widths ``sizes``.
def init_params(sizes, key):
    """Return ``[(w_0, b_0), (w_1, b_1), ...]`` for consecutive width pairs.

    ``key`` is split into ``len(sizes)`` subkeys; the first ``len(sizes) - 1``
    initialise the layers in order.
    """
    layer_keys = random.split(key, len(sizes))
    layers = []
    for fan_in, fan_out, layer_key in zip(sizes[:-1], sizes[1:], layer_keys):
        layers.append(glorot_layer_params(fan_in, fan_out, layer_key))
    return layers

# Initial parameters: list of (w, b) pairs from the local Glorot init above
# (this local init_params shadows the one imported from ngrad.models).
params = init_params(layer_sizes, random.PRNGKey(seed))
# presumably model(params, x) evaluates the network at a single point x — see ngrad.models.mlp
model = mlp(activation)
# Batched over points (axis 0 of x); params are shared across points.
v_model = vmap(model, (None, 0))



def u_star(x):
    """
    Exact solution u*(x) = sum_{k=1}^{d/2} x_{2k-1} * x_{2k} (1-based indices).

    For the d = 10 problem this is x_1 x_2 + x_3 x_4 + ... + x_9 x_10, i.e.
    k runs from 1 to 5. Every term is bilinear in two distinct coordinates,
    so u* is harmonic (its Laplacian is zero).

    :param x: input vector of even length d (d = 10 in this script)
    :return: scalar u*(x)
    """
    # 0-based slicing: x[0::2] -> x_1, x_3, ...; x[1::2] -> x_2, x_4, ...
    odd_coords = x[::2]
    even_coords = x[1::2]
    return jnp.sum(odd_coords * even_coords)



# Exact solution evaluated on a batch of points (axis 0).
v_u_star = vmap(u_star, (0))


def _grad_norm_u_star(x):
    """Euclidean norm |grad u*(x)|, with the gradient evaluated once."""
    g = grad(u_star)(x)
    return jnp.dot(g, g) ** 0.5


v_grad_u_star = vmap(_grad_norm_u_star, (0))

# rhs
# For -Δu = f with the exact solution u* above, Δu* = 0 (each term is bilinear
# in distinct coordinates), so the right-hand side is identically zero.
def f(x):
    """Right-hand side of the Poisson problem; identically zero for this u*."""
    return jnp.zeros_like(x[0])

# gramians


from new_natgrad.gram import gram_factory, gvp_factory

from jax.flatten_util import ravel_pytree


f_params,unravel=ravel_pytree(params)
print(f_params.shape)

# NOTE: interior_res and gramian_res are redefined further down in terms of the
# hand-written forward-mode `Laplace`; these autodiff-based versions are
# overwritten before training.
@jit
def interior_res(params,x):
    return model(params, x)


laplace_model = lambda params: laplace(lambda x: model(params, x))
# Squared strong-form PDE residual (Δu + f)^2 at a single point x.
residual = lambda params, x: (laplace_model(params)(x) + f(x))**2.

@jit
def gramian_res(params,x):
    return laplace_model(params)(x)


#%%
def tanh(x):
    """Hyperbolic tangent activation."""
    return jnp.tanh(x)


def tanh_prime(x):
    """First derivative of tanh: 1 - tanh(x)^2."""
    t = jnp.tanh(x)
    return 1.0 - t**2


def tanh_double_prime(x):
    """Second derivative of tanh: -2 tanh(x) (1 - tanh(x)^2)."""
    t = jnp.tanh(x)
    return -2.0 * t * (1.0 - t**2)


# Activation and its derivatives, read by the forward-mode `Laplace` below.
activation, activation_prime, activation_double_prime = (
    tanh,
    tanh_prime,
    tanh_double_prime,
)
def Laplace(params, x, *, sigma=None, sigma_prime=None, sigma_double_prime=None):
    """Laplacian of the MLP output w.r.t. its input, via forward-mode propagation.

    The network is ``z_{l+1} = sigma(W_l z_l + b_l)`` for every layer but the
    last, followed by a final affine layer without activation. For each input
    direction i the first derivatives dz/dx_i and the pure second derivatives
    d^2 z / dx_i^2 are propagated layer by layer; mixed second derivatives are
    never formed. All n directions are carried together as the columns of a
    (width, n) matrix, so each layer costs one matrix product instead of n
    matrix-vector products.

    :param params: sequence of ``(w, b)`` pairs with ``w`` of shape (out, in)
    :param x: 1-D input point of length n
    :param sigma: activation; defaults to the module-level ``activation``
    :param sigma_prime: derivative of ``sigma``; defaults to ``activation_prime``
    :param sigma_double_prime: second derivative of ``sigma``; defaults to
        ``activation_double_prime``
    :return: scalar sum over outputs of sum_i d^2 out / dx_i^2 (the Laplacian
        for a scalar-output network)
    """
    if sigma is None:
        sigma = activation
    if sigma_prime is None:
        sigma_prime = activation_prime
    if sigma_double_prime is None:
        sigma_double_prime = activation_double_prime

    z = x
    # jac[:, i] = dz/dx_i (starts as the identity); hess_diag[:, i] = d^2 z / dx_i^2.
    jac = jnp.eye(x.shape[0], dtype=x.dtype)
    hess_diag = jnp.zeros_like(jac)

    for w, b in params[:-1]:
        # Affine map: derivatives transform linearly (the bias drops out).
        pre = jnp.dot(w, z) + b
        jac = jnp.dot(w, jac)
        hess_diag = jnp.dot(w, hess_diag)

        # Chain rule through the activation; hess_diag must use the jac from
        # before the activation update.
        s1 = sigma_prime(pre)[:, None]
        s2 = sigma_double_prime(pre)[:, None]
        hess_diag = s2 * jac**2 + s1 * hess_diag
        jac = s1 * jac
        z = sigma(pre)

    # Final affine layer (no activation); only second derivatives are needed.
    final_w, _ = params[-1]
    hess_diag = jnp.dot(final_w, hess_diag)

    # Laplacian = sum over input directions (and over outputs).
    return jnp.sum(hess_diag)

#%%
@jit
def interior_res(params, x):
    """Interior PDE residual Δu_θ(x); the right-hand side is zero for this problem."""
    return Laplace(params, x)


@jit
def gramian_res(params, x):
    """Interior function passed to ``gvp_factory`` in ``build_A``; identical to ``interior_res``."""
    return Laplace(params, x)


# loss
@jit
def interior_loss(params):
    """0.5 * integral of the squared interior residual, via ``interior_integrator``.

    The residual is squared to match the interior term of ``loss`` (mean of
    squared residuals) and the way ``l2_norm`` squares its integrand.
    """
    return 0.5 * interior_integrator(lambda x: v_residual(params, x) ** 2)


def boundary_res(params, x):
    """Dirichlet boundary residual u_θ(x) - u*(x)."""
    return model(params, x) - u_star(x)


# Residuals batched over points (axis 0); params are shared across points.
v_residual = jit(vmap(interior_res, (None, 0)))
v_boundary_res = jit(vmap(boundary_res, (None, 0)))



@jit
def loss(params):
    """Least-squares PINN loss on the fixed collocation points.

    Computes 0.5 * mean(interior residual^2) + 0.5 * mean(boundary residual^2),
    using the point sets stored on the interior and boundary integrators.
    """
    interior_pts = interior_integrator._x
    boundary_pts = boundary_integrator._x
    interior_term = jnp.mean(v_residual(params, interior_pts) ** 2)
    boundary_term = jnp.mean(v_boundary_res(params, boundary_pts) ** 2)
    return 0.5 * interior_term + 0.5 * boundary_term


# Evaluate once at import time (compiles the jitted loss).
loss(params)



#%%
LM = args.LM

# Flattened parameters and an all-ones vector of matching size; `v` is also the
# Lanczos starting vector used by orthogonal_decomposition_L.
f_params, unravel = ravel_pytree(params)
v = jnp.ones_like(f_params)


def build_A(v, params):
    """
    Apply the damped Gramian operator to a flat parameter-space vector ``v``.

    Returns ``G_int v + G_bdry v + mu v``. The two products come from
    ``gvp_factory`` applied to ``gramian_res`` on the interior points and to
    ``boundary_res`` on the boundary points (presumably J^T J v for each
    residual — see new_natgrad.gram). The Levenberg-Marquardt shift is
    ``mu = min(loss(params), LM)``.
    """
    interior_pts = interior_integrator._x
    boundary_pts = boundary_integrator._x

    gvp_interior = gvp_factory(gramian_res, params, interior_pts)
    gvp_boundary = gvp_factory(boundary_res, params, boundary_pts)

    damping = jnp.min(jnp.array([loss(params), LM]))

    return gvp_interior(v) + gvp_boundary(v) + damping * v


# Evaluate once at import time, exercising the operator at the initial params.
build_A(v, params)


#%%
# set up grid line search: candidate step sizes 0.5**0, 0.5**1, ..., 0.5**30
grid = jnp.linspace(0, 30, 31)
steps = 0.5**grid
ls_update = grid_line_search_factory(loss, steps)

# errors
# NOTE: `error` reads the module-level `params` at call time, so it evaluates
# the current network after `params` is rebound during training.
error = lambda x: model(params, x) - u_star(x)
v_error = vmap(error, (0))


def _abs_grad_error(x):
    """|grad(u_θ - u*)(x)|, with the gradient evaluated once."""
    g = grad(error)(x)
    return jnp.dot(g, g) ** 0.5


v_error_abs_grad = vmap(_abs_grad_error)


def l2_norm(f, integrator):
    """Return sqrt(integrator(f(x)**2)), the L2 norm of a batched function ``f``."""
    return integrator(lambda x: (f(x))**2)**0.5


norm_sol_l2 = l2_norm(v_u_star, eval_integrator)
# "H1" here is ||u||_L2 + || |grad u| ||_L2 (a sum, not the root of the sum of squares).
norm_sol_h1 = norm_sol_l2 + l2_norm(v_grad_u_star, eval_integrator)


#%%
from jax.tree_util import Partial

# Operator v -> build_A(v, params) at the initial parameters.
A=Partial(build_A,params=params)
f_params, unravel = ravel_pytree(params)
# Number of trainable parameters (size of the flattened parameter vector).
n=f_params.shape[0]

# a3=jnp.linalg.eigh(apply_A_to_canonical_basis(A, n))[0][::-1]
# Preconditioner rank; also fixes the Lanczos depth of `algorithm` defined below.
tau=args.tau
# NOTE(review): taup is not read anywhere in this script.
taup=n-1

from matfree import decomp
from lanczos import *


def _build_arrowhead_matrix(alpha, beta, k):
  """Build the "arrowhead" shaped matrix needed for thick-restart Lanczos."""
  m = alpha.size
  upper_triangle = (
      jnp.einsum('i,j', beta * (k > jnp.arange(m)), k == jnp.arange(m))
      + jnp.diag(beta[:-1] * (k <= jnp.arange(m - 1)), k=1)
  )
  return upper_triangle + upper_triangle.T + jnp.diag(alpha)

# `timeit` is imported further down the file but used here first.
import timeit

start_time = timeit.default_timer()

# Lanczos tridiagonalization with `tau` matrix-vector products and full
# reorthogonalization; the result is unpacked by orthogonal_decomposition_L
# into the basis and the (alpha, beta) coefficients.
algorithm = decomp.tridiag_sym(tau, reortho='full', materialize=False,custom_vjp=False)

def create_tridiagonal_matrix(alpha, beta):
    """Assemble the symmetric tridiagonal matrix with diagonal ``alpha`` and
    sub-/super-diagonal ``beta`` (expects ``len(beta) == len(alpha) - 1``).

    An empty ``beta`` yields ``diag(alpha)``.
    """
    tridiag = jnp.diag(alpha)
    if not len(beta):
        return tridiag
    tridiag = tridiag + jnp.diag(beta, k=1)
    tridiag = tridiag + jnp.diag(beta, k=-1)
    return tridiag
from jax.numpy.linalg import svd
from jax.scipy.linalg import qr
from jax.tree_util import Partial

def orthogonal_decomposition_L(params, tau):
    """
    Approximate the dominant eigenpairs of the damped Gramian with Lanczos.

    Runs ``tau`` Lanczos matrix-vector products (full reorthogonalization) on
    ``v -> build_A(v, params)``, starting from the module-level vector ``v``.
    It then assembles the tridiagonal matrix T from the Lanczos coefficients
    and decomposes T with a Hermitian SVD. Assuming the operator is symmetric
    positive semi-definite, the singular values equal the eigenvalues; they
    are returned in descending order.

    :param params: network parameters at which the operator is built.
    :param tau: number of Lanczos matrix-vector products. It is static under
        ``jit``, so a new value triggers recompilation.
    :return: ``(V, sigma)``: the Ritz vectors in parameter space (one per
        column) and the corresponding Ritz values.
    """
    A = Partial(build_A, params=params)

    # Build the Lanczos routine for the requested depth. The module-level
    # `algorithm` is fixed to the initial tau, which would make the rank
    # increases in the training loop (`tau += 50`) ineffective.
    lanczos = decomp.tridiag_sym(tau, reortho='full', materialize=False, custom_vjp=False)
    (V_k, (alpha, beta)), _ = lanczos(A, v)

    T = create_tridiagonal_matrix(alpha, beta)
    U, Sigma, _ = svd(T, hermitian=True, full_matrices=False)

    # Rows of V_k are the Lanczos basis vectors; map T's eigenvectors back.
    high_dim_eigenvectors = V_k.T @ U
    return high_dim_eigenvectors, Sigma


orthogonal_decomposition_L = jit(orthogonal_decomposition_L, static_argnums=(1,))

# _,a2=orthogonal_decomposition_L(params,n-1)
# print(jnp.mean((a2-a3[:tau])**2)/jnp.mean(a2**2))
#%%

from jax.tree_util import Partial
from new_natgrad.cg_utils import _vdot_real_tree,_add,_mul,_sub,tree_leaves
@Partial
def _identity(x):
  return x
@Partial
def _identity(x):
  return x

def _cg_solve(A, b, x0=None, *, maxiter=10000, tol=1e-5, atol=0.0, M=_identity):
    """
    Solve ``A x = b`` with (preconditioned) conjugate gradients.

    ``A`` and ``M`` are linear maps on the pytree structure of ``b``; ``M``
    should approximate ``A^{-1}``. Iteration stops when the squared residual
    norm drops to ``max(tol**2 * ||b||**2, atol**2)`` (the "non-legacy"
    criterion of ``scipy.sparse.linalg.cg``) or after ``maxiter`` iterations.

    :param A: callable applying the (symmetric positive definite) operator.
    :param b: right-hand side pytree.
    :param x0: initial guess; defaults to zeros shaped like ``b``.
    :param maxiter: maximum number of CG iterations.
    :param tol: relative tolerance on the residual norm.
    :param atol: absolute tolerance on the residual norm.
    :param M: preconditioner; ``_identity`` means unpreconditioned CG.
    :return: ``(x, k)``: the approximate solution and the number of iterations run.
    """
    if x0 is None:
        x0 = jax.tree_util.tree_map(jnp.zeros_like, b)

    bs = _vdot_real_tree(b, b)
    atol2 = jnp.maximum(jnp.square(tol) * bs, jnp.square(atol))

    # https://en.wikipedia.org/wiki/Conjugate_gradient_method#The_preconditioned_conjugate_gradient_method
    r0 = _sub(b, A(x0))
    p0 = z0 = M(r0)
    dtype = jnp.result_type(*tree_leaves(p0))
    gamma0 = _vdot_real_tree(r0, z0).astype(dtype)

    def cond_fun(value):
        _, r, gamma, _, k = value
        # Without preconditioning gamma == <r, r>, which saves a dot product.
        rs = gamma.real if M is _identity else _vdot_real_tree(r, r)
        return (rs > atol2) & (k < maxiter)

    def body_fun(value):
        x, r, gamma, p, k = value
        Ap = A(p)
        alpha = gamma / _vdot_real_tree(p, Ap).astype(dtype)
        x_ = _add(x, _mul(alpha, p))
        r_ = _sub(r, _mul(alpha, Ap))
        z_ = M(r_)
        gamma_ = _vdot_real_tree(r_, z_).astype(dtype)
        beta_ = gamma_ / gamma
        p_ = _add(z_, _mul(beta_, p))
        return x_, r_, gamma_, p_, k + 1

    initial_value = (x0, r0, gamma0, p0, 0)
    x_final, _, _, _, k = jax.lax.while_loop(cond_fun, body_fun, initial_value)
    return x_final, k



#%%

from jax import jit, lax
import jax.numpy as jnp

# JIT compile the function ensuring that actual_size doesn't cause recompilation
# woodbury_action_matrix_free_jit = woodbury_action_matrix_free#jit(woodbury_action_matrix_free)
from jax import jit, numpy as jnp


@jit
def woodbury_action_matrix_free(b, alpha_k, V_k, effective_middle_term):
    """
    Apply the Woodbury-form inverse of ``V_k diag(a) V_k^T + alpha_k I`` to ``b``.

    Assume ``V_k`` has orthonormal (or zero, padded) columns and
    ``effective_middle_term = a / (a + alpha_k)`` (zero on padded columns).
    Then the inverse is ``(1/alpha_k) (I - V_k diag(effective_middle_term) V_k^T)``.
    It is applied matrix-free with two mat-vecs, without forming any n x n matrix.

    :param b: vector to which the inverse is applied.
    :param alpha_k: scalar shift (the Levenberg-Marquardt damping).
    :param V_k: (n, r) matrix of basis vectors, zero-padded to a fixed width r.
    :param effective_middle_term: length-r diagonal of the middle term.
    :return: the approximation to ``(V_k diag(a) V_k^T + alpha_k I)^{-1} b``.
    """
    V_k_T_b = jnp.dot(V_k.T, b)
    return (1.0 / alpha_k) * (b - jnp.dot(V_k, effective_middle_term * V_k_T_b))


#%%
# tau=50
class MatrixFreeSolver:
    """Preconditioned CG for the damped Gramian system ``build_A(x, params) = b``.

    The preconditioner is the low-rank Woodbury inverse supplied as
    ``woodbury_action_matrix_free``.
    """

    def __init__(self, build_A, woodbury_action_matrix_free):
        # Keep the callables; parameters arrive per call as traced arguments.
        self.build_A = build_A
        self.woodbury_action_matrix_free = woodbury_action_matrix_free
        # Compile once; later calls reuse the compiled solver while argument
        # shapes and dtypes stay the same.
        self._solve = jit(self._solve_impl)

    def _solve_impl(self, params, alpha_k,  V_k, effective_middle_term, ak, b, x0, maxiter, tol):
        def operator(vec):
            return self.build_A(vec, params)

        def preconditioner(vec):
            return self.woodbury_action_matrix_free(
                vec,
                alpha_k=alpha_k,
                V_k=V_k,
                effective_middle_term=effective_middle_term,
            )

        # NOTE: `ak` and `x0` are accepted for interface compatibility but are
        # not used; the initial guess is the preconditioned right-hand side.
        initial_guess = preconditioner(b)
        return _cg_solve(
            A=operator,
            b=b,
            x0=initial_guess,
            maxiter=maxiter,
            tol=tol,
            M=preconditioner,
        )

    def solve(self, params, alpha_k,  V_k, effective_middle_term, ak, b, x0, maxiter, tol):
        """Public entry point; returns ``(solution, cg_iterations)``."""
        return self._solve(params, alpha_k, V_k, effective_middle_term, ak, b, x0, maxiter, tol)
    
    





    

#%%
# Preconditioned-CG solver, built once; its jitted solve is reused every training step.
solver = MatrixFreeSolver(build_A, woodbury_action_matrix_free)



#%%









#%%

from jax.tree_util import Partial




# NOTE(review): A22 is not used in this script. vmap is given in_axes=(0, None)
# but the wrapped Partial takes a single positional argument (params is bound
# by keyword), so calling A22 will likely fail — confirm before using.
A22=jax.jit(Partial(jax.vmap(Partial(build_A, params=params),(0,None))))


#%%

def split_A22(params, V, chunk_size):
    """Apply ``build_A(., params)`` to every column of ``V``, chunk by chunk.

    Columns are processed ``chunk_size`` at a time with ``vmap`` (presumably to
    bound peak memory). Results are stacked row-wise: row j of the output is
    ``build_A(V[:, j], params)``, so callers transpose to get column form.

    :param params: network parameters at which the operator is built.
    :param V: 2-D array whose columns are the vectors to transform.
    :param chunk_size: positive number of columns per batched call.
    :return: array of shape ``(V.shape[1], n_out)``.
    :raises ValueError: if ``chunk_size`` is not positive or ``V`` has no columns.
    """
    from functools import partial

    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    num_cols = V.shape[1]
    if num_cols == 0:
        raise ValueError("V must have at least one column")

    # params is bound inside the partial, so vmap maps only over the vectors.
    apply_A = vmap(partial(build_A, params=params))
    # The last slice is shorter than chunk_size when num_cols is not a multiple of it.
    results = [
        apply_A(V[:, start:start + chunk_size].T)
        for start in range(0, num_cols, chunk_size)
    ]
    return jnp.concatenate(results, axis=0)
#%%


from jax.tree_util import Partial


import jax
import jax.numpy as jnp
from jax.numpy.linalg import norm, qr

def subspace_iteration(params, V, num_iterations=1, tolerance=1e-6):
    """
    Refine a basis by block power (subspace) iteration.

    Each sweep applies the operator to all columns of ``V`` (through
    ``split_A22`` with chunk size ``args.minb``). It then re-orthonormalizes
    the result with a QR factorization.

    Parameters:
        params: network parameters at which the operator is built.
        V: 2-D array whose columns span the current subspace estimate.
        num_iterations: number of sweeps to perform.
        tolerance: accepted for API compatibility. No convergence test is
            performed, so exactly ``num_iterations`` sweeps always run.

    Returns:
        The orthonormal factor Q of the last sweep. Eigenvalue estimates are
        not updated.
    """
    basis = V
    for _ in range(num_iterations):
        # split_A22 stacks A v_j as rows; transpose back to column form.
        image = split_A22(params, basis, args.minb).T
        basis, _ = qr(image)
    return basis

#%%

@jit
def build_Aeig(v, params):
    """
    Rayleigh quotient ``v^T A v / v^T v`` of the damped Gramian operator at ``params``.

    Uses the same operator as ``build_A``: the interior Gramian from
    ``gramian_res``, the boundary Gramian from ``boundary_res``, and the
    Levenberg-Marquardt shift ``mu = min(loss(params), LM)``.
    """
    x_Omega = interior_integrator._x
    x_Gamma = boundary_integrator._x

    # Use the linear residual `gramian_res`, not the squared `residual`, so the
    # quotient is taken with respect to the operator applied by `build_A`.
    apply_int = gvp_factory(gramian_res, params, x_Omega)
    apply_bdry = gvp_factory(boundary_res, params, x_Gamma)

    mu = jnp.min(jnp.array([loss(params), LM]))

    Av = apply_int(v) + apply_bdry(v) + mu * v
    return jnp.dot(v.T, Av) / jnp.dot(v.T, v)


# Rayleigh quotients for a batch of vectors (rows of the first argument).
batch_rayleigh_quotient = jit(vmap(build_Aeig, (0, None)))






import json


def save_checkpoint(params, loss_value, l2_error, h1_error, step, iteration, kit, file_name="checkpoints.json"):
    """
    Write a JSON snapshot of scalar training metrics to ``file_name``.

    The file is overwritten on every call. ``params`` is accepted but not
    serialized.

    :param kit: CG iteration count. It is converted with ``int`` so that JAX
        and NumPy integer scalars serialize.
    """
    checkpoint_data = {
        'iteration': int(iteration),
        'loss': float(loss_value),
        'l2_error': float(l2_error),
        'h1_error': float(h1_error),
        'step': float(step),
        'cg_iteration': int(kit),
    }

    with open(file_name, 'w') as f:
        json.dump(checkpoint_data, f, indent=4)

#%%



# Initial rank-tau eigen-decomposition. It seeds the first preconditioner and
# serves as the basis checked by the residual test in the training loop.
V_k0, ak0 = orthogonal_decomposition_L(params, tau)





#%%
import datetime
import timeit
import json
from lanczos import *
from jax import grad, vmap, jit
import jax.numpy as jnp
from jax.tree_util import Partial
from jax.flatten_util import ravel_pytree

actual_size = 10000

Train = True
its = []
l2_errors = []
kit = 0
data = {
    'seed': seed,  # seed used for parameter initialization (args.seed)
    'loss': [],
    'l2_errors': [],
    'iterations': [],  # CG iterations per optimizer step
    'optim_time': []
}

# Wall-clock time budget in seconds (--budget, default 3000).
time_budget = args.budget
start_time0 = timeit.default_timer()
elapsed_time = 0
iteration = 0

if Train:
    # Initial guess for cg solve
    f_lm_grad = jnp.zeros_like(ravel_pytree(params)[0])

    # Levenberg-Marquardt, matrix free, with line search
    top_eigenvalues = []
    steps = []

    while elapsed_time < time_budget:
        start_time = timeit.default_timer()

        f_grads, unravel = ravel_pytree(grad(loss)(params))

        A = Partial(build_A, params=params)
        mu = jnp.min(jnp.array([loss(params), LM]))
        # The Woodbury preconditioner uses the current damping on every
        # iteration, not only when the eigenbasis is recomputed.
        alpha_k = mu

        # Eigen-pair residual ||A V - V diag(a)||_F of the current basis.
        # split_A22 stacks A v_j as rows (hence the transpose); ak0 scales columns.
        residual_norm = jnp.linalg.norm(
            split_A22(params, V_k0, args.minb).T - V_k0 * ak0[None, :],
            ord='fro',
        )

        if iteration % 10 == 0 or residual_norm > args.tresh:
            # Grow the rank when even the smallest retained eigenvalue is far
            # above the damping.
            if jnp.min(ak0) > 1000 * mu:
                tau += 50
            V_k0, ak0 = orthogonal_decomposition_L(params, tau)
        else:
            V_k0 = subspace_iteration(params, V_k0)

        # Keep only eigenpairs above the damping and zero-pad to width tau so
        # the jitted solver sees a fixed shape.
        keep = ak0 > mu
        V_k, ak = V_k0[:, keep], ak0[keep]
        actual_size = V_k.shape[1]
        V_k = jnp.pad(V_k, ((0, 0), (0, tau - V_k.shape[1])), constant_values=0)
        ak = jnp.pad(ak, (0, tau - ak.shape[0]), constant_values=0)

        active_mask = (jnp.arange(tau) < actual_size).astype(float)

        # Padded entries get a small positive dummy value to avoid division by zero.
        ak_adjusted = jnp.where(active_mask, ak, 1e-10)

        # Diagonal middle term a / (a + alpha_k) of the Woodbury inverse.
        middle_term_inv = 1 / (jnp.ones(tau) + alpha_k * (1.0 / ak_adjusted))

        # Zero the middle term on padded columns.
        effective_middle_term = middle_term_inv * active_mask

        f_lm_grad, kit = solver.solve(params=params, alpha_k=alpha_k, V_k=V_k, ak=ak, b=f_grads, x0=f_lm_grad, maxiter=args.maxiter, tol=mu, effective_middle_term=effective_middle_term)

        lm_grad = unravel(f_lm_grad)

        # Linesearch on logarithmic grid
        params, actual_step = ls_update(params, lm_grad)

        l2_error = l2_norm(v_error, eval_integrator)
        h1_error = l2_error + l2_norm(v_error_abs_grad, eval_integrator)
        current_loss = loss(params)

        # Save data to the dictionary
        data['loss'].append(float(current_loss))
        data['l2_errors'].append(float(l2_error))
        data['iterations'].append(int(kit))
        data['optim_time'].append(float(timeit.default_timer() - start_time0))

        print(
            f'NG Iteration: {iteration} with loss: {current_loss} with error '
            f'L2: {l2_error} and error H1: {h1_error} and '
            f'step: {actual_step} '
            f'iterationtime: {float(timeit.default_timer() - start_time)} '
            f'cgit: {kit} '
            f'run time {timeit.default_timer() - start_time0}'
        )

        # Periodic checkpoint hook (every 100 iterations or at the end of the
        # budget). NOTE: the file write is currently disabled.
        if iteration % 100 == 0 or float(timeit.default_timer() - start_time0) >= time_budget:
            timestamp = time.strftime("%Y%m%d-%H%M%S")
            # filename = f'poisson10d_{args.seed}.json'
            # with open(filename, 'w') as out_file:
            #     json.dump(data, out_file, indent=4)
            # print(f"Data saved to {filename}")

        # Update elapsed time and iteration count
        elapsed_time = timeit.default_timer() - start_time0
        iteration += 1

print(f"Total run time: {elapsed_time} seconds")


final_filename = f'poisson10dMIL_seed_{args.seed}_budget_{args.budget}_maxiter_{args.maxiter}.json'
# Use a dedicated handle name so the module-level `f` is not rebound.
with open(final_filename, 'w') as out_file:
    json.dump(data, out_file, indent=4)

print(f"Data saved to {final_filename}")

import json
import timeit
from datetime import datetime





