from functools import partial
import os

os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import argparse
from tqdm import tqdm

import wandb

import numpy as np

import jax
import optax

import pcax as px
import pcax.predictive_coding as pxc
import pcax.nn as pxnn
import pcax.functional as pxf
import pcax.utils as pxu


from utils_pcax.data import get_datax_cnn, make_mean_images
from utils_pcax.utils import set_seed
from utils_pcax_archive.models import (
    AE,
    AE_Latent,
    AddLatent,
    CNNModel,
    Model,
    energy_ae_latent,
    initialisation,
    energy_down,
)
from utils_pcax.eval import (
    eval_acc,
    eval_latent_decoding_acc,
    eval_latent_reconstruction,
    eval_latent_reconstruction_free_latent,
    eval_rmse,
)

"""
    This script can be used to train convolutional models with the following architectures:
        - AE
    This model can be trained in supervised, unsupervised or combined supervised/unsupervised mode on CIFAR-10 and CIFAR-100.
"""


def setup_train(energy_weights, is_free_latents=False, optim_w_latent=None):
    """Build the per-epoch training function for a given weight energy.

    Args:
        energy_weights: energy function that is differentiated during the
            learning step. It is called as ``energy_weights(x, y, model=model)``
            and, because ``has_aux=True`` is used, must return a pair
            ``(energy, aux)``.
        is_free_latents: when True the returned trainer uses
            ``train_on_batch_free`` (two optimizers); otherwise it uses
            ``train_on_batch`` (single optimizer).
        optim_w_latent: optimizer applied to the parameters tagged
            ``latent=True``. It is only used when ``is_free_latents`` is True;
            in that case it must not be None, since ``.step`` is called on it.

    Returns:
        A function ``train(dl, *, model, optim_w, verbose=False)`` that runs
        one pass over the batches of ``dl``.
    """

    @pxf.jit()
    def train_on_batch(
        x: jax.Array,
        y: jax.Array,
        *,
        model: Model,
        optim_w: pxu.Optim,
    ):
        """Run one learning step on a batch with a single optimizer.

        Returns the energy value computed by ``energy_weights`` for the batch.
        """

        model.train()

        # initialise the vodes with x and y
        with pxu.step(model, pxc.STATUS.INIT, clear_params=pxc.VodeParam.Cache):
            initialisation(x, y, model=model)

        # Learning step
        with pxu.step(model, clear_params=pxc.VodeParam.Cache):
            # NOTE(review): the mask presumably restricts differentiation to
            # the LayerParam leaves (the weights) -- confirm against pcax docs.
            # `y_` is the auxiliary output of the energy and is not used here.
            (e, y_), g = pxf.value_and_grad(
                pxu.Mask(pxnn.LayerParam, [False, True]), has_aux=True
            )(energy_weights)(x, y, model=model)
        # Apply the gradients taken w.r.t. the model to all optimised weights.
        optim_w.step(model, g["model"])
        return e

    @pxf.jit()
    def train_on_batch_free(
        x: jax.Array,
        y: jax.Array,
        *,
        model: Model,
        optim_w: pxu.Optim,
        optim_w_latent: pxu.Optim,
    ):
        """Run one learning step on a batch with separate optimizers.

        Gradients are computed once and then split by parameter tags:
        ``optim_w`` steps the LayerParams that are neither ``latent`` nor
        ``frozen``, while ``optim_w_latent`` steps the LayerParams tagged
        ``latent=True``. Returns the energy value for the batch.
        """

        model.train()

        # initialise the vodes with x and y
        with pxu.step(model, pxc.STATUS.INIT, clear_params=pxc.VodeParam.Cache):
            initialisation(x, y, model=model)

        # Learning step
        with pxu.step(model, clear_params=pxc.VodeParam.Cache):
            # NOTE(review): same gradient mask as in `train_on_batch`;
            # presumably selects the LayerParam leaves -- confirm.
            (e, y_), g = pxf.value_and_grad(
                pxu.Mask(pxnn.LayerParam, [False, True]), has_aux=True
            )(energy_weights)(x, y, model=model)
        # Step the non-latent, non-frozen weights. The masks mirror the ones
        # used in `main` when the two optimizers are constructed.
        # NOTE(review): the meaning of the trailing positional `True` is
        # defined by `pxu.Optim.step` and is not visible here -- TODO confirm.
        optim_w.step(
            model,
            pxu.Mask(pxu.m(pxnn.LayerParam).has_not(latent=True).has_not(frozen=True))(
                g["model"]
            ),
            True,
        )
        # Step the weights tagged `latent=True` with their own optimizer.
        optim_w_latent.step(
            model, pxu.Mask(pxu.m(pxnn.LayerParam).has(latent=True))(g["model"]), True
        )
        return e

    # Select the batch step once; the latent optimizer is bound here so that
    # both variants can be called with the same keyword arguments in `train`.
    local_train = (
        partial(train_on_batch_free, optim_w_latent=optim_w_latent)
        if is_free_latents
        else train_on_batch
    )

    def train(dl, *, model: Model, optim_w: pxu.Optim, verbose: bool = False):
        """Train for one pass over ``dl``.

        The per-batch energy is computed but discarded; the function always
        returns the placeholder pair ``(0.0, 0.0)``. ``verbose`` is accepted
        but not used.
        """
        for x, y in dl:
            e = local_train(x, y, model=model, optim_w=optim_w)
        return 0.0, 0.0

    return train


def cosine_scedule(lr_p, nm_epochs, train_dl):
    """Return the optax learning-rate schedule used for a whole training run.

    Args:
        lr_p: base learning rate.
        nm_epochs: number of training epochs.
        train_dl: training data loader; only ``len(train_dl)`` (the number of
            batches per epoch) is used.

    Returns:
        ``optax.constant_schedule(lr_p)`` when ``nm_epochs == 0``. Otherwise a
        warmup + cosine-decay schedule over ``len(train_dl) * nm_epochs``
        steps that starts at ``lr_p``, peaks at ``1.1 * lr_p`` after the first
        10% of the steps and decays to ``0.1 * lr_p``.
    """
    if nm_epochs == 0:
        return optax.constant_schedule(lr_p)

    total_steps = len(train_dl) * nm_epochs
    return optax.warmup_cosine_decay_schedule(
        init_value=lr_p,
        peak_value=1.1 * lr_p,
        # optax documents `warmup_steps` as an integer number of steps, so use
        # integer arithmetic instead of the float product `0.1 * total_steps`.
        warmup_steps=total_steps // 10,
        decay_steps=total_steps,
        end_value=0.1 * lr_p,
        exponent=1.0,
    )


def _to_wandb_images(images, *, is_grayscale, limit=None):
    """Convert a sequence of image arrays into a list of ``wandb.Image``.

    Args:
        images: sliceable sequence of images. Colour images are transposed
            with ``(1, 2, 0)`` before being wrapped (i.e. they are expected
            channel-first); grayscale images are passed through unchanged.
        is_grayscale: True to log with ``mode="L"``, False for ``mode="RGB"``.
        limit: if given, only the first ``limit`` images are converted.
    """
    if limit is not None:
        images = images[:limit]
    if is_grayscale:
        return [wandb.Image(np.array(image), mode="L") for image in images]
    return [
        wandb.Image(np.transpose(np.array(image), (1, 2, 0)), mode="RGB")
        for image in images
    ]


def _strip_inner_vodes(layers, *, keep_placeholders):
    """Return ``layers`` with the Vodes strictly between first and last removed.

    The first and last entries are always kept as they are. When
    ``keep_placeholders`` is True every removed Vode is replaced by a static
    identity so that positions in the layer list are preserved.
    """
    kept = [layers[0]]
    for layer in layers[1:-1]:
        if not isinstance(layer, pxc.Vode):
            kept.append(layer)
        elif keep_placeholders:
            # needed for free latent model to know where to join model.down
            kept.append(px.static(lambda x: x))
    kept.append(layers[-1])
    return kept


def main(args):
    """Build, train and evaluate the convolutional autoencoder.

    Args:
        args: ``argparse.Namespace`` with the options declared by this script's
            command-line parser. When ``args.is_wandb`` is True the values are
            overridden by ``wandb.config`` and all results are logged to wandb.

    Returns:
        ``(model, optim_w)``: the trained model and the optimizer of the
        non-latent weights.

    Raises:
        AssertionError: if the combination of ``is_supervised``,
            ``is_free_latents``, ``alpha_up`` and ``alpha_down`` is unsupported.
    """
    if args.is_wandb:
        wandb.init(project="test")
        # Let a wandb sweep/config override the command-line values.
        for key, value in wandb.config.items():
            setattr(args, key, value)
        wandb.config.update(args)

    seed = args.seed
    set_seed(seed)

    verbose = args.verbose
    is_wandb = args.is_wandb

    # check model type
    is_supervised = args.is_supervised
    is_free_latents = args.is_free_latents
    is_stop_gradient_free = args.is_stop_gradient_free

    nm_epochs = args.nm_epochs
    epochs_per_val = args.epochs_per_val

    batch_size = args.batch_size
    activity_init = args.activity_init
    activity_init_kwargs = {}

    lr_p = np.round(args.lr_p, 8)  # for reproducibility
    weight_decay = np.round(args.weight_decay, 8)

    model_name = args.model_name
    input_var = args.input_var
    activation = args.activation_fn
    latent_dim = args.latent_dim
    data_dim = args.data_dim
    alpha_up = args.alpha_up
    alpha_down = args.alpha_down

    dataset = args.dataset
    is_mnist = dataset.split("_")[-1] == "mnist"
    input_channels = 1 if is_mnist else 3

    model_latent_dim = args.free_latent_dim
    latent_init = args.activity_init

    # some checks
    assert alpha_up + alpha_down > 0.0, "At least one of the energies should be active"
    if not is_supervised:
        assert (
            alpha_down > 0.0
        ), "amortised inference preffered setup is with alpha_up=1 and alpha_down=1"
        assert (
            not is_free_latents
        ), "Free latents are not supported for unsupervised learning"
    else:
        assert is_free_latents
        assert alpha_up > 0 and alpha_down > 0

    model = CNNModel(
        input_size=data_dim,
        output_size=latent_dim,
        input_channels=input_channels,
        cnn_name=model_name,
        activation=activation,
        input_var=input_var,
        alpha_up=alpha_up,
        alpha_down=alpha_down,
        is_supervised=is_supervised,
        activity_init=activity_init,
        activity_init_kwargs=activity_init_kwargs,
    )
    if is_free_latents:
        model = AddLatent(
            model,
            model_latent_dim,
            latent_init,
            latent_var=alpha_up / alpha_down,
            is_stop_gradient=is_stop_gradient_free,
        )
        # Tag the latent branch and the downward layers so that they are
        # handled by the dedicated latent optimizer (see the masks below).
        model.latent_vode.h.latent = True
        model.latent_layer_down.nn.weight.latent = True
        model.latent_layer_down.nn.bias.latent = True
        model.latent_layer_up.nn.weight.latent = True
        model.latent_layer_up.nn.bias.latent = True
        for l in model.down:
            if isinstance(l, pxnn.Layer):
                if hasattr(l.nn, "weight"):
                    l.nn.weight.latent = True
                    l.nn.bias.latent = True

    # convert pc to bp by removing the pc layers
    # In the supervised (free latent) setup the removed Vodes are replaced by
    # identity placeholders; in the unsupervised setup they are dropped.
    # Both directions now only scan the inner layers (`[1:-1]`): the previous
    # unsupervised `model.down` clean-up scanned `[1:]` and could therefore
    # append a non-Vode last layer twice.
    model.down = _strip_inner_vodes(model.down, keep_placeholders=is_supervised)
    model.up = _strip_inner_vodes(model.up, keep_placeholders=is_supervised)

    # clean vodes
    model.vodes = [model.vodes[0], model.vodes[-1]]

    model = AE(model) if not is_supervised else AE_Latent(model)

    train_dl, val_dl, test_dl = get_datax_cnn(args)

    # Create a copy of args with is_supervised overridden to True
    if not is_supervised:
        args_supervised = argparse.Namespace(**vars(args))
        args_supervised.is_supervised = True
        _, val_dl_labelled, test_dl_labelled = get_datax_cnn(args_supervised)

    lr_p = cosine_scedule(lr_p, nm_epochs, train_dl)

    with pxu.step(model, pxc.STATUS.INIT, clear_params=pxc.VodeParam.Cache):
        # Dummy batch of zeros so the model is initialised before the
        # optimizers are created.
        initialisation(
            jax.numpy.zeros((batch_size, latent_dim)),
            jax.numpy.zeros((batch_size, input_channels, *data_dim)),
            model=model,
        )
        optim_w = pxu.Optim(
            optax.adamw(lr_p, weight_decay=weight_decay),
            pxu.Mask(pxu.m(pxnn.LayerParam).has_not(latent=True).has_not(frozen=True))(
                model
            ),
        )
        optim_w_latent = (
            None
            if not is_free_latents
            else pxu.Optim(
                optax.adamw(
                    cosine_scedule(args.lr_p_latent, nm_epochs, train_dl),
                    weight_decay=weight_decay,
                ),
                pxu.Mask(pxu.m(pxnn.LayerParam).has(latent=True))(model),
            )
        )

    if args.make_mean_image and is_supervised:
        make_mean_images(args, mode="val", verbose=verbose)
        make_mean_images(args, mode="test", verbose=verbose)

    # setup learning
    if not is_supervised:
        weights_energy = energy_down
    else:
        if is_free_latents:
            weights_energy = energy_ae_latent
        else:
            raise NotImplementedError(
                "Only free latents are implemented for supervised learning"
            )
    train = setup_train(
        weights_energy, is_free_latents=is_free_latents, optim_w_latent=optim_w_latent
    )

    nm_img_report = 30
    # Epoch -1 performs no training: it only yields a baseline evaluation
    # (note that `-1 % epochs_per_val == epochs_per_val - 1`).
    epoch_dl = tqdm(np.arange(-1, nm_epochs)) if verbose else np.arange(-1, nm_epochs)
    for e in epoch_dl:
        e_up, e_down = (
            train(train_dl, model=model, optim_w=optim_w, verbose=False)
            if e > -1
            else (0.0, 0.0)
        )

        if verbose:
            epoch_dl.set_description_str(
                f"Epochs {e+1}/{nm_epochs}: E up {e_up/(alpha_up+1e-6):.3f}, E down {e_down/(alpha_down+1e-6):.3f}"
            )
        if e % epochs_per_val == epochs_per_val - 1:
            res = {}
            if is_supervised:
                res["val_acc_ff"] = eval_acc(
                    val_dl,
                    model,
                    mode="fp",
                    mode_kwargs={"noise_var": 0.0},
                    verbose=verbose,
                )
                res["val_rmse_ff"], imgs_ff = eval_rmse(
                    model,
                    batch_size,
                    dataset,
                    "val",
                    mode="fp",
                    mode_kwargs={"noise_var": 0.0},
                    verbose=verbose,
                )

                res["val_combined_err_ff"] = (1 - res["val_acc_ff"]) * 2 + res[
                    "val_rmse_ff"
                ]

                res["img_ff"] = _to_wandb_images(imgs_ff, is_grayscale=is_mnist)
                if is_free_latents:
                    res["val_recon_mse_ff"], recon_ff = (
                        eval_latent_reconstruction_free_latent(
                            val_dl,
                            model,
                            mode="fp",
                            mode_kwargs={"noise_var": 0.0},
                            verbose=verbose,
                        )
                    )
                    res["val_recon_ff"] = _to_wandb_images(
                        recon_ff, is_grayscale=is_mnist, limit=nm_img_report
                    )
                    res["val_combined_err_inf_free"] = (
                        1 - res["val_acc_ff"]
                    ) * 2 + res["val_recon_mse_ff"]
            else:
                res["val_acc_ff"] = eval_latent_decoding_acc(
                    val_dl,
                    val_dl_labelled,
                    model,
                    mode="fp",
                    mode_kwargs={"noise_var": 0.0},
                    verbose=verbose,
                )

                res["val_recon_mse_ff"], recon_ff = eval_latent_reconstruction(
                    val_dl,
                    model,
                    mode="fp",
                    mode_kwargs={"noise_var": 0.0},
                    verbose=verbose,
                )
                res["val_recon_ff"] = _to_wandb_images(
                    recon_ff, is_grayscale=is_mnist, limit=nm_img_report
                )
            if not is_supervised or is_free_latents:
                # The reference images are only logged once, with the baseline.
                if e == -1:
                    _, y = next(iter(val_dl))
                    res["val_base_imgs"] = _to_wandb_images(
                        y, is_grayscale=is_mnist, limit=nm_img_report
                    )
            if is_wandb:
                wandb.log(res)

    # get final results
    res = {}
    if is_supervised:
        # NOTE(review): unlike validation, `mode`/`mode_kwargs` are not passed
        # here, so the defaults of `eval_acc`/`eval_rmse` apply -- confirm
        # that this is intended.
        res["test_acc_ff"] = eval_acc(test_dl, model, verbose=verbose)
        res["test_rmse_ff"], imgs_ff = eval_rmse(
            model, batch_size, dataset, "test", verbose=verbose
        )
        res["img_ff"] = _to_wandb_images(imgs_ff, is_grayscale=is_mnist)
        if is_free_latents:
            res["test_recon_mse_ff"], recon_ff = eval_latent_reconstruction_free_latent(
                test_dl,
                model,
                mode="fp",
                mode_kwargs={"noise_var": 0.0},
                verbose=verbose,
            )
            # Truncated for grayscale data as well (previously only RGB was).
            res["test_recon_ff"] = _to_wandb_images(
                recon_ff, is_grayscale=is_mnist, limit=nm_img_report
            )
    else:
        res["test_acc_ff"] = eval_latent_decoding_acc(
            test_dl,
            test_dl_labelled,
            model,
            mode="fp",
            mode_kwargs={"noise_var": 0.0},
            verbose=verbose,
        )
        res["test_recon_mse_ff"], recon_ff = eval_latent_reconstruction(
            test_dl, model, mode="fp", mode_kwargs={"noise_var": 0.0}, verbose=verbose
        )
        res["recon_ff"] = _to_wandb_images(
            recon_ff, is_grayscale=is_mnist, limit=nm_img_report
        )
    if not is_supervised or is_free_latents:
        _, y = next(iter(test_dl))
        res["test_base_imgs"] = _to_wandb_images(
            y, is_grayscale=is_mnist, limit=nm_img_report
        )

    if is_wandb:
        wandb.log(res)
        wandb.finish()

    return model, optim_w


def str2bool(value):
    """Parse a command-line boolean: True only for "true" (case-insensitive)."""
    return str(value).lower() == "true"


def build_parser():
    """Create the command-line parser of the training script.

    Arguments may also be read from a file by passing ``@path`` on the command
    line (``fromfile_prefix_chars="@"``).
    """
    parser = argparse.ArgumentParser(
        description="Training a bidirectional CNN model.", fromfile_prefix_chars="@"
    )
    parser.add_argument(
        "--seed", type=int, default=0, help="set seed for reproducibility"
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="cifar10",
        choices=["mnist", "fashion_mnist", "cifar10", "cifar100"],
        help="dataset to use for training",
    )
    parser.add_argument(
        "--is-up-initialisation-default",
        type=str2bool,
        default=True,
        help="initialisation of the model activity with the upward pass or downward pass",
    )
    parser.add_argument("--train-size", type=int, default=50000, help="training size")
    parser.add_argument("--val-size", type=int, default=5000, help="validation size")
    parser.add_argument("--test-size", type=int, default=5000, help="test size")
    parser.add_argument(
        "--batch-size", type=int, default=256, help="training batch size"
    )
    parser.add_argument(
        "--nm-epochs", type=int, default=50, help="number of epochs to train"
    )
    parser.add_argument(
        "--lr-x", type=float, default=0.01, help="learning rate of the latent state x"
    )
    parser.add_argument(
        "--momentum",
        type=float,
        default=0.0,
        help="momentum coefficient",
    )
    parser.add_argument(
        "--T", type=int, default=8, help="number of inference iterations"
    )
    parser.add_argument(
        "--T-eval",
        type=int,
        default=20,
        help="number of inference iterations at evaluation time",
    )
    parser.add_argument(
        "--lr-p",
        type=float,
        default=0.0006644297056700742,
        help="learning rate of the model parameters",
    )
    parser.add_argument(
        "--weight-decay",
        type=float,
        default=0.0008104328387325853,
        help="weight decay of the model parameters",
    )
    parser.add_argument(
        "--activity-decay",
        type=float,
        default=0.0,
        help="activity decay coefficient",
    )
    # Two integers are required: `main` unpacks this value with `*data_dim`,
    # so a single scalar (what `type=int` alone would produce) cannot work.
    parser.add_argument(
        "--data-dim",
        type=int,
        nargs=2,
        default=(32, 32),
        help="spatial size of the input images, given as two integers",
    )
    parser.add_argument(
        "--activation-fn",
        type=str,
        default="tanh",
        choices=["relu", "tanh", "l-relu", "gelu"],
        help="activation function of the hidden layers",
    )
    parser.add_argument(
        "--is-wandb",
        type=str2bool,
        default=False,
        help="log the results to wandb",
    )
    parser.add_argument(
        "--verbose",
        type=str2bool,
        default=True,
        help="print the results to the console",
    )
    parser.add_argument(
        "--epochs-per-val",
        type=int,
        default=5,
        help="number of epochs between validation",
    )
    parser.add_argument(
        "--activity-init",
        type=str,
        default="ff",
        choices=["ff", "zero", "randn", "noisy-ff", "xavier"],
        help="initialisation scheme of the model activities",
    )
    parser.add_argument(
        "--make-mean-image",
        type=str2bool,
        default=False,
        help="recompute the mean image per class of dataset",
    )

    parser.add_argument(
        "--model-name",
        type=str,
        default="SmallVGG",
        choices=[
            "SmallVGG",
            "TestCNN",
            "MonoVGG",
            "MiniVGG",
            "VGG11",
            "VGG5",
            "VGG5np",
        ],
        help="name of the CNN architecture",
    )

    parser.add_argument(
        "--alpha-up",
        type=float,
        default=1.0,
        help="scaling factor for the upward model",
    )
    parser.add_argument(
        "--alpha-down",
        type=float,
        default=1.0,
        help="scaling factor for the downward model",
    )

    parser.add_argument(
        "--latent-dim", type=int, default=10, help="output size of the output layer"
    )
    parser.add_argument(
        "--is-supervised",
        type=str2bool,
        default=True,
        help="supervised or unsupervised learning",
    )
    parser.add_argument(
        "--is-hybrid",
        type=str2bool,
        default=False,
        help="hybrid model with down energy only for inference",
    )
    parser.add_argument(
        "--h-var",
        type=float,
        default=0.0,
        help="Langevin noise variance of the (hidden) pc layers",
    )
    parser.add_argument(
        "--input-var", type=float, default=1.0, help="input variance of the input layer"
    )
    parser.add_argument(
        "--is-cnn",
        type=str2bool,
        default=True,
        help="should always be true to be changed later",
    )

    parser.add_argument(
        "--is-free-latents",
        type=str2bool,
        default=True,
        help="to be used with supervised model to combine supervised and unsupervised learning",
    )
    parser.add_argument(
        "--is-stop-gradient-free",
        type=str2bool,
        default=False,
        help="stop-gradient option of the free latent model (used with --is-free-latents)",
    )
    parser.add_argument(
        "--free-latent-dim", type=int, default=128, help="size of the free latent space"
    )
    parser.add_argument(
        "--lr-p-latent",
        type=float,
        default=0.001,
        help="learning rate of the parameters tagged as latent",
    )
    return parser


def parse_args(argv=None):
    """Parse ``argv`` (``sys.argv[1:]`` when None) into a Namespace.

    ``data_dim`` is normalised to a tuple, whether it comes from the default
    or from the command line (where argparse produces a list).
    """
    args = build_parser().parse_args(argv)
    args.data_dim = tuple(args.data_dim)
    return args


if __name__ == "__main__":
    args = parse_args()
    main(args)
