from functools import partial
import logging
import pprint
import sys

import hydra
from omegaconf import OmegaConf
import torch
import torch.utils.data.dataloader
import jax
import jax.numpy as jnp
import numpy as np
import orbax
import optax
import diffrax
import lineax
import lightning.pytorch as pl

import conf.conf
import conf.dataset
import conf.model
import userdiffusion.unet
import userfm.callbacks
import userfm.datasets
import userfm.diffusion
import userfm.event_constraints
import userfm.loggers
import userfm.steering_jax
import userfm.sde_diffusion
import userfm.utils


# Orbax reports through the 'absl' logger; raise its threshold so routine INFO
# messages (e.g. on checkpoint save/restore) are not printed.
logging.getLogger(name='absl').setLevel(level=logging.WARNING)

log = userfm.utils.getLoggerByFilename(__file__)


class JaxLightning(pl.LightningModule):
    """Lightning wrapper around a JAX model trained with manual optimization.

    Lightning drives the loop, logging and checkpoint callbacks; all numerical
    work is done in JAX. Parameters live on ``self.params`` and
    ``self.params_ema`` (an exponential moving average of ``self.params``), and
    the optax optimizer state lives on ``self.opt_state``.

    Supported probability paths (``cfg.get_model().probability_path``):
    ``conf.model.ConditionalOT`` and ``conf.model.VarianceExploding``.
    Supported prediction-time steering (``cfg.steering``):
    ``conf.model.NoSteering``, ``conf.model.SourceParallelTempering`` and
    ``conf.model.FeynmannKac``.

    Several methods are wrapped in ``jax.jit`` with ``self`` as a static
    argument. ``__hash__`` is identity-based, so each instance gets its own
    compilation-cache entries. Attributes read from ``self`` inside those
    methods are baked into the compiled function at trace time.
    """

    def __init__(self, cfg, key, dataloaders, train_data_std, model):
        """
        Args:
            cfg: run configuration; must provide ``get_model()`` and ``dataset``,
                plus ``steering`` for prediction.
            key: JAX PRNG key, split into model-init, training and prediction keys.
            dataloaders: mapping from split name to dataloader; the ``'train'``
                and ``'val'`` entries are used by the dataloader hooks.
            train_data_std: standard deviation of the training data, used for
                input/output scaling in ``velocity``.
            model: network exposing ``init(key, x=, t=, train=)`` and
                ``apply(params, x=, t=, train=)``.
        """
        super().__init__()
        # Parameters are updated by hand in ``training_step``; Lightning only
        # drives the loop.
        self.automatic_optimization = False

        self.cfg = cfg
        (
            self.key_model_init,
            self.key_train,
            self.key_predict
        ) = jax.random.split(key, 3)
        self.dataloaders = dataloaders
        self.train_data_std = train_data_std
        self.model = model

        # Only the variance-exploding path needs an SDE object.
        if isinstance(self.cfg.get_model().probability_path, conf.model.VarianceExploding):
            self.diffusion = userfm.sde_diffusion.get_sde_diffusion(self.cfg.get_model().probability_path)

        # EMA time scale: each training step moves params_ema a fraction
        # 1 / ema_ts of the way towards params (see ``step``).
        self.ema_ts = self.cfg.get_model().epoch_count / self.cfg.get_model().ema_folding_count

        # Gradient w.r.t. ``params`` (third positional argument of ``loss``);
        # ``loss`` also returns a dict of monitors as auxiliary output.
        self.loss_and_grad = jax.value_and_grad(self.loss, argnums=2, has_aux=True)

    def __hash__(self):
        # Identity hash so the instance can be used as a static argument to jax.jit.
        return hash(id(self))

    @property
    def noise_std(self):
        """Standard deviation of the source (noise) distribution that sampling starts from.

        Raises:
            ValueError: for an unsupported probability path.
        """
        if isinstance(self.cfg.get_model().probability_path, conf.model.VarianceExploding):
            return self.diffusion.sigma(1)
        elif isinstance(self.cfg.get_model().probability_path, conf.model.ConditionalOT):
            return 1
        else:
            raise ValueError(f'Noise std. dev. unknown for probability path: {self.cfg.get_model().probability_path}')

    def setup(self, stage):
        """Initialize the state needed for the Lightning ``stage``.

        ``'fit'`` initializes ``params``, with ``params_ema`` starting from the
        same pytree. ``'validate'`` with ``cfg.predict`` set builds the
        sampling time grid, the event constraint and the jitted helpers for
        the configured steering method.

        Raises:
            ValueError: for any other stage.
        """
        if stage == 'fit':
            key_model_init, self.key_model_init = jax.random.split(self.key_model_init)
            self.params = self.model_init(
                key_model_init,
                (self.cfg.dataset.batch_size, self.cfg.dataset.len_trajectory, self.cfg.dataset.dim),
                self.model
            )
            self.params_ema = self.params
        elif stage == 'validate':
            if self.cfg.predict:
                self.event_constraint = userfm.event_constraints.get_event_constraint(self.cfg.dataset)
                self.times = self.get_sample_times(self.cfg.steering.sampling_step_count)
                dt = self.times[1] - self.times[0]
                # Build jitted helpers once here to prevent repeated compilation
                # across predict batches.
                if isinstance(self.cfg.steering, conf.model.SourceParallelTempering):
                    # Potential on source noise: event-constraint potential of the
                    # sample obtained by integrating y0 over the full time grid.
                    def push_forward(y0):
                        return self.event_constraint.potential_fn(self.sample(self.times[0], self.times[-1], dt, y0).ys[-1], power=self.cfg.steering.penalization_power)
                    self.potential_fn = jax.jit(push_forward)
                    self.inverse_temperature = self.get_inverse_temperature(self.cfg.steering.tilt, self.cfg.steering.chain_count)
                    self.step_size_angle = self.get_step_size_angle(self.cfg.steering.chain_count)
                    self.kernel = userfm.steering_jax.KernelPreconditionedCrankNicolson(std=self.noise_std)
                    x_shape = (self.cfg.dataset.batch_size, self.cfg.dataset.len_trajectory, self.cfg.dataset.dim)
                    self.ptmcmc = userfm.steering_jax.PTMCMC(
                        self.kernel, self.potential_fn, self.cfg.steering.update_count,
                        shape=x_shape,
                        inverse_temperature=self.inverse_temperature,
                        step_size_angle=self.step_size_angle,
                    )
                elif isinstance(self.cfg.steering, conf.model.FeynmannKac):
                    # Reward and proposal are vmapped over the ensemble axis: keys
                    # on axis 0, particles on axis 1 of the (batch, ensemble, ...) array.
                    self.feynmann_kac_intermediate_reward_fn = jax.jit(jax.vmap(
                        self.get_feynmann_kac_intermediate_reward_fn(partial(self.event_constraint.reward_fn, power=self.cfg.steering.penalization_power)),
                        in_axes=(0, None, None, None, 1), out_axes=1,
                    ))
                    self.feynmann_kac_potential_fn = jax.jit(jax.vmap(self.get_feynmann_kac_potential_fn(), in_axes=(1, 1), out_axes=1))
                    def propose(key, t0, t1, dt, sample):
                        return self.sample(t0, t1, dt, sample, as_score=True, key=key).ys[-1]
                    self.propose_fn = jax.jit(jax.vmap(propose, in_axes=(0, None, None, None, 1), out_axes=1))
                    # Per batch element: resample the ensemble proportionally to the
                    # potentials when the effective sample size falls below half
                    # the ensemble size; otherwise keep it unchanged.
                    def adaptive_resampling(key, ensemble, potentials):
                        probabilities = potentials / potentials.sum()
                        effective_sample_size = 1 / jnp.square(probabilities).sum()
                        return jnp.where(
                            effective_sample_size < self.cfg.steering.ensemble_size / 2,
                            jax.random.choice(key, ensemble, shape=ensemble.shape[:1], p=probabilities),
                            ensemble,
                        )
                    self.resampler = jax.jit(jax.vmap(adaptive_resampling))
        else:
            raise ValueError(f'Unknown stage: {stage}')

    def model_init(self, key, x_shape, model):
        """Initialize ``model`` parameters with dummy inputs of shape ``x_shape``.

        The time input has shape ``(x_shape[0],)``.
        """
        x = jnp.ones(x_shape)
        t = jnp.ones(x_shape[0])
        params = model.init(key, x=x, t=t, train=False)
        return params

    def configure_optimizers(self):
        """Create the optax Adam optimizer and its state.

        The learning rate decays by ``lr_decay`` every 512 steps (staircase).
        Returns None, so Lightning manages no torch optimizer of its own.
        """
        learning_rate_scheduler = optax.exponential_decay(
            init_value=self.cfg.get_model().lr,
            transition_steps=512,  # Number of steps before decay
            decay_rate=self.cfg.get_model().lr_decay,  # Decay factor
            staircase=True  # Whether to use staircase decay
        )
        self.optimizer = optax.adam(learning_rate=learning_rate_scheduler)
        self.opt_state = self.optimizer.init(self.params)

    def train_dataloader(self):
        return self.dataloaders['train']

    def training_step(self, batch, batch_idx):
        """Run one manual optimization step on ``batch``.

        Updates ``params``, ``params_ema`` and ``opt_state``, evaluates the EMA
        parameters on the same batch, and returns both losses and their
        monitors as scalar torch tensors.
        """
        key_training_step, self.key_train = jax.random.split(self.key_train)
        loss, monitors, self.params, self.params_ema, self.opt_state = self.step(
            key_training_step, batch,
            self.params, self.params_ema,
            self.opt_state,
        )
        # use same key_training_step to ensure identical sampling
        loss_ema, monitors_ema = self.loss(key_training_step, batch, self.params_ema)
        self.optimizers().step()  # increment global step for PyTorch Lightning logging and checkpointing
        outputs = dict(
            loss=loss,
            loss_ema=loss_ema,
            monitors=monitors,
            monitors_ema=monitors_ema,
        )
        return jax.tree.map(lambda x: torch.tensor(x.item()), outputs)

    def val_dataloader(self):
        """Return the validation loader when fitting, or the prediction loader when predicting.

        Prediction runs through Lightning's validation loop.
        """
        if self.cfg.fit:
            return self.dataloaders['val']
        elif self.cfg.predict:
            return self.predict_dataloader()

    def validation_step(self, batch, batch_idx):
        """When fitting, report the latest ``train_loss_ema`` as ``loss_val``.

        The value is -1 during the sanity check. When predicting, delegate to
        ``predict_step``.
        """
        if self.cfg.fit:
            if self.trainer.sanity_checking:
                return dict(loss_val=torch.tensor(-1.))
            else:
                return dict(loss_val=self.trainer.callback_metrics['train_loss_ema'])
        elif self.cfg.predict:
            return self.predict_step(batch, batch_idx)

    def predict_dataloader(self):
        """Yield one PRNG key per prediction batch (``cfg.steering.batch_count`` keys)."""
        return torch.utils.data.DataLoader(
            jax.random.split(self.key_predict, self.cfg.steering.batch_count),
            collate_fn=lambda x: x[0],
        )

    def predict_step(self, key, batch_idx):
        """Generate one batch of samples from the prediction key ``key``.

        The result is stored in ``self._predict_state`` together with the
        event constraint rather than returned. It is presumably consumed by a
        callback (TODO confirm). Requires ``setup('validate')`` with
        ``cfg.predict`` set.

        Raises:
            ValueError: for an unknown steering configuration.
        """
        if isinstance(self.cfg.steering, conf.model.NoSteering):
            (
                key_initial,
                key_solver,
            ) = jax.random.split(key)
            t0 = self.times[0]
            t1 = self.times[-1]
            dt = self.times[1] - self.times[0]
            batch = self.sample(
                t0, t1, dt,
                # Use the dedicated sub-key: after splitting, the parent ``key``
                # is not guaranteed to be independent of ``key_solver``.
                self.noise_std * jax.random.normal(key_initial, (self.cfg.dataset.batch_size, self.cfg.dataset.len_trajectory, self.cfg.dataset.dim)),
                as_score=self.cfg.steering.solver is conf.model.Solver.EULER_MARUYAMA,
                key=key_solver,
            ).ys[-1]
        elif isinstance(self.cfg.steering, conf.model.SourceParallelTempering):
            batch = self.source_parallel_tempering_sample(
                key,
                self.cfg.dataset.batch_size,
                self.cfg.steering.update_count,
                potential_fn=self.potential_fn,
                times=self.times,
                inverse_temperature=self.inverse_temperature,
                step_size_angle=self.step_size_angle,
            ).ys[-1]
        elif isinstance(self.cfg.steering, conf.model.FeynmannKac):
            batch = self.feynmann_kac_steering_sample(
                key,
                self.cfg.dataset.batch_size,
                reward_fn=self.feynmann_kac_intermediate_reward_fn,
                times=self.times,
            )
        else:
            raise ValueError(f'Unknown steering: {self.cfg.steering}')
        self._predict_state = dict(batch=batch, event_constraint=self.event_constraint)

    @partial(jax.jit, static_argnames=['self', 'train'])
    def velocity(self, x, t, params, train=False):
        """Evaluate the learned vector field at states ``x`` and times ``t``.

        ``t`` must have singleton axes 1 and 2 (e.g. shape ``(batch, 1, 1)``);
        those axes are squeezed before ``t`` is passed to the network. For a
        ``VarianceExploding`` path with ``finzi_karras_weighting``, the network
        input and output are both scaled by
        ``1 / sqrt(sigma(1 - t)**2 + (scale(1 - t) * train_data_std)**2)``.
        """
        if (
            isinstance(self.cfg.get_model().probability_path, conf.model.VarianceExploding)
            and
            self.cfg.get_model().probability_path.finzi_karras_weighting
        ):
            # scaling is equivalent to that in Karras et al. https://arxiv.org/abs/2206.00364
            sigma, scale = self.diffusion.sigma(1 - t), self.diffusion.scale(1 - t)
            # Karras et al. $c_in$ and $s(t)$ of EDM.
            input_scale = 1 / jnp.sqrt(sigma**2 + (scale * self.train_data_std)**2)
            out = self.model.apply(params, x=x * input_scale, t=t.squeeze((1, 2)), train=train)
            # Karras et al. the denominator of $c_out$ of EDM; where is the numerator?
            return out / jnp.sqrt(sigma**2 + (scale * self.train_data_std)**2)
        else:
            return self.model.apply(params, x=x, t=t.squeeze((1, 2)), train=train)

    @partial(jax.jit, static_argnames=['self'])
    def conditional_ot(self, t, x_noise, x_data):
        """Conditional optimal-transport path ``xt = (1 - t) * x_noise + t * x_data``.

        t=0 corresponds to noise and t=1 to data. Returns ``xt``, the path
        coefficients, the velocity target ``x_data - x_noise`` and its
        x-derivative, and the conditional ``d/dx log p_t``. A small eps guards
        the divisions at t=1.
        """
        mean_scale, std = t, 1 - t
        xt = std * x_noise + mean_scale * x_data
        velocity_target = x_data - x_noise
        eps = 1e-6
        return dict(
            xt=xt,
            mean_scale=mean_scale, std=std,
            velocity_target=velocity_target,
            dx_velocity_target=-1 / (std + eps),
            dx_log_pt=-(xt - mean_scale * x_data) / (std + eps)**2,
        )

    @partial(jax.jit, static_argnames=['self'])
    def variance_exploding_conditional(self, t, x_noise, x_data):
        """Variance-exploding conditional path ``xt = x_data + (sigma(1 - t) + eps) * x_noise``.

        The velocity target is ``-dsigma(1 - t) / (sigma + eps) * (xt - x_data)``.
        ``dt_std`` is also returned for the loss weighting.
        """
        mean_scale, std = jnp.ones_like(t), self.diffusion.sigma(1 - t)
        eps = 1e-6
        # add eps here to make equal to divisor in velocity_target
        xt = x_data + (std + eps) * x_noise
        dt_std = self.diffusion.dsigma(1 - t)
        dx_velocity_target = -dt_std / (std + eps)
        velocity_target = dx_velocity_target * (xt - x_data)
        return dict(
            xt=xt,
            mean_scale=1., std=std,
            velocity_target=velocity_target,
            dx_velocity_target=dx_velocity_target,
            dx_log_pt=-(xt - mean_scale * x_data) / (std + eps)**2,
            dt_std=dt_std,
        )

    @partial(jax.jit, static_argnames=['self'])
    def loss(self, key, x_data, params):
        """Flow-matching regression loss of ``velocity`` against the conditional target.

        Times are drawn per batch element with shape ``(batch, 1, 1)``. For
        ``VarianceExploding`` the draw is stratified over [tmin, 1] and the
        loss is weighted by ``1 / dsigma(1 - t)**2``.

        Returns:
            ``(flow_loss, {'flow_loss': flow_loss})``.

        Raises:
            ValueError: for an unsupported probability path (at trace time).
        """
        (
            key_time,
            key_noise,
        ) = jax.random.split(key, 2)
        if isinstance(self.cfg.get_model().probability_path, conf.model.VarianceExploding):
            # Low-discrepancy times: one uniform offset shifted by i / n,
            # wrapped into [0, 1). ``arange(n) / n`` (not ``linspace(0, 1, n)``,
            # whose endpoints 0 and 1 wrap to the same time) keeps all n
            # strata distinct.
            batch_size = x_data.shape[0]
            u0 = jax.random.uniform(key_time)
            u = jnp.remainder(u0 + jnp.arange(batch_size) / batch_size, 1)
            t = u * (1 - self.diffusion.tmin) + self.diffusion.tmin
            t = t[:, None, None]
        else:
            t = jax.random.uniform(key_time, shape=(x_data.shape[0], 1, 1))

        x_noise = jax.random.normal(key_noise, x_data.shape)

        if isinstance(self.cfg.get_model().probability_path, conf.model.ConditionalOT):
            context = self.conditional_ot(t, x_noise, x_data)
            weighting = 1.
        elif isinstance(self.cfg.get_model().probability_path, conf.model.VarianceExploding):
            context = self.variance_exploding_conditional(t, x_noise, x_data)
            weighting = 1 / context['dt_std']**2
        else:
            raise ValueError(f'Unknown conditional flow: {self.cfg.get_model().probability_path}')

        velocity_pred = self.velocity(context['xt'], t, params, train=True)
        flow_loss = ((velocity_pred - context['velocity_target'])**2 * weighting).mean()

        monitors = {'flow_loss': flow_loss}

        return flow_loss, monitors

    @partial(jax.jit, static_argnames=['self'])
    def step(self, key, batch, params, params_ema, opt_state):
        """Apply one optimizer update and one EMA update.

        Returns:
            ``(loss, monitors, params, params_ema, opt_state)``.
        """
        (loss, monitors), grads = self.loss_and_grad(key, batch, params)
        updates, opt_state = self.optimizer.update(grads, opt_state)
        params = optax.apply_updates(params, updates)
        ema_update = lambda p, ema: ema + (p - ema) / self.ema_ts
        params_ema = jax.tree.map(ema_update, params, params_ema)
        return loss, monitors, params, params_ema, opt_state

    @partial(jax.jit, static_argnames=['self', 'as_score'])
    def sample(self, t0, t1, dt, y0, as_score=False, key=None, params=None, saveat=diffrax.SaveAt(t1=True)):
        """Integrate from ``t0`` to ``t1`` with step ``dt``, starting at ``y0``.

        With ``as_score=False``, the ODE ``dy/dt = velocity(y, t)`` is solved
        with Heun's method. With ``as_score=True`` (``VarianceExploding`` only;
        ``key`` required), the SDE with drift ``2 * velocity`` and diagonal
        diffusion ``self.diffusion.diffusion(None, None, 1 - t)`` is solved with
        Euler steps on a ``VirtualBrownianTree``.

        Returns the ``diffrax`` solution; by default only the state at ``t1``
        is saved (``.ys[-1]``).

        NOTE(review): because ``self`` is static under jax.jit, ``params=None``
        reads ``self.params_ema`` at trace time. An already-compiled call will
        not see later reassignments of ``params_ema``; pass ``params``
        explicitly if they can change.

        Raises:
            ValueError: if ``as_score`` is set for a non-VE path or without ``key``.
        """
        if params is None:
            params = self.params_ema
        if as_score:
            if not isinstance(self.cfg.get_model().probability_path, conf.model.VarianceExploding):
                raise ValueError(
                    f'Sampling using the score function is only implemented for a VarianceExploding probability path, not {self.cfg.get_model().probability_path.__class__.__name__}. '
                    'Please choose a different probability path or set as_score=False.'
                )
            if key is None:
                raise ValueError('Sampling using the score function requires a PRNG key. Please pass a key as a keyword argument.')
            def g2_score(t, y, args):
                # diffrax passes a scalar t; broadcast to the (batch, 1, 1) shape velocity expects.
                if not hasattr(t, 'shape') or not t.shape:
                    t = jnp.ones((y0.shape[0], 1, 1)) * t
                return 2 * self.velocity(y, t, params)
            def diffusion(t, y, args):
                return lineax.DiagonalLinearOperator(self.diffusion.diffusion(None, None, 1 - t))
            return diffrax.diffeqsolve(
                diffrax.MultiTerm(
                    diffrax.ODETerm(g2_score),
                    diffrax.ControlTerm(
                        diffusion,
                        diffrax.VirtualBrownianTree(t0, t1, tol=1e-3, shape=y0.shape, key=key)
                    ),
                ),
                solver=diffrax.Euler(),
                t0=t0, dt0=dt, t1=t1,
                y0=y0,
                saveat=saveat,
            )
        else:
            def velocity(t, y, args):
                # diffrax passes a scalar t; broadcast to the (batch, 1, 1) shape velocity expects.
                if not hasattr(t, 'shape') or not t.shape:
                    t = jnp.ones((y0.shape[0], 1, 1)) * t
                return self.velocity(y, t, params)
            return diffrax.diffeqsolve(
                diffrax.ODETerm(vector_field=velocity),
                diffrax.Heun(),
                t0=t0, dt0=dt, t1=t1,
                y0=y0,
                saveat=saveat,
            )

    def get_sample_times(self, sample_step_count):
        """Return ``sample_step_count`` times ``(time_min + i) / sample_step_count``.

        ``time_min`` is 0, or the path's ``time_min`` for ``VarianceExploding``.

        NOTE(review): the grid ends at
        ``(sample_step_count - 1 + time_min) / sample_step_count``, so
        integration stops short of t=1, and ``time_min`` is itself divided by
        the step count. Confirm both are intended.
        """
        time_min = 0.
        if isinstance(self.cfg.get_model().probability_path, conf.model.VarianceExploding):
            time_min = self.cfg.get_model().probability_path.time_min
        return (time_min + jnp.arange(sample_step_count)) / sample_step_count

    def get_inverse_temperature(self, tilt, chain_count):
        """Inverse temperatures for the tempering chains, linearly spaced from 0 to ``tilt``."""
        return jnp.linspace(0, tilt, chain_count)

    def get_step_size_angle(self, chain_count, eps=1e-3, angle_min=.05):
        """Per-chain step-size angles, linearly spaced from ``pi/2 - eps`` down to ``angle_min``."""
        return jnp.linspace(jnp.pi / 2 - eps, angle_min, chain_count)

    def source_parallel_tempering_sample(self, key, batch_size, iter_count, potential_fn, times, inverse_temperature, step_size_angle, params=None):
        """Run the PTMCMC sampler built in ``setup`` on the source noise, then integrate the result.

        Only ``key`` and ``times`` are used. The sampler built in ``setup``
        already holds the potential, update count, temperatures and step
        sizes, so the remaining arguments are ignored. ``self.ptmcmc.run(key)[-1]``
        is presumably the final chain state (TODO confirm).

        Returns:
            The ``sample`` solution of integrating that noise over ``times``.
        """
        tempered_noise = self.ptmcmc.run(
            key,
        )[-1]
        t0 = times[0]
        t1 = times[-1]
        dt = times[1] - times[0]
        return self.sample(t0, t1, dt, tempered_noise)

    def feynmann_kac_steering_sample(self, key, batch_size, reward_fn, times, params=None):
        """Draw ``batch_size`` samples with Feynman-Kac (particle) steering.

        For each batch element, an ensemble of ``cfg.steering.ensemble_size``
        particles is integrated step by step along ``times``. At each step
        the particles are adaptively resampled by their potentials,
        propagated with ``self.propose_fn``, and re-scored with ``reward_fn``.
        The particle with the largest final potential is returned for each
        batch element. ``params`` is unused.

        For MAX/SUM potentials, the final potential is
        ``exp(tilt * r) / prod(earlier potentials)``, so the product of all
        potentials equals ``exp(tilt * r)`` of the final sample. It is computed
        in log space because the product of many exponentiated potentials
        overflows float32.

        Requires ``setup('validate')`` with a ``FeynmannKac`` steering config.

        Returns:
            Array of shape ``(batch_size, len_trajectory, dim)``.

        Raises:
            ValueError: for an unsupported ``cfg.steering.potential``.
        """
        (
            key_initialize_ensemble,
            key_particle_resampling,
            key_reward_fn,
            key_propose_next_intermediate,
        ) = jax.random.split(key, 4)
        x_shape = (batch_size, self.cfg.steering.ensemble_size, self.cfg.dataset.len_trajectory, self.cfg.dataset.dim)
        batch = jax.random.normal(key_initialize_ensemble, x_shape) * self.noise_std
        t0 = times[0]
        t1 = times[-1]
        dt = times[1] - times[0]
        key_reward, key_reward_fn = jax.random.split(key_reward_fn)
        reward = reward_fn(jax.random.split(key_reward, self.cfg.steering.ensemble_size), t0, t1, dt, batch)
        if self.cfg.steering.potential in (conf.model.FeynmannKacPotential.DIFFERENCE, conf.model.FeynmannKacPotential.MAX):
            aggregated_reward = reward
        elif self.cfg.steering.potential is conf.model.FeynmannKacPotential.SUM:
            aggregated_reward = jnp.zeros((batch_size, self.cfg.steering.ensemble_size))
        else:
            raise ValueError(f'Aggregated reward initialization not implemented for {self.cfg.steering.potential.name}.')
        potential, aggregated_reward = self.feynmann_kac_potential_fn(reward, aggregated_reward)
        # Running sum of log-potentials applied before the final step.
        log_aggregated_potential = 0.
        for t0, t1 in zip(times, times[1:]):
            log_aggregated_potential = log_aggregated_potential + jnp.log(potential)
            # NOTE(review): the resampler permutes ``batch`` only; ``aggregated_reward``
            # and ``log_aggregated_potential`` stay indexed by the pre-resampling
            # particle positions, so they do not follow particle lineage. Confirm.
            key_resample, key_particle_resampling = jax.random.split(key_particle_resampling)
            batch = self.resampler(
                jax.random.split(key_resample, batch_size),
                batch,
                potential,
            )
            key_propose, key_propose_next_intermediate = jax.random.split(key_propose_next_intermediate)
            batch = self.propose_fn(jax.random.split(key_propose, self.cfg.steering.ensemble_size), t0, t1, dt, batch)
            key_reward, key_reward_fn = jax.random.split(key_reward_fn)
            # NOTE(review): ``batch`` has just been propagated to t1, yet reward_fn
            # receives (t0, t1). The sub-ensemble pushforward reward then
            # integrates t0 -> t1 rather than t1 -> times[-1]. Confirm intended.
            reward = reward_fn(jax.random.split(key_reward, self.cfg.steering.ensemble_size), t0, t1, dt, batch)
            if t1 == times[-1] and self.cfg.steering.potential in (conf.model.FeynmannKacPotential.MAX, conf.model.FeynmannKacPotential.SUM):
                potential = jnp.exp(self.cfg.steering.tilt * reward)
            else:
                potential, aggregated_reward = self.feynmann_kac_potential_fn(reward, aggregated_reward)
        if self.cfg.steering.potential in (conf.model.FeynmannKacPotential.MAX, conf.model.FeynmannKacPotential.SUM):
            final_log_potential = jnp.log(potential) - log_aggregated_potential
        elif self.cfg.steering.potential is conf.model.FeynmannKacPotential.DIFFERENCE:
            final_log_potential = jnp.log(potential)
        else:
            raise ValueError(f'Final potential computation not implemented for {self.cfg.steering.potential.name}.')
        # log is monotone, so argmax over log-potentials selects the same particle.
        return jax.vmap(lambda b, i: b[i])(batch, jnp.argmax(final_log_potential, axis=1))

    def get_feynmann_kac_potential_fn(self):
        """Return ``potential(new_reward, aggregated_reward) -> (potential, new_aggregated_reward)``.

        The function depends on ``cfg.steering.potential``:

        - DIFFERENCE: ``exp(tilt * (new - aggregated))``; the aggregate becomes ``new``.
        - MAX: the aggregate becomes ``max(new, aggregated)``; potential ``exp(tilt * aggregate)``.
        - SUM: the aggregate becomes ``new + aggregated``; potential ``exp(tilt * aggregate)``.

        Raises:
            ValueError: for an unknown potential.
        """
        if self.cfg.steering.potential is conf.model.FeynmannKacPotential.DIFFERENCE:
            def potential(new_reward, aggregated_reward):
                new_potential = jnp.exp(self.cfg.steering.tilt * (new_reward - aggregated_reward))
                new_aggregated_reward = new_reward
                return new_potential, new_aggregated_reward
        elif self.cfg.steering.potential is conf.model.FeynmannKacPotential.MAX:
            def potential(new_reward, aggregated_reward):
                new_aggregated_reward = jnp.maximum(new_reward, aggregated_reward)
                new_potential = jnp.exp(self.cfg.steering.tilt * new_aggregated_reward)
                return new_potential, new_aggregated_reward
        elif self.cfg.steering.potential is conf.model.FeynmannKacPotential.SUM:
            def potential(new_reward, aggregated_reward):
                new_aggregated_reward = new_reward + aggregated_reward
                new_potential = jnp.exp(self.cfg.steering.tilt * new_aggregated_reward)
                return new_potential, new_aggregated_reward
        else:
            raise ValueError(f'Unknown Feynmann-Kac steering potential: {self.cfg.steering.potential}')
        return potential

    def get_feynmann_kac_intermediate_reward_fn(self, fn, params=None):
        """Return ``reward_fn(key, t0, t1, dt, batch)`` for the intermediate Feynman-Kac reward.

        ``fn`` maps samples to rewards. ``params`` defaults to
        ``self.params_ema`` as read at this call.

        - ExpectedSample: ``fn(batch + 2 * velocity(batch, t0))``.
        - SubEnsemblePushforward: each particle is pushed from t0 to t1 with
          the stochastic sampler ``sub_ensemble_size`` times, and the reward
          is ``log(mean(exp(fn(...))))`` over the sub-ensemble.

        Raises:
            ValueError: for an unknown intermediate reward.
        """
        if params is None:
            params = self.params_ema
        if isinstance(self.cfg.steering.intermediate_reward, conf.model.FeynmannKacIntermediateRewardExpectedSample):
            def reward_fn(key, t0, t1, dt, batch):
                t0 = t0 * jnp.ones((batch.shape[0], 1, 1))
                return fn(batch + 2 * self.velocity(batch, t0, params))
        elif isinstance(self.cfg.steering.intermediate_reward, conf.model.FeynmannKacIntermediateRewardSubEnsemblePushforward):
            def reward_fn(key, t0, t1, dt, batch):
                sub_ensemble_size = self.cfg.steering.intermediate_reward.sub_ensemble_size
                key_sample_sub_ensemble = jax.random.split(key, sub_ensemble_size)
                x_shape = (batch.shape[0], sub_ensemble_size, self.cfg.dataset.len_trajectory, self.cfg.dataset.dim)
                batch = jnp.broadcast_to(batch[:, None], x_shape)
                def get_sub_ensemble_member_reward(key, y0):
                    return fn(self.sample(t0, t1, dt, y0, as_score=True, key=key).ys[-1])
                sub_ensemble_member_reward = jax.vmap(get_sub_ensemble_member_reward, in_axes=(0, 1), out_axes=1)(key_sample_sub_ensemble, batch)
                # log(mean(exp(r))) computed stably: the naive form underflows
                # or overflows for large-magnitude rewards.
                return jax.nn.logsumexp(sub_ensemble_member_reward, axis=1) - jnp.log(sub_ensemble_size)
        else:
            raise ValueError(f'Unknown Feynmann-Kac steering intermediate reward: {self.cfg.steering.intermediate_reward}')
        return reward_fn


@hydra.main(**userfm.utils.HYDRA_INIT)
def main(cfg: conf.conf.TrainingAndEvaluation):
    """Train and/or evaluate a model as described by the Hydra config ``cfg``.

    The resolved config is instantiated and recorded in the database. A
    dataset and dataloader are built for every split whose batch count is
    positive. A ``JaxLightning`` module is created, optionally with trained
    parameters restored from an orbax checkpoint. Finally, a Lightning
    trainer either fits it (``cfg.fit``) or runs validation-as-prediction
    (``cfg.predict``).
    """
    with conf.conf.Session() as db:
        resolved = OmegaConf.to_container(cfg, resolve=True)
        cfg = conf.conf.orm.instantiate_and_insert_config(db, resolved)
        db.commit()
        log.info('Command: python %s', ' '.join(sys.argv[:-1]))
        log.info(pprint.pformat(cfg))
        log.info('Output directory: %s', cfg.run_dir)

        log.info('JAX process: %d / %d', jax.process_index(), jax.process_count())
        log.info('JAX devices: %r', jax.devices())

        split_names = ('train', 'val', 'test')
        root_key = jax.random.key(userfm.utils.RNG_RANDBITS[cfg.rng_seed])
        # One key per split plus one for the Lightning module.
        *key_dataset_splits, key_jax_lightning = jax.random.split(root_key, len(split_names) + 1)

        # Splits with a non-positive batch count are skipped entirely.
        datasets = {}
        for split_name, split_key in zip(split_names, key_dataset_splits):
            batch_count = getattr(cfg.dataset, f'batch_count_{split_name}')
            if batch_count > 0:
                datasets[split_name] = userfm.datasets.get_dataset(cfg.dataset, split_key, batch_count)

        # The training-set std is computed when fitting and reloaded from the
        # trained model's run directory otherwise.
        if not cfg.fit:
            train_data_std = np.load(cfg.model.conf.run_dir/'train_data_std.npy')
        else:
            train_data_std = datasets['train'][:].std()
            np.save(cfg.run_dir/'train_data_std.npy', np.array(train_data_std))
        log.info('Training set standard deviation: %(data_std).7f', dict(data_std=train_data_std))

        def make_dataloader(dataset):
            # Sequential, fixed-size batches; the incomplete final batch is dropped
            # and batches are passed through without collation.
            sequential = torch.utils.data.SequentialSampler(dataset)
            batches = torch.utils.data.BatchSampler(sequential, cfg.dataset.batch_size, drop_last=True)
            return torch.utils.data.DataLoader(dataset, batch_sampler=batches, collate_fn=lambda x: x)

        dataloaders = {split_name: make_dataloader(dataset) for split_name, dataset in datasets.items()}
        dataloaders['predict'] = dataloaders.get(cfg.predict_split)

        unet_cfg = userdiffusion.unet.unet_64_config(
            cfg.dataset.dim,
            base_channels=cfg.get_model().base_channel_count,
            attention=cfg.get_model().use_attention,
        )
        jax_lightning = JaxLightning(
            cfg, key_jax_lightning, dataloaders, train_data_std,
            userdiffusion.unet.UNet(unet_cfg),
        )
        if isinstance(cfg.model, conf.conf.Trained):
            checkpointer = orbax.checkpoint.PyTreeCheckpointer()
            ckpt_path = cfg.model.conf.run_dir/cfg.model.ckpt_filename
            ema_ckpt_path = ckpt_path.parent/f'{ckpt_path.stem}_ema{ckpt_path.suffix}'
            jax_lightning.params = checkpointer.restore(ckpt_path)
            jax_lightning.params_ema = checkpointer.restore(ema_ckpt_path)

        logger = userfm.loggers.CSVLogger(cfg.run_dir, name=None)
        # Periodic checkpoints every 100 epochs plus a 'last' link; no top-k monitoring.
        checkpoint_callback = userfm.callbacks.ModelCheckpoint(
            dirpath=cfg.run_dir,
            filename='{epoch}',
            save_last='link',
            every_n_epochs=100,
            save_on_train_epoch_end=False,
            enable_version_counter=False,
        )
        if cfg.fit:
            stats_callback = userfm.callbacks.LogStats()
        else:
            stats_callback = userfm.callbacks.LogPredictStats(cfg.log_constraint_values)

        pl_trainer = pl.Trainer(
            max_epochs=cfg.get_model().epoch_count,
            logger=logger,
            precision=32,
            callbacks=[checkpoint_callback, stats_callback],
            log_every_n_steps=1,
            check_val_every_n_epoch=1 if cfg.predict else 50,
            deterministic=True,
        )

        if cfg.fit:
            pl_trainer.fit(jax_lightning)
        elif cfg.predict:
            # Prediction runs through ``validate`` because that path supports logging.
            pl_trainer.validate(jax_lightning)


if __name__ == '__main__':
    # Resolve the run directory before Hydra parses the command line. The exact
    # behavior of get_run_dir/set_run_dir is defined in conf.conf; presumably
    # they derive the output directory from the CLI overrides and register it
    # for Hydra (TODO confirm).
    last_override, run_dir = conf.conf.get_run_dir()
    conf.conf.set_run_dir(last_override, run_dir)
    main()
