import argparse
import os
import socket
import time
from timeit import default_timer as timer
from typing import Tuple

# IMPORTANT: set before importing geomstats/riemannian_score_sde
os.environ["GEOMSTATS_BACKEND"] = "jax"

import haiku as hk
import jax
import jax.numpy as jnp
import jax.random as jr
import optax
import sympy as sp
import yaml
from functools import partial
from hydra.utils import get_class, instantiate
from jax import grad, jacrev, jit, vmap
from jax.scipy.linalg import expm
from omegaconf import OmegaConf
from tqdm import tqdm
import wandb
try:
    KeyArray = jax.random.KeyArray
except AttributeError:
    KeyArray = getattr(jax, "Array", jnp.ndarray)
from archs import ModifiedMlp

import sys
sys.path.append('riemannian-score-sde')
from score_sde.datasets import DataLoader, TensorDataset, random_split
from score_sde.losses import get_ema_loss_step_fn
from score_sde.utils import ParametrisedScoreFunction, TrainState, batch_mul, restore, save
from score_sde.models.flow import SDEPushForward  # type only / isinstance check

from riemannian_score_sde.losses import get_dsm_loss_fn
from riemannian_score_sde.utils.normalization import compute_normalization
from riemannian_score_sde.utils.vis import plot, plot_ref

def log_print(msg: str):
    """Write ``msg`` to stdout and flush the stream immediately."""
    return print(msg, flush=True)

# Command-line options. Downstream uses in this script: --n_max feeds the
# non-PINN DSM loss, --beta_f overrides cfg.flow.beta_schedule.beta_f,
# --K sets cfg.dataset.K and --mcmc_t_thresh is passed to mcmc_sample.
parser = argparse.ArgumentParser(description='Train pinn')
parser.add_argument('--exp_name', type=str, default='tmp',
                    help='exp name')
# BooleanOptionalAction already stores a bool and adds a --no-use_pinn flag;
# its `type` parameter is redundant and deprecated since Python 3.12.
parser.add_argument('--use_pinn', default=True, action=argparse.BooleanOptionalAction,
                    help='use_pinn')
parser.add_argument('--pinn_name', type=str, default='exp013-9',
                    help='pinn_name (run directory under FP_SO3_jax/data-pinn/)')
parser.add_argument('--n_max', type=int, default=-1,
                    help='n_max')
parser.add_argument('--beta_f', type=float, default=5,
                    help='beta_f')
parser.add_argument('--seed', type=int, default=None,
                    help='seed')
parser.add_argument('--K', type=int, default=32,
                    help='K')
parser.add_argument('--mcmc_t_thresh', type=float, default=0.,
                    help='mcmc_t_thresh')
args = parser.parse_args()

# Without an explicit --seed, derive one from the wall clock and experiment
# name. str hashing is salted per process (PYTHONHASHSEED), so the value is
# not reproducible by itself; the chosen seed is saved with the other args
# in args.yaml further below.
if args.seed is None:
    args.seed = hash(str(time.time())+args.exp_name) % (2**32)

# Load the pretrained PINN: its training arguments (args.yaml) and weights
# (last_epoch.npy) from FP_SO3_jax/data-pinn/<pinn_name>/.
exp_name = args.pinn_name
eqn_name = 'FP_SO3_jax'
with open(f'{eqn_name}/data-pinn/{exp_name}/args.yaml', 'r') as f:
    # NOTE(review): yaml.FullLoader can construct Python objects; only load
    # args files produced by trusted runs.
    pinn_args = argparse.Namespace(**yaml.load(f, Loader=yaml.FullLoader))
# NOTE(review): allow_pickle=True unpickles the file contents; trust the source.
variables = jnp.load(
    f'{eqn_name}/data-pinn/{exp_name}/last_epoch.npy', allow_pickle=True
).item()

# Rebuild the network with the same hyper-parameters it was trained with.
pinn = ModifiedMlp(
    num_layers=pinn_args.n_layers - 1,
    hidden_dim=pinn_args.n_hidden,
    out_dim=1,
    fourier_emb=dict(
        embed_scale=pinn_args.sigma_fourier,
        embed_dim=pinn_args.n_fourier,
    ),
    reparam=dict(
        type='weight_fact',
        mean=1.,
        stddev=0.1,
    ),
)

def f_0(variables, x, t):
    """Evaluate the PINN on the input ``concatenate([x, t], axis=-1)``.

    The network output is returned twice: the first copy is what ``jacrev``
    differentiates, and the second is the ``has_aux`` payload. This lets
    callers recover the undifferentiated value too.
    """
    network_input = jnp.concatenate([x, t], axis=-1)
    out = pinn.apply(variables, network_input)
    return out, out

def f_t(variables, x, t):
    """Jacobian of the PINN output with respect to ``t`` (argument index 2).

    Returns ``(u_t, (u_t, u_0))`` so it can itself be differentiated with
    ``has_aux=True``.
    """
    d_dt, value = jacrev(f_0, argnums=2, has_aux=True)(variables, x, t)
    return d_dt, (d_dt, value)

def f_x(variables, x, t):
    """Jacobian of the PINN output with respect to ``x`` (argument index 1).

    Returns ``(u_x, (u_x, u_0))``; ``f_xx`` differentiates this once more.
    """
    d_dx, value = jacrev(f_0, argnums=1, has_aux=True)(variables, x, t)
    return d_dx, (d_dx, value)

def f_xx(variables, x, t):
    """Second derivative of the PINN output with respect to ``x``.

    Returns ``(u_xx, (u_xx, u_x, u_0))``.
    """
    d2_dx2, (d_dx, value) = jacrev(f_x, argnums=1, has_aux=True)(variables, x, t)
    return d2_dx2, (d2_dx2, d_dx, value)


import sympy as sp
# Dimension used by the radial formulas below. With n = 3, D is the volume
# density ratio (sin r / r)^(n-1) of the unit n-sphere in geodesic polar
# coordinates (presumably S^3, i.e. unit quaternions covering SO(3) -- confirm).
n = 3
# Truncation order (exclusive power of s) used by `iterate_approx`.
order = 15

r = sp.symbols('r')
D = (sp.sin(r) / r) ** (n-1)
# Logarithmic derivative D'/D, used in the radial Laplacian `lap`.
DD = (1/D*sp.diff(D,r)).cancel().simplify()

def lap(u):
    """Radial Laplacian of the sympy expression ``u(r)``.

    Computes ``u'' + ((n - 1)/r + D'/D) * u'``, where ``DD`` holds ``D'/D``.
    """
    du = sp.diff(u, r)
    d2u = sp.diff(u, r, 2)
    return d2u + ((n - 1) / r + DD) * du

s = sp.symbols('s')


def iterate(u, i):
    """Next coefficient of the recursion, via an exact (Risch) integral.

    Computes ``r**(-i) / sqrt(D) * F(r)``, where ``F`` is an antiderivative in
    ``s`` of ``sqrt(D) * lap(u)`` (with ``r -> s``) times ``s**(i - 1)``.
    NOTE: the value at s = 0 is not subtracted; a disabled variant did
    subtract ``integral.limit(s, 0)``.
    """
    weighted = (sp.sqrt(D) * lap(u)).cancel().subs(r, s) * s ** (i - 1)
    antiderivative = sp.integrate(weighted.cancel(), s, risch=True).cancel().simplify()
    antiderivative = antiderivative.subs(s, r)
    scaled = (r ** (-i)) / sp.sqrt(D) * antiderivative
    return scaled.cancel().simplify()


s = sp.symbols('s')


def iterate_approx(u, i):
    """Series-truncated variant of ``iterate``.

    Before integrating, the integrand is replaced by its Taylor polynomial
    about s = 0, keeping powers below ``s**order``. The result is then
    truncated to powers of ``r`` below ``r**10``. As in ``iterate``, the
    antiderivative's value at s = 0 is not subtracted.
    """
    weighted = (sp.sqrt(D) * lap(u)).cancel().subs(r, s) * s ** (i - 1)
    poly = weighted.series(s, 0, order).removeO()
    antiderivative = sp.integrate(poly.cancel(), s).cancel().simplify()
    antiderivative = antiderivative.subs(s, r)
    scaled = (r ** (-i)) / sp.sqrt(D) * antiderivative
    return scaled.cancel().simplify().series(r, 0, 10).removeO()


# Expansion coefficients: u_0 = D^(-1/2); u_1 uses the exact `iterate`, and
# u_2..u_4 use the series-truncated `iterate_approx`.
u_0 = 1/sp.sqrt(D).evalf()
u_1 = iterate(u_0,1).evalf()
u_2 = iterate_approx(u_1,2).evalf()
u_3 = iterate_approx(u_2,3).evalf()
u_4 = iterate_approx(u_3,4).evalf()

# Euclidean n-dimensional heat kernel (4*pi*t)^(-n/2) * exp(-r^2/(4t)) and its log.
t = sp.symbols('t')
G = (4*sp.pi*t)**(-n/2) * sp.exp(-r**2/(4*t))
logG = (-n/2) * sp.log(4*sp.pi*t) - r**2/(4*t)

# Truncated small-time expansions S_k = G * sum_{i<=k} u_i t^i (this looks like
# a Minakshisundaram-Pleijel parametrix -- confirm) and their logs.
# In the visible code only logS4 is used (via f_logS4).
S0 = G * u_0
S1 = G * (u_0 + u_1*t)
S2 = G * (u_0 + u_1*t + u_2*t**2)
S3 = G * (u_0 + u_1*t + u_2*t**2 + u_3*t**3)
S4 = G * (u_0 + u_1*t + u_2*t**2 + u_3*t**3 + u_4*t**4)

logS0 = logG + sp.log(u_0).evalf()
logS1 = logG + sp.log(u_0 + u_1*t).evalf()
logS2 = logG + sp.log(u_0 + u_1*t + u_2*t**2).evalf()
logS3 = logG + sp.log(u_0 + u_1*t + u_2*t**2 + u_3*t**3).evalf()
logS4 = logG + sp.log(u_0 + u_1*t + u_2*t**2 + u_3*t**3 + u_4*t**4).evalf()

# JAX-callable log S4(r, t); used by compute_logp_sp for small t.
f_logS4 = sp.lambdify((r,t),logS4.simplify().expand().evalf(),modules = "jax")
# use metric tr(uv^T)
# NOTE(review): the gradient helpers mention tr(uv^T)/2 instead; confirm which
# metric normalisation is intended.

def matrix_to_quaternion(R):
    """Convert a 3x3 rotation matrix to a quaternion ``[w, x, y, z]``.

    ``w = 0.5 * sqrt(max(1 + trace(R), 0))`` is non-negative. The vector part
    is divided by ``4 * w``, so the result is undefined (division by zero)
    when ``trace(R) <= -1``, i.e. for rotations by pi.
    """
    diag_sum = R[0, 0] + R[1, 1] + R[2, 2]
    w = 0.5 * jnp.sqrt(jnp.maximum(1.0 + diag_sum, 0.0))
    denom = 4.0 * w
    x, y, z = (
        (R[2, 1] - R[1, 2]) / denom,
        (R[0, 2] - R[2, 0]) / denom,
        (R[1, 0] - R[0, 1]) / denom,
    )
    return jnp.array([w, x, y, z])

def so3_dist(R1, R2, arccos=True):
    """Distance-like quantity between two rotation matrices.

    With ``w = (trace(R1 @ R2.T) - 1) / 2`` (the cosine of the relative
    rotation angle theta for exact rotations):

    * ``arccos=True``: returns ``arccos(w) / 2``, i.e. theta / 2, in [0, pi/2].
    * ``arccos=False``: returns the raw ``w``, which is not clipped.

    Round-off, or inputs that are not exactly orthogonal, can push ``w``
    slightly outside [-1, 1], where ``arccos`` returns NaN. ``w`` is clipped
    before ``arccos`` so the distance stays finite; values strictly inside
    the range are unchanged.
    """
    R = R1 @ R2.T
    trace = R[0, 0] + R[1, 1] + R[2, 2]
    w = 0.5 * (trace - 1)
    if arccos:
        return jnp.arccos(jnp.clip(w, -1.0, 1.0)) / 2
    return w
    
    
def quaternion_to_matrix(q):
    """Convert a quaternion ``q = [w, x, y, z]`` into a 3x3 rotation matrix.

    ``q`` is expected to have unit norm; no normalisation is done here.
    """
    w, x, y, z = q
    return jnp.array([
        [w * w + x * x - y * y - z * z, 2 * (x * y - w * z), 2 * (x * z + w * y)],
        [2 * (x * y + w * z), w * w - x * x + y * y - z * z, 2 * (y * z - w * x)],
        [2 * (x * z - w * y), 2 * (y * z + w * x), w * w - x * x - y * y + z * z],
    ])

def sample_so3(key, n_samples):
    """Draw ``n_samples`` random rotation matrices, shape ``(n_samples, 3, 3)``.

    Isotropic Gaussian 4-vectors are normalised to unit quaternions, which
    are uniform on the 3-sphere, and then converted to matrices.
    """
    gaussian = jax.random.normal(key, shape=(n_samples, 4))
    unit_quats = gaussian / jnp.linalg.norm(gaussian, axis=1, keepdims=True)
    return jax.vmap(quaternion_to_matrix)(unit_quats)


def project_elementwise(x, v):
    """Project a 3x3 matrix ``v`` onto the tangent space at rotation ``x``.

    Returns ``x @ skew(x^T v)`` with ``skew(M) = (M - M^T) / 2``.
    """
    local = jnp.transpose(x) @ v
    return x @ ((local - jnp.transpose(local)) / 2)

from jax.scipy.linalg import expm

def exp_elementwise(x, v):
    """Move from rotation ``x`` along tangent matrix ``v``: ``x @ expm(x^T v)``."""
    return x @ expm(jnp.transpose(x) @ v)

# Batched versions of the per-matrix helpers, mapped over the leading axis.
project, exp = vmap(project_elementwise), vmap(exp_elementwise)

def compute_logp_pinn(x_0,x,t):
    """PINN estimate for a single pair of rotations at time ``t``.

    The network is queried with the quaternion of ``x_0 @ x.T`` and ``t``,
    and the first output channel is returned. The name suggests this is a
    log-density (log p) -- confirm against the PINN training setup.
    """
    relative_quat = matrix_to_quaternion(x_0 @ x.T)
    _, net_out = f_0(variables, relative_quat, t)
    return net_out[..., 0]


def correct_so3(M):
    """
    Project a 3x3 matrix ``M`` onto SO(3) using its SVD ``M = U S V^T``.

    Returns ``U @ diag(1, 1, sign(det(U V^T))) @ V^T``. Flipping the sign of
    the last column of ``U`` when ``U V^T`` is a reflection guarantees a
    result with determinant +1.
    """
    left, _, right_t = jnp.linalg.svd(M, full_matrices=True)
    flip = jnp.sign(jnp.linalg.det(left @ right_t))
    return left @ jnp.diag(jnp.array([1.0, 1.0, flip])) @ right_t

def compute_logp_sp(x_0,x,t):
    """Truncated-series log-density ``logS4(r, t)`` for one pair of rotations.

    ``r = so3_dist(x_0, x)`` is clipped to ``[1e-7, pi - 0.001]`` so that r
    stays away from 0. Because ``so3_dist`` returns values in ``[0, pi/2]``,
    the upper bound appears never to bind.
    """
    dist = jnp.clip(so3_dist(x_0, x), 1e-7, jnp.pi - 0.001)
    return f_logS4(dist, t[..., 0])

# Switch-over time for the hybrid log-density: compute_logp / compute_logp_grad
# use the PINN where t > t_threshold and the truncated series (logS4) otherwise.
t_threshold = 0.3

def compute_logp_grad_pinn(x_0,x,t):
    """Riemannian gradient in ``x`` of the PINN log-density for one pair.

    The Euclidean gradient is projected onto the tangent space at ``x``.
    NOTE(review): no factor of 2 is applied. An earlier variant multiplied by
    2 to account for the metric tr(uv^T)/2; confirm which metric is intended.
    """
    euclidean_grad = grad(compute_logp_pinn, argnums=1)(x_0, x, t)
    return project_elementwise(x, euclidean_grad)

def compute_logp_grad_sp(x_0,x,t):
    """Riemannian gradient in ``x`` of the series log-density for one pair.

    The Euclidean gradient is projected onto the tangent space at ``x``.
    NOTE(review): no factor of 2 is applied. An earlier variant multiplied by
    2 to account for the metric tr(uv^T)/2; confirm which metric is intended.
    """
    euclidean_grad = grad(compute_logp_sp, argnums=1)(x_0, x, t)
    return project_elementwise(x, euclidean_grad)

@jit
def compute_logp(x_0,x,t):
    """Hybrid log-density for a batch of (x_0, x, t) triples.

    Both estimators are evaluated for every element. The PINN value is kept
    where ``t > t_threshold`` and the truncated series value elsewhere.

    Args:
        x_0, x: batched 3x3 matrices with a shared leading batch axis.
        t: per-element times, reshaped to ``(-1, 1)``.

    Returns:
        One value per batch element.
    """
    t_col = t.reshape(-1, 1)
    from_pinn = vmap(compute_logp_pinn)(x_0, x, t_col)
    from_series = vmap(compute_logp_sp)(x_0, x, t_col)
    return jnp.where(t_col[:, 0] > t_threshold, from_pinn, from_series)

def compute_logp_grad(x_0,x,t):
    """Hybrid Riemannian gradient of the log-density for a batch.

    The PINN gradient is used where ``t > t_threshold`` and the series
    gradient elsewhere. The result is zeroed where
    ``|w| > 1 - 1e-7``, with ``w = (trace(x_0 @ x.T) - 1) / 2``
    (``so3_dist(..., False)``), i.e. where x is (numerically) at x_0 or at
    the opposite extreme, where the distance is not differentiable.
    """
    t_col = t.reshape(-1, 1)
    grad_pinn = vmap(compute_logp_grad_pinn)(x_0, x, t_col)
    grad_series = vmap(compute_logp_grad_sp)(x_0, x, t_col)
    blended = jnp.where(t_col[:, :, None] > t_threshold, grad_pinn, grad_series)
    cos_like = vmap(so3_dist, in_axes=(0, 0, None))(x_0, x, False)
    degenerate = jnp.abs(cos_like)[:, None, None] > 1 - 1e-7
    return jnp.where(degenerate, jnp.zeros_like(blended), blended)


# jax.config.update('jax_default_matmul_precision', 'float32')

def compute_mcmc(key, x_0, t, n_iter=25, t_thresh=0., correct_freq=10):
    """Approximately sample x ~ p(x | x_0, t) on SO(3) by random-walk Metropolis.

    Proposals move along a random tangent direction: ``exp(x, project(x, N * sqrt(16 t)))``
    with ``N`` standard normal 3x3 noise. Acceptance uses ``compute_logp``.
    Iterates are re-projected onto SO(3) with ``correct_so3`` every
    ``correct_freq`` steps and on the final step.

    Args:
        key: JAX PRNG key.
        x_0: batch of starting rotations, shape (n_data, 3, 3).
        t: per-sample times with n_data elements (reshaped to (n_data, 1, 1)).
        n_iter: number of Metropolis steps (static under jit).
        t_thresh: samples with ``t < t_thresh`` are returned unchanged as
            ``x_0`` with log-probability placeholder 1.0.
        correct_freq: SO(3) re-projection period.

    Returns:
        ``(x, logp_x)`` with shapes (n_data, 3, 3) and (n_data,).
    """
    n_data = t.shape[0]
    t = t.reshape(n_data, 1, 1)
    t_col = t.reshape(n_data, 1)
    step_scale = (t * 16) ** 0.5

    # Use a fresh subkey for the initial proposal so `key` itself is only
    # ever split, never consumed and then reused.
    key, init_key = jax.random.split(key)
    x = exp(x_0, project(x_0, jr.normal(init_key, (n_data, 3, 3)) * step_scale))
    logp_x = compute_logp(x_0, x, t_col)

    def body_fn(i, state):
        key, x, logp_x = state
        key, prop_key, accept_key = jax.random.split(key, 3)
        x_new = exp(x, project(x, jr.normal(prop_key, (n_data, 3, 3)) * step_scale))
        logp_x_new = compute_logp(x_0, x_new, t_col)
        accept_ratio = jnp.exp(logp_x_new - logp_x)
        accept = jr.uniform(accept_key, (n_data,)) < accept_ratio
        x = jnp.where(accept[:, None, None], x_new, x)
        logp_x = jnp.where(accept, logp_x_new, logp_x)

        # Counteract numerical drift off the manifold periodically and at the end.
        do_correct = jnp.logical_or(jnp.mod(i + 1, correct_freq) == 0, (i + 1) == n_iter)
        x = jax.lax.cond(
            do_correct,
            lambda _: vmap(correct_so3)(x),
            lambda _: x,
            operand=None
        )
        return key, x, logp_x

    key, x, logp_x = jax.lax.fori_loop(0, n_iter, body_fn, (key, x, logp_x))

    # One boolean per sample; index down to shape (n_data,) so the mask lines
    # up with logp_x instead of broadcasting (n_data, 1) against (n_data,).
    below = t[:, 0, 0] < t_thresh
    x = jnp.where(below[:, None, None], x_0, x)
    logp_x = jnp.where(below, 1.0, logp_x)

    return x, logp_x


mcmc_sample = jax.jit(compute_mcmc, static_argnames=["n_iter", "t_thresh", "correct_freq"])

def get_dsm_loss_fn_pinn(
    pushforward: SDEPushForward,
    model: ParametrisedScoreFunction,
    train: bool = True,
    like_w: bool = True,
    eps: float = 1e-3,
    s_zero=True,
    **kwargs
):
    """Build a denoising score-matching loss with PINN/series score targets.

    With ``s_zero`` (default), noisy samples y_t are drawn by ``mcmc_sample``
    and the target score is ``compute_logp_grad``. Both are evaluated at the
    heat time ``beta_schedule.rescale_t(t) / 16``, with ``t ~ U[t0 + eps, tf]``.
    The MCMC threshold comes from the module-level ``args.mcmc_t_thresh``.
    Otherwise the SDE's own ``marginal_sample`` / ``varhadan_exp`` are used.

    ``like_w=True`` weights squared errors by g(t)^2; ``like_w=False`` scales
    both score and target by the marginal std instead. Extra keyword
    arguments are accepted and ignored.

    Returns:
        ``loss_fn(rng, params, states, batch) -> (loss, new_model_state)``.
    """
    sde = pushforward.sde

    def loss_fn(
        rng: KeyArray, params: dict, states: dict, batch: dict
    ) -> Tuple[float, dict]:
        score_fn = sde.reparametrise_score_fn(model, params, states, train, True)
        y_0 = pushforward.transform.inv(batch["data"])
        context = batch["context"]

        rng, step_rng = jr.split(rng)
        t = jr.uniform(step_rng, (y_0.shape[0],), minval=sde.t0 + eps, maxval=sde.tf)
        rng, step_rng = jr.split(rng)

        if s_zero:
            # l_{t|0}: sample y_t | y_0 by MCMC and take the hybrid score as target.
            heat_t = sde.beta_schedule.rescale_t(t[:, None]) / 16
            with jax.default_matmul_precision("float32"):
                y_t, _ = mcmc_sample(step_rng, y_0, heat_t, t_thresh=args.mcmc_t_thresh, n_iter=25)
            logp_grad = compute_logp_grad(y_0, y_t, heat_t)
            std = jnp.expand_dims(sde.marginal_prob(jnp.zeros_like(y_t), t)[1], -1)
        else:
            # l_{t|s}: target from the last step of the simulated SDE path.
            y_t, y_hist, timesteps = sde.marginal_sample(
                step_rng, y_0, t, return_hist=True
            )
            y_s = y_hist[-2]
            delta_t, logp_grad = sde.varhadan_exp(y_s, y_t, timesteps[-2], timesteps[-1])
            delta_t = t  # NOTE: works better?
            std = jnp.expand_dims(sde.marginal_prob(jnp.zeros_like(y_t), delta_t)[1], -1)

        score, new_model_state = score_fn(y_t, t, context, rng=step_rng)
        score = score.reshape(y_t.shape)

        if like_w:
            # E_{p(y_0)}[g(t)^2 * || s_theta(y_t, t) - grad log p(y_t | y_0) ||^2]
            g2 = sde.coefficients(jnp.zeros_like(y_0), t)[1] ** 2
            losses = sde.manifold.metric.squared_norm(score - logp_grad, y_t) * g2
        else:
            score = batch_mul(std, score)
            logp_grad = batch_mul(std, logp_grad)
            losses = sde.manifold.metric.squared_norm(score - logp_grad, y_t)

        return jnp.mean(losses), new_model_state

    return loss_fn


# Load the base SO(3) config and apply command-line overrides.
cfg = OmegaConf.load("riemannian-score-sde/configs/config_so3.yaml")

cfg.dataset.data_dir = 'data'
# NOTE(review): the main section instantiates cfg.beta_schedule (not
# cfg.flow.beta_schedule) and passes it to the flow; confirm this override
# reaches the schedule actually used (e.g. via a config interpolation).
cfg.flow.beta_schedule.beta_f = args.beta_f
# Per-experiment output directory: experiment_results_so3/<exp_name>.
path = f'experiment_results_so3/{args.exp_name}'
os.makedirs(path, exist_ok=True)

wandb.init(
    project='so3-diffusion',
    config=vars(args),
    name=args.exp_name
)

# Persist the parsed arguments (including the resolved seed) with the run.
with open(os.path.join(path,'args.yaml'),'w') as f:
    yaml.dump(vars(args),f)
cfg.dataset.K = args.K

cfg.seed = args.seed

cfg.ckpt_dir = f"{path}/ckpt" 

def train(train_state):
    """Run the optimisation loop from ``train_state.step`` up to ``cfg.steps``.

    The loss is ``get_dsm_loss_fn_pinn`` when ``args.use_pinn`` is set, and
    ``get_dsm_loss_fn`` (with ``n_max=args.n_max``) otherwise. Every
    ``cfg.val_freq`` steps the state is checkpointed to ``ckpt_path`` and,
    if ``cfg.train_val``, validated. Relies on module-level globals
    (cfg, args, pushforward, model, optimiser, train_ds, ckpt_path).

    Returns:
        ``(train_state, success)``, where ``success`` is False if a NaN loss
        was encountered (training stops at that step).
    """
    loss_kwargs = dict(
        like_w=False,
        eps=1e-3,
        pushforward=pushforward,
        model=model,
        train=True,
        thresh=0.,
    )
    if args.use_pinn:
        loss = get_dsm_loss_fn_pinn(n_max=-1, **loss_kwargs)
    else:
        loss = get_dsm_loss_fn(n_max=args.n_max, **loss_kwargs)

    train_step_fn = jax.jit(get_ema_loss_step_fn(loss, optimizer=optimiser, train=True))

    rng = train_state.rng
    pbar = tqdm(
        range(train_state.step, cfg.steps),
        total=cfg.steps - train_state.step,
        bar_format="{desc}{bar}{r_bar}",
        mininterval=1,
    )
    train_time = timer()
    total_train_time = 0
    # Bound before the loop so the final log works even when no steps remain
    # (e.g. when resuming a finished run).
    step = train_state.step
    for step in pbar:
        data, context = next(train_ds)
        batch = {"data": data, "context": context}
        rng, next_rng = jax.random.split(rng)
        (rng, train_state), loss_val = train_step_fn((next_rng, train_state), batch)
        if jnp.isnan(loss_val).any():
            print("Loss is nan")
            return train_state, False

        if step % 50 == 0:
            wandb.log({"train/loss": float(loss_val)}, step=step)
            pbar.set_description(f"Loss: {loss_val:.3f}")

        if step > 0 and step % cfg.val_freq == 0:
            wandb.log({"train/time_per_it": float((timer() - train_time) / cfg.val_freq)}, step=step)
            total_train_time += timer() - train_time
            save(ckpt_path, train_state)
            eval_time = timer()
            if cfg.train_val:
                evaluate(train_state, "val", step)
                wandb.log({"val/time_per_it": float((timer() - eval_time))}, step=step)
            train_time = timer()

    # Also count the steps run after the last validation checkpoint.
    total_train_time += timer() - train_time
    wandb.log({"train/total_time": total_train_time}, step=step)
    return train_state, True

def evaluate(train_state, stage, step=None):
    """Estimate the mean model log-likelihood on the val or test data.

    Uses the EMA parameters. Logs ``{stage}/logp`` (mean per sample) and
    ``{stage}/nfe`` (mean function evaluations per batch) to wandb. For
    ``stage == "test"`` it also logs the normalisation constant ``Z``
    estimated by ``compute_normalization``.

    Args:
        train_state: provides ``params_ema`` and ``model_state``.
        stage: ``"val"`` selects ``eval_ds``; any other value selects ``test_ds``.
        step: wandb step to log under.
    """
    print("Running evaluation")
    dataset = eval_ds if stage == "val" else test_ds

    model_w_dicts = (model, train_state.params_ema, train_state.model_state)
    likelihood_fn = jax.jit(pushforward.get_log_prob(model_w_dicts, train=False))

    def _accumulate(batches):
        """Sum log-likelihoods and NFEs over ``batches`` and count samples."""
        total_logp, total_nfe, count = 0.0, 0.0, 0
        for batch in batches:
            logp_step, nfe_step = likelihood_fn(*batch)
            total_logp += logp_step.sum()
            total_nfe += nfe_step
            count += logp_step.shape[0]
        return total_logp, total_nfe, count

    if hasattr(dataset, "__len__"):
        logp, nfe, N = _accumulate(dataset)
        n_batches = len(dataset)
    else:
        # Unsized dataset, possibly the same object as train_ds. Temporarily
        # switch to the eval batch size for ~20k samples, and always restore
        # the training batch size, even if evaluation raises.
        n_batches = round(20_000 / cfg.eval_batch_size)
        dataset.batch_dims = [cfg.eval_batch_size]
        try:
            logp, nfe, N = _accumulate(next(dataset) for _ in range(n_batches))
        finally:
            dataset.batch_dims = [cfg.batch_size]

    logp /= N
    nfe /= n_batches

    wandb.log({f"{stage}/logp": logp}, step=step)
    print(f"{stage}/logp = {logp:.3f}")
    wandb.log({f"{stage}/nfe": nfe}, step=step)
    print(f"{stage}/nfe = {nfe:.1f}")

    if stage == "test":  # Estimate normalisation constant
        default_context = None
        Z = compute_normalization(
            likelihood_fn, data_manifold, context=default_context
        )
        print(f"Z = {Z:.2f}")
        wandb.log({f"{stage}/Z": Z}, step=step)


### Main
# Startup diagnostics, then create the checkpoint directory <cwd>/<cfg.ckpt_dir>.
print("Stage : Startup")
print(f"Jax devices: {jax.devices()}")
run_path = os.getcwd()
print(f"run_path: {run_path}")
print(f"hostname: {socket.gethostname()}")
ckpt_path = os.path.join(run_path, cfg.ckpt_dir)
os.makedirs(ckpt_path, exist_ok=True)


print("Stage : Instantiate model")
# Build the data manifold, transform, beta schedule, SDE ("flow"), base
# distribution and the pushforward model from the Hydra/OmegaConf config.
rng = jax.random.PRNGKey(cfg.seed)
data_manifold = instantiate(cfg.manifold)
transform = instantiate(cfg.transform, data_manifold)
# The score model operates on the transform's domain manifold.
model_manifold = transform.domain
beta_schedule = instantiate(cfg.beta_schedule)
flow = instantiate(cfg.flow, manifold=model_manifold, beta_schedule=beta_schedule)
base = instantiate(cfg.base, model_manifold, flow)
pushforward = instantiate(cfg.pushf, flow, base, transform=transform)

print("Stage : Instantiate dataset")
rng, next_rng = jax.random.split(rng)
# NOTE(review): next_rng is reused for the dataset, random_split and all three
# DataLoaders below; confirm this key reuse is intended.
dataset = instantiate(cfg.dataset, rng=next_rng)

if isinstance(dataset, TensorDataset):
    # Split and wrap the dataset into dataloaders.
    train_ds, eval_ds, test_ds = random_split(
        dataset, lengths=cfg.splits, rng=next_rng
    )
    train_ds, eval_ds, test_ds = (
        DataLoader(train_ds, batch_dims=cfg.batch_size, rng=next_rng, shuffle=True),
        DataLoader(eval_ds, batch_dims=cfg.eval_batch_size, rng=next_rng),
        DataLoader(test_ds, batch_dims=cfg.eval_batch_size, rng=next_rng),
    )
    print(
        f"Train size: {len(train_ds.dataset)}. Val size: {len(eval_ds.dataset)}. Test size: {len(test_ds.dataset)}"
    )
else:
    # Non-tensor dataset: the same object serves train, val and test.
    train_ds, eval_ds, test_ds = dataset, dataset, dataset

print("Stage : Instantiate vector field model")

def model(y, t, context=None):
    r"""Vector field s_\theta: y, t, context -> T_y M

    Haiku forward function; it is wrapped with ``hk.transform_with_state``
    after this definition. The score network is built from ``cfg.generator``
    with ``cfg.architecture`` and ``cfg.embedding`` on ``model_manifold``.

    Conditioning: when ``context`` is given, the flattened times ``t`` are
    prepended to it as an extra column. If ``context``'s leading size differs
    from the batch size, it is first treated as a single row and repeated
    across the batch. Without a context, ``t`` alone conditions the network.
    """
    output_shape = get_class(cfg.generator._target_).output_shape(model_manifold)
    score = instantiate(
        cfg.generator,
        cfg.architecture,
        cfg.embedding,
        output_shape,
        manifold=model_manifold,
    )
    # TODO: parse context into embedding map
    if context is not None:
        t_expanded = jnp.expand_dims(t.reshape(-1), -1)
        if context.shape[0] != y.shape[0]:
            context = jnp.repeat(jnp.expand_dims(context, 0), y.shape[0], 0)
        context = jnp.concatenate([t_expanded, context], axis=-1)
    else:
        context = t
    return score(y, context)

model = hk.transform_with_state(model)

# Initialise parameters/state using zero times and one batch taken from train_ds.
# Note: this rebinds the module-level name `t` (previously the sympy symbol).
rng, next_rng = jax.random.split(rng)
t = jnp.zeros((cfg.batch_size, 1))
data, context = next(train_ds)
params, state = model.init(rng=next_rng, y=transform.inv(data), t=t, context=context)

print("Stage : Instantiate optimiser")
# Optimiser from cfg.optim, with updates scaled by the cfg.scheduler schedule.
schedule_fn = instantiate(cfg.scheduler)
optimiser = optax.chain(instantiate(cfg.optim), optax.scale_by_schedule(schedule_fn))
opt_state = optimiser.init(params)

# Resume/test: restore from ckpt_path. Otherwise start a fresh TrainState
# (EMA params initialised to params) and save it immediately.
if cfg.resume or cfg.mode == "test":  # if resume or evaluate
    train_state = restore(ckpt_path)
else:
    rng, next_rng = jax.random.split(rng)
    train_state = TrainState(
        opt_state=opt_state,
        model_state=state,
        step=0,
        params=params,
        ema_rate=cfg.ema_rate,
        params_ema=params,
        # NOTE(review): train() does read train_state.rng, so this TODO may be stale.
        rng=next_rng,  # TODO: we should actually use this for reproducibility
    )
    save(ckpt_path, train_state)

# Dispatch on cfg.mode: "train" trains only; "test" evaluates the restored
# checkpoint only; "all" trains, then evaluates if training hit no NaN loss.
# `success` is only read for mode "all" (short-circuit), where train() set it.
if cfg.mode == "train" or cfg.mode == "all":
    # if train_state.step == 0 and cfg.test_test:
    #     evaluate(train_state, "test", step=cfg.steps)
    # if train_state.step == 0 and cfg.test_plot:
        # generate_plots(train_state, "test", step=-1)
    print("Stage : Training")
    train_state, success = train(train_state)
if cfg.mode == "test" or (cfg.mode == "all" and success):
    print("Stage : Test")
    if cfg.test_val:
        evaluate(train_state, "val", step=cfg.steps)
    if cfg.test_test:
        evaluate(train_state, "test", step=cfg.steps)
    # if cfg.test_plot:
    #     generate_plots(train_state, "test", step=cfg.steps)
    success = True



