import os
import jax 
import jax.numpy as jnp
import jax.random as random
from typing import Sequence, Union, Callable
import optax
import scipy.io 
from flax.training import train_state, orbax_utils
from layer import cayley, Unitary, QuadPotential, LipNonlin, LipSwish
import orbax.checkpoint
from rosenbrock_utils import *
from flax import linen as nn 
import numpy as np
import math
import jax.numpy as jnp
from jax import vmap
import random

#Code is heavily adapted from https://github.com/acfr/PLNet/blob/main/surrogte_loss/rb20d_train.py

# Random 4-digit tag that keeps this run's output files/directories apart.
# At this point `random` is still Python's stdlib module (the plain
# `import random` above); it is rebound to jax.random just below.
file_name = str(random.randint(1000, 9999))

print("FILE NAME", file_name)

# NOTE: these re-imports rebind `random` to jax.random. Every later `random.*`
# call in this file (PRNGKey, split, permutation, uniform) relies on that.
import jax.numpy as jnp
import jax.random as random

# Sampling range: passed as x_min=-bound / x_max=bound to Sampler, and used as
# the uniform range for trajectory start points in train().
bound = 100.

# Fraction of every training batch filled with generated trajectory points once
# train() regenerates its data; the remaining rows are fresh samples.
generated_data_fraction = 0.5

# When True, std_y is multiplied by 10 below, so standardized targets are an
# extra factor 10 smaller. It also selects the smaller learning-rate schedule
# in gradient_descent_solver_trajectory_adam.
divide_by_10 = True

print('generated data fraction', generated_data_fraction)

print("BOUND", bound)

# Module-level PRNG state; train() and the driver code at the bottom of the
# file split and reassign it. (`global` is a no-op at module scope.)
global rng
rng = random.PRNGKey(0)

# Objective being fitted -- presumably provided by the star import from
# rosenbrock_utils (not visible here).
new_function = schaffer_7

print("FUNCTION", new_function)

# Row-wise (batched over axis 0) version of the objective.
orig_function = vmap(new_function, in_axes=0)

# Global-norm gradient clipping value for the training optimizer in train().
clip_norm_value = 0.5

print('CLIP NORM VALUE', clip_norm_value)

# Input dimension of the objective.
data_dim = 1000

print('data dim', data_dim)

# Estimate the mean/std of the objective on 20000 sampled points; these
# statistics standardize targets in testing_function.
# NOTE(review): random.split(rng, 1) returns a batch containing one key rather
# than a bare key -- confirm that Sampler expects this.
rng_train = random.split(rng, 1)
xtrain = Sampler(rng_train, 
                     20000,
                     data_dim, 
                     x_min=-bound, x_max=bound)

# Evaluate in chunks of temp_batch_size rows to bound peak memory.
temp_batch_size = 2000
num_batches = int(np.ceil(xtrain.shape[0] / temp_batch_size))
results = []    
for i in range(num_batches):
    batch_x = xtrain[i * temp_batch_size : (i + 1) * temp_batch_size]
    batch_y = jnp.squeeze(orig_function(batch_x))
    results.append(batch_y)
results = jnp.concatenate(results, axis=0)

mean_y = jnp.mean(results)
std_y = jnp.std(results)

if divide_by_10:
    std_y = 10*std_y
    print('multiplying std by 10')

print('mean and std', mean_y, std_y)

def testing_function(x):
    """Evaluate ``new_function`` on every row of ``x`` and standardize the result.

    The objective is vectorized over axis 0 with ``vmap``. The outputs are then
    shifted by the module-level ``mean_y`` and divided by ``std_y``.
    """
    raw_values = vmap(new_function, in_axes=0)(x)
    centered = raw_values - mean_y
    return centered / std_y

# Target spacing between consecutive points when trajectories are resampled in
# get_interpolated_tensors; also the required mean displacement in
# grad_descent_on_true_function. With data_dim=1000 and bound=100 it is ~111.8.
interpolation_threshold = 0.1*(data_dim/100)*(bound/2)*(5**0.5) 

# In train(), epochs > threshold_epoch train on the stored (regenerated) data
# set instead of fresh uniform samples.
threshold_epoch = 200

# NOTE(review): not referenced anywhere else in this file.
large_threshold = (threshold_epoch > 3000)

print("THRESHOLD EPOCH", threshold_epoch)

# Convergence threshold sqrt(data_dim / 100) (~3.16 for data_dim=1000). It is
# compared against the max pairwise distance between gradient-descent iterates
# and used to group trajectory end points.
threshold_value = (((data_dim/100)*((1)**2))**(0.5))

print('decent threshold value', threshold_value)


import jax
import jax.numpy as jnp

def get_interpolated_tensors(tensor, threshold=None):
    """Resample one trajectory so consecutive points are roughly ``threshold`` apart.

    Walks the trajectory from its first point:

    * If the next point is at least ``threshold`` away, the current point is
      kept. ``floor(distance / threshold)`` evenly spaced points are then
      inserted strictly between it and the next point.
    * Otherwise, points closer than ``threshold`` are skipped. The walk jumps to
      either the first point that is at least ``threshold`` away or its
      predecessor, whichever distance is closer to ``threshold``.
    * If no later point is ``threshold`` away, the walk stops.

    The final point of the trajectory is always appended exactly once.

    Args:
        tensor: Array of shape (B, H): B ordered points of dimension H.
        threshold: Target spacing. ``None`` (default) uses the module-level
            ``interpolation_threshold``.

    Returns:
        jnp.ndarray of shape (M, H) containing the resampled points.
    """
    if threshold is None:
        threshold = interpolation_threshold

    num_points = tensor.shape[0]
    if num_points == 0:
        return tensor

    results = []
    i = 0
    # Stop before the last point: each step reads tensor[i + 1], and the final
    # point is appended once after the loop. (Looping up to i == B - 1 read
    # past the end, which either raised or duplicated the last point.)
    while i < num_points - 1:
        current_tensor = tensor[i]
        delta = tensor[i + 1] - current_tensor
        distance = jnp.linalg.norm(delta)

        if distance >= threshold:
            n_steps = int(jnp.floor(distance / threshold))
            steps = jnp.linspace(0, 1, n_steps + 2)[1:-1]  # exclude the endpoints
            results.append(current_tensor)
            results.extend(jnp.outer(steps, delta) + current_tensor)
            i += 1
            continue

        # distance < threshold (or NaN). Handling NaN here instead of in a
        # separate `elif distance < threshold` guarantees the walk advances;
        # previously a NaN distance looped forever.
        j = i + 2
        while j < num_points and jnp.linalg.norm(tensor[j] - current_tensor) < threshold:
            j += 1

        results.append(current_tensor)
        if j >= num_points:
            break  # no later point is `threshold` away
        far_dist = jnp.linalg.norm(tensor[j] - current_tensor)
        if not far_dist >= threshold:
            break  # only reachable when the distance is NaN
        near_dist = jnp.linalg.norm(tensor[j - 1] - current_tensor)
        # Jump to whichever candidate's distance is closest to the threshold.
        if jnp.abs(threshold - far_dist) < jnp.abs(threshold - near_dist):
            i = j
        else:
            i = j - 1

    results.append(tensor[-1])
    return jnp.stack(results)


def group_elements_by_threshold(distance_matrix, threshold):
    """Greedily cluster indices whose pairwise distances are all within ``threshold``.

    Indices are visited in order. Each index joins the first existing group
    whose members are all at most ``threshold`` away from it; otherwise it
    starts a new group.

    Args:
        distance_matrix: Square (N, N) array-like of pairwise distances.
        threshold: Maximum allowed distance between any two group members.

    Returns:
        list[int]: The largest group (the earliest one on ties), or ``[]`` when N == 0.
    """
    # A single host transfer up front instead of one device sync per element.
    dist = np.asarray(distance_matrix)
    groups = []
    for i in range(dist.shape[0]):
        for group in groups:
            # Test `<= threshold` rather than `not > threshold` so that NaN
            # distances (e.g. from a diverged trajectory) never count as close.
            if np.all(dist[group, i] <= threshold):
                group.append(i)
                break
        else:
            groups.append([i])
    return max(groups, key=len) if groups else []
    
def adjust_tensor_to_batch_size(tensor, train_batch_size, rng_idx):
    """Trim rows so the row count becomes a multiple of ``train_batch_size``.

    If ``N % train_batch_size`` is zero, ``tensor`` itself is returned unchanged.
    Otherwise ``N % train_batch_size`` rows are dropped at random: the row
    indices are permuted with ``rng_idx`` and the leading ones discarded. As a
    result, the surviving rows come back in shuffled order.

    Args:
        tensor: Array of shape (N, D).
        train_batch_size: Desired divisor of the output row count.
        rng_idx: JAX PRNG key used to choose which rows are dropped.

    Returns:
        Array of shape (N - N % train_batch_size, D).
    """
    n_rows, _ = tensor.shape
    excess = n_rows % train_batch_size
    if not excess:
        return tensor
    shuffled = random.permutation(rng_idx, n_rows)
    return tensor[shuffled[excess:]]


def assert_max_distance_less_than_threshold(trajectories, threshold=threshold_value):
    """Keep only the trajectories whose end points cluster within ``threshold``.

    Despite its name, this function asserts nothing. It groups the final
    positions with ``group_elements_by_threshold`` and keeps the trajectories of
    the largest group. A warning line is printed when that group has a single
    member.

    Args:
        trajectories: Array of shape (T, N, D): T time steps of N points.
        threshold: Maximum pairwise distance between end points in one group.

    Returns:
        Array of shape (T, K, D) holding the K retained trajectories.
    """
    final_points = trajectories[-1, :, :]
    pairwise_diff = final_points[:, None, :] - final_points[None, :, :]
    distances = jnp.sqrt(jnp.sum(pairwise_diff**2, axis=-1))

    keep = group_elements_by_threshold(distances, threshold)

    print('number of good indices', len(keep))
    if len(keep) == 1:
        print("WE HAVE A PROBLEM")

    return trajectories[:, keep, :]

def grad_descent_on_true_function(x, learning_rate=0.01, max_steps=1000):
    """Run Adam on the standardized true objective until the points have moved far enough.

    Minimizes ``sum(testing_function(points))`` starting from ``x``. The run
    stops as soon as the mean per-row displacement from ``x`` exceeds the
    module-level ``interpolation_threshold``.

    Args:
        x: Starting points, shape (N, D).
        learning_rate: Adam step size (default 0.01, as before).
        max_steps: Maximum number of Adam steps (default 1000, as before).

    Returns:
        The updated points. If the displacement target is never reached, the
        points after ``max_steps`` steps are returned and a message is printed.
    """
    tensor = x
    optimizer = optax.adam(learning_rate)
    opt_state = optimizer.init(tensor)

    # Gradient of the summed objective with respect to every point at once.
    grad_func = jax.grad(lambda pts: jnp.sum(testing_function(pts)))

    # Initialised so the failure message below is well-defined even when
    # max_steps == 0.
    norm_delta = jnp.asarray(0.0)
    for _ in range(max_steps):
        grads = grad_func(tensor)
        updates, opt_state = optimizer.update(grads, opt_state)
        tensor = optax.apply_updates(tensor, updates)

        norm_delta = jnp.mean(jnp.linalg.norm(x - tensor, axis=1))
        if norm_delta > interpolation_threshold:
            return tensor
    print(f"true grad descent on traj: failed with norm delta of {norm_delta}, need at least {interpolation_threshold}")
    return tensor
        

import jax
import jax.numpy as jnp
import optax
from typing import Callable

def gradient_descent_solver_trajectory_adam(
    fn: Callable,
    z0: jnp.array,
    max_iter: int = 500,
    patience: int = 100  # currently unused; kept so existing callers keep working
):
    """Run Adam from ``z0`` on ``fn`` and record the trajectory up to its tightest iterate.

    ``fn`` is applied to each row via ``vmap``. For every iterate, the maximum
    pairwise distance between its rows is measured. The iterate with the
    smallest such distance is the "best" one, and the returned trajectory ends
    exactly at it. The run stops early once that distance drops below the
    module-level ``threshold_value``.

    Args:
        fn: Objective for a single point; its output is squeezed to a scalar.
        z0: Starting points, shape (N, D).
        max_iter: Maximum number of Adam updates.
        patience: Unused.

    Returns:
        dict whose entries are aligned over iterates ``0..best_idx``:
            'vgap': mean objective value over the rows at each iterate,
            'step': iterate index,
            'z':    the iterates, shape (best_idx + 1, N, D); 'z'[0] is z0.
    """
    print('starting grad descent')
    grad_fn = jax.jit(jax.vmap(jax.value_and_grad(lambda z: jnp.squeeze(fn(z)))))

    # Linear warm-up to the peak learning rate, then cosine decay to end_value.
    lr_schedule = optax.warmup_cosine_decay_schedule(
        init_value=0.0,
        peak_value=10.0 if not divide_by_10 else 1.0,
        warmup_steps=30,
        decay_steps=4000,
        end_value=2.5 if not divide_by_10 else 0.25)
    optimizer = optax.chain(optax.adam(learning_rate=lr_schedule))
    opt_state = optimizer.init(z0)

    z = z0
    trajectory = []  # trajectory[k] is iterate k (trajectory[0] is z0)
    vgap = []        # vgap[k] is the mean objective value at trajectory[k]
    lowest_max_distance = float('inf')
    best_idx = 0

    # Iterates 0..max_iter are evaluated, and each evaluation also supplies the
    # gradient for the next update. Evaluating *before* updating keeps
    # trajectory[k], vgap[k] and the distance test on the same iterate.
    # (Previously the distance was measured after the update while the
    # trajectory was sliced at the pre-update index, which dropped the best,
    # converged iterate.)
    for k in range(max_iter + 1):
        v, gt = grad_fn(z)
        vgap.append(jnp.mean(v))
        trajectory.append(z)

        pairwise_distances = jnp.sqrt(jnp.sum((z[:, None, :] - z[None, :, :]) ** 2, axis=-1))
        max_distance = jnp.max(pairwise_distances)
        print(k, max_distance)

        if max_distance < lowest_max_distance:
            lowest_max_distance = max_distance
            best_idx = k
            if lowest_max_distance < threshold_value:
                break

        if k < max_iter:
            updates, opt_state = optimizer.update(gt, opt_state)
            z = optax.apply_updates(z, updates)

    data = {
        'vgap': jnp.array(vgap[:best_idx+1]),
        'step': jnp.arange(best_idx + 1),
        'z': jnp.array(trajectory[:best_idx+1])  # trajectory up to and including the best iterate
    }
    print('max distance', lowest_max_distance)
    print('finished grad descent')
    return data


import jax.numpy as jnp

def filter_trajectories_using_mask(trajectories, threshold=0.01):
    """Drop trajectory points that barely move before the next time step.

    The point at time t of a trajectory is kept when its distance to the same
    trajectory's point at time t + 1 exceeds ``threshold``. Points at the final
    time step are always kept. Survivors are returned flattened in time-major
    order: all kept points of step 0, then those of step 1, and so on.

    Args:
        trajectories: Array of shape (T, N, D).
        threshold: Minimum step length for a point to be kept.

    Returns:
        Array of shape (K, D) with the K kept points.
    """
    n_traj = trajectories.shape[1]
    dim = trajectories.shape[2]

    step_lengths = jnp.linalg.norm(jnp.diff(trajectories, axis=0), axis=-1)  # (T-1, N)
    always_keep_last = jnp.ones((1, n_traj), dtype=bool)
    keep = jnp.concatenate([step_lengths > threshold, always_keep_last], axis=0)  # (T, N)

    return trajectories[keep].reshape(-1, dim)



# Differs from the MonLipNet in layer.py: this version makes the Lipschitz upper bound nu (and the monotone lower bound mu) trainable, via the log-parameters `lognu` and `logmu`.

class MonLipNet(nn.Module):
    """Layer with trainable lower bound ``mu`` and upper bound ``nu``.

    Per the field comments below, ``mu`` is the monotone lower bound and ``nu``
    the Lipschitz upper bound. The output is
    ``y = mu * x + by + sum_k sqrt((nu - mu) / 2) * z_k @ S_k^T``, where each
    ``z_k`` is a ReLU feature layer built from ``cayley``-mapped matrices (see
    ``__call__``). ``mu`` and ``nu`` are stored as log-parameters and
    exponentiated, so both stay positive.

    NOTE(review): nothing constrains ``nu > mu``. If training drives ``mu``
    above ``nu``, the ``jnp.sqrt(nu - mu)`` terms in ``__call__`` become NaN.
    """
    units: Sequence[int]  # width of each hidden feature layer z_k
    tau: jnp.float32 = 10.  # not used by __call__; only returned by get_bounds
    # mu: jnp.float32 = 0.1 # Monotone lower bound
    # nu: jnp.float32 = 10.0 # Lipschitz upper bound (nu > mu)
    # act_fn: Callable = nn.relu

    def get_bounds(self):
        """Return ``(mu, nu, tau)`` read from the bound module's parameters."""
        # lognu / logmu are shape-(1,) params; squeeze(..., 0) yields scalars.
        lognu = self.variables['params']['lognu']
        nu = jnp.squeeze(jnp.exp(lognu), 0)
        logmu = self.variables['params']['logmu']
        mu = jnp.squeeze(jnp.exp(logmu), 0)
        return mu, nu, self.tau

    @nn.compact
    def __call__(self, x : jnp.array) -> jnp.array:
        """Apply the layer to ``x``; the last axis is the feature dimension."""
        nx = jnp.shape(x)[-1]  
        # Trainable bounds, initialised to nu = 0.5 and mu = 0.1.
        lognu = self.param('lognu', nn.initializers.constant(jnp.log(0.5)), (1,), jnp.float32)
        nu = jnp.exp(lognu)
        logmu = self.param('logmu', nn.initializers.constant(jnp.log(0.1)), (1,), jnp.float32)
        mu = jnp.exp(logmu)
        #mu = 0.00085501 #nu / self.tau  #this needs to be replaced --> just set mu!!!
        by = self.param('by', nn.initializers.zeros_init(), (nx,), jnp.float32) 
        # Linear part of the output.
        y = mu * x + by 
        
        # Fq is normalised and rescaled by the trainable scalar fq before the
        # cayley map. QT has sum(units) columns, one contiguous block per
        # hidden layer.
        Fq = self.param('Fq', nn.initializers.glorot_normal(), (nx, sum(self.units)), jnp.float32)
        fq = self.param('fq', nn.initializers.constant(jnp.linalg.norm(Fq)), (1,), jnp.float32)
        QT = cayley((fq / jnp.linalg.norm(Fq)) * Fq) 
        # NaN if nu < mu (see class docstring).
        sqrt_2g, sqrt_g2 = jnp.sqrt(2. * (nu - mu)), jnp.sqrt((nu - mu) / 2.)
        idx, nz_1 = 0, 0 
        # Empty "previous layer" for k = 0 (zero features, zero-size Ak_1).
        zk = x[..., :0]
        Ak_1 = jnp.zeros((0, 0))
        for k, nz in enumerate(self.units):
            # Fab has nz + nz_1 rows: after the cayley map, the first nz rows
            # give A_k^T and the remaining nz_1 rows give B_k^T.
            Fab = self.param(f'Fab{k}', nn.initializers.glorot_normal(), (nz+nz_1, nz), jnp.float32)
            fab = self.param(f'fab{k}',nn.initializers.constant(jnp.linalg.norm(Fab)), (1,), jnp.float32)
            ABT = cayley((fab / jnp.linalg.norm(Fab)) * Fab)
            ATk, BTk = ABT[:nz, :], ABT[nz:, :]
            # QT column blocks of the previous and the current layer.
            QTk_1, QTk = QT[:, idx-nz_1:idx], QT[:, idx:idx+nz]
            STk = QTk @ ATk - QTk_1 @ BTk 
            bk = self.param(f'b{k}', nn.initializers.zeros_init(), (nz,), jnp.float32)
            # use relu activation, no need for psi
            # pk = self.param(f'p{k}', nn.initializers.zeros_init(), (nz,), jnp.float32)
            zk = nn.relu(2 * (zk @ Ak_1) @ BTk + sqrt_2g * x @ STk + bk)
            # zk = nn.relu(zk * jnp.exp(-pk)) * jnp.exp(pk)
            y += sqrt_g2 * zk @ STk.T  
            idx += nz 
            nz_1 = nz 
            Ak_1 = ATk.T     

        return y
        
class BiLipNet(nn.Module):
    """Stack of ``depth`` MonLipNet layers interleaved with ``Unitary`` layers.

    Forward pass: uni[0] -> mon[0] -> uni[1] -> ... -> mon[depth-1] -> uni[depth].
    """
    units: Sequence[int]  # hidden widths passed to every MonLipNet
    tau: jnp.float32      # overall tau; each MonLipNet receives tau ** (1 / depth)
    depth: int = 2        # number of MonLipNet layers

    def setup(self):
        uni, mon = [], []
        # Split tau geometrically so that the per-layer values multiply back to tau.
        layer_tau = (self.tau) ** (1/self.depth)
        for _ in range(self.depth):
            uni.append(Unitary())
            mon.append(MonLipNet(self.units, tau=layer_tau))
        # One extra Unitary layer after the last MonLipNet.
        uni.append(Unitary())
        self.uni = uni
        self.mon = mon

    def get_bounds(self):
        """Return (product of mu, product of nu, product of tau) over the MonLipNet layers."""
        lipmin, lipmax, tau = 1., 1., 1.
        for k in range(self.depth):
            mu, nu, ta = self.mon[k].get_bounds()
            lipmin *= mu 
            lipmax *= nu 
            tau *= ta 
        return lipmin, lipmax, tau 
    
    @nn.compact
    def __call__(self, x: jnp.array) -> jnp.array:
        for k in range(self.depth):
            x = self.uni[k](x)
            x = self.mon[k](x)
        x = self.uni[self.depth](x)
        return x 
    
class PLNet(nn.Module):
    """Scalar model: a ``QuadPotential`` evaluated on the output of ``BiLipBlock``.

    Attributes:
        BiLipBlock: Feature-map module (for example a ``BiLipNet``) that also
            provides ``get_bounds``.
    """
    BiLipBlock: nn.Module

    def gmap(self, x: jnp.array) -> jnp.array:
        """Apply only the feature map ``BiLipBlock``."""
        mapped = self.BiLipBlock(x)
        return mapped

    def get_bounds(self):
        """Forward to ``BiLipBlock.get_bounds()``."""
        bounds = self.BiLipBlock.get_bounds()
        return bounds

    @nn.compact
    def __call__(self, x: jnp.array) -> jnp.array:
        features = self.BiLipBlock(x)
        potential = QuadPotential()
        return potential(features)
    
def get_y(x, batch_size=5000, fn=None):
    """Evaluate an objective on ``x`` in chunks of ``batch_size`` rows.

    Args:
        x: Array of shape (N, D).
        batch_size: Number of rows evaluated per call, which bounds peak memory.
        fn: Batched objective. ``None`` (default) uses ``testing_function``.

    Returns:
        1-D array with one value per row, assuming ``fn`` returns one value
        per row (any singleton dimensions are squeezed away). An empty input
        gives an empty float array.
    """
    if fn is None:
        fn = testing_function
    n_rows = x.shape[0]
    if n_rows == 0:
        return jnp.zeros((0,))

    results = []
    for start in range(0, n_rows, batch_size):
        batch_y = jnp.squeeze(fn(x[start:start + batch_size]))
        # A single-row chunk squeezes to a 0-d scalar, which jnp.concatenate
        # rejects, so keep every chunk at least 1-D.
        results.append(jnp.atleast_1d(batch_y))
    return jnp.concatenate(results, axis=0)

        
def data_gen(
    rng: random.PRNGKey,
    data_dim: int = 20, 
    val_min: float = -5.,
    val_max: float = 5.,
    train_batch_size: int = 2000,
    test_batch_size: int = 2000,
    train_batches: int = 200,
    test_batches: int = 1,
    eval_batch_size: int = 2000,
    eval_batches: int = 100,
):
    """Draw train/test/eval inputs with ``Sampler`` and label them with ``get_y``.

    Each split gets its own key from ``rng`` and ``batch_size * batches`` rows.
    ``val_min`` and ``val_max`` are forwarded to ``Sampler`` as ``x_min`` and
    ``x_max``; ``Sampler`` is a project helper whose sampling scheme is not
    shown here.

    Returns:
        dict with inputs/targets for every split, plus the batch bookkeeping
        values and ``data_dim``.
    """
    key_train, key_test, key_eval = random.split(rng, 3)

    def _draw(key, n_samples):
        # All splits share dimensionality and bounds; only key and size differ.
        return Sampler(key, n_samples, data_dim, x_min=val_min, x_max=val_max)

    xtrain = _draw(key_train, train_batch_size * train_batches)
    xtest = _draw(key_test, test_batch_size * test_batches)
    xeval = _draw(key_eval, eval_batch_size * eval_batches)

    return {
        "xtrain": xtrain,
        "ytrain": get_y(xtrain),
        "xtest": xtest,
        "ytest": get_y(xtest),
        "xeval": xeval,
        "yeval": get_y(xeval),
        "train_batches": train_batches,
        "train_batch_size": train_batch_size,
        "test_batches": test_batches,
        "test_batch_size": test_batch_size,
        "eval_batches": eval_batches,
        "eval_batch_size": eval_batch_size,
        "data_dim": data_dim,
    }

import os
import numpy as np
from flax import serialization

def save_model(params, filepath):
    """Serialize ``params`` with ``flax.serialization.to_bytes`` and write them to ``filepath``."""
    with open(filepath, mode='wb') as out_file:
        payload = serialization.to_bytes(params)
        out_file.write(payload)


def train(
    model,
    data,
    name: str = 'bilipnet',
    train_dir: str = '/tmp/results/rosenbrock-nd',
    lr_max: float = 1e-3,
    epochs: int = 1000
):
    """Train ``model`` on ``data``, periodically regenerating the training set.

    Epochs ``0..threshold_epoch`` (inclusive) draw fresh ``Sampler`` batches in
    ``[-bound, bound]``. Later epochs do one shuffled pass over
    ``data['xtrain']`` / ``data['ytrain']``.

    After epoch 200 (hard-coded), every epoch runs Adam gradient descent on the
    current model from 120 uniform start points. The trajectories that end
    together are resampled with ``get_interpolated_tensors``. From
    ``threshold_epoch`` on, these points plus fresh samples replace the stored
    training set.

    Side effects: splits and reassigns the module-level ``rng``; writes orbax
    checkpoints and ``.npy`` trajectory dumps; stores ``train_loss``,
    ``val_loss``, ``lipmin``, ``lipmax``, ``tau`` and ``eval_loss`` in
    ``data``. Returns ``None``.
    """
    ckpt_dir = f'{train_dir}/ckpt'
    os.makedirs(ckpt_dir, exist_ok=True)

    data_dim = data['data_dim']
    train_batches = data['train_batches']
    train_batch_size = data['train_batch_size']
    # Rows per batch taken from generated trajectories once the data is
    # regenerated; the remaining rows of each batch are fresh samples.
    generated_batch_size = int(generated_data_fraction*train_batch_size)

    global rng
    rng, rng_model = random.split(rng)

    params = model.init(rng_model, jnp.ones(data_dim))
    param_count = sum(x.size for x in jax.tree_util.tree_leaves(params))
    print(f'model: {name}, size: {param_count/1000000:.2f}M')

    # One-cycle LR schedule sized for the initially planned number of steps.
    total_steps = train_batches * epochs
    scheduler = optax.linear_onecycle_schedule(transition_steps=total_steps,
                                               peak_value=lr_max,
                                               pct_start=0.25,
                                               pct_final=0.7,
                                               div_factor=10.,
                                               final_div_factor=200.)
    opt = optax.chain(optax.clip_by_global_norm(clip_norm_value), optax.adam(learning_rate=scheduler))
    model_state = train_state.TrainState.create(apply_fn=model.apply,
                                                params=params,
                                                tx=opt)

    @jax.jit
    def fitloss(state, params, x, y):
        """Mean ``optax.l2_loss`` between the model output on ``x`` and ``y``."""
        yh = state.apply_fn(params, x)
        assert yh.shape == y.shape, f"Shape mismatch: predicted {yh.shape}, target {y.shape}"
        loss = (optax.l2_loss(yh, y)).mean()
        return loss

    @jax.jit
    def train_step(state, x, y):
        """Take one optimizer step on the batch ``(x, y)`` and return ``(state, loss)``."""
        grad_fn = jax.value_and_grad(fitloss, argnums=1)
        loss, grads = grad_fn(state, state.params, x, y)
        state = state.apply_gradients(grads=grads)
        return state, loss

    train_loss, val_loss = [], []
    Lipmin, Lipmax, Tau = [], [], []
    for epoch in range(epochs):
        print(jnp.std(data['ytrain']), "STD")

        if epoch > threshold_epoch:
            # One shuffled pass over the stored (regenerated) training set.
            for k in range(1):
                tloss = 0.
                rng, rng_idx = random.split(rng)
                print("STEP 2.5")
                idx = random.permutation(rng_idx, train_batches*train_batch_size)
                print("STEP 3")
                idx = jnp.reshape(idx, (train_batches, train_batch_size))
                for b in range(train_batches):
                    x = data['xtrain'][idx[b, :], :]
                    y = data['ytrain'][idx[b, :]]
                    model_state, loss = train_step(model_state, x, y)
                    tloss += loss
                tloss_temp = tloss / train_batches
                vloss = fitloss(model_state, model_state.params, data['xtest'], data['ytest'])
                print(k, tloss_temp, vloss)
        else:
            # Fresh samples for every batch.
            tloss = 0.
            for b in range(train_batches):
                rng, rng_train = random.split(rng)
                x = Sampler(rng_train, train_batch_size, data_dim, x_min=-bound, x_max=bound)
                y = get_y(x)
                model_state, loss = train_step(model_state, x, y)
                tloss += loss

        tloss /= train_batches
        train_loss.append(tloss)

        vloss = fitloss(model_state, model_state.params, data['xtest'], data['ytest'])
        val_loss.append(vloss)

        lipmin, lipmax, tau = model.apply(model_state.params, method=model.get_bounds)
        Lipmin.append(lipmin)
        Lipmax.append(lipmax)
        Tau.append(tau)

        print(f'Epoch: {epoch+1:3d} | loss: {tloss:.4f}/{vloss:.4f}, tau: {tau:.1f}, Lip: {lipmin:.3f}/{lipmax:.2f}')

        # Trajectory generation starts only after epoch 200 (hard-coded).
        if epoch <= 200:
            continue

        if epoch >= threshold_epoch:
            save_args = orbax_utils.save_args_from_target(model_state)
            orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
            orbax_checkpointer.save(
                f'testing_divergence_{file_name}/{data_dim}_dims_{epoch+1}_epoch_modelandopt',  # Save path
                model_state,  # Save the entire model state, including optimizer
                save_args=save_args
            )

        global_min_value = model_state.params['params']['QuadPotential_0']['c']

        N = 120
        rng, rng_idx = random.split(rng)
        # N start points drawn uniformly from [-bound, bound]^data_dim.
        new_points = random.uniform(rng_idx, (N, data_dim), minval=-bound, maxval=bound)

        # Gradient-descent trajectories of the current model from each start point.
        trajectories_data = gradient_descent_solver_trajectory_adam(lambda x: model.apply(model_state.params, x),
                                                                    z0=new_points,
                                                                    max_iter=4000)

        # The final mean value should match the learned minimum value `c`. This
        # is an explicit check instead of `assert` inside a bare `except:`,
        # which hid unrelated errors and was skipped entirely under `python -O`.
        min_gap = jnp.abs(trajectories_data['vgap'][-1] - global_min_value)
        if not min_gap < 1e-2:
            print(f"error global min gap {min_gap}")

        trajectories = trajectories_data['z']

        # Keep only the trajectories whose end points cluster together.
        trajectories = assert_max_distance_less_than_threshold(trajectories, threshold=threshold_value)

        global_min = trajectories[-1,0,:]
        distance = jnp.linalg.norm(global_min)  # distance from the origin

        print('distance from global min', distance, 'global min', global_min)

        trajectories = jnp.concatenate([get_interpolated_tensors(trajectories[:,i,:]) for i in range(trajectories.shape[1])],axis=0)

        np.save(f'full_trajectories_{epoch+2}_{file_name}.npy',np.array(trajectories_data['z']))
        np.save(f'filtered_trajectories_{epoch+2}_{file_name}.npy', np.array(trajectories))

        rng, rng_idx = random.split(rng)
        trajectories = adjust_tensor_to_batch_size(trajectories, generated_batch_size, rng_idx)

        # The row count is now an exact multiple of generated_batch_size.
        new_train_batches = trajectories.shape[0] // generated_batch_size
        if new_train_batches == 0:
            # Too few generated points for even one batch. Keep the previous
            # data and batch count instead of switching to an empty training
            # set, which divided by zero on the next epoch.
            print(f'only {trajectories.shape[0]} generated points (< {generated_batch_size}); keeping previous training data')
            continue
        train_batches = new_train_batches

        print('new number of batches', train_batches)

        rng, rng_train = random.split(rng)
        # Fill the rest of every batch with fresh samples.
        x = Sampler(rng_train, train_batches*(train_batch_size - generated_batch_size), data_dim, x_min=-bound, x_max=bound)

        trajectories = jnp.concatenate([trajectories, x],axis=0)

        if epoch < threshold_epoch:
            continue

        data['xtrain'] = trajectories
        data['ytrain'] = get_y(trajectories)

    eloss = fitloss(model_state, model_state.params, data['xeval'], data['yeval'])
    print(f'{name}: eval loss: {eloss:.4f}')

    data['train_loss'] = jnp.array(train_loss)
    data['val_loss'] = jnp.array(val_loss)
    data['lipmin'] = jnp.array(Lipmin)
    data['lipmax'] = jnp.array(Lipmax)
    data['tau'] = jnp.array(Tau)
    data['eval_loss'] = eloss

    orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
    save_args = orbax_utils.save_args_from_target(model_state.params)
    orbax_checkpointer.save(f'{ckpt_dir}/params', model_state.params, save_args=save_args)


# Driver: build the dataset once, then train one PLNet per tau value.
lr_max = 1e-2
epochs = 3000
n_batch = 50  # used for both train_batches and eval_batches below


# NOTE(review): the directory name says "rosenbrock", but the objective is
# `new_function` (set near the top of this file). Rename if that matters.
root_dir = f'./results/rosenbrock-dim{data_dim}-batch{n_batch}'
#rng = random.PRNGKey(42)
rng, rng_data = random.split(rng, 2)
# Initial data with x_min=-bound / x_max=bound; batch sizes use data_gen's defaults.
data= data_gen(rng_data,train_batches=n_batch,test_batches=1,eval_batches=n_batch,data_dim=data_dim, val_min = -bound, val_max = bound)

name = 'BiLipNet'
depth = 4  # number of MonLipNet layers inside the BiLipNet
for tau in [500]:
    print("TAU", tau)
    train_dir = f'{root_dir}/{name}-{depth}-tau{tau}'
    # Each MonLipNet gets eight hidden feature layers of width 1024.
    block = BiLipNet([1024]*8, depth=depth, tau=tau)
    model = PLNet(block)
    # train() mutates `data` (its training set and result entries), so with
    # more than one tau value every later run would start from the previous
    # run's regenerated data.
    train(model, data, name=name, train_dir=train_dir, lr_max=lr_max, epochs=epochs)
