import os

os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import argparse
from tqdm import tqdm
import random

import wandb

import numpy as np

import jax
import optax

import pcax as px
import pcax.predictive_coding as pxc
import pcax.nn as pxnn
import pcax.functional as pxf
import pcax.utils as pxu


from utils_pcax.data import get_datax_mlp, make_mean_images
from utils_pcax.utils import set_seed
from utils_pcax.models import (
    AE,
    AE_Latent,
    AddLatent,
    Model,
    energy_ae_latent,
    energy_down,
    initialisation,
)
from utils_pcax.eval import (
    eval_acc,
    eval_latent_decoding_acc,
    eval_latent_reconstruction,
    eval_latent_reconstruction_free_latent,
    eval_rmse,
)

"""
    This script can be used to train:
        - an autoencoder
    The model can be trained in supervised, unsupervised, or combined supervised/unsupervised (free-latent) mode on MNIST or Fashion-MNIST.
"""


def setup_train(energy_weights):
    """Build the per-epoch weight-training routine for one energy function.

    Args:
        energy_weights: energy function invoked as
            ``energy_weights(x, y, model=model)``. It is differentiated with
            ``has_aux=True``, so it must return an ``(energy, aux)`` pair.

    Returns:
        ``train(dl, *, model, optim_w, verbose=False)``, which performs one
        jitted optimisation step per ``(x, y)`` batch yielded by ``dl`` and
        returns a placeholder ``(0.0, 0.0)`` pair of (up, down) energies.
    """

    @pxf.jit()
    def step_batch(
        x: jax.Array,
        y: jax.Array,
        *,
        model: Model,
        optim_w: pxu.Optim,
    ):
        model.train()

        # Re-initialise the vode activities for this batch.
        with pxu.step(model, pxc.STATUS.INIT, clear_params=pxc.VodeParam.Cache):
            initialisation(x, y, model=model)

        # Weight update: gradients are taken only w.r.t. pxnn.LayerParam
        # (selected by the mask); the auxiliary output is discarded.
        with pxu.step(model, clear_params=pxc.VodeParam.Cache):
            grad_energy = pxf.value_and_grad(
                pxu.Mask(pxnn.LayerParam, [False, True]), has_aux=True
            )(energy_weights)
            (energy, _aux), grads = grad_energy(x, y, model=model)
        optim_w.step(model, grads["model"])
        return energy

    def train(dl, *, model: Model, optim_w: pxu.Optim, verbose: bool = False):
        # ``verbose`` is accepted for interface compatibility but not used here.
        for x_batch, y_batch in dl:
            step_batch(x_batch, y_batch, model=model, optim_w=optim_w)
        # NOTE(review): per-batch energies are not aggregated; the caller
        # receives a constant (0.0, 0.0) for its progress-bar display.
        return 0.0, 0.0

    return train


def _strip_vodes(layers, keep_placeholders):
    """Return a copy of ``layers`` with its interior ``pxc.Vode`` entries removed.

    The first and last entries are always kept as-is. When
    ``keep_placeholders`` is true, each interior Vode is replaced by an
    identity layer instead of being dropped, so the positions of the remaining
    layers are unchanged (the free-latent model needs this to know where to
    join ``model.down``).
    """
    stripped = [layers[0]]
    for layer in layers[1:-1]:
        if not isinstance(layer, pxc.Vode):
            stripped.append(layer)
        elif keep_placeholders:
            stripped.append(px.static(lambda x: x))
    stripped.append(layers[-1])
    return stripped


def main(args):
    """Train and evaluate the (optionally free-latent) autoencoder.

    Args:
        args: parsed command-line namespace (see the ``__main__`` parser).
            When ``args.is_wandb`` is true, values from ``wandb.config``
            (e.g. a sweep) first override the matching attributes.

    Returns:
        ``(model, optim_w)`` after training and the final test evaluation.

    Raises:
        AssertionError: if the alpha_up / alpha_down / free-latent settings
            are inconsistent.
        NotImplementedError: for supervised training without free latents.
    """
    if args.is_wandb:
        wandb.init(project="test")
        # Let wandb.config (e.g. sweep parameters) override the CLI arguments.
        for key, value in wandb.config.items():
            setattr(args, key, value)
        wandb.config.update(args)

    seed = args.seed
    set_seed(seed)

    verbose = args.verbose
    is_wandb = args.is_wandb

    # check model type
    is_supervised = args.is_supervised
    is_free_latents = args.is_free_latents

    nm_epochs = args.nm_epochs
    epochs_per_val = args.epochs_per_val

    batch_size = args.batch_size
    activity_init = args.activity_init
    activity_init_kwargs = {}

    lr_p = args.lr_p
    weight_decay = args.weight_decay

    is_shared_weights = args.is_shared_weights
    input_var = args.input_var
    activation = args.activation_fn
    latent_dim = args.latent_dim
    hidden_dim = args.hidden_dim
    data_dim = args.data_dim
    alpha_up = args.alpha_up
    alpha_down = args.alpha_down
    nm_layers = args.nm_layers

    dataset = args.dataset

    model_latent_dim = args.free_latent_dim
    latent_init = args.activity_init

    # some checks
    assert alpha_up + alpha_down > 0.0, "At least one of the energies should be active"
    if not is_supervised:
        assert (
            alpha_down > 0.0
        ), "amortised inference preffered setup is with alpha_up=1 and alpha_down=1"
    if is_free_latents:
        assert is_supervised, "Free latents are only used in supervised learning"
        # alpha_down > 0 also guards the alpha_up / alpha_down division below.
        assert alpha_up > 0 and alpha_down > 0

    model = Model(
        input_dim=latent_dim,
        hidden_dim=hidden_dim,
        output_dim=data_dim,
        nm_layers=nm_layers,
        activation=activation,
        input_var=input_var,
        alpha_up=alpha_up,
        alpha_down=alpha_down,
        is_supervised=is_supervised,
        is_shared_weights=is_shared_weights,
        activity_init=activity_init,
        activity_init_kwargs=activity_init_kwargs,
    )
    if is_free_latents:
        model = AddLatent(
            model,
            model_latent_dim,
            latent_init,
            latent_var=alpha_up / alpha_down,
            is_stop_gradient=False,
        )
        model.latent_vode.h.latent = True

    # Convert PC to BP by removing the interior vodes. In supervised mode they
    # are replaced by identity placeholders (needed by the free-latent model to
    # know where to join model.down); in unsupervised mode they are dropped.
    # Both directions use the same [first, interior..., last] slicing, so the
    # last layer is never duplicated.
    model.down = _strip_vodes(model.down, keep_placeholders=is_supervised)
    model.up = _strip_vodes(model.up, keep_placeholders=is_supervised)

    # clean vodes: keep only the first and last ones
    model.vodes = [model.vodes[0], model.vodes[-1]]

    model = AE(model) if not is_supervised else AE_Latent(model)

    # Initialise on dummy batches of the training shapes, then build the
    # AdamW optimiser over the layer parameters.
    with pxu.step(model, pxc.STATUS.INIT, clear_params=pxc.VodeParam.Cache):
        initialisation(
            jax.numpy.zeros((batch_size, latent_dim)),
            jax.numpy.zeros((batch_size, data_dim)),
            model=model,
        )
        optim_w = pxu.Optim(
            optax.adamw(lr_p, weight_decay=weight_decay),
            pxu.Mask(pxnn.LayerParam)(model),
        )

    train_dl, val_dl, test_dl = get_datax_mlp(args)

    # Unsupervised runs additionally load labelled val/test splits, consumed
    # by eval_latent_decoding_acc below. Copy args so the caller's namespace
    # keeps is_supervised=False.
    if not is_supervised:
        args_supervised = argparse.Namespace(**vars(args))
        args_supervised.is_supervised = True
        _, val_dl_labelled, test_dl_labelled = get_datax_mlp(args_supervised)

    # setup learning
    if not is_supervised:
        weights_energy = energy_down
    elif is_free_latents:
        weights_energy = energy_ae_latent
    else:
        raise NotImplementedError(
            "Only free latents are implemented for supervised learning"
        )
    train = setup_train(weights_energy)

    if args.make_mean_image and is_supervised:
        make_mean_images(args, mode="val", verbose=verbose)
        make_mean_images(args, mode="test", verbose=verbose)

    nm_report = 30
    # Epoch -1 skips training so that a baseline validation can be logged.
    epoch_dl = tqdm(np.arange(-1, nm_epochs)) if verbose else np.arange(-1, nm_epochs)
    for e in epoch_dl:
        random.shuffle(train_dl)
        e_up, e_down = (
            train(train_dl, model=model, optim_w=optim_w, verbose=False)
            if e > -1
            else (0.0, 0.0)
        )

        if verbose:
            epoch_dl.set_description_str(
                f"Epochs {e+1}/{nm_epochs}: E up {e_up/(alpha_up+1e-6):.3f}, E down {e_down/(alpha_down+1e-6):.3f}"
            )

        # Python's modulo makes e == -1 satisfy this too (baseline validation).
        if e % epochs_per_val == epochs_per_val - 1:
            res = {}
            if is_supervised:
                res["val_acc_ff"] = eval_acc(
                    val_dl,
                    model,
                    mode="fp",
                    mode_kwargs={"noise_var": 0.0},
                    verbose=verbose,
                )
                res["val_rmse_ff"], imgs_ff = eval_rmse(
                    model,
                    batch_size,
                    dataset,
                    "val",
                    mode="fp",
                    mode_kwargs={"noise_var": 0.0},
                    verbose=verbose,
                )
                res["val_combined_err_ff"] = (1 - res["val_acc_ff"]) * 2 + res[
                    "val_rmse_ff"
                ]
                res["img_ff"] = [
                    wandb.Image(np.array(image), mode="L") for image in imgs_ff
                ]
                if is_free_latents:
                    res["val_recon_mse_free_ff"], imgs_ff_free = (
                        eval_latent_reconstruction_free_latent(
                            val_dl,
                            model,
                            mode="fp",
                            mode_kwargs={"noise_var": 0.0},
                            verbose=verbose,
                        )
                    )
                    res["val_recon_ff_free"] = [
                        wandb.Image(np.array(image).reshape(28, 28), mode="L")
                        for image in imgs_ff_free[:nm_report]
                    ]
                    res["val_combined_err_inf_free"] = (
                        1 - res["val_acc_ff"]
                    ) * 2 + res["val_recon_mse_free_ff"]
            else:
                res["val_acc_ff"] = eval_latent_decoding_acc(
                    val_dl, val_dl_labelled, model, mode="fp",
                    mode_kwargs={"noise_var": 0.0},
                    verbose=verbose,
                )
                res["val_recon_mse_ff"], recon_ff = eval_latent_reconstruction(
                    val_dl,
                    model,
                    mode="fp",
                    mode_kwargs={"noise_var": 0.0},
                    verbose=verbose,
                )
                res["val_recon_ff"] = [
                    wandb.Image(np.array(image).reshape(28, 28), mode="L")
                    for image in recon_ff[:nm_report]
                ]
            if not is_supervised or is_free_latents:
                if e == -1:
                    _, y = val_dl[0]
                    res["val_base_imgs"] = [
                        wandb.Image(np.array(image).reshape(28, 28), mode="L")
                        for image in y[:nm_report]
                    ]

            if is_wandb:
                wandb.log(res)

    # get final results
    res = {}

    if is_supervised:
        # NOTE(review): unlike validation, these two calls rely on the default
        # mode / mode_kwargs of eval_acc and eval_rmse — confirm intended.
        res["test_acc_ff"] = eval_acc(test_dl, model, verbose=verbose)
        res["test_rmse_ff"], imgs_ff = eval_rmse(
            model, batch_size, dataset, "test", verbose=verbose
        )
        res["img_ff"] = [wandb.Image(np.array(image), mode="L") for image in imgs_ff]
        if is_free_latents:
            res["test_recon_mse_free_ff"], imgs_ff_free = (
                eval_latent_reconstruction_free_latent(
                    test_dl,
                    model,
                    mode="fp",
                    mode_kwargs={"noise_var": 0.0},
                    verbose=verbose,
                )
            )
            res["test_recon_ff_free"] = [
                wandb.Image(np.array(image).reshape(28, 28), mode="L")
                for image in imgs_ff_free[:nm_report]
            ]
    else:
        res["test_acc_ff"] = eval_latent_decoding_acc(
            test_dl, test_dl_labelled, model,
            mode="fp",
            mode_kwargs={"noise_var": 0.0},
            verbose=verbose,
        )
        res["test_recon_mse_ff"], recon_ff = eval_latent_reconstruction(
            test_dl, model, mode="fp", mode_kwargs={"noise_var": 0.0}, verbose=verbose
        )
        res["test_recon_ff"] = [
            wandb.Image(np.array(image).reshape(28, 28), mode="L")
            for image in recon_ff[:nm_report]
        ]
    if not is_supervised or is_free_latents:
        _, y = test_dl[0]
        res["test_base_imgs"] = [
            wandb.Image(np.array(image).reshape(28, 28), mode="L")
            for image in y[:nm_report]
        ]

    if is_wandb:
        wandb.log(res)
        wandb.finish()

    return model, optim_w


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Training a bidirectional MLP model.", fromfile_prefix_chars="@"
    )
    parser.add_argument(
        "--seed", type=int, default=0, help="set seed for reproducibility"
    )

    parser.add_argument(
        "--dataset",
        type=str,
        default="fashion_mnist",
        choices=["mnist", "fashion_mnist"],
        help="dataset to use for training",
    )
    parser.add_argument(
        "--is-up-initialisation-default",
        type=lambda x: (str(x).lower() == "true"),
        default=True,
        help="initialisation of the model activity with the upward pass or downward pass",
    )
    parser.add_argument("--train-size", type=int, default=10000, help="training size")
    parser.add_argument("--val-size", type=int, default=1000, help="validation size")
    parser.add_argument("--test-size", type=int, default=1000, help="test size")
    parser.add_argument(
        "--batch-size", type=int, default=256, help="training batch size"
    )
    parser.add_argument(
        "--nm-epochs", type=int, default=25, help="number of epochs to train"
    )
    parser.add_argument(
        "--lr-x", type=float, default=0.01, help="learning rate of the latent state x"
    )
    parser.add_argument(
        "--momentum",
        type=float,
        default=0.0,
        help="learning rate of the latent state x",
    )
    parser.add_argument(
        "--T", type=int, default=8, help="number of inference iterations"
    )
    parser.add_argument(
        "--T-eval",
        type=int,
        default=100,
        help="number of inference iterations at evaluation time",
    )
    parser.add_argument(
        "--lr-p",
        type=float,
        default=0.001,
        help="learning rate of the model parameters",
    )
    parser.add_argument(
        "--weight-decay",
        type=float,
        default=0.0000,
        help="weight decay of the model parameters",
    )
    parser.add_argument(
        "--activity-decay",
        type=float,
        default=0.0,
        help="weight decay of the model parameters",
    )
    parser.add_argument(
        "--data-dim", type=int, default=784, help="input size of the input layer"
    )
    parser.add_argument(
        "--hidden-dim", type=int, default=256, help="hidden size of the hidden layers"
    )
    parser.add_argument(
        "--nm-layers", type=int, default=4, help="number of hidden layers"
    )
    parser.add_argument(
        "--activation-fn",
        type=str,
        default="tanh",
        choices=["relu", "tanh", "l-relu", "gelu"],
        help="activation function of the hidden layers",
    )
    parser.add_argument(
        "--is-wandb",
        type=lambda x: (str(x).lower() == "true"),
        default=False,
        help="log the results to wandb",
    )
    parser.add_argument(
        "--verbose",
        type=lambda x: (str(x).lower() == "true"),
        default=True,
        help="print the results to the console",
    )
    parser.add_argument(
        "--epochs-per-val",
        type=int,
        default=5,
        help="number of epochs between validation",
    )
    parser.add_argument(
        "--is-shared-weights",
        type=lambda x: (str(x).lower() == "true"),
        default=False,
        help="use shared weights for the upward and downward models",
    )
    parser.add_argument(
        "--activity-init",
        type=str,
        default="ff",
        choices=["ff", "zero", "randn", "noisy-ff", "xavier"],
        help="activation function of the hidden layers",
    )
    parser.add_argument(
        "--make-mean-image",
        type=lambda x: (str(x).lower() == "true"),
        default=False,
        help="recompute the mean image per class of dataset",
    )

    parser.add_argument(
        "--alpha-up",
        type=float,
        default=1.0,
        help="scaling factor for the upward model",
    )
    parser.add_argument(
        "--alpha-down",
        type=float,
        default=1.0,
        help="scaling factor for the downward model",
    )

    parser.add_argument(
        "--latent-dim", type=int, default=30, help="output size of the output layer"
    )
    parser.add_argument(
        "--is-supervised",
        type=lambda x: (str(x).lower() == "true"),
        default=False,
        help="supervised or unsupervised learning",
    )
    parser.add_argument(
        "--is-hybrid",
        type=lambda x: (str(x).lower() == "true"),
        default=False,
        help="hybrid model with down energy only for inference",
    )
    parser.add_argument(
        "--h-var",
        type=float,
        default=0.0,
        help="Langevin noise variance of the (hidden) pc layers",
    )
    parser.add_argument(
        "--input-var", type=float, default=1.0, help="input variance of the input layer"
    )
    parser.add_argument(
        "--is-cnn",
        type=lambda x: (str(x).lower() == "true"),
        default=False,
        help="should always be true to be changed later",
    )

    parser.add_argument(
        "--is-free-latents",
        type=lambda x: (str(x).lower() == "true"),
        default=False,
        help="to be used with supervised model to combine supervised and unsupervised learning",
    )
    parser.add_argument(
        "--free-latent-dim", type=int, default=30, help="size of the free latent space"
    )

    args = parser.parse_args()
    main(args)
