import jax
import ot_jax.optimal_transport.jax_wasserstein as jw
import numpy as np

# Restrict JAX to the GPU backend and make GPU index 3 the default device.
# Index 3 is only "the last GPU" on a 4-GPU machine; with fewer visible GPUs
# the indexing below raises IndexError.
jax.config.update('jax_platform_name', 'gpu')
jax.config.update('jax_default_device', jax.devices('gpu')[3])

VOLS_PATH = (
    "data/rotmol3d/"
    "rotated_EMDB2660_mask_radius=128_downscale_factor=8_n_angles=36.npy"
)

# Pass the path (not an open file object) so np.load opens and closes the
# .npy file itself instead of leaking the handle.
vols = np.load(VOLS_PATH)

# Upper bound on the weighted transport cost between the first two volumes.
print(jw.weighted_cost_upper_bound(vols[0], vols[1], p=2, scale_factor=2))
# Observed output: Array(0., dtype=float64, weak_type=True)
# NOTE(review): vols[0] and vols[1] are presumably two different rotations
# (file name says n_angles=36), so a zero bound looks suspicious -- verify
# the inputs actually differ and check weighted_cost_upper_bound's behaviour.
