import jax
import jax.numpy as jnp
import pickle

def save_model_to_file(params, file_name):
    """
    Save JAX model parameters to a file using pickle.

    Every JAX array leaf in ``params`` is copied to host memory as a plain
    ``numpy.ndarray`` before pickling. The resulting file therefore holds
    ordinary numpy data rather than JAX array objects.

    Args:
    - params: Model parameters, a pytree (e.g. nested dicts/lists/tuples)
      whose leaves are JAX arrays. Leaves that are not arrays are stored as-is.
    - file_name: The path and name of the file where the parameters should be
      saved. An existing file at this path is overwritten.

    Returns:
    - file_name: Returns the path where the model was saved (useful for
      chaining operations).

    Note:
    - The output is a pickle. Unpickling runs code embedded in the file, so
      only load such files from trusted sources.
    """
    # jax.device_get transfers every array leaf to host and returns numpy
    # arrays; jnp.asarray would have kept them as (device-backed) JAX arrays.
    params_numpy = jax.device_get(params)

    with open(file_name, 'wb') as f:
        pickle.dump(params_numpy, f)

    return file_name
