import os

os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import jax
import jax.numpy as jnp
import optax

import pcax as px
import pcax.predictive_coding as pxc
import pcax.nn as pxnn
import pcax.functional as pxf
import pcax.utils as pxu

import matplotlib.pyplot as plt
import numpy as np
import wandb
from tqdm import tqdm
import random
import os
import argparse

import warnings

import torch

from utils_pcax.data import get_datax_mlp, make_mean_images
from utils_pcax.utils import sgdld, set_seed


"""
    This script can be used to train:
        - bimodal genPC on MNIST or Fashion MNIST.
"""


class Model(pxc.EnergyModule):
    """Energy model with one latent vode driving two prediction heads.

    The latent state is passed through ``act_fn`` and then through two linear
    layers; their outputs become the predictions (``u``) of ``input_vode`` and
    ``output_vode``. According to the comment in ``model_down`` the input side
    carries labels and the output side carries data.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        activation: str,
        input_var=1.0,
        is_supervised=True,
        direction: str = "down",
        activity_init="ff",
        activity_init_kwargs={},
        latent_init="xavier",
    ) -> None:
        """Create the vodes, the two linear layers and the latent initialiser.

        Args:
            input_dim: Size of ``input_vode``.
            hidden_dim: Size of ``latent_vode``.
            output_dim: Size of ``output_vode``.
            activation: Name of the activation applied to the latent state.
                A few names are mapped explicitly; any other name is looked
                up as an attribute of ``jax.nn``.
            input_var: Divisor of the squared-error energy attached to
                ``latent_vode``.
            is_supervised: Only read by the "zero" activity initialisation of
                ``input_vode`` (zeros when True, 0.1 otherwise).
            direction: Stored as a static attribute; not read anywhere else
                in this class.
            activity_init: "ff" (empty ruleset/transforms, presumably the Vode
                defaults) or "zero" (custom INIT rule).
            activity_init_kwargs: Not used in this class.
                NOTE(review): mutable default argument; harmless only while
                it stays unused.
            latent_init: "zero" selects an all-zeros latent state; any other
                value silently selects the uniform initialiser.

        Raises:
            ValueError: If ``activity_init`` is neither "ff" nor "zero".
        """
        super().__init__()

        def se_energy_input(vode, rkg: px.RandomKeyGenerator = px.RKG):
            """Squared error energy function derived from a Gaussian distribution."""
            e = vode.get("h") - vode.get("u")
            return 0.5 * (e * e) / input_var

        self.direction = px.static(direction)

        # Resolve the activation name to a callable.
        if activation == "relu":
            activation = jax.nn.relu
        elif activation == "tanh":
            activation = jax.nn.tanh
        elif activation == "silu":
            activation = jax.nn.silu
        elif activation == "l-relu":
            activation = jax.nn.leaky_relu
        elif activation == "h-tanh":
            activation = jax.nn.hard_tanh
        elif activation == "linear":
            activation = lambda x: x
        else:
            # Fallback: any other name must exist on jax.nn (AttributeError otherwise).
            activation = getattr(jax.nn, activation)

        self.act_fn = px.static(activation)

        # Select how vode activities are set during the INIT status.
        if activity_init == "ff":
            ruleset = {}
            tforms = {}
            tforms_out = {}
            tforms_in = {}
        elif activity_init == "zero":
            # All three vodes share the rule; the "to_zero" transform differs per vode.
            ruleset = {pxc.STATUS.INIT: ("h, u <- u:to_zero",)}
            tforms = {"to_zero": lambda n, k, v, rkg: jnp.zeros(n.shape.get())}
            # Despite its name, the output transform fills with -1, not 0.
            tforms_out = {"to_zero": lambda n, k, v, rkg: -jnp.ones(n.shape.get())}
            tforms_in = {
                "to_zero": lambda n, k, v, rkg: (
                    jnp.zeros(n.shape.get())
                    if is_supervised
                    else 0.1 * jnp.ones(n.shape.get())
                )
            }
        else:
            raise ValueError(f"Unknown activity_init: {activity_init}")

        self.input_vode = pxc.Vode(
            (input_dim,),
            ruleset=ruleset,
            tforms=tforms_in,
        )

        self.output_vode = pxc.Vode(
            (output_dim,),
            ruleset=ruleset,
            tforms=tforms_out,
        )

        # Only the latent vode receives the custom squared-error energy.
        self.latent_vode = pxc.Vode(
            (hidden_dim,),
            energy_fn=se_energy_input,
            ruleset=ruleset,
            tforms=tforms,
        )

        # Both heads read from the latent state: latent -> input and latent -> output.
        self.layer_input = pxnn.Linear(hidden_dim, input_dim)
        self.layer_output = pxnn.Linear(hidden_dim, output_dim)
        # Identity on the input head, tanh on the output head.
        self.input_act = px.static(lambda x: x)
        self.output_act = px.static(jax.nn.tanh)

        # Both observed vodes start frozen; infer_on_batch toggles these flags per mode.
        self.input_vode.h.frozen = True
        self.output_vode.h.frozen = True

        # Bound of the uniform latent initialiser: sqrt(6 / hidden_dim).
        self.latent_limit = jnp.sqrt(6 / hidden_dim)

        # Both initialisers take one argument that they ignore.
        latent_zero = lambda x: jnp.zeros(self.latent_vode.shape)
        latent_xavier = lambda x: jax.random.uniform(
            px.RKG(),
            shape=(self.latent_vode.shape),
            minval=-self.latent_limit,
            maxval=self.latent_limit,
        )
        self.latent_init = (
            px.static(latent_zero)
            if latent_init == "zero"
            else px.static(latent_xavier)
        )

    def __call__(self, x, y, is_input_initialisation: bool = False):
        """Initialise the latent state, run ``model_down`` and re-clamp x / y.

        Args:
            x: Value written to ``input_vode.h``; skipped when None.
            y: Value written to ``output_vode.h``; skipped when None.
            is_input_initialisation: Accepted but not read in this method.

        Returns:
            The tuple returned by ``model_down`` (the ``u`` values of
            ``input_vode`` and ``output_vode``).
        """
        # return jnp.zeros_like(x)
        if x is not None:
            self.input_vode.set("h", x)
        if y is not None:
            self.output_vode.set("h", y)

        self.latent_vode.set("h", self.latent_init(None))
        outputs = self.model_down(x, y)

        # Write the provided values again after the pass; presumably the vode
        # calls in model_down can overwrite h during INIT -- TODO confirm.
        if x is not None:
            self.input_vode.set("h", x)
        if y is not None:
            self.output_vode.set("h", y)
        return outputs

    def model_down(self, x, y):
        """Propagate the latent state through both heads.

        ``x`` and ``y`` are not read in this method; the predictions depend
        only on the current ``latent_vode.h``.

        Returns:
            Tuple ``(input_vode u, output_vode u)``.
        """
        # x is label and y is data
        latent = self.act_fn(
            self.latent_vode(self.latent_vode.get("h"))
        )  # has zero energy, assume uniform distribution

        # The vode call results are bound but unused; the predictions are read
        # back from the vodes' "u" values below.
        input = self.input_vode(self.input_act(self.layer_input(latent)))
        output = self.output_vode(self.output_act(self.layer_output(latent)))

        return self.input_vode.get("u"), self.output_vode.get("u")


@pxf.vmap(
    pxu.Mask(pxc.VodeParam | pxc.VodeParam.Cache, (None, 0)),
    in_axes=(0, 0),
    out_axes=(0, 0),
)
def initialisation(x, y, *, model: Model, is_input_initialisation: bool = True):
    """Call ``model`` on one (x, y) example; vmapped over the leading batch axis.

    The mask presumably batches the vode parameters and their cache along
    axis 0 while sharing everything else -- confirm against the pcax docs.

    Returns:
        Whatever ``model(x, y, ...)`` returns, stacked along axis 0.
    """
    predictions = model(x, y, is_input_initialisation=is_input_initialisation)
    return predictions


@pxf.vmap(
    pxu.Mask(pxc.VodeParam | pxc.VodeParam.Cache, (None, 0)),
    in_axes=(0, 0),
    out_axes=(None, 0, 0),
    axis_name="batch",
)
def energy(x, y, *, model: Model):
    """Run the generative pass and return the batch-averaged model energy.

    Returns:
        ``(mean_energy, input_prediction, output_prediction)``. The energy is
        averaged over the "batch" axis and returned unbatched; the two
        predictions keep their leading batch axis.
    """
    input_prediction, output_prediction = model.model_down(x, y)
    mean_energy = jax.lax.pmean(model.energy(), "batch")
    return mean_energy, input_prediction, output_prediction


# T, is_input_initialisation and mode are static under jit (positions 0, 3, 4).
@pxf.jit(static_argnums=(0, 3, 4))
def infer_on_batch(
    T: int,
    x: jax.Array,
    y: jax.Array,
    is_input_initialisation: bool,
    mode: int,
    *,
    model: Model,
    optim_h: pxu.Optim,
):
    """Run T inference steps on the vode states and return the clamped states.

    Args:
        T: Number of inference (scan) iterations.
        x: Batch for ``input_vode``; callers in this file also pass None.
        y: Batch for ``output_vode``; callers in this file also pass None.
        is_input_initialisation: Forwarded to ``initialisation``.
        mode: 0/"constrained" (both vodes frozen), 1/"label-only" (output
            free), 2/"data-only" (input free), 3/"unconstrained" (both free).
            Any other value leaves the current frozen flags untouched.
        model: Model whose vode states are optimised in place.
        optim_h: Optimiser for the vode states; initialised and cleared here.

    Returns:
        Tuple ``(input_vode h, output_vode h)`` after inference.
    """
    mode_mapping = {
        0: "constrained",
        1: "label-only",
        2: "data-only",
        3: "unconstrained",
    }
    # Integers are translated to names; a name passed directly falls through unchanged.
    mode = mode_mapping.get(mode, mode)

    def h_step(i, x, y, *, model, optim_h):
        """One inference step: energy gradient w.r.t. non-frozen vode params, then update."""
        with pxu.step(model, clear_params=pxc.VodeParam.Cache):
            (e, (x_, y_)), g = pxf.value_and_grad(
                pxu.Mask(pxu.m(pxc.VodeParam).has_not(frozen=True), [False, True]),
                has_aux=True,
            )(energy)(x, y, model=model)
        # NOTE(review): meaning of the third positional argument (True) is not
        # visible here -- check pxu.Optim.step.
        optim_h.step(model, g["model"], True)
        return (x, y), None

    model.train()

    # Choose which observed vodes the optimiser is allowed to change.
    if mode == "constrained":
        model.input_vode.h.frozen = True
        model.output_vode.h.frozen = True
    elif mode == "label-only":
        model.input_vode.h.frozen = True
        model.output_vode.h.frozen = False
    elif mode == "data-only":
        model.input_vode.h.frozen = False
        model.output_vode.h.frozen = True
    elif mode == "unconstrained":
        model.input_vode.h.frozen = False
        model.output_vode.h.frozen = False

    # Init step
    with pxu.step(model, pxc.STATUS.INIT, clear_params=pxc.VodeParam.Cache):
        initialisation(
            x, y, model=model, is_input_initialisation=is_input_initialisation
        )

    # The optimiser state is built after the frozen flags are set, so it
    # covers exactly the vode params that are free in this mode.
    optim_h.init(pxu.Mask(pxu.m(pxc.VodeParam).has_not(frozen=True))(model))

    # Inference steps
    pxf.scan(h_step, xs=jax.numpy.arange(T))(x, y, model=model, optim_h=optim_h)

    optim_h.clear()

    # restore frozen states
    model.input_vode.h.frozen = True
    model.output_vode.h.frozen = True
    return model.input_vode.get("h"), model.output_vode.get("h")


# T and is_input_initialisation are static under jit (positions 0 and 3).
@pxf.jit(static_argnums=(0, 3))
def train_on_batch(
    T: int,
    x: jax.Array,
    y: jax.Array,
    is_input_initialisation: bool,
    *,
    model: Model,
    optim_w: pxu.Optim,
    optim_h: pxu.Optim,
):
    """Run inference on one batch, then apply one weight update.

    Args:
        T: Number of inference iterations passed to ``infer_on_batch``.
        x: Batch for ``input_vode``.
        y: Batch for ``output_vode``.
        is_input_initialisation: Forwarded to ``infer_on_batch``.
        model: Model updated in place.
        optim_w: Optimiser for the layer parameters.
        optim_h: Optimiser for the vode states (used during inference).

    Returns:
        The energy value computed in the learning step.
    """

    model.train()

    # Mode 0 = "constrained": both observed vodes stay frozen during inference.
    infer_on_batch(T, x, y, is_input_initialisation, 0, model=model, optim_h=optim_h)

    # Learning step
    with pxu.step(model, clear_params=pxc.VodeParam.Cache):
        # Gradient is taken w.r.t. the layer parameters only; y_ receives the
        # auxiliary outputs of `energy` and is not used afterwards.
        (e, y_), g = pxf.value_and_grad(
            pxu.Mask(pxnn.LayerParam, [False, True]), has_aux=True
        )(energy)(x, y, model=model)
    optim_w.step(model, g["model"])
    return e


def train(
    dl,
    T,
    is_input_initialisation,
    *,
    model: Model,
    optim_w: pxu.Optim,
    optim_h: pxu.Optim,
    verbose: bool = False,
):
    """Train the model for one pass over ``dl``.

    Args:
        dl: Iterable of ``(x, y)`` batches.
        T: Number of inference iterations per batch.
        is_input_initialisation: Forwarded to ``train_on_batch``.
        model: Model updated in place.
        optim_w: Optimiser for the layer parameters.
        optim_h: Optimiser for the vode states.
        verbose: When True, wrap ``dl`` in a tqdm bar showing the energy.

    Returns:
        The energy of the last processed batch, or ``0.0`` when ``dl`` yields
        no batches.
    """
    dl = tqdm(dl, desc="Energy: ") if verbose else dl
    # Defined up front so an empty loader returns a value instead of raising
    # UnboundLocalError at the final return.
    e = 0.0
    for x, y in dl:
        e = train_on_batch(
            T,
            x,
            y,
            is_input_initialisation,
            model=model,
            optim_w=optim_w,
            optim_h=optim_h,
        )
        if verbose:
            dl.set_description_str(f"Energy: {e:.2f}")
    return e


def eval_acc(test_dl, model, T, is_input_initialisation, optim_h, verbose=True):
    """Measure classification accuracy by inferring the input side from data.

    For each ``(x, y)`` batch the input vode is inferred with ``x`` withheld
    (mode 2, "data-only") and the argmax of the result is compared with the
    argmax of ``x``.

    Returns:
        Fraction of examples whose inferred argmax matches the target.
    """
    n_correct, n_seen = 0, 0
    for x, y in test_dl:
        inferred = infer_on_batch(
            T, None, y, is_input_initialisation, 2, model=model, optim_h=optim_h
        )
        pred_up = inferred[0]
        hits = np.argmax(pred_up, axis=1) == np.argmax(x, axis=1)
        n_correct += np.sum(hits)
        n_seen += len(x)

    accuracy = n_correct / n_seen
    if verbose:
        print(f"Accuracy: {accuracy:.2f}")
    return accuracy


def eval_gen(model, batch_size, T, is_input_initialisation, optim_h, verbose=True):
    """Generate one image per class by inferring the output side from labels.

    A batch of one-hot labels cycling through 10 classes is clamped on the
    input side and the output vode is inferred (mode 1, "label-only"). The
    result is reshaped to 28x28 images and the first 10 are kept.

    Args:
        model: Model to run inference with.
        batch_size: Number of label rows fed to the model.
        T: Number of inference iterations.
        is_input_initialisation: Forwarded to ``infer_on_batch``.
        optim_h: Optimiser for the vode states.
        verbose: When True, show the images in a matplotlib figure.

    Returns:
        Array holding the first 10 generated 28x28 images.
    """
    # assess generation of mean images
    X_test = jax.nn.one_hot(np.arange(batch_size) % 10, 10)

    # No data batch is needed: the output side is passed as None and inferred.
    pred_down = infer_on_batch(
        T, X_test, None, is_input_initialisation, 1, model=model, optim_h=optim_h
    )[1].reshape(batch_size, 28, 28)[:10]

    if verbose:
        fig, axs = plt.subplots(1, 10)
        for i in range(10):
            axs[i].imshow(pred_down[i], cmap="gray", vmin=-1, vmax=1)
            axs[i].axis("off")
        plt.show()
    return pred_down


def eval_rmse(
    model,
    batch_size,
    dataset,
    datasubset,
    T,
    is_input_initialisation,
    optim_h,
    verbose=True,
):
    """Compare generated class images against stored mean images.

    Loads ``mean_images/{dataset}_{datasubset}.pt``, generates images with
    ``eval_gen`` and computes the root-mean-square error between the two.

    Args:
        model: Model to generate images with.
        batch_size: Batch size forwarded to ``eval_gen``.
        dataset: Dataset name used in the mean-image file name.
        datasubset: Subset name used in the mean-image file name.
        T: Number of inference iterations.
        is_input_initialisation: Forwarded to ``eval_gen``.
        optim_h: Optimiser for the vode states.
        verbose: When True, print the RMSE.

    Returns:
        ``(rmse, images)`` on success, where ``images`` are the generated
        images after rescaling. ``None`` when the mean images could not be
        loaded after 100 attempts.
    """
    img_folder = "mean_images"

    # load correct mean images
    path = f"{img_folder}/{dataset}_{datasubset}.pt"

    mean_images = None
    last_error = None
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", FutureWarning)
        # try to load mean images until it works
        for _ in range(100):
            try:
                mean_images = torch.load(path).numpy()
            except Exception as err:
                # Only ordinary errors are retried; KeyboardInterrupt and
                # SystemExit propagate instead of being swallowed.
                last_error = err
            else:
                break

    if mean_images is None:
        print(f"Could not load mean images from {path}: {last_error!r}")
        return None

    MAP = eval_gen(
        model, batch_size, T, is_input_initialisation, optim_h, verbose=False
    )
    # Affine rescale that maps the interval [-1, 1] onto [0, 1].
    MAP = MAP / 2 + 0.5

    rmse = np.sqrt(((mean_images - MAP) ** 2).mean())
    if verbose:
        print(f"RMSE: {rmse:.2f}")
    return rmse, MAP


def main(args):
    """Build the model, train it and report validation / test metrics.

    Args:
        args: Parsed command-line namespace (see the argparse setup at the
            bottom of this file). When ``args.is_wandb`` is set, values found
            in ``wandb.config`` overwrite the matching attributes.

    Raises:
        ValueError: At the first validation point when ``args.is_supervised``
            is False.
    """
    if args.is_wandb:
        wandb.init(project="test")
        # Values present in the wandb config override the parsed arguments
        # before the merged set is written back to the config.
        for key, value in wandb.config.items():
            setattr(args, key, value)
        wandb.config.update(args)

    set_seed(args.seed)

    verbose = args.verbose
    is_wandb = args.is_wandb
    is_supervised = args.is_supervised

    nm_epochs = args.nm_epochs
    epochs_per_val = args.epochs_per_val

    T = args.T
    T_eval = args.T_eval

    batch_size = args.batch_size
    lr = args.lr_x
    momentum = args.momentum
    gamma = 0
    h_var = args.h_var
    activity_init = args.activity_init
    activity_init_kwargs = {"layer_var": h_var}
    is_input_initialisation = args.is_input_initialisation_default
    latent_init = args.latent_init

    lr_p = args.lr_p
    weight_decay = args.weight_decay

    input_var = args.input_var
    activation = args.activation_fn
    latent_dim = args.latent_dim
    hidden_dim = args.hidden_dim
    data_dim = args.data_dim
    direction = args.direction

    dataset = args.dataset

    # The model's input side has size latent_dim and its output side data_dim.
    model = Model(
        input_dim=latent_dim,
        hidden_dim=hidden_dim,
        output_dim=data_dim,
        activation=activation,
        input_var=input_var,
        is_supervised=is_supervised,
        direction=direction,
        activity_init=activity_init,
        activity_init_kwargs=activity_init_kwargs,
        latent_init=latent_init,
    )

    optim_h = pxu.Optim(sgdld(lr, momentum, h_var, gamma))
    # Run one INIT pass on zero batches before the weight optimiser is built
    # from the model's layer parameters.
    with pxu.step(model, pxc.STATUS.INIT, clear_params=pxc.VodeParam.Cache):
        initialisation(
            jax.numpy.zeros((batch_size, latent_dim)),
            jax.numpy.zeros((batch_size, data_dim)),
            model=model,
            is_input_initialisation=is_input_initialisation,
        )
        optim_w = pxu.Optim(
            optax.adamw(lr_p, weight_decay=weight_decay),
            pxu.Mask(pxnn.LayerParam)(model),
        )

    train_dl, val_dl, test_dl = get_datax_mlp(args)

    if args.make_mean_image and is_supervised:
        make_mean_images(args, mode="val", verbose=verbose)
        make_mean_images(args, mode="test", verbose=verbose)

    def evaluate(dl, subset):
        """Collect accuracy, RMSE and generated images for one data subset."""
        res = {}
        res[f"{subset}_acc_inf"] = eval_acc(
            dl, model, T_eval, False, optim_h, verbose=verbose
        )
        rmse_result = eval_rmse(
            model,
            batch_size,
            dataset,
            subset,
            T_eval,
            True,
            optim_h,
            verbose=verbose,
        )
        # eval_rmse returns None when the mean images cannot be loaded; skip
        # the RMSE and image entries instead of failing on tuple unpacking.
        if rmse_result is not None:
            res[f"{subset}_rmse_inf"], imgs_inf = rmse_result
            res["img_inf"] = [
                wandb.Image(np.array(image), mode="L") for image in imgs_inf
            ]
        return res

    epochs = np.arange(-1, nm_epochs)
    epoch_dl = tqdm(epochs) if verbose else epochs
    for e in epoch_dl:
        random.shuffle(train_dl)
        # Epoch -1 does no training; it still reaches the validation branch
        # below because -1 % epochs_per_val == epochs_per_val - 1.
        if e > -1:
            e_model = train(
                train_dl,
                T=T,
                is_input_initialisation=is_input_initialisation,
                model=model,
                optim_w=optim_w,
                optim_h=optim_h,
                verbose=False,
            )
        else:
            e_model = 0.0

        if verbose:
            epoch_dl.set_description_str(f"Epochs {e+1}/{nm_epochs}: E {e_model:.3f}")

        if e % epochs_per_val == epochs_per_val - 1:
            if not is_supervised:
                raise ValueError("Unsupervised learning not supported")
            res = evaluate(val_dl, "val")

            if is_wandb:
                wandb.log(res)

    # get final results
    res = {}
    if is_supervised:
        res = evaluate(test_dl, "test")

    if is_wandb:
        wandb.log(res)
        wandb.finish()


if __name__ == "__main__":

    def _str2bool(value):
        """Parse a CLI boolean: only the text "true" (any case) yields True."""
        return str(value).lower() == "true"

    # Arguments can also be read from a file given as "@path".
    parser = argparse.ArgumentParser(
        description="Training a bidirectional MLP model.", fromfile_prefix_chars="@"
    )
    parser.add_argument(
        "--seed", type=int, default=0, help="set seed for reproducibility"
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="mnist",
        choices=["mnist", "fashion_mnist"],
        help="dataset to use for training",
    )
    parser.add_argument(
        "--is-input-initialisation-default",
        type=_str2bool,
        default=False,
        help="initialisation of the model activity with the upward pass or downward pass",
    )
    parser.add_argument("--train-size", type=int, default=10000, help="training size")
    parser.add_argument("--val-size", type=int, default=1000, help="validation size")
    parser.add_argument("--test-size", type=int, default=1000, help="test size")
    parser.add_argument(
        "--batch-size", type=int, default=256, help="training batch size"
    )
    parser.add_argument(
        "--nm-epochs", type=int, default=25, help="number of epochs to train"
    )
    parser.add_argument(
        "--lr-x", type=float, default=0.01, help="learning rate of the latent state x"
    )
    parser.add_argument(
        "--momentum",
        type=float,
        default=0.0,
        help="momentum of the latent state x updates",
    )
    parser.add_argument(
        "--T", type=int, default=100, help="number of inference iterations"
    )
    parser.add_argument(
        "--T-eval",
        type=int,
        default=1000,
        help="number of inference iterations at evaluation time",
    )
    parser.add_argument(
        "--lr-p",
        type=float,
        default=0.0001,
        help="learning rate of the model parameters",
    )
    parser.add_argument(
        "--weight-decay",
        type=float,
        default=0.0,
        help="weight decay of the model parameters",
    )
    parser.add_argument(
        "--data-dim",
        type=int,
        default=784,
        help="size of the data layer (flattened image)",
    )
    parser.add_argument(
        "--hidden-dim", type=int, default=256, help="size of the latent (hidden) layer"
    )
    parser.add_argument(
        "--activation-fn",
        type=str,
        default="linear",
        choices=["relu", "tanh", "l-relu", "linear", "gelu"],
        help="activation function of the hidden layers",
    )
    parser.add_argument(
        "--is-wandb",
        type=_str2bool,
        default=False,
        help="log the results to wandb",
    )
    parser.add_argument(
        "--verbose",
        type=_str2bool,
        default=True,
        help="print the results to the console",
    )
    parser.add_argument(
        "--epochs-per-val",
        type=int,
        default=5,
        help="number of epochs between validation",
    )
    parser.add_argument(
        "--activity-init",
        type=str,
        default="ff",
        choices=["ff"],
        help="initialisation scheme of the node activities",
    )
    parser.add_argument(
        "--latent-init",
        type=str,
        default="xavier",
        choices=["zero", "xavier"],
        help="initialisation of the latent layer state",
    )
    parser.add_argument(
        "--make-mean-image",
        type=_str2bool,
        default=False,
        help="recompute the mean image per class of dataset",
    )

    # Despite its name, this value is passed to the model as the size of its
    # input (label) side.
    parser.add_argument(
        "--latent-dim",
        type=int,
        default=10,
        help="size of the label layer (number of classes)",
    )
    parser.add_argument(
        "--is-supervised",
        type=_str2bool,
        default=True,
        help="supervised or unsupervised learning",
    )
    parser.add_argument(
        "--h-var",
        type=float,
        default=0.0,
        help="Langevin noise variance of the (hidden) pc layers",
    )
    parser.add_argument(
        "--input-var", type=float, default=1.0, help="input variance of the input layer"
    )

    parser.add_argument(
        "--direction",
        type=str,
        default="down",
        choices=["down", "up"],
        help="direction of multimodal model",
    )

    args = parser.parse_args()
    main(args)
