import jax
import jax.numpy as jnp
import einops
from diffuse.sde import SDE, LinearSchedule, SDEState
from diffuse.conditional import CondSDE
from diffuse.unet import UNet
from diffuse.optimizer import ImplicitState, impl_step
from diffuse.filter import generate_cond_sample
import optax

# Load necessary data and initialize variables.
# For .npz archives jnp.load hands back numpy's lazy NpzFile; using it as a
# context manager closes the underlying file once the arrays have been read.
with jnp.load("dataset/mnist.npz") as data:
    xs = data["X"]
# Append a trailing channel axis: (b, h, w) -> (b, h, w, 1).
xs = einops.rearrange(xs, "b h w -> b h w 1")

tf = 2.0  # time horizon; the schedule and time grids below span [0, tf]
batch_size = 256  # training hyper-parameter, unused below
n_epochs = 3000  # training hyper-parameter, unused below
n_t = 256  # number of time steps of the network's grid
dt = tf / n_t  # step size passed to the UNet

key = jax.random.PRNGKey(0)

# Define beta schedule and SDE. The schedule's horizon is tied to `tf` so the
# two cannot drift apart.
beta = LinearSchedule(b_min=0.02, b_max=5.0, t0=0.0, T=tf)
sde = SDE(beta)

# Load trained neural network.
nn_unet = UNet(dt, 64, upsampling="pixel_shuffle")
# NOTE(review): allow_pickle=True unpickles arbitrary Python objects -- only
# load checkpoints from a trusted source.
with jnp.load("ann_2999.npz", allow_pickle=True) as nn_trained:
    # `.item()` unwraps the single stored object (presumably the params pytree).
    nn_params = nn_trained["params"].item()


def nn_score(x, t):
    """Evaluate the trained UNet at input ``x`` and time ``t``.

    Uses the module-level ``nn_unet`` model with the loaded ``nn_params``
    and returns the network output unchanged.
    """
    score = nn_unet.apply(nn_params, x, t)
    return score


# Set up conditional SDE.
x = xs[0]
# NOTE(review): SquareMask is not imported anywhere in this file, so this line
# raises NameError as written -- add the import from the module defining it.
mask = SquareMask(10, x.shape)
xi = jnp.array([10.0, 20.0])
# Measure through the mask object (same API used for past_y below); the free
# function `measure` previously called here is not defined in this file.
y = mask.measure(xi, x)
cond_sde = CondSDE(beta=beta, mask=mask, tf=tf, score=nn_score)

# Give each random consumer its own key: reusing `key` both for conditional
# sampling and for the path noise below would correlate the two streams.
key_cond, key_path = jax.random.split(key)

# Generate conditional sample.
res = generate_cond_sample(y, xi, key_cond, cond_sde, x.shape)

# Prepare data for implicit optimization on a coarser time grid.
# This rebinds n_t (256 above, used for the UNet's dt) to the optimizer's grid.
n_t = 29
ts = jnp.linspace(0, tf, n_t)
dt = tf / (n_t - 1)  # spacing of `ts`

past_y = y  # same measurement of x at design xi as used for sampling
key_noise = jax.random.split(key_path, n_t)
state_0 = SDEState(past_y, jnp.zeros_like(past_y))
# One sde.path call per (key, time) pair; state_0 is shared across the batch.
past_y = jax.vmap(sde.path, in_axes=(0, None, 0))(key_noise, state_0, ts)

# NOTE(review): n_particles / n_contrast are not used by the slices below,
# which take 39 and up to 50 entries along axis 1 -- confirm intended counts.
n_particles = 10
n_contrast = 10
thetas = res[1][0][:, 0:39]
cntrst_thetas = res[1][0][:, 50:100]

# Set up optimizer over the measurement design.
design = xi
learning_rate = 1e-3
optimizer = optax.adam(learning_rate)
opt_state = optimizer.init(design)

# Create initial state.
initial_state = ImplicitState(thetas, cntrst_thetas, design, opt_state)

# Independent key for the optimization step.
key_step = jax.random.PRNGKey(42)

# Run impl_step.
new_state = impl_step(initial_state, key_step, past_y, cond_sde, optimizer, ts, dt)

# Print results.
print("New state:", new_state)
