from functools import partial
import os

os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import argparse
from tqdm import tqdm
import random

import wandb

import numpy as np

import jax
import optax

import pcax as px
import pcax.predictive_coding as pxc
import pcax.nn as pxnn
import pcax.functional as pxf
import pcax.utils as pxu


from utils_pcax.data import get_datax_mlp, make_mean_images
from utils_pcax.utils import set_seed, sgdld
from utils_pcax.models import (
    AddLatent,
    Model,
    energy_down,
    energy_up,
    energy,
    energy_per_stream,
    initialisation,
    setup_infer_on_batch,
)
from utils_pcax.eval import (
    eval_acc,
    eval_latent_decoding_acc,
    eval_latent_reconstruction,
    eval_latent_reconstruction_free_latent,
    eval_rmse,
    eval_fid_inception,
)

"""
    This script can be used to train:
        - uBP == discBP
        - genBP
        - hybridBP
    Most of these models can be trained in supervised, unsupervised or combined supervised/unsupervised mode on MNIST or Fashion MNIST.
"""


def setup_train(local_infer_on_batch, energy_weights):
    """Build an epoch-level training function for a given inference/energy setup.

    Args:
        local_infer_on_batch: inference routine, called as
            ``local_infer_on_batch(T, x, y, is_up_initialisation, 0,
            model=..., optim_h=...)`` before every weight update.
        energy_weights: energy used for the weight update. It is called as
            ``energy_weights(x, y, model=model)`` and must return
            ``(energy, aux)`` (it is differentiated with ``has_aux=True``).

    Returns:
        ``train(dl, T, is_up_initialisation, *, model, optim_w, optim_h,
        verbose=False)``. It runs one pass over ``dl`` (an indexable sequence
        of ``(x, y)`` batches) and returns ``energy_per_stream`` evaluated on
        the first batch of ``dl`` after the updates.
    """

    @pxf.jit(static_argnums=(0, 3))
    def train_on_batch(
        T: int,
        x: jax.Array,
        y: jax.Array,
        is_up_initialisation: bool,
        *,
        model: Model,
        optim_w: pxu.Optim,
        optim_h: pxu.Optim,
    ):
        """Run ``T`` inference steps on ``(x, y)``, then one weight update; return the energy."""
        model.train()

        local_infer_on_batch(
            T, x, y, is_up_initialisation, 0, model=model, optim_h=optim_h
        )

        # Learning step. The mask presumably restricts gradients to
        # pxnn.LayerParam leaves (matching the mask optim_w is built with).
        with pxu.step(model, clear_params=pxc.VodeParam.Cache):
            (e, _aux), g = pxf.value_and_grad(
                pxu.Mask(pxnn.LayerParam, [False, True]), has_aux=True
            )(energy_weights)(x, y, model=model)
        optim_w.step(model, g["model"])
        return e

    def train(
        dl,
        T,
        is_up_initialisation,
        *,
        model: Model,
        optim_w: pxu.Optim,
        optim_h: pxu.Optim,
        verbose: bool = False,
    ):
        """Train for one pass over ``dl`` and report per-stream energies.

        Raises:
            ValueError: if ``dl`` is empty (the first batch is needed to
                report the energies).
        """
        if len(dl) == 0:
            raise ValueError("train() needs a non-empty data loader")
        batches = tqdm(dl, leave=False) if verbose else dl
        for x, y in batches:
            train_on_batch(
                T,
                x,
                y,
                is_up_initialisation,
                model=model,
                optim_w=optim_w,
                optim_h=optim_h,
            )
        # energies are reported on the first batch only, after this epoch's updates
        return energy_per_stream(dl[0][0], dl[0][1], model=model)

    return train


def _strip_pc_layers(model, is_supervised):
    """Turn the PC model into its BP counterpart by removing intermediate vodes.

    Mutates ``model.down``, ``model.up`` and ``model.vodes`` in place.

    * Supervised: every ``pxc.Vode`` strictly between the first and last
      entries of ``model.down`` / ``model.up`` is replaced by an identity
      ``px.static``. This preserves layer positions, which the free-latent
      model needs in order to know where to join ``model.down``. Only the
      first and last vodes are kept.
    * Unsupervised: the first four entries of ``model.down`` and the last four
      of ``model.up`` are kept as-is. Vodes in between are dropped (not
      replaced). The first, second and last vodes are kept.
    """
    if is_supervised:

        def keep_positions(layers):
            # identity placeholders keep the index of every remaining layer
            middle = [
                px.static(lambda x: x) if isinstance(layer, pxc.Vode) else layer
                for layer in layers[1:-1]
            ]
            return [layers[0], *middle, layers[-1]]

        model.down = keep_positions(model.down)
        model.up = keep_positions(model.up)
        model.vodes = [model.vodes[0], model.vodes[-1]]
    else:

        def drop_vodes(layers):
            return [layer for layer in layers if not isinstance(layer, pxc.Vode)]

        # Slice [4:-1] so the last entry is appended exactly once. Iterating
        # [4:] and then appending [-1] would duplicate a non-Vode last layer.
        model.down = [
            *model.down[:4],
            *drop_vodes(model.down[4:-1]),
            model.down[-1],
        ]
        model.up = [model.up[0], *drop_vodes(model.up[1:-4]), *model.up[-4:]]
        model.vodes = [model.vodes[0], model.vodes[1], model.vodes[-1]]


def _select_energies(alpha_up, alpha_down, is_hybrid):
    """Return ``(inference_energy, weights_energy)`` for the configured model."""
    if alpha_up == 0.0:
        return energy_down, energy_down
    if alpha_down == 0.0:
        return energy_up, energy_up
    if is_hybrid:
        # hybrid: latents are inferred with the down energy only,
        # weights are trained on the full energy
        return energy_down, energy
    return energy, energy


def _wandb_images(images, limit=None, shape=None):
    """Wrap ``images`` as grayscale ``wandb.Image`` objects.

    Args:
        images: sequence of array-likes (must support slicing if ``limit``).
        limit: if given, only the first ``limit`` images are converted.
        shape: if given, every image is reshaped to ``shape`` first.
    """
    if limit is not None:
        images = images[:limit]
    wrapped = []
    for image in images:
        arr = np.array(image)
        if shape is not None:
            arr = arr.reshape(shape)
        wrapped.append(wandb.Image(arr, mode="L"))
    return wrapped


def main(args):
    """Train and evaluate a BP-style bidirectional MLP, then report metrics.

    The run proceeds as follows:
      1. Optionally initialise wandb and merge its config into ``args``.
      2. Validate the configuration.
      3. Build the model and strip its intermediate PC layers.
      4. Train for ``args.nm_epochs`` epochs, validating every
         ``args.epochs_per_val`` epochs (and once before training, at epoch -1).
      5. Evaluate on the test split and log to wandb if enabled.
      6. For supervised runs, save the parameters to ``results/bp_small_mnist``.

    Args:
        args: parsed ``argparse.Namespace`` from this script's CLI. It is
            mutated with the wandb config when ``args.is_wandb`` is set.

    Returns:
        ``(model, optim_h, optim_w, is_up_initialisation)``.

    Raises:
        AssertionError: on an unsupported combination of options.
        NotImplementedError: if ``args.h_var > 0``. Stochastic (FID)
            evaluation is not implemented.
    """
    if args.is_wandb:
        wandb.init(project="test")
        # a wandb sweep may override the command-line arguments
        for key, value in wandb.config.items():
            setattr(args, key, value)
        wandb.config.update(args)

    set_seed(args.seed)

    verbose = args.verbose
    is_wandb = args.is_wandb

    # model type
    is_supervised = args.is_supervised
    is_hybrid = args.is_hybrid
    is_free_latents = args.is_free_latents

    nm_epochs = args.nm_epochs
    epochs_per_val = args.epochs_per_val

    # The inference process is kept so that the bias can be used to have an
    # iterative process for unsupervised learning.
    T = args.T
    T_eval = args.T_eval

    batch_size = args.batch_size
    lr = args.lr_x
    momentum = args.momentum
    activity_decay = args.activity_decay
    gamma = 0
    h_var = args.h_var
    activity_init = args.activity_init
    activity_init_kwargs = {"layer_var": h_var}
    is_up_initialisation = args.is_up_initialisation_default

    lr_p = args.lr_p
    weight_decay = args.weight_decay

    is_shared_weights = args.is_shared_weights
    input_var = args.input_var
    activation = args.activation_fn
    latent_dim = args.latent_dim
    hidden_dim = args.hidden_dim
    data_dim = args.data_dim
    alpha_up = args.alpha_up
    alpha_down = args.alpha_down
    nm_layers = args.nm_layers

    dataset = args.dataset

    model_latent_dim = args.free_latent_dim
    latent_init = args.activity_init

    # --- configuration checks ----------------------------------------------
    assert alpha_up + alpha_down > 0.0, "At least one of the energies should be active"
    if not is_supervised:
        assert (
            is_hybrid
        ), "BP unsupervised learning with amortised inference, AE not implemented"
        assert (
            alpha_up == 1.0 and alpha_down == 1.0
        ), "amortised inference preferred setup is with alpha_up=1 and alpha_down=1"
        assert is_up_initialisation
        assert T > 0, "T should be greater than 0 for unsupervised learning"
    if is_free_latents:
        assert (
            is_hybrid
        ), "Free latents update of free labels is done with down energy like amortised inference models"
        assert is_supervised, "Free latents are only used in supervised learning"
        assert alpha_down > 0.0
        assert alpha_up > 0.0
        assert T > 0
    if h_var > 0.0:
        # TODO: stochastic evaluation (eval_fid_inception) is not wired up.
        # Fail here, before building the model and loading data.
        raise NotImplementedError(
            "h_var > 0 is not supported: stochastic (FID) evaluation is not implemented"
        )

    # --- model ---------------------------------------------------------------
    model = Model(
        input_dim=latent_dim,
        hidden_dim=hidden_dim,
        output_dim=data_dim,
        nm_layers=nm_layers,
        activation=activation,
        input_var=input_var,
        alpha_up=alpha_up,
        alpha_down=alpha_down,
        is_supervised=is_supervised,
        is_shared_weights=is_shared_weights,
        activity_init=activity_init,
        activity_init_kwargs=activity_init_kwargs,
    )
    if is_free_latents:
        model = AddLatent(
            model,
            model_latent_dim,
            latent_init,
            latent_var=alpha_up / alpha_down,
            is_stop_gradient=False,
        )
        # NOTE(review): flag consumed by the pcax/utils_pcax code; its exact
        # effect is not visible in this file.
        model.latent_vode.h.latent = True

    # convert PC to BP by removing the intermediate PC layers
    _strip_pc_layers(model, is_supervised)

    optim_h = pxu.Optim(sgdld(lr, momentum, h_var, gamma))
    # noise-free latent optimiser used for inference at evaluation time
    optim_h_eval_no_noise = pxu.Optim(
        sgdld(lr, momentum, 0.0, gamma, activity_decay=activity_decay)
    )
    optim_h_latent = (
        pxu.Optim(sgdld(lr * alpha_up / alpha_down, momentum, h_var, gamma))
        if is_free_latents
        else None
    )
    with pxu.step(model, pxc.STATUS.INIT, clear_params=pxc.VodeParam.Cache):
        initialisation(
            jax.numpy.zeros((batch_size, latent_dim)),
            jax.numpy.zeros((batch_size, data_dim)),
            model=model,
            is_up_initialisation=is_up_initialisation,
        )
        optim_w = pxu.Optim(
            optax.adamw(lr_p, weight_decay=weight_decay),
            pxu.Mask(pxnn.LayerParam)(model),
        )

    # --- data ----------------------------------------------------------------
    train_dl, val_dl, test_dl = get_datax_mlp(args)

    # unsupervised runs still need labelled splits to score latent decoding
    if not is_supervised:
        args_supervised = argparse.Namespace(**vars(args))
        args_supervised.is_supervised = True
        _, val_dl_labelled, test_dl_labelled = get_datax_mlp(args_supervised)

    # --- learning setup --------------------------------------------------------
    inference_energy, weights_energy = _select_energies(
        alpha_up, alpha_down, is_hybrid
    )
    local_infer_on_batch = setup_infer_on_batch(
        inference_energy, is_free_latents=is_free_latents
    )
    if is_free_latents:
        local_infer_on_batch = partial(
            local_infer_on_batch, optim_h_latent=optim_h_latent
        )
    train = setup_train(local_infer_on_batch, weights_energy)

    def fp_kwargs():
        # fresh dict per call in case an eval helper mutates its kwargs
        return {"noise_var": 0.0}

    def inference_kwargs():
        return {
            "T": T_eval,
            "optim_h": optim_h_eval_no_noise,
            "is_up_initialisation": is_up_initialisation,
            "infer_on_batch": local_infer_on_batch,
        }

    if args.make_mean_image and is_supervised:
        make_mean_images(args, mode="val", verbose=verbose)
        make_mean_images(args, mode="test", verbose=verbose)

    nm_report = 30  # number of images logged per image panel
    img_shape = (28, 28)

    # --- training loop ---------------------------------------------------------
    # epoch -1 runs no training step, only the initial validation
    epochs = np.arange(-1, nm_epochs)
    epoch_dl = tqdm(epochs) if verbose else epochs
    for epoch in epoch_dl:
        random.shuffle(train_dl)
        if epoch > -1:
            e_up, e_down = train(
                train_dl,
                T=T,
                is_up_initialisation=is_up_initialisation,
                model=model,
                optim_w=optim_w,
                optim_h=optim_h,
                verbose=False,
            )
        else:
            e_up, e_down = 0.0, 0.0

        if verbose:
            epoch_dl.set_description_str(
                f"Epochs {epoch+1}/{nm_epochs}: E up {e_up/(alpha_up+1e-6):.3f}, E down {e_down/(alpha_down+1e-6):.3f}"
            )

        # true at epoch -1 and then once every `epochs_per_val` epochs
        if epoch % epochs_per_val != epochs_per_val - 1:
            continue

        res = {}
        if is_supervised:
            res["val_acc_ff"] = eval_acc(
                val_dl, model, mode="fp", mode_kwargs=fp_kwargs(), verbose=verbose
            )
            res["val_rmse_ff"], imgs_ff = eval_rmse(
                model,
                batch_size,
                dataset,
                "val",
                mode="fp",
                mode_kwargs=fp_kwargs(),
                verbose=verbose,
            )
            res["val_combined_err_ff"] = (1 - res["val_acc_ff"]) * 2 + res[
                "val_rmse_ff"
            ]
            res["img_ff"] = _wandb_images(imgs_ff)
            if is_free_latents:
                res["val_recon_mse_free_ff"], imgs_ff_free = (
                    eval_latent_reconstruction_free_latent(
                        val_dl,
                        model,
                        mode="fp",
                        mode_kwargs=fp_kwargs(),
                        verbose=verbose,
                    )
                )
                res["val_recon_mse_free_inf"], imgs_inf_free = (
                    eval_latent_reconstruction_free_latent(
                        val_dl,
                        model,
                        mode="inference",
                        mode_kwargs=inference_kwargs(),
                        verbose=verbose,
                    )
                )
                res["val_recon_ff_free"] = _wandb_images(
                    imgs_ff_free, nm_report, img_shape
                )
                res["val_recon_inf_free"] = _wandb_images(
                    imgs_inf_free, nm_report, img_shape
                )
                res["val_combined_err_inf_free"] = (
                    1 - res["val_acc_ff"]
                ) * 2 + res["val_recon_mse_free_inf"]
        else:
            res["val_acc_ff"] = eval_latent_decoding_acc(
                val_dl,
                val_dl_labelled,
                model,
                mode="fp",
                mode_kwargs=fp_kwargs(),
                verbose=verbose,
            )
            res["val_acc_inf"] = eval_latent_decoding_acc(
                val_dl,
                val_dl_labelled,
                model,
                mode="inference",
                mode_kwargs=inference_kwargs(),
                verbose=verbose,
            )
            res["val_recon_mse_ff"], recon_ff = eval_latent_reconstruction(
                val_dl, model, mode="fp", mode_kwargs=fp_kwargs(), verbose=verbose
            )
            res["val_recon_mse_inf"], recon_inf = eval_latent_reconstruction(
                val_dl,
                model,
                mode="inference",
                mode_kwargs=inference_kwargs(),
                verbose=verbose,
            )
            res["val_recon_ff"] = _wandb_images(recon_ff, nm_report, img_shape)
            res["val_recon_inf"] = _wandb_images(recon_inf, nm_report, img_shape)

        # reference images are logged once, before training
        if (not is_supervised or is_free_latents) and epoch == -1:
            _, y = val_dl[0]
            res["val_base_imgs"] = _wandb_images(y, nm_report, img_shape)

        if is_wandb:
            wandb.log(res)

    # --- final test evaluation ---------------------------------------------
    res = {}
    if is_supervised:
        # NOTE(review): unlike validation, no mode/mode_kwargs are passed here,
        # so the eval helpers' defaults apply — confirm this is intended.
        res["test_acc_ff"] = eval_acc(test_dl, model, verbose=verbose)
        res["test_rmse_ff"], imgs_ff = eval_rmse(
            model, batch_size, dataset, "test", verbose=verbose
        )
        res["img_ff"] = _wandb_images(imgs_ff)
        if is_free_latents:
            res["test_recon_mse_free_ff"], imgs_ff_free = (
                eval_latent_reconstruction_free_latent(
                    test_dl,
                    model,
                    mode="fp",
                    mode_kwargs=fp_kwargs(),
                    verbose=verbose,
                )
            )
            res["test_recon_mse_free_inf"], imgs_inf_free = (
                eval_latent_reconstruction_free_latent(
                    test_dl,
                    model,
                    mode="inference",
                    mode_kwargs=inference_kwargs(),
                    verbose=verbose,
                )
            )
            res["test_recon_ff_free"] = _wandb_images(
                imgs_ff_free, nm_report, img_shape
            )
            res["test_recon_inf_free"] = _wandb_images(
                imgs_inf_free, nm_report, img_shape
            )
    else:
        res["test_acc_ff"] = eval_latent_decoding_acc(
            test_dl,
            test_dl_labelled,
            model,
            mode="fp",
            mode_kwargs=fp_kwargs(),
            verbose=verbose,
        )
        res["test_acc_inf"] = eval_latent_decoding_acc(
            test_dl,
            test_dl_labelled,
            model,
            mode="inference",
            mode_kwargs=inference_kwargs(),
            verbose=verbose,
        )
        res["test_recon_mse_ff"], recon_ff = eval_latent_reconstruction(
            test_dl, model, mode="fp", mode_kwargs=fp_kwargs(), verbose=verbose
        )
        res["test_recon_mse_inf"], recon_inf = eval_latent_reconstruction(
            test_dl,
            model,
            mode="inference",
            mode_kwargs=inference_kwargs(),
            verbose=verbose,
        )
        res["test_recon_ff"] = _wandb_images(recon_ff, nm_report, img_shape)
        res["test_recon_inf"] = _wandb_images(recon_inf, nm_report, img_shape)

    if not is_supervised or is_free_latents:
        _, y = test_dl[0]
        res["test_base_imgs"] = _wandb_images(y, nm_report, img_shape)

    if is_wandb:
        wandb.log(res)
        wandb.finish()

    if is_supervised:
        save_path = "results/bp_small_mnist"
        # make sure the output directory exists before saving
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        pxu.save_params(model, save_path)

    return model, optim_h, optim_w, is_up_initialisation


if __name__ == "__main__":

    def _str2bool(value):
        """Parse a CLI boolean: only the string "true" (any case) is True."""
        return str(value).lower() == "true"

    parser = argparse.ArgumentParser(
        description="Training a bidirectional MLP model.", fromfile_prefix_chars="@"
    )
    parser.add_argument(
        "--seed", type=int, default=0, help="set seed for reproducibility"
    )

    parser.add_argument(
        "--dataset",
        type=str,
        default="fashion_mnist",
        choices=["mnist", "fashion_mnist"],
        help="dataset to use for training",
    )
    parser.add_argument(
        "--is-up-initialisation-default",
        type=_str2bool,
        default=True,
        help="initialisation of the model activity with the upward pass or downward pass",
    )
    parser.add_argument("--train-size", type=int, default=10000, help="training size")
    parser.add_argument("--val-size", type=int, default=1000, help="validation size")
    parser.add_argument("--test-size", type=int, default=1000, help="test size")
    parser.add_argument(
        "--batch-size", type=int, default=256, help="training batch size"
    )
    parser.add_argument(
        "--nm-epochs", type=int, default=25, help="number of epochs to train"
    )
    parser.add_argument(
        "--lr-x", type=float, default=0.01, help="learning rate of the latent state x"
    )
    parser.add_argument(
        "--momentum",
        type=float,
        default=0.0,
        help="momentum of the latent state x optimiser",
    )
    parser.add_argument(
        "--T", type=int, default=8, help="number of inference iterations"
    )
    parser.add_argument(
        "--T-eval",
        type=int,
        default=100,
        help="number of inference iterations at evaluation time",
    )
    parser.add_argument(
        "--lr-p",
        type=float,
        default=0.0002,
        help="learning rate of the model parameters",
    )
    parser.add_argument(
        "--weight-decay",
        type=float,
        default=0.0000,
        help="weight decay of the model parameters",
    )
    parser.add_argument(
        "--activity-decay",
        type=float,
        default=0.0,
        help="activity decay of the latent-state optimiser used at evaluation time",
    )
    parser.add_argument(
        "--data-dim", type=int, default=784, help="size of the data (image) layer"
    )
    parser.add_argument(
        "--hidden-dim", type=int, default=256, help="hidden size of the hidden layers"
    )
    parser.add_argument(
        "--nm-layers", type=int, default=4, help="number of hidden layers"
    )
    parser.add_argument(
        "--activation-fn",
        type=str,
        default="tanh",
        choices=["relu", "tanh", "l-relu", "gelu"],
        help="activation function of the hidden layers",
    )
    parser.add_argument(
        "--is-wandb",
        type=_str2bool,
        default=False,
        help="log the results to wandb",
    )
    parser.add_argument(
        "--verbose",
        type=_str2bool,
        default=True,
        help="print the results to the console",
    )
    parser.add_argument(
        "--epochs-per-val",
        type=int,
        default=5,
        help="number of epochs between validation",
    )
    parser.add_argument(
        "--is-shared-weights",
        type=_str2bool,
        default=False,
        help="use shared weights for the upward and downward models",
    )
    parser.add_argument(
        "--activity-init",
        type=str,
        default="ff",
        choices=["ff", "zero", "randn", "noisy-ff", "xavier"],
        help="initialisation scheme of the layer activities",
    )
    parser.add_argument(
        "--make-mean-image",
        type=_str2bool,
        default=False,
        help="recompute the mean image per class of dataset",
    )

    parser.add_argument(
        "--alpha-up",
        type=float,
        default=1.0,
        help="scaling factor for the upward model",
    )
    parser.add_argument(
        "--alpha-down",
        type=float,
        default=1.0,
        help="scaling factor for the downward model",
    )

    parser.add_argument(
        "--latent-dim", type=int, default=10, help="size of the latent layer"
    )
    parser.add_argument(
        "--is-supervised",
        type=_str2bool,
        default=True,
        help="supervised or unsupervised learning",
    )
    parser.add_argument(
        "--is-hybrid",
        type=_str2bool,
        default=True,
        help="hybrid model with down energy only for inference",
    )
    parser.add_argument(
        "--h-var",
        type=float,
        default=0.0,
        help="Langevin noise variance of the (hidden) pc layers",
    )
    parser.add_argument(
        "--input-var", type=float, default=1.0, help="input variance of the input layer"
    )
    parser.add_argument(
        "--is-cnn",
        type=_str2bool,
        default=False,
        help="CNN flag; not read by main() directly, args are passed on to the data utilities",
    )

    parser.add_argument(
        "--is-free-latents",
        type=_str2bool,
        default=False,
        help="to be used with supervised model to combine supervised and unsupervised learning",
    )
    parser.add_argument(
        "--free-latent-dim", type=int, default=30, help="size of the free latent space"
    )

    args = parser.parse_args()
    main(args)
