from functools import partial
import os

os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import argparse
from tqdm import tqdm

import wandb

import numpy as np

import jax
import optax

import pcax.predictive_coding as pxc
import pcax.nn as pxnn
import pcax.functional as pxf
import pcax.utils as pxu


from utils_pcax.data import get_datax_cnn, make_mean_images
from utils_pcax.utils import set_seed, sgdld, sgd_scaled
from utils_pcax.models import (
    AddLatent,
    CNNModel,
    Model,
    energy_per_stream,
    energy_weights,
    initialisation,
    energy,
    energy_down,
    energy_up,
    setup_infer_on_batch,
)
from utils_pcax.eval import (
    eval_acc,
    eval_latent_decoding_acc,
    eval_latent_reconstruction,
    eval_latent_reconstruction_free_latent,
    eval_rmse,
    eval_fid_inception,
)

"""
    This script can be used to train convolutional models with the following architectures:
        - bPC
        - uPC == discPC
        - genPC == dPC
        - hybridPC
    Most of these models can be trained in supervised, unsupervised, or combined supervised/unsupervised mode on MNIST, Fashion-MNIST, CIFAR-10 and CIFAR-100.
"""


def setup_train(
    local_infer_on_batch, energy_weights, is_free_latents=False, optim_w_latent=None
):
    """Build the per-epoch training function.

    Args:
        local_infer_on_batch: Inference routine, called as
            ``local_infer_on_batch(T, x, y, is_up_initialisation, 0,
            model=..., optim_h=...)`` before every weight update.
        energy_weights: Energy function used for the weight update. It is
            wrapped in ``pxf.value_and_grad(..., has_aux=True)``, so it must
            return an ``(energy, aux)`` pair. (This parameter shadows the
            module-level import of the same name.)
        is_free_latents: If True, the weight gradient is split between two
            optimisers. ``optim_w`` receives the LayerParams not tagged
            ``latent=True`` / ``frozen=True``; ``optim_w_latent`` receives
            the LayerParams tagged ``latent=True``.
        optim_w_latent: Optimiser for the latent-tagged parameters. Only used
            when ``is_free_latents`` is True.

    Returns:
        ``train(dl, T, is_up_initialisation, *, model, optim_w, optim_h,
        verbose=False)``. It runs one epoch over ``dl`` and returns
        ``energy_per_stream`` evaluated on the last batch.
    """

    # T (arg 0) and is_up_initialisation (arg 3) are static for jit, so each
    # distinct value triggers a separate compilation.
    @pxf.jit(static_argnums=(0, 3))
    def train_on_batch(
        T: int,
        x: jax.Array,
        y: jax.Array,
        is_up_initialisation: bool,
        *,
        model: Model,
        optim_w: pxu.Optim,
        optim_h: pxu.Optim,
    ):
        """Run inference on one batch, then apply one ``optim_w`` weight step."""

        model.train()

        local_infer_on_batch(
            T, x, y, is_up_initialisation, 0, model=model, optim_h=optim_h
        )

        # Learning step
        with pxu.step(model, clear_params=pxc.VodeParam.Cache):
            (e, y_), g = pxf.value_and_grad(
                pxu.Mask(pxnn.LayerParam, [False, True]), has_aux=True
            )(energy_weights)(x, y, model=model)
        optim_w.step(model, g["model"])
        return e

    @pxf.jit(static_argnums=(0, 3))
    def train_on_batch_free(
        T: int,
        x: jax.Array,
        y: jax.Array,
        is_up_initialisation: bool,
        *,
        model: Model,
        optim_w: pxu.Optim,
        optim_h: pxu.Optim,
        optim_w_latent: pxu.Optim,
    ):
        """Same as ``train_on_batch``, but routes latent-tagged parameter
        gradients to ``optim_w_latent`` and all other gradients to ``optim_w``."""

        model.train()

        local_infer_on_batch(
            T, x, y, is_up_initialisation, 0, model=model, optim_h=optim_h
        )

        # Learning step
        with pxu.step(model, clear_params=pxc.VodeParam.Cache):
            (e, y_), g = pxf.value_and_grad(
                pxu.Mask(pxnn.LayerParam, [False, True]), has_aux=True
            )(energy_weights)(x, y, model=model)
        # NOTE(review): the trailing positional ``True`` presumably allows
        # stepping with a masked (partial) gradient tree; confirm against
        # pcax ``Optim.step``.
        optim_w.step(
            model,
            pxu.Mask(pxu.m(pxnn.LayerParam).has_not(latent=True).has_not(frozen=True))(
                g["model"]
            ),
            True,
        )
        optim_w_latent.step(
            model, pxu.Mask(pxu.m(pxnn.LayerParam).has(latent=True))(g["model"]), True
        )
        return e

    local_train = (
        partial(train_on_batch_free, optim_w_latent=optim_w_latent)
        if is_free_latents
        else train_on_batch
    )

    def train(
        dl,
        T,
        is_up_initialisation,
        *,
        model: Model,
        optim_w: pxu.Optim,
        optim_h: pxu.Optim,
        verbose: bool = False,
    ):
        """Train for one epoch over ``dl``; return the energies of the last batch.

        ``verbose`` is accepted for interface compatibility but is unused.

        Raises:
            ValueError: if ``dl`` yields no batches, because there is then no
                batch on which to evaluate the energies.
        """
        has_batches = False
        for x, y in dl:
            has_batches = True
            local_train(
                T,
                x,
                y,
                is_up_initialisation,
                model=model,
                optim_w=optim_w,
                optim_h=optim_h,
            )
        if not has_batches:
            raise ValueError(
                "train() received an empty data loader; "
                "cannot compute the epoch energies."
            )
        return energy_per_stream(x, y, model=model)

    return train


def cosine_scedule(
    lr_p,
    nm_epochs,
    train_dl,
    warmup_frac=0.1,
    peak_factor=1.1,
    end_factor=0.1,
):
    """Build the learning-rate schedule for the weight optimisers.

    The (misspelled) name is kept for backward compatibility.

    If there are no optimisation steps to schedule, a constant schedule at
    ``lr_p`` is returned. This covers ``nm_epochs == 0``, an empty
    ``train_dl`` and a negative product.

    Otherwise ``optax.warmup_cosine_decay_schedule`` is used over
    ``len(train_dl) * nm_epochs`` steps. The schedule starts at ``lr_p``,
    warms up to ``peak_factor * lr_p`` over the first ``warmup_frac`` of the
    steps, and then decays to ``end_factor * lr_p``.

    Args:
        lr_p: Base learning rate.
        nm_epochs: Number of training epochs.
        train_dl: Training data loader; only ``len(train_dl)`` is used, and
            only when ``nm_epochs != 0``.
        warmup_frac: Fraction of the total steps spent warming up.
        peak_factor: Peak learning rate as a multiple of ``lr_p``.
        end_factor: Final learning rate as a multiple of ``lr_p``.

    Returns:
        An optax schedule (a callable from step count to learning rate).
    """
    if nm_epochs == 0:
        return optax.constant_schedule(lr_p)
    total_steps = len(train_dl) * nm_epochs
    if total_steps <= 0:
        # optax rejects non-positive decay_steps; nothing to decay anyway.
        return optax.constant_schedule(lr_p)
    return optax.warmup_cosine_decay_schedule(
        init_value=lr_p,
        peak_value=peak_factor * lr_p,
        # optax declares warmup_steps / decay_steps as ints.
        warmup_steps=int(round(warmup_frac * total_steps)),
        decay_steps=total_steps,
        end_value=end_factor * lr_p,
        exponent=1.0,
    )


def _to_wandb_images(images, is_mnist, limit=None):
    """Convert a batch of images to a list of ``wandb.Image`` objects.

    Args:
        images: Sliceable, iterable collection of images.
        is_mnist: If True, each image is logged as-is in mode ``"L"``.
            Otherwise it is transposed with axes ``(1, 2, 0)`` (moving the
            first axis last) and logged in mode ``"RGB"``.
        limit: If given, only the first ``limit`` images are converted.

    Returns:
        List of ``wandb.Image``.
    """
    if limit is not None:
        images = images[:limit]
    if is_mnist:
        return [wandb.Image(np.array(image), mode="L") for image in images]
    return [
        wandb.Image(np.transpose(np.array(image), (1, 2, 0)), mode="RGB")
        for image in images
    ]


def main(args):
    """Train a convolutional PC model and evaluate it on the val/test splits.

    All settings are read from ``args``. The steps are:

    1. Optionally start a wandb run and copy ``wandb.config`` entries onto
       ``args``.
    2. Build a ``CNNModel``. Optionally load parameters from
       ``args.load_path``, and wrap the model in ``AddLatent`` for free
       latents.
    3. Build data loaders, optimisers and the train/inference functions.
    4. Run the epoch loop over ``e = -1 .. nm_epochs - 1``. ``e == -1`` only
       evaluates. Validation runs whenever
       ``e % epochs_per_val == epochs_per_val - 1``, which includes
       ``e == -1``.
    5. Evaluate on the test split, log to wandb and optionally save params.

    Returns:
        ``(model, optim_h, optim_w, is_up_initialisation)``.

    Raises:
        SystemExit: (no exit code) when, after epoch 0, validation shows
            diverging reconstruction error or accuracy below 0.10.
    """
    if args.is_wandb:
        wandb.init(project="test")
        # Values provided through wandb.config override the CLI arguments.
        for key, value in wandb.config.items():
            setattr(args, key, value)
        wandb.config.update(args)

    seed = args.seed
    set_seed(seed)

    verbose = args.verbose
    is_wandb = args.is_wandb

    # check model type
    is_supervised = args.is_supervised
    is_hybrid = args.is_hybrid
    is_free_latents = args.is_free_latents

    nm_epochs = args.nm_epochs
    epochs_per_val = args.epochs_per_val

    T = args.T
    T_eval = args.T_eval

    batch_size = args.batch_size
    lr = args.lr_x
    momentum = args.momentum
    activity_decay = args.activity_decay
    gamma = 0  # passed to sgdld for the evaluation-time activity optimisers
    h_var = args.h_var
    activity_init = args.activity_init
    activity_init_kwargs = {"layer_var": h_var}
    is_up_initialisation = args.is_up_initialisation_default

    lr_p = args.lr_p
    weight_decay = args.weight_decay

    model_name = args.model_name
    input_var = args.input_var
    activation = args.activation_fn
    latent_dim = args.latent_dim
    data_dim = args.data_dim
    alpha_up = args.alpha_up
    alpha_down = args.alpha_down

    dataset = args.dataset
    # Datasets whose name ends in "mnist" are single-channel; all others are
    # treated as 3-channel.
    is_mnist = dataset.split("_")[-1] == "mnist"
    input_channels = 1 if is_mnist else 3

    model_latent_dim = args.free_latent_dim
    latent_init = args.activity_init

    # some checks
    assert alpha_up + alpha_down > 0.0, "At least one of the energies should be active"
    if is_hybrid:
        assert (
            alpha_up == 1.0 and alpha_down == 1.0
        ), "Hybrid model only works with alpha_up=1 and alpha_down=1"
    if is_free_latents:
        assert alpha_down > 0.0, "Free latents only work with alpha_down > 0"

    model = CNNModel(
        input_size=data_dim,
        output_size=latent_dim,
        input_channels=input_channels,
        cnn_name=model_name,
        activation=activation,
        input_var=input_var,
        alpha_up=alpha_up,
        alpha_down=alpha_down,
        is_supervised=is_supervised,
        activity_init=activity_init,
        activity_init_kwargs=activity_init_kwargs,
    )
    if args.load_path is not None:
        pxu.load_params(model, args.load_path)
        # NOTE(review): a previous comment here read "fix all the up layers",
        # but no freezing is performed. Only LayerParams tagged frozen=True
        # are excluded from optim_w (see its mask below).
    if is_free_latents:
        model = AddLatent(
            model, model_latent_dim, latent_init, latent_var=alpha_up / alpha_down
        )
        # Tag parameters as latent=True. LayerParams carrying this tag are
        # optimised by optim_w_latent instead of optim_w (see masks below).
        model.latent_vode.h.latent = True
        model.latent_layer_down.nn.weight.latent = True
        model.latent_layer_down.nn.bias.latent = True
        model.latent_layer_up.nn.weight.latent = True
        model.latent_layer_up.nn.bias.latent = True
        for layer in model.down:
            if isinstance(layer, pxnn.Layer):
                if hasattr(layer.nn, "weight"):
                    layer.nn.weight.latent = True
                    layer.nn.bias.latent = True

    train_dl, val_dl, test_dl = get_datax_cnn(args)

    # Create a copy of args with is_supervised overridden to True
    if not is_supervised:
        args_supervised = argparse.Namespace(**vars(args))
        args_supervised.is_supervised = True
        _, val_dl_labelled, test_dl_labelled = get_datax_cnn(args_supervised)

    weight_schedule = cosine_scedule(lr_p, nm_epochs, train_dl)

    optim_h = pxu.Optim(optax.sgd(lr, momentum=momentum))
    optim_h_eval = pxu.Optim(
        sgdld(lr, momentum, h_var, gamma, activity_decay=activity_decay)
    )
    optim_h_eval_no_noise = pxu.Optim(
        sgdld(lr, momentum, 0.0, gamma, activity_decay=activity_decay)
    )
    optim_h_latent = (
        None
        if not is_free_latents
        else pxu.Optim(sgd_scaled(args.lr_x_latent, momentum, 1.0 / alpha_down))
    )  # order of tests - 0.003
    with pxu.step(model, pxc.STATUS.INIT, clear_params=pxc.VodeParam.Cache):
        initialisation(
            jax.numpy.zeros((batch_size, latent_dim)),
            jax.numpy.zeros((batch_size, input_channels, *data_dim)),
            model=model,
            is_up_initialisation=is_up_initialisation,
        )
        optim_w = pxu.Optim(
            optax.adamw(weight_schedule, weight_decay=weight_decay),
            pxu.Mask(pxu.m(pxnn.LayerParam).has_not(latent=True).has_not(frozen=True))(
                model
            ),
        )
        optim_w_latent = (
            None
            if not is_free_latents
            else pxu.Optim(
                optax.adamw(
                    cosine_scedule(args.lr_p_latent, nm_epochs, train_dl),
                    weight_decay=weight_decay,
                ),
                pxu.Mask(pxu.m(pxnn.LayerParam).has(latent=True))(model),
            )
        )

    # setup energy function: one energy drives inference, another the weights
    if alpha_up == 0.0:
        inference_energy = energy_down
        weights_energy = energy_down
    elif alpha_down == 0.0:
        inference_energy = energy_up
        weights_energy = energy_up
    elif is_hybrid:
        inference_energy = energy_down
        weights_energy = energy
    elif is_free_latents:
        inference_energy = energy
        weights_energy = energy_weights
    else:
        inference_energy = energy
        weights_energy = energy
    local_infer_on_batch = setup_infer_on_batch(
        inference_energy, is_free_latents=is_free_latents
    )
    if is_free_latents:
        local_infer_on_batch = partial(
            local_infer_on_batch, optim_h_latent=optim_h_latent
        )
    train = setup_train(
        local_infer_on_batch,
        weights_energy,
        is_free_latents=is_free_latents,
        optim_w_latent=optim_w_latent,
    )

    def inference_kwargs(optim_h_inf, is_up):
        """Fresh ``mode_kwargs`` dict for the ``mode="inference"`` evaluations."""
        return {
            "T": T_eval,
            "optim_h": optim_h_inf,
            "is_up_initialisation": is_up,
            "infer_on_batch": local_infer_on_batch,
        }

    if args.make_mean_image and is_supervised:
        make_mean_images(args, mode="val", verbose=verbose)
        make_mean_images(args, mode="test", verbose=verbose)

    nm_img_report = 10  # max number of images logged per (truncated) panel
    # NOTE(review): passed as the 3rd positional arg of eval_fid_inception,
    # presumably the number of generated samples; confirm in utils_pcax.eval.
    fid_nm_samples = 1000
    epoch_range = np.arange(-1, nm_epochs)
    epoch_dl = tqdm(epoch_range) if verbose else epoch_range
    for e in epoch_dl:
        # e == -1 only evaluates the freshly initialised model.
        e_up, e_down = (
            train(
                train_dl,
                T=T,
                is_up_initialisation=is_up_initialisation,
                model=model,
                optim_w=optim_w,
                optim_h=optim_h,
                verbose=False,
            )
            if e > -1
            else (0.0, 0.0)
        )

        if verbose:
            epoch_dl.set_description_str(
                f"Epochs {e+1}/{nm_epochs}: E up {e_up/(alpha_up+1e-6):.3f}, E down {e_down/(alpha_down+1e-6):.3f}"
            )
        if e % epochs_per_val != epochs_per_val - 1:
            continue

        res = {}
        if is_supervised:
            res["val_acc_ff"] = eval_acc(
                val_dl,
                model,
                mode="fp",
                mode_kwargs={"noise_var": 0.0},
                verbose=verbose,
            )
            res["val_rmse_ff"], imgs_ff = eval_rmse(
                model,
                batch_size,
                dataset,
                "val",
                mode="fp",
                mode_kwargs={"noise_var": 0.0},
                verbose=verbose,
            )
            res["val_combined_err_ff"] = (1 - res["val_acc_ff"]) * 2 + res[
                "val_rmse_ff"
            ]

            res["val_acc_inf"] = eval_acc(
                val_dl,
                model,
                mode="inference",
                mode_kwargs=inference_kwargs(optim_h_eval_no_noise, True),
                verbose=verbose,
            )
            res["val_rmse_inf"], imgs_inf = eval_rmse(
                model,
                batch_size,
                dataset,
                "val",
                mode="inference",
                mode_kwargs=inference_kwargs(optim_h_eval_no_noise, False),
                verbose=verbose,
            )
            res["val_combined_err_inf"] = (1 - res["val_acc_inf"]) * 2 + res[
                "val_rmse_inf"
            ]

            res["img_ff"] = _to_wandb_images(imgs_ff, is_mnist)
            res["img_inf"] = _to_wandb_images(imgs_inf, is_mnist)
            if is_free_latents:
                res["val_recon_mse_ff"], recon_ff = (
                    eval_latent_reconstruction_free_latent(
                        val_dl,
                        model,
                        mode="fp",
                        mode_kwargs={"noise_var": 0.0},
                        verbose=verbose,
                    )
                )
                res["val_recon_mse_inf"], recon_inf = (
                    eval_latent_reconstruction_free_latent(
                        val_dl,
                        model,
                        mode="inference",
                        mode_kwargs=inference_kwargs(
                            optim_h_eval_no_noise, is_up_initialisation
                        ),
                        verbose=verbose,
                    )
                )
                res["val_recon_ff"] = _to_wandb_images(
                    recon_ff, is_mnist, nm_img_report
                )
                res["val_recon_inf"] = _to_wandb_images(
                    recon_inf, is_mnist, nm_img_report
                )
                res["val_combined_err_inf_free"] = (1 - res["val_acc_inf"]) + res[
                    "val_recon_mse_inf"
                ] * 4
        else:
            res["val_acc_ff"] = eval_latent_decoding_acc(
                val_dl,
                val_dl_labelled,
                model,
                mode="fp",
                mode_kwargs={"noise_var": 0.0},
                verbose=verbose,
            )
            res["val_acc_inf"] = eval_latent_decoding_acc(
                val_dl,
                val_dl_labelled,
                model,
                mode="inference",
                mode_kwargs=inference_kwargs(
                    optim_h_eval_no_noise, is_up_initialisation
                ),
                verbose=verbose,
            )
            res["val_recon_mse_ff"], recon_ff = eval_latent_reconstruction(
                val_dl,
                model,
                mode="fp",
                mode_kwargs={"noise_var": 0.0},
                verbose=verbose,
            )
            res["val_recon_mse_inf"], recon_inf = eval_latent_reconstruction(
                val_dl,
                model,
                mode="inference",
                mode_kwargs=inference_kwargs(
                    optim_h_eval_no_noise, is_up_initialisation
                ),
                verbose=verbose,
            )
            res["val_recon_ff"] = _to_wandb_images(recon_ff, is_mnist, nm_img_report)
            res["val_recon_inf"] = _to_wandb_images(
                recon_inf, is_mnist, nm_img_report
            )
        if (not is_supervised or is_free_latents) and e == -1:
            # Reference targets are logged only once, before training.
            _, y_base = next(iter(val_dl))
            res["val_base_imgs"] = _to_wandb_images(y_base, is_mnist, nm_img_report)

        if h_var > 0.0:
            res["is_ff"], res["val_fid_ff"], imgs_ff = eval_fid_inception(
                dataset,
                val_dl,
                fid_nm_samples,
                model,
                latent_dim,
                batch_size,
                mode="fp",
                mode_kwargs={"noise_var": h_var},
                verbose=verbose,
                subset="val",
                labels=is_supervised,
            )
            res["is_inf"], res["val_fid_inf"], imgs_inf = eval_fid_inception(
                dataset,
                val_dl,
                fid_nm_samples,
                model,
                latent_dim,
                batch_size,
                mode="inference",
                mode_kwargs=inference_kwargs(optim_h_eval, False),
                verbose=verbose,
                subset="val",
                labels=is_supervised,
            )
            res["img_sto_ff"] = _to_wandb_images(imgs_ff, is_mnist, nm_img_report)
            res["img_sto_inf"] = _to_wandb_images(imgs_inf, is_mnist, nm_img_report)
            if is_supervised:
                res["var_combined_sto"] = res["val_fid_inf"] / 100 + (
                    1 - res["val_acc_inf"]
                )
        if is_wandb:
            wandb.log(res)

        # exit run if diverging energy or if reconstructed images are not good
        if e > 0:
            is_recon_diverging = (not is_supervised or is_free_latents) and (
                res["val_recon_mse_inf"] > 0.25
                or res["val_recon_mse_ff"] > 0.50
                # NaN is the only value not equal to itself.
                or res["val_recon_mse_inf"] != res["val_recon_mse_inf"]
            )
            is_acc_diverging = is_supervised and res["val_acc_inf"] < 0.10
            if is_recon_diverging or is_acc_diverging:
                if is_wandb:
                    wandb.finish()
                raise SystemExit

    # get final results
    res = {}
    if is_supervised:
        res["test_acc_ff"] = eval_acc(test_dl, model, verbose=verbose)
        res["test_rmse_ff"], imgs_ff = eval_rmse(
            model, batch_size, dataset, "test", verbose=verbose
        )

        res["test_acc_inf"] = eval_acc(
            test_dl,
            model,
            mode="inference",
            mode_kwargs=inference_kwargs(optim_h_eval_no_noise, True),
            verbose=verbose,
        )
        res["test_rmse_inf"], imgs_inf = eval_rmse(
            model,
            batch_size,
            dataset,
            "test",
            mode="inference",
            mode_kwargs=inference_kwargs(optim_h_eval_no_noise, False),
            verbose=verbose,
        )

        res["img_ff"] = _to_wandb_images(imgs_ff, is_mnist)
        res["img_inf"] = _to_wandb_images(imgs_inf, is_mnist)
        if is_free_latents:
            res["test_recon_mse_ff"], recon_ff = eval_latent_reconstruction_free_latent(
                test_dl,
                model,
                mode="fp",
                mode_kwargs={"noise_var": 0.0},
                verbose=verbose,
            )
            res["test_recon_mse_inf"], recon_inf = (
                eval_latent_reconstruction_free_latent(
                    test_dl,
                    model,
                    mode="inference",
                    mode_kwargs=inference_kwargs(
                        optim_h_eval_no_noise, is_up_initialisation
                    ),
                    verbose=verbose,
                )
            )
            # Truncated for both MNIST and RGB (previously MNIST logged all).
            res["test_recon_ff"] = _to_wandb_images(recon_ff, is_mnist, nm_img_report)
            res["test_recon_inf"] = _to_wandb_images(
                recon_inf, is_mnist, nm_img_report
            )
    else:
        res["test_acc_ff"] = eval_latent_decoding_acc(
            test_dl,
            test_dl_labelled,
            model,
            mode="fp",
            mode_kwargs={"noise_var": 0.0},
            verbose=verbose,
        )
        res["test_acc_inf"] = eval_latent_decoding_acc(
            test_dl,
            test_dl_labelled,
            model,
            mode="inference",
            mode_kwargs=inference_kwargs(optim_h_eval_no_noise, is_up_initialisation),
            verbose=verbose,
        )
        res["test_recon_mse_ff"], recon_ff = eval_latent_reconstruction(
            test_dl, model, mode="fp", mode_kwargs={"noise_var": 0.0}, verbose=verbose
        )
        res["test_recon_mse_inf"], recon_inf = eval_latent_reconstruction(
            test_dl,
            model,
            mode="inference",
            mode_kwargs=inference_kwargs(optim_h_eval_no_noise, is_up_initialisation),
            verbose=verbose,
        )
        res["recon_ff"] = _to_wandb_images(recon_ff, is_mnist, nm_img_report)
        res["recon_inf"] = _to_wandb_images(recon_inf, is_mnist, nm_img_report)
    if not is_supervised or is_free_latents:
        _, y_base = next(iter(test_dl))
        res["test_base_imgs"] = _to_wandb_images(y_base, is_mnist, nm_img_report)

    if h_var > 0.0:
        res["is_ff"], res["test_fid_ff"], imgs_ff = eval_fid_inception(
            dataset,
            test_dl,
            fid_nm_samples,
            model,
            latent_dim,
            batch_size,
            mode="fp",
            mode_kwargs={"noise_var": h_var},
            verbose=verbose,
            subset="test",
            labels=is_supervised,
        )
        res["is_inf"], res["test_fid_inf"], imgs_inf = eval_fid_inception(
            dataset,
            test_dl,
            fid_nm_samples,
            model,
            latent_dim,
            batch_size,
            mode="inference",
            mode_kwargs=inference_kwargs(optim_h_eval, False),
            verbose=verbose,
            subset="test",
            labels=is_supervised,
        )
        res["img_sto_inf"] = _to_wandb_images(imgs_inf, is_mnist, nm_img_report)
        res["img_sto_ff"] = _to_wandb_images(imgs_ff, is_mnist, nm_img_report)

    if is_wandb:
        wandb.log(res)
        wandb.finish()

    if args.save_path is not None:
        pxu.save_params(model, args.save_path)

    return model, optim_h, optim_w, is_up_initialisation


def _str2bool(value):
    """Parse a command-line boolean value.

    Accepts (case-insensitively) ``true/t/yes/y/1`` and ``false/f/no/n/0``.
    Any other string raises ``argparse.ArgumentTypeError``. A typo such as
    ``--is-wandb Ture`` is therefore reported instead of silently becoming
    ``False``.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in {"true", "t", "yes", "y", "1"}:
        return True
    if lowered in {"false", "f", "no", "n", "0"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def build_parser():
    """Build the argument parser for this training script.

    Arguments can also be read from a file by passing ``@path``
    (``fromfile_prefix_chars="@"``).

    Returns:
        argparse.ArgumentParser: the configured (not yet parsed) parser.
    """
    parser = argparse.ArgumentParser(
        description="Training a bidirectional CNN model.", fromfile_prefix_chars="@"
    )

    # --- general / data -------------------------------------------------------
    parser.add_argument(
        "--seed", type=int, default=0, help="set seed for reproducibility"
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="cifar10",
        choices=["mnist", "fashion_mnist", "cifar10", "cifar100"],
        help="dataset to use for training",
    )
    parser.add_argument(
        "--is-up-initialisation-default",
        type=_str2bool,
        default=True,
        help="initialisation of the model activity with the upward pass or downward pass",
    )
    parser.add_argument("--train-size", type=int, default=50000, help="training size")
    parser.add_argument("--val-size", type=int, default=5000, help="validation size")
    parser.add_argument("--test-size", type=int, default=5000, help="test size")
    parser.add_argument(
        "--batch-size", type=int, default=256, help="training batch size"
    )
    parser.add_argument(
        "--nm-epochs", type=int, default=50, help="number of epochs to train"
    )

    # --- inference (latent state) optimisation -------------------------------
    parser.add_argument(
        "--lr-x", type=float, default=0.001, help="learning rate of the latent state x"
    )
    parser.add_argument(
        "--momentum",
        type=float,
        default=0.0,
        help="momentum of the latent state x updates",
    )
    parser.add_argument(
        "--T", type=int, default=32, help="number of inference iterations"
    )
    parser.add_argument(
        "--T-eval",
        type=int,
        default=100,
        help="number of inference iterations at evaluation time",
    )

    # --- parameter optimisation ----------------------------------------------
    parser.add_argument(
        "--lr-p",
        type=float,
        default=0.0001,
        help="learning rate of the model parameters",
    )
    parser.add_argument(
        "--weight-decay",
        type=float,
        default=0.0001,
        help="weight decay of the model parameters",
    )
    parser.add_argument(
        "--activity-decay",
        type=float,
        default=0.0,
        help="decay applied to the latent activities",
    )

    # --- architecture ----------------------------------------------------------
    # The CLI used to declare this as a single int while defaulting to a tuple.
    # It now accepts one value (square input) or two values. parse_args()
    # normalises the result to a 2-tuple.
    parser.add_argument(
        "--data-dim",
        type=int,
        nargs="+",
        metavar="SIZE",
        default=(32, 32),
        help="input size of the input layer; one value for a square input or two values",
    )
    parser.add_argument(
        "--activation-fn",
        type=str,
        default="l-relu",
        choices=["relu", "tanh", "l-relu", "gelu"],
        help="activation function of the hidden layers",
    )

    # --- logging -----------------------------------------------------------------
    parser.add_argument(
        "--is-wandb",
        type=_str2bool,
        default=False,
        help="log the results to wandb",
    )
    parser.add_argument(
        "--verbose",
        type=_str2bool,
        default=True,
        help="print the results to the console",
    )
    parser.add_argument(
        "--epochs-per-val",
        type=int,
        default=5,
        help="number of epochs between validation",
    )

    parser.add_argument(
        "--activity-init",
        type=str,
        default="ff",
        choices=["ff", "zero", "randn", "noisy-ff", "xavier"],
        help="initialisation scheme of the latent activities",
    )
    parser.add_argument(
        "--make-mean-image",
        type=_str2bool,
        default=False,
        help="recompute the mean image per class of dataset",
    )

    parser.add_argument(
        "--model-name",
        type=str,
        # e.g. 'SmallVGG', 'TestCNN', 'MonoVGG', 'MiniVGG', 'VGG11', 'VGG5', 'VGG5np'
        default="SmallVGG",
        help="name of the CNN architecture to train",
    )

    parser.add_argument(
        "--alpha-up",
        type=float,
        default=1.0,
        help="scaling factor for the upward model",
    )
    parser.add_argument(
        "--alpha-down",
        type=float,
        default=1.0,
        help="scaling factor for the downward model",
    )

    parser.add_argument(
        "--latent-dim", type=int, default=10, help="output size of the output layer"
    )
    parser.add_argument(
        "--is-supervised",
        type=_str2bool,
        default=True,
        help="supervised or unsupervised learning",
    )
    parser.add_argument(
        "--is-hybrid",
        type=_str2bool,
        default=False,
        help="hybrid model with down energy only for inference",
    )
    parser.add_argument(
        "--h-var",
        type=float,
        default=0.0,
        help="Langevin noise variance of the (hidden) pc layers",
    )
    parser.add_argument(
        "--input-var", type=float, default=1.0, help="input variance of the input layer"
    )
    parser.add_argument(
        "--is-cnn",
        type=_str2bool,
        default=True,
        help="should always be true to be changed later",
    )

    # --- free latents (combined supervised / unsupervised learning) ------------
    parser.add_argument(
        "--is-free-latents",
        type=_str2bool,
        default=False,
        help="to be used with supervised model to combine supervised and unsupervised learning",
    )
    parser.add_argument(
        "--free-latent-dim", type=int, default=256, help="size of the free latent space"
    )
    parser.add_argument(
        "--lr-x-latent",
        type=float,
        default=0.01,
        help="learning rate of the free latent states (used with --is-free-latents)",
    )
    parser.add_argument(
        "--lr-p-latent",
        type=float,
        default=0.001,
        help="learning rate of the free latent parameters (used with --is-free-latents)",
    )

    # --- checkpointing -----------------------------------------------------------
    # example: --load-path results/models/VGG_5PC_Base
    parser.add_argument("--load-path", type=str, default=None)
    parser.add_argument("--save-path", type=str, default=None)

    return parser


def parse_args(argv=None):
    """Parse the command-line arguments of this training script.

    Args:
        argv: list of argument strings to parse. ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: the parsed arguments. ``data_dim`` is always a
        2-tuple of ints; a single value ``n`` is expanded to ``(n, n)``.
    """
    parser = build_parser()
    args = parser.parse_args(argv)

    dims = tuple(args.data_dim)
    if len(dims) == 1:
        dims = (dims[0], dims[0])
    elif len(dims) != 2:
        # parser.error prints usage and exits with status 2, like other CLI errors.
        parser.error("--data-dim expects one or two integers")
    args.data_dim = dims
    return args


if __name__ == "__main__":
    main(parse_args())