import copy
from typing import Tuple, Sequence, Union
import jax
import jax.numpy as jnp
import jax.random as jax_random
import numpy as np

import pcax as px
import pcax.predictive_coding as pxc
import pcax.nn as pxnn
import pcax.functional as pxf
from pcax.predictive_coding._energy import se_energy
from pcax.predictive_coding._parameter import VodeParam
from pcax.predictive_coding._vode import Ruleset
import pcax.utils as pxu
import equinox as eqx

from pcax.nn import Layer
from pcax.core import RandomKeyGenerator, RKG
from functools import partial


from utils_pcax.utils import (
    sample_multivariate_Gauss,
    sample_multivariate_Gauss_diag_cov,
)


def set_init(init: str, h_var=0.0):
    """Build the Vode ruleset and transformation table for an initialisation scheme.

    Args:
        init: Name of the scheme to use at ``pxc.STATUS.INIT``. Must be one of
            the scheme names listed in ``scheme_tforms`` below.
        h_var: Factor multiplying the standard-normal noise added by the
            "ff_randn_h_var" scheme (it scales the noise itself, i.e. it acts
            as a standard deviation despite its name). Unused by the other
            schemes.

    Returns:
        A ``(ruleset, tforms)`` tuple. ``ruleset`` maps every scheme name to
        its rule tuple and additionally maps ``pxc.STATUS.INIT`` to the rule
        selected by ``init``. ``tforms`` maps transformation names to
        callables taking ``(n, k, v, rkg)`` and returning the new value.

    Raises:
        KeyError: If ``init`` is not a known scheme name.
    """

    def _uniform(bound):
        # Factory so that every transform captures its own bound.
        return lambda n, k, v, rkg: jax_random.uniform(
            rkg(), shape=n.shape.get(), minval=-bound, maxval=bound
        )

    def _standard_normal(n, k, v, rkg):
        return jax_random.normal(rkg(), n.shape.get())

    def _xavier_uniform(n, k, v, rkg):
        limit = jnp.sqrt(6 / n.shape.get()[0])
        return jax_random.uniform(
            rkg(), shape=n.shape.get(), minval=-limit, maxval=limit
        )

    def _ff_plus_normal(n, k, v, rkg):
        return v + jax_random.normal(rkg(), n.shape.get())

    def _ff_plus_scaled_normal(n, k, v, rkg):
        return v + h_var * jax_random.normal(rkg(), n.shape.get())

    # Scheme name -> transformation applied to ``u`` (None = plain forward copy).
    scheme_tforms = {
        "ff": None,
        "zero": "to_zero",
        "randn": "randn",
        "narrow_randn": "narrow_randn",
        "xav": "xav",
        "ff_randn": "ff_randn",
        "ff_randn_h_var": "ff_randn_h_var",
        "narrow_uniform": "narrow_uniform",
        "mid_uniform": "mid_uniform",
        "wide_uniform": "wide_uniform",
        "xl_uniform": "xl_uniform",
    }
    ruleset = {
        name: ("h, u <- u" if tform is None else f"h, u <- u:{tform}",)
        for name, tform in scheme_tforms.items()
    }

    tforms = {
        "to_zero": lambda n, k, v, rkg: jnp.zeros(n.shape.get()),
        "randn": _standard_normal,
        # NOTE(review): identical to "randn" (unit variance) -- confirm whether a
        # smaller scale was intended for the "narrow" variant.
        "narrow_randn": _standard_normal,
        "xav": _xavier_uniform,
        "ff_randn": _ff_plus_normal,
        "narrow_uniform": _uniform(1),
        "mid_uniform": _uniform(3),
        "wide_uniform": _uniform(5),
        "xl_uniform": _uniform(10),
        "ff_randn_h_var": _ff_plus_scaled_normal,
    }

    if init not in ruleset:
        raise KeyError(
            f"unknown activity init {init!r}; expected one of {sorted(scheme_tforms)}"
        )
    ruleset[pxc.STATUS.INIT] = ruleset[init]

    return ruleset, tforms


class LinearTranspose(px.Module):
    """Applies the transposed weight matrix of an existing bias-free linear layer.

    The wrapped layer is referenced, not copied, so the weight is read from it
    on every call.
    """

    def __init__(self, linear):
        """Store ``linear`` and check that it carries no bias."""
        super().__init__()
        self.linear = linear
        # Only a bias-free layer can be transposed this way.
        bias = self.linear.nn.bias
        assert bias is None

    def __call__(self, x):
        """Return ``W.T @ x`` where ``W`` is the wrapped layer's weight."""
        transposed_weight = self.linear.nn.weight.T
        return transposed_weight @ x


class Model(pxc.EnergyModule):
    """Chain of Vodes linked by a "down" and an "up" stack of layers.

    ``self.down`` walks from ``self.vodes[0]`` to ``self.vodes[-1]`` and
    ``self.up`` walks the opposite way. Both are flat lists of callables that
    start with a Vode followed by one group per layer: ``(layer, activation,
    Vode)`` in the pre-activation layout, ``(activation, layer, Vode)`` in the
    post-activation layout (where the last group carries an extra output
    activation right before its Vode).
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        nm_layers: int,
        activation: str,
        input_var=1.0,
        alpha_up=1.0,
        alpha_down=1.0,
        is_supervised=True,
        is_shared_weights=False,
        activity_init="ff",
        activity_init_kwargs=None,
        is_post_activation=False,
        out_activation_down="tanh",
        out_activation_up=None,
    ) -> None:
        """Create the Vodes, the down/up layers and the two propagation lists.

        Args:
            input_dim: Size of the first (non-prior) Vode.
            hidden_dim: Size of every intermediate Vode.
            output_dim: Size of the last Vode.
            nm_layers: Number of Vodes in the chain, not counting the extra
                prior Vode added when ``is_supervised`` is False. Must be >= 2.
            activation: Activation name. "l-relu" and "h-tanh" are aliases for
                ``jax.nn.leaky_relu`` and ``jax.nn.hard_tanh``; any other name
                is looked up on ``jax.nn``.
            input_var: Divides the squared-error energy of the last Vode.
            alpha_up: Stored; scales the noise of the up ``*_fp`` passes.
            alpha_down: Stored; scales the noise of the down ``*_fp`` passes.
            is_supervised: When False, a prior Vode of size ``input_dim`` is
                prepended, linked by identity steps, and the final up
                activation multiplies by zero.
            is_shared_weights: When True the linear layers are built without
                bias and the up layers are ``LinearTranspose`` views of the
                down layers.
            activity_init: Scheme name forwarded to ``set_init``.
            activity_init_kwargs: Optional mapping; only its "h_var" entry is
                read (default 0.0).
            is_post_activation: Selects the post-activation layout.
            out_activation_down: Name of a ``jax.nn`` function used as the
                output activation of the down pass, or None for identity.
            out_activation_up: Same for the up pass; ignored when
                ``is_supervised`` is False.
        """
        super().__init__()

        assert nm_layers >= 2

        def se_energy_input(vode, rkg: px.RandomKeyGenerator = px.RKG):
            """Squared error energy function derived from a Gaussian distribution."""
            e = vode.get("h") - vode.get("u")
            return 0.5 * (e * e) / input_var

        # Short aliases; every other name is resolved directly on jax.nn.
        activation_aliases = {"l-relu": "leaky_relu", "h-tanh": "hard_tanh"}
        activation = getattr(jax.nn, activation_aliases.get(activation, activation))

        self.activation = px.static(activation)

        # Only "h_var" is consumed from the init kwargs.
        h_var = (activity_init_kwargs or {}).get("h_var", 0.0)
        ruleset, tforms = set_init(activity_init, h_var=h_var)

        # Optional prior Vode first, then the regular chain.
        self.vodes = (
            []
            if is_supervised
            else [pxc.Vode((input_dim,), ruleset=ruleset, tforms=tforms)]
        )
        self.vodes += (
            [pxc.Vode((input_dim,), ruleset=ruleset, tforms=tforms)]
            + [
                pxc.Vode((hidden_dim,), ruleset=ruleset, tforms=tforms)
                for _ in range(nm_layers - 2)
            ]
            + [
                pxc.Vode(
                    (output_dim,),
                    energy_fn=se_energy_input,
                    ruleset=ruleset,
                    tforms=tforms,
                )
            ]
        )

        is_bias = not is_shared_weights

        def linear_chain(dims):
            # One linear layer per consecutive pair of sizes, built left to right.
            return [
                pxnn.Linear(n_in, n_out, is_bias)
                for n_in, n_out in zip(dims, dims[1:])
            ]

        down_dims = [input_dim] + [hidden_dim] * (nm_layers - 2) + [output_dim]

        # setup down layers (identity placeholder in front when unsupervised)
        zero_prior = px.static(lambda x: x)
        self.layers_down = [] if is_supervised else [zero_prior]
        self.layers_down += linear_chain(down_dims)

        # setup up layers
        self.layers_up = []
        if not is_shared_weights:
            self.layers_up += linear_chain(down_dims[::-1])
            if not is_supervised:
                self.layers_up.append(px.static(lambda x: x))
        else:
            # NOTE(review): when is_supervised is False, layers_down also holds
            # the identity placeholder, which LinearTranspose does not look
            # able to wrap -- confirm this combination is never used.
            self.layers_up += [LinearTranspose(l) for l in reversed(self.layers_down)]

        self.out_activation_down = (
            px.static(lambda x: x)
            if out_activation_down is None
            else px.static(getattr(jax.nn, out_activation_down))
        )
        if not is_supervised:
            # 0.0 because vodes[0] contains constant zeros so vodes[1] should be
            # able to predict it
            out_up = px.static(lambda x: 0.0 * x)
        elif out_activation_up is None:
            out_up = px.static(lambda x: x)
        else:
            out_up = px.static(getattr(jax.nn, out_activation_up))
        self.out_activation_up = out_up

        self.vodes[0].h.frozen = True  # fixed latent state
        self.vodes[-1].h.frozen = True  # fixed data

        self.alpha_up = alpha_up
        self.alpha_down = alpha_down

        self.input_var = input_var
        self.key = px.RKG()

        # Stored as a zero-argument callable; read back with self.is_supervised().
        self.is_supervised = px.static(lambda x=is_supervised: x)

        if is_post_activation:
            self.model_setup_post_activation()
        else:
            self.model_setup_pre_activation()

    def model_setup_pre_activation(self):
        """Build ``self.down`` / ``self.up`` as Vode + (layer, activation, Vode) groups."""
        unsupervised = not self.is_supervised()

        last_down = len(self.layers_down) - 1
        self.down = [self.vodes[0]]
        for idx, (vode, layer) in enumerate(zip(self.vodes[1:], self.layers_down)):
            if idx == last_down:
                act_fn = self.out_activation_down
            elif unsupervised and idx == 0:
                # identity after the prior placeholder layer
                act_fn = px.static(lambda x: x)
            else:
                act_fn = self.activation
            self.down += [layer, act_fn, vode]

        last_up = len(self.layers_up) - 1
        self.up = [self.vodes[-1]]
        for idx, (vode, layer) in enumerate(zip(self.vodes[-2::-1], self.layers_up)):
            if idx == last_up:
                act_fn = self.out_activation_up
            elif unsupervised and idx + 2 == len(self.layers_down):
                act_fn = px.static(lambda x: x)
            else:
                act_fn = self.activation
            self.up += [layer, act_fn, vode]

    def model_setup_post_activation(self):
        """Build ``self.down`` / ``self.up`` as Vode + (activation, layer, Vode) groups.

        The last group of each list gets the corresponding output activation
        inserted between its layer and its Vode.
        """
        unsupervised = not self.is_supervised()

        last_down = len(self.layers_down) - 1
        self.down = [self.vodes[0]]
        for idx, (vode, layer) in enumerate(zip(self.vodes[1:], self.layers_down)):
            if unsupervised and idx == 0:
                act_fn = px.static(lambda x: x)
            else:
                act_fn = self.activation
            group = [act_fn, layer]
            if idx == last_down:
                group.append(self.out_activation_down)
            group.append(vode)
            self.down += group

        last_up = len(self.layers_up) - 1
        self.up = [self.vodes[-1]]
        for idx, (vode, layer) in enumerate(zip(self.vodes[-2::-1], self.layers_up)):
            if unsupervised and idx + 2 == len(self.layers_down):
                act_fn = px.static(lambda x: x)
            else:
                act_fn = self.activation
            group = [act_fn, layer]
            if idx == last_up:
                group.append(self.out_activation_up)
            group.append(vode)
            self.up += group

    def model_down(self, x):
        """Run every step of ``self.down`` and return ``u`` of its last Vode.

        ``x`` is not read here; the pass starts from ``h`` of ``self.down[0]``.
        """
        value = self.down[0].get("h")
        for step in self.down:
            value = step(value)
        return self.down[-1].get("u")

    def model_up(self, y):
        """Run every step of ``self.up`` and return ``u`` of its last Vode.

        ``y`` is not read here; the pass starts from ``h`` of ``self.up[0]``.
        """
        value = self.up[0].get("h")
        for step in self.up:
            value = step(value)
        return self.up[-1].get("u")

    def _propagate_fp(self, steps, value, noise_var, alpha):
        """Feed ``value`` through ``steps`` without calling any Vode.

        Non-Vode steps are applied to the running value. At each Vode, if
        ``noise_var > 0``, the value is replaced by the result of
        ``sample_multivariate_Gauss_diag_cov`` called with a vector filled
        with ``noise_var / alpha`` (presumably the diagonal covariance) and a
        fresh key from ``px.RKG``.
        """
        for step in steps:
            if not isinstance(step, pxc.Vode):
                value = step(value)
            elif noise_var > 0:
                value = sample_multivariate_Gauss_diag_cov(
                    value,
                    noise_var * jnp.ones(len(value)) / alpha,
                    px.RKG(),
                )
        return value

    def model_down_fp(self, input, noise_var: float = 0.0):
        """Propagate ``input`` through ``self.down[1:-1]`` without touching Vode state."""
        return self._propagate_fp(self.down[1:-1], input, noise_var, self.alpha_down)

    def model_up_fp(self, input, noise_var: float = 0.0):
        """Propagate ``input`` through ``self.up[1:-1]`` without touching Vode state."""
        return self._propagate_fp(self.up[1:-1], input, noise_var, self.alpha_up)

    def hidden_to_input_fp(self, input, noise_var: float = 0.0):
        """Propagate ``input`` through ``self.down[4:-1]``, i.e. past the first group.

        Noise is scaled by ``alpha_down`` so that it agrees with
        ``model_down_fp``, which walks the same list.
        """
        return self._propagate_fp(self.down[4:-1], input, noise_var, self.alpha_down)

    def input_to_hidden_fp(self, input, noise_var: float = 0.0):
        """Propagate ``input`` through ``self.up[1:-4]``, i.e. stopping one group early."""
        return self._propagate_fp(self.up[1:-4], input, noise_var, self.alpha_up)

    def __call__(self, x, y, is_up_initialisation: bool = True):
        """Clamp the end Vodes and run one initialisation pass.

        Args:
            x: Value written to ``h`` of ``self.down[0]``; skipped when None.
            y: Value written to ``h`` of ``self.up[0]``; skipped when None.
            is_up_initialisation: Run ``model_up`` when True, else ``model_down``.

        Returns:
            The value returned by the selected pass.
        """
        if x is not None:
            self.down[0].set("h", x)
        if y is not None:
            self.up[0].set("h", y)

        output = self.model_up(y) if is_up_initialisation else self.model_down(x)

        # need reset if initialisation has overwritten the values
        if x is not None:
            self.down[0].set("h", x)
        if y is not None:
            self.up[0].set("h", y)
        return output


class SubModel(pxc.EnergyModule):
    """View on a contiguous slice of a ``Model``'s down/up propagation lists.

    The slice reuses the step objects of the given model rather than creating
    new ones.
    """

    def __init__(self, model: Model, top_layer_idx=None, bottom_layer_idx=None):
        """Slice ``model.down`` and the reversed ``model.up`` with the same indices.

        Args:
            model: Model whose ``down``/``up`` lists are sliced.
            top_layer_idx: Start index of the slice into ``model.down``.
            bottom_layer_idx: Stop index of the slice into ``model.down``.

        Both resulting lists must start and end on a Vode (checked by assert).
        """
        super().__init__()

        self.down = model.down[top_layer_idx:bottom_layer_idx]
        self.vodes = [v for v in self.down if isinstance(v, pxc.Vode)]
        # model.up runs the other way, so slice it in "down" order and flip back.
        self.up = model.up[::-1][top_layer_idx:bottom_layer_idx][::-1]

        assert (
            isinstance(self.down[0], pxc.Vode)
            and isinstance(self.down[-1], pxc.Vode)
            and isinstance(self.up[0], pxc.Vode)
            and isinstance(self.up[-1], pxc.Vode)
        )

        self.alpha_down = model.alpha_down
        self.alpha_up = model.alpha_up

    def model_down(self, x):
        """Run every step of ``self.down`` and return ``u`` of its last Vode.

        ``x`` is not read here; the pass starts from ``h`` of ``self.down[0]``.
        """
        value = self.down[0].get("h")
        for step in self.down:
            value = step(value)
        return self.down[-1].get("u")

    def model_up(self, y):
        """Run every step of ``self.up`` and return ``u`` of its last Vode.

        ``y`` is not read here; the pass starts from ``h`` of ``self.up[0]``.
        """
        value = self.up[0].get("h")
        for step in self.up:
            value = step(value)
        return self.up[-1].get("u")

    def _propagate_fp(self, steps, value, noise_var, alpha):
        """Feed ``value`` through ``steps`` without calling any Vode.

        Non-Vode steps are applied to the running value. At each Vode, if
        ``noise_var > 0``, the value is replaced by the result of
        ``sample_multivariate_Gauss_diag_cov`` called with a vector filled
        with ``noise_var / alpha`` (presumably the diagonal covariance) and a
        fresh key from ``px.RKG``.
        """
        for step in steps:
            if not isinstance(step, pxc.Vode):
                value = step(value)
            elif noise_var > 0:
                value = sample_multivariate_Gauss_diag_cov(
                    value,
                    noise_var * jnp.ones(len(value)) / alpha,
                    px.RKG(),
                )
        return value

    def model_down_fp(self, input, noise_var: float = 0.0):
        """Propagate ``input`` through ``self.down[1:-1]`` without touching Vode state."""
        return self._propagate_fp(self.down[1:-1], input, noise_var, self.alpha_down)

    def model_up_fp(self, input, noise_var: float = 0.0):
        """Propagate ``input`` through ``self.up[1:-1]`` without touching Vode state."""
        return self._propagate_fp(self.up[1:-1], input, noise_var, self.alpha_up)

    def hidden_to_input_fp(self, input, noise_var: float = 0.0):
        """Propagate ``input`` through ``self.down[4:-1]``, i.e. past the first group.

        Noise is scaled by ``alpha_down`` so that it agrees with
        ``model_down_fp``, which walks the same list.
        """
        return self._propagate_fp(self.down[4:-1], input, noise_var, self.alpha_down)

    def input_to_hidden_fp(self, input, noise_var: float = 0.0):
        """Propagate ``input`` through ``self.up[1:-4]``, i.e. stopping one group early."""
        return self._propagate_fp(self.up[1:-4], input, noise_var, self.alpha_up)

    def __call__(self, x, y, is_up_initialisation: bool = True):
        """Clamp the end Vodes and run one initialisation pass.

        Args:
            x: Value written to ``h`` of ``self.down[0]``; skipped when None.
            y: Value written to ``h`` of ``self.up[0]``; skipped when None.
            is_up_initialisation: Run ``model_up`` when True, else ``model_down``.

        Returns:
            The value returned by the selected pass.
        """
        if x is not None:
            self.down[0].set("h", x)
        if y is not None:
            self.up[0].set("h", y)

        output = self.model_up(y) if is_up_initialisation else self.model_down(x)

        # need reset if initialisation has overwritten the values
        if x is not None:
            self.down[0].set("h", x)
        if y is not None:
            self.up[0].set("h", y)
        return output


class AddLatent(pxc.EnergyModule):
    """Wraps a ``Model`` and attaches an extra latent Vode to its first hidden Vode.

    The latent Vode is linked by ``latent_layer_down`` (latent -> hidden) and
    ``latent_layer_up`` (hidden -> latent). The wrapped model's ``up``, ``down``
    and ``vodes`` lists are reused, not copied.
    """

    def __init__(
        self,
        model: Model,
        latent_dim,
        latent_init,
        latent_var=None,
        is_shared_weights=False,
        is_stop_gradient=False,
        is_post_activation=False,
    ) -> None:
        """Create the latent Vode and its two linking layers.

        Args:
            model: Model to extend.
            latent_dim: Size of the latent Vode.
            latent_init: Initialisation scheme name for the latent Vode; must
                be a key of the local rule table.
            latent_var: When not None, the latent Vode uses a squared-error
                energy divided by this value (or by 1.0 when its ``u`` sums to
                zero); otherwise ``se_energy`` is used.
            is_shared_weights: When True the layers are built without bias and
                the down layer is a ``LinearTranspose`` of the up layer.
            is_stop_gradient: When True, ``jax.lax.stop_gradient`` is applied
                to the hidden value before it is mapped to the latent.
            is_post_activation: Selects how and at which index of ``self.down``
                the latent contribution is combined.

        Raises:
            KeyError: If ``latent_init`` is not a known scheme name.
        """
        super().__init__()

        self.up = model.up
        self.down = model.down
        self.vodes = model.vodes
        self.activation = model.activation

        is_bias = not is_shared_weights

        def _uniform(bound):
            # Factory so that every transform captures its own bound.
            return lambda n, k, v, rkg: jax.random.uniform(
                rkg(), shape=n.shape.get(), minval=-bound, maxval=bound
            )

        def _xavier_uniform(n, k, v, rkg):
            limit = jnp.sqrt(6 / n.shape.get()[0])
            return jax.random.uniform(
                rkg(), shape=n.shape.get(), minval=-limit, maxval=limit
            )

        # Scheme name -> transformation applied to ``u`` (None = plain copy).
        scheme_tforms = {
            "ff": None,
            "zero": "to_zero",
            "randn": "randn",
            "xav": "xav",
            "ff_randn": "ff_randn",
            "narrow_uniform": "narrow_uniform",
            "mid_uniform": "mid_uniform",
            "wide_uniform": "wide_uniform",
            "xl_uniform": "xl_uniform",
        }
        ruleset = {
            name: ("h, u <- u" if tform is None else f"h, u <- u:{tform}",)
            for name, tform in scheme_tforms.items()
        }

        tforms = {
            "to_zero": lambda n, k, v, rkg: jnp.zeros(n.shape.get()),
            "randn": lambda n, k, v, rkg: jax_random.normal(rkg(), n.shape.get()),
            "xav": _xavier_uniform,
            "ff_randn": lambda n, k, v, rkg: v
            + jax_random.normal(rkg(), n.shape.get()),
            "narrow_uniform": _uniform(1),
            "mid_uniform": _uniform(3),
            "wide_uniform": _uniform(5),
            "xl_uniform": _uniform(10),
        }

        def se_energy_latent(vode, rkg: px.RandomKeyGenerator = px.RKG):
            """Squared error energy function derived from a Gaussian distribution."""
            # Use latent_var only when the prediction u does not sum to zero.
            var = jax.lax.cond(
                jnp.sum(vode.get("u")) != 0, lambda: latent_var, lambda: 1.0
            )
            e = vode.get("h") - vode.get("u")
            return 0.5 * (e * e) / var

        # Read the shape through .get(), like every other shape access in this file.
        hidden_dim = np.prod(self.vodes[1].shape.get())
        if latent_init not in ruleset:
            raise KeyError(
                f"unknown latent init {latent_init!r}; "
                f"expected one of {sorted(scheme_tforms)}"
            )
        ruleset[pxc.STATUS.INIT] = ruleset[latent_init]
        self.latent_vode = pxc.Vode(
            (latent_dim,),
            ruleset=ruleset,
            tforms=tforms,
            energy_fn=se_energy_latent if latent_var is not None else se_energy,
        )
        self.latent_dim = latent_dim
        self.latent_limit = jnp.sqrt(6 / latent_dim)

        self.latent_layer_up = pxnn.Linear(hidden_dim, latent_dim, is_bias)
        self.latent_layer_down = (
            pxnn.Linear(latent_dim, hidden_dim, is_bias)
            if not is_shared_weights
            else LinearTranspose(self.latent_layer_up)
        )
        # Ignores its argument and always yields a zero latent.
        self.bias_latent = px.static(lambda x: jnp.zeros((latent_dim,)))

        # Marker flag on the latent state; it is not read inside this class.
        self.latent_vode.h.latent = True

        self.alpha_down = model.alpha_down
        self.alpha_up = model.alpha_up

        self.grad_transform = (
            px.static(lambda x: x)
            if not is_stop_gradient
            else px.static(jax.lax.stop_gradient)
        )

        def combination_fn_pre(input, latent, l, ld):
            return l(input + ld(latent))

        def combination_fn_post(input, latent, l, ld, activation):
            return activation(l(input) + ld(latent))

        if is_post_activation:
            self.combination_fn = px.static(
                partial(combination_fn_post, activation=self.activation)
            )
            self.combination_idx = px.static(3)
        else:
            self.combination_fn = px.static(combination_fn_pre)
            self.combination_idx = px.static(2)

    def __call__(self, x, y, is_up_initialisation: bool = True, latent=None):
        """Clamp the end Vodes (and optionally the latent) and run one pass.

        Args:
            x: Value written to ``h`` of ``self.vodes[0]``; skipped when None.
            y: Value written to ``h`` of ``self.vodes[-1]``; skipped when None.
            is_up_initialisation: Run ``model_up`` when True, else ``model_down``.
            latent: Value for ``h`` of the latent Vode. On a down pass a None
                latent is replaced by zeros.

        Returns:
            The output of the selected pass; for the up pass, which returns a
            tuple, only its first element.
        """
        if not is_up_initialisation:
            if latent is None:
                # .get() unwraps the shape, as done in model_down_fp.
                self.latent_vode.set("h", jnp.zeros(self.latent_vode.shape.get()))
            else:
                self.latent_vode.set("h", latent)

        if x is not None:
            self.vodes[0].set("h", x)
        if y is not None:
            self.vodes[-1].set("h", y)

        if is_up_initialisation:
            output = self.model_up(y)
        else:
            output = self.model_down(x, latent)

        # need reset if initialisation has overwritten the values
        if x is not None:
            self.vodes[0].set("h", x)
        if y is not None:
            self.vodes[-1].set("h", y)
        if latent is not None:
            self.latent_vode.set("h", latent)

        if isinstance(output, tuple):
            return output[0]
        return output

    def model_down(self, x, latent=None):
        """Run ``self.down`` while mixing in the latent at ``combination_idx``.

        ``x`` is not read here; the pass starts from ``h`` of ``self.down[0]``.
        A None latent is replaced by zeros. Returns ``u`` of the last Vode.
        """
        value = self.down[0].get("h")
        if latent is None:
            latent = self.bias_latent(jnp.ones((1,)))
        latent = self.latent_vode(latent)
        for idx, step in enumerate(self.down):
            if idx == self.combination_idx:
                value = self.combination_fn(
                    value, latent, step, self.latent_layer_down
                )
            else:
                value = step(value)
        return self.down[-1].get("u")

    def model_up(self, y):
        """Run ``self.up`` and additionally map the first hidden value to the latent.

        ``y`` is not read here; the pass starts from ``h`` of ``self.up[0]``.

        Returns:
            ``(u of self.vodes[0], u of the latent Vode)``.
        """
        value = self.up[0].get("h")
        for step in self.up[:-3]:
            value = step(value)
        # The last group and the latent branch both start from the same value.
        self.up[-1](self.up[-2](self.up[-3](value)))
        self.latent_vode(self.latent_layer_up(self.grad_transform(value)))
        return self.vodes[0].get("u"), self.latent_vode.get("u")

    def _propagate_fp(self, steps, value, noise_var, alpha):
        """Feed ``value`` through ``steps`` without calling any Vode.

        Non-Vode steps are applied to the running value. At each Vode, if
        ``noise_var > 0``, the value is replaced by the result of
        ``sample_multivariate_Gauss_diag_cov`` called with a vector filled
        with ``noise_var / alpha`` (presumably the diagonal covariance) and a
        fresh key from ``px.RKG``.
        """
        for step in steps:
            if not isinstance(step, pxc.Vode):
                value = step(value)
            elif noise_var > 0:
                value = sample_multivariate_Gauss_diag_cov(
                    value,
                    noise_var * jnp.ones(len(value)) / alpha,
                    px.RKG(),
                )
        return value

    def hidden_to_input_fp(self, input, noise_var: float = 0.0):
        """Propagate ``input`` through ``self.down[4:-1]``, i.e. past the first group.

        This is the tail of the down pass, so noise is scaled by ``alpha_down``.
        """
        return self._propagate_fp(self.down[4:-1], input, noise_var, self.alpha_down)

    def model_down_fp(self, input, noise_var: float = 0.0, latent=None):
        """Down pass without touching Vode state, mixing in ``latent`` (zeros if None).

        NOTE(review): ``self.down[1]`` / ``self.down[2]`` are hard-coded here
        instead of using ``self.combination_idx``; this looks tailored to the
        pre-activation layout -- confirm before using with post-activation.
        """
        if latent is None:
            latent = jnp.zeros(self.latent_vode.shape.get())

        value = self.down[1](input)
        value = self.combination_fn(
            value, latent, self.down[2], self.latent_layer_down
        )
        if len(self.down) > 4:
            value = self.hidden_to_input_fp(value, noise_var)
        return value

    def model_up_fp(self, input, noise_var: float = 0.0):
        """Up pass without touching Vode state.

        Returns:
            ``(output, latent)``: the value after the last two non-Vode steps
            of ``self.up`` and the value mapped by ``latent_layer_up``, both
            computed from the same hidden value.
        """
        hidden = self._propagate_fp(self.up[1:-3], input, noise_var, self.alpha_up)
        output = self.up[-2](self.up[-3](hidden))
        latent = self.latent_layer_up(hidden)
        return output, latent


class OModel(pxc.EnergyModule):
    """Joins the up path of one ``Model`` and the down path of another into a loop.

    The two paths are made to share their end Vodes: the last Vode of the down
    path is the first Vode of the up path and vice versa.
    """

    def __init__(
        self,
        model_up: Model,
        model_down: Model,
        alpha_up=1.0,
        alpha_down=1.0,
    ) -> None:
        """Tie the two paths together and reset the state of every Vode.

        Args:
            model_up: Model providing the ``up`` list.
            model_down: Model providing the ``down`` list.
            alpha_up: Stored; scales the noise of the up ``*_fp`` passes.
            alpha_down: Stored; scales the noise of the down ``*_fp`` passes.
        """
        super().__init__()

        self.up = model_up.up
        self.down = model_down.down

        # make the models share the same input and output vodes
        self.down[-1] = self.up[0]  # keep last vode of up
        self.up[-1] = self.down[0]  # keep first vode of down

        self.alpha_up = alpha_up
        self.alpha_down = alpha_down

        # Give every Vode a fresh state and cache; each shared end Vode is
        # visited exactly once.
        self.vodes = []
        for step in self.down[:-1] + self.up[-2::-1]:
            if isinstance(step, pxc.Vode):
                step.h = VodeParam()
                step.cache = VodeParam.Cache()
                self.vodes.append(step)

    def init_energy_to_zero(self):
        """Call every Vode on its own ``h`` (presumably making ``u`` equal ``h``)."""
        for vode in self.vodes:
            vode(vode.get("h"))

    def model_down(self, x):
        """Call ``init_energy_to_zero`` and then run the full down pass."""
        self.init_energy_to_zero()
        return self.model_down_no_init(x)

    def model_up(self, y):
        """Call ``init_energy_to_zero`` and then run the full up pass."""
        self.init_energy_to_zero()
        return self.model_up_no_init(y)

    def model_down_no_init(self, x):
        """Run every step of ``self.down`` and return ``u`` of its last Vode.

        ``x`` is not read here; the pass starts from ``h`` of ``self.down[0]``.
        """
        value = self.down[0].get("h")
        for step in self.down:
            value = step(value)
        return self.down[-1].get("u")

    def model_up_no_init(self, y):
        """Run every step of ``self.up`` and return ``u`` of its last Vode.

        ``y`` is not read here; the pass starts from ``h`` of ``self.up[0]``.
        """
        value = self.up[0].get("h")
        for step in self.up:
            value = step(value)
        return self.up[-1].get("u")

    def _propagate_fp(self, steps, value, noise_var, alpha):
        """Feed ``value`` through ``steps`` without calling any Vode.

        Non-Vode steps are applied to the running value. At each Vode, if
        ``noise_var > 0``, the value is replaced by the result of
        ``sample_multivariate_Gauss_diag_cov`` called with a vector filled
        with ``noise_var / alpha`` (presumably the diagonal covariance) and a
        fresh key from ``px.RKG``.
        """
        for step in steps:
            if not isinstance(step, pxc.Vode):
                value = step(value)
            elif noise_var > 0:
                value = sample_multivariate_Gauss_diag_cov(
                    value,
                    noise_var * jnp.ones(len(value)) / alpha,
                    px.RKG(),
                )
        return value

    def model_down_fp(self, input, noise_var: float = 0.0):
        """Propagate ``input`` through ``self.down[1:-1]`` without touching Vode state."""
        return self._propagate_fp(self.down[1:-1], input, noise_var, self.alpha_down)

    def model_up_fp(self, input, noise_var: float = 0.0):
        """Propagate ``input`` through ``self.up[1:-1]`` without touching Vode state."""
        return self._propagate_fp(self.up[1:-1], input, noise_var, self.alpha_up)

    def hidden_to_input_fp(self, input, noise_var: float = 0.0):
        """Propagate ``input`` through ``self.down[4:-1]``, i.e. past the first group.

        Noise is scaled by ``alpha_down`` so that it agrees with
        ``model_down_fp``, which walks the same list.
        """
        return self._propagate_fp(self.down[4:-1], input, noise_var, self.alpha_down)

    def input_to_hidden_fp(self, input, noise_var: float = 0.0):
        """Propagate ``input`` through ``self.up[1:-4]``, i.e. stopping one group early."""
        return self._propagate_fp(self.up[1:-4], input, noise_var, self.alpha_up)

    def __call__(self, x, y, is_up_initialisation: bool = True):
        """Clamp the shared end Vodes and run both passes, one after the other.

        Args:
            x: Value written to ``h`` of ``self.down[0]``; skipped when None.
            y: Value written to ``h`` of ``self.up[0]``; skipped when None.
            is_up_initialisation: When True the up pass runs first and the
                down pass second; otherwise the order is reversed.

        Returns:
            The value returned by whichever pass ran second.
        """
        if x is not None:
            self.down[0].set("h", x)
        if y is not None:
            self.up[0].set("h", y)

        # The first pass is run only for its effect on the Vodes.
        if is_up_initialisation:
            self.model_up_no_init(y)
            output = self.model_down_no_init(x)
        else:
            self.model_down_no_init(x)
            output = self.model_up_no_init(y)

        # need reset if initialisation has overwritten the values
        if x is not None:
            self.down[0].set("h", x)
        if y is not None:
            self.up[0].set("h", y)
        return output


class Model_w_latent(Model):
    """``Model`` with an extra latent vode next to ``vodes[0]``.

    In the down pass the latent contributes, through ``latent_layer_down``,
    to the prediction of ``vodes[1]``; in the up pass it is predicted, through
    ``latent_layer_up``, from the same activity that predicts ``vodes[0]``.
    """

    def __init__(
        self,
        latent_dim,
        latent_init,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        nm_layers: int,
        activation: str,
        input_var=1.0,
        alpha_up=1.0,
        alpha_down=1.0,
        is_supervised=True,
        is_shared_weights=False,
        activity_init="ff",
        activity_init_kwargs=None,
        output_layer_init=None,
        input_layer_init=None,
        latent_var=None,
    ) -> None:
        """Build the base model, then add the latent vode and its layers.

        Args:
            latent_dim: size of the latent vode.
            latent_init: ruleset key used to initialise the latent vode.
            input_dim, hidden_dim, output_dim, nm_layers, activation,
            input_var, alpha_up, alpha_down, is_supervised,
            is_shared_weights, activity_init, activity_init_kwargs:
                forwarded unchanged to ``Model.__init__``.
            output_layer_init: ruleset key for ``vodes[-1]``; falls back to
                ``activity_init`` when ``None``.
            input_layer_init: ruleset key for ``vodes[0]``; falls back to
                ``activity_init`` when ``None``.
            latent_var: when not ``None``, the latent vode uses
                ``se_energy_latent`` with this variance instead of
                ``se_energy``.
        """
        # Avoid a shared mutable default; the parent still receives a dict.
        if activity_init_kwargs is None:
            activity_init_kwargs = {}

        super().__init__(
            input_dim,
            hidden_dim,
            output_dim,
            nm_layers,
            activation,
            input_var,
            alpha_up,
            alpha_down,
            is_supervised,
            is_shared_weights,
            activity_init,
            activity_init_kwargs,
        )

        is_bias = not is_shared_weights

        ruleset = {
            "ff": ("h, u <- u",),
            "zero": ("h, u <- u:to_zero",),
            "ones": ("h, u <- u:to_ones",),
            "min_ones": ("h, u <- u:to_min_ones",),
            "tenth": ("h, u <- u:to_tenth",),
            "randn": ("h, u <- u:randn",),
            "xav": ("h, u <- u:xav",),
            "ff_randn": ("h, u <- u:ff_randn",),
            "narrow_uniform": ("h, u <- u:narrow_uniform",),
            "mid_uniform": ("h, u <- u:mid_uniform",),
            "wide_uniform": ("h, u <- u:wide_uniform",),
            "xl_uniform": ("h, u <- u:xl_uniform",),
            "latent_gen": ("h, u <- u:mid_uniform",),
        }

        def _uniform(bound):
            # Initialiser drawing uniformly from [-bound, bound] in the vode's shape.
            return lambda n, k, v, rkg: jax.random.uniform(
                rkg(), shape=(n.shape.get()), minval=-bound, maxval=bound
            )

        tforms = {
            "to_zero": lambda n, k, v, rkg: jnp.zeros(n.shape.get()),
            "to_ones": lambda n, k, v, rkg: jnp.ones(n.shape.get()),
            # "min_ones" means minus ones; it previously duplicated "to_ones".
            "to_min_ones": lambda n, k, v, rkg: -jnp.ones(n.shape.get()),
            "to_tenth": lambda n, k, v, rkg: 0.1 * jnp.ones(n.shape.get()),
            "randn": lambda n, k, v, rkg: jax_random.normal(rkg(), n.shape.get()),
            "xav": lambda n, k, v, rkg: jax.random.uniform(
                rkg(),
                shape=(n.shape.get()),
                minval=-jnp.sqrt(6 / n.shape.get()[0]),
                maxval=jnp.sqrt(6 / n.shape.get()[0]),
            ),
            "ff_randn": lambda n, k, v, rkg: v
            + jax_random.normal(rkg(), n.shape.get()),
            "narrow_uniform": _uniform(1),
            "mid_uniform": _uniform(3),
            "wide_uniform": _uniform(5),
            "xl_uniform": _uniform(10),
        }

        def se_energy_latent(vode, rkg: px.RandomKeyGenerator = px.RKG):
            """Squared error energy function derived from a Gaussian distribution."""
            # Use latent_var when the prediction "u" has a non-zero sum, else unit variance.
            var = jax.lax.cond(
                jnp.sum(vode.get("u")) != 0, lambda: latent_var, lambda: 1.0
            )
            e = vode.get("h") - vode.get("u")
            return 0.5 * (e * e) / var

        ruleset[pxc.STATUS.INIT] = ruleset[latent_init]
        self.latent_vode = pxc.Vode(
            (latent_dim,),
            ruleset=ruleset,
            tforms=tforms,
            energy_fn=se_energy_latent if latent_var is not None else se_energy,
        )
        self.latent_dim = latent_dim
        self.latent_limit = jnp.sqrt(6 / latent_dim)

        self.latent_layer_up = pxnn.Linear(hidden_dim, latent_dim, is_bias)
        self.latent_layer_down = (
            pxnn.Linear(latent_dim, hidden_dim, is_bias)
            if not is_shared_weights
            else LinearTranspose(self.latent_layer_up)
        )
        # Constant zero "prior" for the latent; the argument is ignored.
        self.bias_latent = px.static(lambda x: jnp.zeros((latent_dim,)))

        # sanitise initialisation of vodes
        vode_rset, input_rset, output_rset = (
            copy.deepcopy(ruleset),
            copy.deepcopy(ruleset),
            copy.deepcopy(ruleset),
        )

        vode_rset[pxc.STATUS.INIT] = vode_rset[activity_init]
        input_rset[pxc.STATUS.INIT] = (
            input_rset[activity_init]
            if input_layer_init is None
            else input_rset[input_layer_init]
        )
        output_rset[pxc.STATUS.INIT] = (
            output_rset[activity_init]
            if output_layer_init is None
            else output_rset[output_layer_init]
        )
        output_rset["latent_gen"] = ruleset["narrow_uniform"]

        # All vodes get the hidden ruleset first; the two ends are then overridden.
        for v in self.vodes:
            v.ruleset = Ruleset(vode_rset, tforms)
        self.vodes[-1].ruleset = Ruleset(output_rset, tforms)
        self.vodes[0].ruleset = Ruleset(input_rset, tforms)

    def __call__(self, x, y, is_up_initialisation: bool = True):
        """Run the parent initialisation, then one latent-aware pass.

        Returns the output of ``model_up`` (first element only, since it
        returns a tuple) or of ``model_down``, depending on
        ``is_up_initialisation``.
        """
        if not is_up_initialisation:
            # Use .shape.get(), as model_down_fp does, to obtain the plain shape.
            self.latent_vode(jnp.zeros(self.latent_vode.shape.get()))
        output = super().__call__(x, y, is_up_initialisation=is_up_initialisation)
        if x is not None:
            self.vodes[0].set("h", x)
        if y is not None:
            self.vodes[-1].set("h", y)

        if is_up_initialisation:
            output = self.model_up(y)
        else:
            output = self.model_down(x)

        if x is not None:  # need reset if initialisation has overwritten the values
            self.vodes[0].set("h", x)
        if y is not None:
            self.vodes[-1].set("h", y)

        if isinstance(output, tuple):
            return output[0]
        return output

    def model_down(self, x):
        """Down pass through all vodes; returns ``u`` of ``vodes[-1]``.

        The first layer's pre-activation is ``layers_down[0](h0)`` plus
        ``latent_layer_down(latent)``. ``x`` itself is not read here; the
        pass starts from ``vodes[0].get("h")``.
        """
        input = self.vodes[0](
            self.vodes[0].get("h")
        )  # have zero energy, rely on preinitialisation
        latent = self.latent_vode(self.bias_latent(jnp.ones((1,))))
        for idx, (v, l) in enumerate(zip(self.vodes[1:], self.layers_down)):
            act_fn = (
                self.activation
                if idx + 1 < len(self.layers_down)
                else self.out_activation_down
            )
            if idx == 0:
                input = act_fn(l(input) + self.latent_layer_down(latent))
            else:
                input = act_fn(l(input))
            input = v(input)
        return self.vodes[-1].get("u")

    def model_up(self, y):
        """Up pass; returns ``(u of vodes[0], u of latent_vode)``.

        ``y`` itself is not read here; the pass starts from
        ``vodes[-1].get("h")``.
        """
        input = self.vodes[-1](
            self.vodes[-1].get("h")
        )  # have zero energy, rely on preinitialisation
        for v, l in zip(self.vodes[-2:0:-1], self.layers_up[:-1]):
            input = v(self.activation(l(input)))
        # Both vodes[0] and the latent are predicted from the same activity.
        self.vodes[0](self.out_activation_up(self.layers_up[-1](input)))
        self.latent_vode(self.latent_layer_up(input))
        return self.vodes[0].get("u"), self.latent_vode.get("u")

    def model_down_unsupervised(self, x):
        """Down pass driven by the latent only; returns ``u`` of ``vodes[-1]``.

        ``layers_down[0]`` is not applied: the first prediction comes from
        ``latent_layer_down`` alone.
        """
        self.vodes[0](
            self.vodes[0].get("h")
        )  # have zero energy, rely on preinitialisation
        input = self.latent_vode(self.bias_latent(jnp.ones((1,))))
        for idx, (v, l) in enumerate(zip(self.vodes[1:], self.layers_down)):
            act_fn = (
                self.activation
                if idx + 1 < len(self.layers_down)
                else self.out_activation_down
            )
            if idx == 0:
                input = act_fn(self.latent_layer_down(input))
            else:
                input = act_fn(l(input))
            input = v(input)
        return self.vodes[-1].get("u")

    def model_up_unsupervised(self, y):
        """Up pass where ``vodes[0]`` is fed its own ``h``.

        Returns ``(u of vodes[0], u of latent_vode)``.
        """
        input = self.vodes[-1](
            self.vodes[-1].get("h")
        )  # have zero energy, rely on preinitialisation
        for v, l in zip(self.vodes[-2:0:-1], self.layers_up[:-1]):
            input = v(self.activation(l(input)))
        self.vodes[0](self.vodes[0].get("h"))
        self.latent_vode(self.latent_layer_up(input))
        return self.vodes[0].get("u"), self.latent_vode.get("u")

    def hidden_to_input_fp(self, input, noise_var: float = 0.0):
        """Apply ``layers_down[1:]`` to ``input`` without touching any vode.

        With ``noise_var > 0``, ``sample_multivariate_Gauss`` is applied after
        every layer except the last, with covariance
        ``noise_var * I / alpha_down``.
        """
        for l in self.layers_down[1:-1]:
            input = self.activation(l(input))
            if noise_var > 0:
                input = sample_multivariate_Gauss(
                    input, noise_var * jnp.eye(len(input)) / self.alpha_down, px.RKG()
                )
        input = self.out_activation_down(self.layers_down[-1](input))
        return input

    def model_down_fp(self, input, noise_var: float = 0.0, latent=None):
        """Vode-free down pass from ``input`` plus ``latent``.

        ``latent`` defaults to zeros of the latent vode's shape.
        """
        if latent is None:
            latent = jnp.zeros(self.latent_vode.shape.get())
        # With a single down layer, that layer is also the output layer.
        act_fn = (
            self.activation if 1 != len(self.layers_down) else self.out_activation_down
        )
        input = act_fn(self.layers_down[0](input) + self.latent_layer_down(latent))
        if len(self.layers_down) > 1:
            input = self.hidden_to_input_fp(input, noise_var)
        return input

    def model_down_fp_unsupervised(self, latent, noise_var: float = 0.0):
        """Vode-free down pass driven by ``latent`` alone.

        ``latent`` defaults to zeros of the latent vode's shape when ``None``.
        """
        if latent is None:
            latent = jnp.zeros(self.latent_vode.shape.get())
        act_fn = (
            self.activation if 1 != len(self.layers_down) else self.out_activation_down
        )
        input = act_fn(self.latent_layer_down(latent))
        if len(self.layers_down) > 1:
            input = self.hidden_to_input_fp(input, noise_var)
        return input

    def model_up_fp(self, y, noise_var: float = 0.0):
        """Vode-free up pass from ``y``; returns ``(output, latent)``.

        With ``noise_var > 0``, ``sample_multivariate_Gauss`` is applied after
        every layer except the last, with covariance
        ``noise_var * I / alpha_up``.
        """
        input = y
        for l in self.layers_up[:-1]:
            input = self.activation(l(input))
            if noise_var > 0:
                input = sample_multivariate_Gauss(
                    input, noise_var * jnp.eye(len(input)) / self.alpha_up, px.RKG()
                )
        output = self.out_activation_up(self.layers_up[-1](input))
        latent = self.latent_layer_up(input)
        return output, latent


class Model_graph(pxc.EnergyModule):
    """Energy module with a down stream and an up stream over one chain of vodes.

    ``self.down`` and ``self.up`` are flat lists of the form
    ``[vode, layer, activation, vode, ...]`` that the pass methods iterate.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        nm_layers: int,
        activation: str,
        input_var=1.0,
        alpha_up=1.0,
        alpha_down=1.0,
        is_supervised=True,
        is_shared_weights=False,
        activity_init="ff",
        activity_init_kwargs=None,
        is_post_activation=False,
        out_activation_down="tanh",
        out_activation_up=None,
    ) -> None:
        """Create the vodes, the down/up layers and the two step lists.

        Args:
            input_dim, hidden_dim, output_dim: vode sizes.
            nm_layers: number of vodes in the supervised case (must be >= 2);
                the unsupervised case prepends one more ``input_dim`` vode.
            activation: ``"l-relu"`` / ``"h-tanh"`` or any attribute name of
                ``jax.nn``.
            input_var: variance used by the energy of the last vode.
            alpha_up, alpha_down: stored and used to scale fixed-pass noise.
            is_supervised: see ``nm_layers``; also selects the up-stream
                output activation.
            is_shared_weights: when true, up layers are ``LinearTranspose``
                of the down layers and linear layers have no bias.
            activity_init: key passed to ``set_init``.
            activity_init_kwargs: not used by this class.
            is_post_activation: must be falsy (not implemented).
            out_activation_down, out_activation_up: ``jax.nn`` attribute names
                or ``None`` for identity.
        """
        super().__init__()

        assert nm_layers >= 2

        assert not is_post_activation, "Not implemented yet"

        def se_energy_input(vode, rkg: px.RandomKeyGenerator = px.RKG):
            """Squared error energy function derived from a Gaussian distribution."""
            e = vode.get("h") - vode.get("u")
            return 0.5 * (e * e) / input_var

        # Short aliases map to jax.nn names; everything else is looked up directly.
        activation_aliases = {"l-relu": "leaky_relu", "h-tanh": "hard_tanh"}
        activation = getattr(jax.nn, activation_aliases.get(activation, activation))

        self.activation = px.static(activation)

        ruleset, tforms = set_init(activity_init)
        tforms_out, tforms_in = tforms, tforms

        self.vodes = (
            []
            if is_supervised
            else [
                pxc.Vode(
                    (input_dim,),
                    ruleset=ruleset,
                    tforms=tforms_in,
                )
            ]
        )
        self.vodes += (
            [
                pxc.Vode(
                    (input_dim,),
                    ruleset=ruleset,
                    tforms=tforms_in if is_supervised else tforms,
                )
            ]
            + [
                pxc.Vode(
                    (hidden_dim,),
                    ruleset=ruleset,
                    tforms=tforms,
                )
                for _ in range(nm_layers - 2)
            ]
            + [
                pxc.Vode(
                    (output_dim,),
                    energy_fn=se_energy_input,
                    ruleset=ruleset,
                    tforms=tforms_out,
                )
            ]
        )

        is_bias = not is_shared_weights

        # setup down layers
        zero_prior = px.static(lambda x: x)
        self.layers_down = [] if is_supervised else [zero_prior]
        if nm_layers == 2:
            self.layers_down.append(pxnn.Linear(input_dim, output_dim, is_bias))
        else:
            self.layers_down += (
                [pxnn.Linear(input_dim, hidden_dim, is_bias)]
                + [
                    pxnn.Linear(hidden_dim, hidden_dim, is_bias)
                    for _ in range(nm_layers - 3)
                ]
                + [pxnn.Linear(hidden_dim, output_dim, is_bias)]
            )

        # setup up layers
        self.layers_up = []
        if not is_shared_weights:
            if nm_layers == 2:
                self.layers_up.append(pxnn.Linear(output_dim, input_dim, is_bias))
            else:
                self.layers_up += (
                    [pxnn.Linear(output_dim, hidden_dim, is_bias)]
                    + [
                        pxnn.Linear(hidden_dim, hidden_dim, is_bias)
                        for _ in range(nm_layers - 3)
                    ]
                    + [pxnn.Linear(hidden_dim, input_dim, is_bias)]
                )
            if not is_supervised:
                self.layers_up.append(px.static(lambda x: x))
        else:
            self.layers_up += [LinearTranspose(l) for l in reversed(self.layers_down)]

        self.out_activation_down = (
            px.static(lambda x: x)
            if out_activation_down is None
            else px.static(getattr(jax.nn, out_activation_down))
        )
        self.out_activation_up = (
            px.static(lambda x: 0.0 * x)
            if not is_supervised
            else (
                px.static(lambda x: x)
                if out_activation_up is None
                else px.static(getattr(jax.nn, out_activation_up))
            )
        )  # 0.0 because vodes[0] contains constant zeros so vodes[1] should be able to predict it

        self.vodes[0].h.frozen = True  # fixed latent state
        self.vodes[-1].h.frozen = True  # fixed data

        self.alpha_up = alpha_up
        self.alpha_down = alpha_down

        self.input_var = input_var
        self.key = px.RKG()

        # Stored as a static callable; read back with self.is_supervised().
        self.is_supervised = px.static(lambda x=is_supervised: x)

        # set transformation steps
        self.down = [self.vodes[0]]
        for idx, (v, l) in enumerate(zip(self.vodes[1:], self.layers_down)):
            if idx == len(self.layers_down) - 1:
                act_fn = self.out_activation_down
            elif (not self.is_supervised()) and idx == 0:
                # Identity after the zero-prior layer in the unsupervised case.
                act_fn = px.static(lambda x: x)
            else:
                act_fn = self.activation
            self.down += [l, act_fn, v]

        self.up = [self.vodes[-1]]
        for idx, (v, l) in enumerate(zip(self.vodes[-2::-1], self.layers_up)):
            if idx == len(self.layers_up) - 1:
                act_fn = self.out_activation_up
            elif (not self.is_supervised()) and idx + 2 == len(self.layers_down):
                act_fn = px.static(lambda x: x)
            else:
                act_fn = self.activation
            self.up += [l, act_fn, v]

    def model(self, x, y):
        """Feed every vode a prediction built from its neighbours' ``h``.

        Interior vodes receive the activation of the summed down and up
        contributions. Returns ``(u of vodes[-1], u of vodes[0])``; ``x`` and
        ``y`` are not read here.
        """
        # for each layer set input transformation
        self.vodes[0](
            self.out_activation_up(self.layers_up[-1](self.vodes[1].get("h")))
        )
        for v, v_prev, v_next, ld, lu in zip(
            self.vodes[1:-1],
            self.vodes[:-2],
            self.vodes[2:],
            self.layers_down[:-1],
            self.layers_up[-2::-1],
        ):
            v(self.activation(ld(v_prev.get("h")) + lu(v_next.get("h"))))
        self.vodes[-1](
            self.out_activation_down(self.layers_down[-1](self.vodes[-2].get("h")))
        )
        return self.vodes[-1].get("u"), self.vodes[0].get("u")

    def model_down(self, x):
        """Apply every step of ``self.down`` in order; return ``u`` of the last vode."""
        input = self.down[0].get("h")
        for l in self.down:
            input = l(input)
        return self.down[-1].get("u")

    def model_up(self, y):
        """Apply every step of ``self.up`` in order; return ``u`` of the last vode."""
        input = self.up[0].get("h")
        for l in self.up:
            input = l(input)
        return self.up[-1].get("u")

    def model_down_fp(self, input, noise_var: float = 0.0):
        """Down pass over ``self.down[1:-1]`` that never calls a vode.

        At vode positions, ``sample_multivariate_Gauss_diag_cov`` is applied
        when ``noise_var > 0`` (diagonal ``noise_var / alpha_down``).
        """
        for l in self.down[1:-1]:
            if isinstance(l, pxc.Vode):
                if noise_var > 0:
                    input = sample_multivariate_Gauss_diag_cov(
                        input,
                        noise_var * jnp.ones(len(input)) / self.alpha_down,
                        px.RKG(),
                    )
            else:
                input = l(input)
        return input

    def model_up_fp(self, input, noise_var: float = 0.0):
        """Up pass over ``self.up[1:-1]`` that never calls a vode.

        At vode positions, ``sample_multivariate_Gauss_diag_cov`` is applied
        when ``noise_var > 0`` (diagonal ``noise_var / alpha_up``).
        """
        for l in self.up[1:-1]:
            if isinstance(l, pxc.Vode):
                if noise_var > 0:
                    input = sample_multivariate_Gauss_diag_cov(
                        input,
                        noise_var * jnp.ones(len(input)) / self.alpha_up,
                        px.RKG(),
                    )
            else:
                input = l(input)
        return input

    def __call__(self, x, y, is_up_initialisation: bool = True):
        """Clamp the end vodes, run one up or down pass, and re-clamp."""
        if x is not None:
            self.down[0].set("h", x)
        if y is not None:
            self.up[0].set("h", y)

        if is_up_initialisation:
            output = self.model_up(y)
        else:
            output = self.model_down(x)

        if x is not None:  # need reset if initialisation has overwritten the values
            self.down[0].set("h", x)
        if y is not None:
            self.up[0].set("h", y)
        return output


class LatentHead(pxc.EnergyModule):
    """View over the first two vodes and the latent parts of a ``Model_w_latent``.

    No new vodes or layers are created: every attribute is taken from the
    wrapped model.
    """

    def __init__(
        self,
        model_w_latent: Model_w_latent,
    ) -> None:
        """Copy references to the relevant vodes, layers and settings."""
        super().__init__()

        self.alpha_up = model_w_latent.alpha_up
        self.alpha_down = model_w_latent.alpha_down

        self.activation = model_w_latent.activation
        self.out_activation_up = model_w_latent.out_activation_up

        self.vodes = model_w_latent.vodes[0:2]

        self.label_vode = model_w_latent.vodes[0]
        self.hidden_vode = model_w_latent.vodes[1]

        self.lab_layer_down = model_w_latent.layers_down[0]
        self.lab_layer_up = model_w_latent.layers_up[-1]

        self.latent_vode = model_w_latent.latent_vode
        self.latent_dim = model_w_latent.latent_dim
        self.latent_limit = model_w_latent.latent_limit

        self.latent_layer_up = model_w_latent.latent_layer_up
        self.latent_layer_down = model_w_latent.latent_layer_down
        self.bias_latent = model_w_latent.bias_latent

    def __call__(self, x, y, is_up_initialisation: bool = True):
        """Clamp label/hidden vodes, run one up or down pass, and re-clamp.

        ``x`` is written to the label vode and ``y`` to the hidden vode.
        """
        if not is_up_initialisation:
            # Use .shape.get(), as Model_w_latent.model_down_fp does, to
            # obtain the plain shape of the latent vode.
            self.latent_vode(0.5 * jnp.ones(self.latent_vode.shape.get()))

        if x is not None:
            self.label_vode.set("h", x)
        if y is not None:
            self.hidden_vode.set("h", y)

        if is_up_initialisation:
            output = self.model_up(y)
        else:
            output = self.model_down(x)

        if x is not None:  # need reset if initialisation has overwritten the values
            self.label_vode.set("h", x)
        if y is not None:
            self.hidden_vode.set("h", y)
        return output

    def model_down(self, x):
        """Predict the hidden vode from label plus latent; return its ``u``."""
        input = self.label_vode(
            self.label_vode.get("h")
        )  # have zero energy, rely on preinitialisation
        latent = self.latent_vode(self.bias_latent(jnp.ones((1,))))  # L2 regularisation
        self.hidden_vode(
            self.activation(self.lab_layer_down(input) + self.latent_layer_down(latent))
        )  # add bias
        return self.hidden_vode.get("u")

    def model_up(self, y):
        """Predict label and latent from the hidden vode.

        Returns ``(u of label_vode, u of latent_vode)``.
        """
        input = self.hidden_vode(self.hidden_vode.get("h"))
        self.label_vode(self.out_activation_up(self.lab_layer_up(input)))
        self.latent_vode(self.out_activation_up(self.latent_layer_up(input)))
        return self.label_vode.get("u"), self.latent_vode.get("u")


@pxf.vmap(
    pxu.Mask(pxc.VodeParam | pxc.VodeParam.Cache, (None, 0)), in_axes=(0,), out_axes=0
)
def fp_down_latents(x, *, model: Model, noise_var: float = 0.0):
    """Batched wrapper around ``model.hidden_to_input_fp``."""
    propagated = model.hidden_to_input_fp(x, noise_var)
    return propagated


@pxf.vmap(
    pxu.Mask(pxc.VodeParam | pxc.VodeParam.Cache, (None, 0)), in_axes=(0,), out_axes=0
)
def fp_up_latents(y, *, model: Model, noise_var: float = 0.0):
    """Batched wrapper around ``model.input_to_hidden_fp``.

    Args:
        y: batch of values, mapped over axis 0.
        model: model whose ``input_to_hidden_fp`` is called per sample.
        noise_var: forwarded unchanged to ``input_to_hidden_fp``.

    Returns:
        The per-sample results stacked along axis 0.
    """
    # This function used to be defined twice with identical bodies; a single
    # definition is kept.
    return model.input_to_hidden_fp(y, noise_var)


@pxf.vmap(
    pxu.Mask(pxc.VodeParam | pxc.VodeParam.Cache, (None, 0)), in_axes=(0,), out_axes=0
)
def fp_down(x, *, model: Model, noise_var: float = 0.0):
    """Batched wrapper around ``model.model_down_fp``."""
    prediction = model.model_down_fp(x, noise_var)
    return prediction


@pxf.vmap(
    pxu.Mask(pxc.VodeParam | pxc.VodeParam.Cache, (None, 0)), in_axes=(0,), out_axes=0
)
def fp_up(y, *, model: Model, noise_var: float = 0.0):
    """Batched wrapper around ``model.model_up_fp``."""
    prediction = model.model_up_fp(y, noise_var)
    return prediction


@pxf.vmap(
    pxu.Mask(pxc.VodeParam | pxc.VodeParam.Cache, (None, 0)), in_axes=(0, 0), out_axes=0
)
def fp_down_w_latent(x, latent, *, model: Model_w_latent, noise_var: float = 0.0):
    """Batched wrapper around ``model.model_down_fp`` with an explicit latent.

    Both ``x`` and ``latent`` are mapped over axis 0.
    """
    prediction = model.model_down_fp(x, noise_var, latent)
    return prediction


@pxf.jit(static_argnums=(2))
def fp_on_batch(
    x: jax.Array,
    y: jax.Array,
    noise_var: float,
    *,
    model: Model,
):
    """Run the fixed down and up passes on a batch.

    Returns ``(x_out, y_out)``: ``x_out`` is ``fp_up(y)`` and ``y_out`` is
    ``fp_down(x)``. A side whose input is ``None`` is skipped and yields
    ``None``.
    """
    y_out = fp_down(x, model=model, noise_var=noise_var) if x is not None else None
    x_out = fp_up(y, model=model, noise_var=noise_var) if y is not None else None
    return x_out, y_out


@pxf.vmap(
    pxu.Mask(pxc.VodeParam | pxc.VodeParam.Cache, (None, 0)), in_axes=(0, 0), out_axes=0
)
def initialisation(x, y, *, model: Model, is_up_initialisation: bool = True):
    """Batched call of ``model(x, y, ...)``, with ``x`` and ``y`` mapped over axis 0."""
    result = model(x, y, is_up_initialisation=is_up_initialisation)
    return result


@pxf.vmap(
    pxu.Mask(pxc.VodeParam | pxc.VodeParam.Cache, (None, 0)),
    in_axes=(0, 0, 0),
    out_axes=0,
)
def initialisation_latent(
    x, y, latent, *, model: Model, is_up_initialisation: bool = True
):
    """Batched call of ``model(x, y, ..., latent=latent)``.

    ``x``, ``y`` and ``latent`` are all mapped over axis 0. The model's
    ``__call__`` must accept a ``latent`` keyword.
    """
    result = model(x, y, latent=latent, is_up_initialisation=is_up_initialisation)
    return result


@pxf.vmap(
    pxu.Mask(pxc.VodeParam | pxc.VodeParam.Cache, (None, 0)),
    in_axes=(0, 0),
    out_axes=(None, 0),
    axis_name="batch",
)
def energy_down(x, y, *, model: Model):
    """Energy of the down pass, scaled by ``alpha_down``.

    Returns ``(batch-mean energy, per-sample down prediction)``.
    """
    with pxu.step(model, clear_params=pxc.VodeParam.Cache):
        prediction = model.model_down(x)
        weighted = model.energy() * model.alpha_down
    mean_energy = jax.lax.pmean(weighted, "batch")
    return mean_energy, prediction


@pxf.vmap(
    pxu.Mask(pxc.VodeParam | pxc.VodeParam.Cache, (None, 0)),
    in_axes=(0, 0),
    out_axes=(None, 0),
    axis_name="batch",
)
def energy_up(x, y, *, model: Model):
    """Energy of the up pass, scaled by ``alpha_up``.

    Returns ``(batch-mean energy, per-sample up prediction)``.
    """
    with pxu.step(model, clear_params=pxc.VodeParam.Cache):
        prediction = model.model_up(y)
        weighted = model.energy() * model.alpha_up
    mean_energy = jax.lax.pmean(weighted, "batch")
    return mean_energy, prediction


@pxf.vmap(
    pxu.Mask(pxc.VodeParam | pxc.VodeParam.Cache, (None, 0)),
    in_axes=(0, 0),
    out_axes=(None, 0, 0),
    axis_name="batch",
)
def energy(x, y, *, model: Model):
    """Sum of the alpha-weighted up and down energies.

    The up pass is evaluated first, then the down pass. Returns
    ``(batch-mean total energy, down prediction, up prediction)``.
    """
    with pxu.step(model, clear_params=pxc.VodeParam.Cache):
        up_pred = model.model_up(y)
        up_term = model.energy() * model.alpha_up

    with pxu.step(model, clear_params=pxc.VodeParam.Cache):
        down_pred = model.model_down(x)
        down_term = model.energy() * model.alpha_down

    total = jax.lax.pmean(up_term + down_term, "batch")
    return total, down_pred, up_pred


@pxf.vmap(
    pxu.Mask(pxc.VodeParam | pxc.VodeParam.Cache, (None, 0)),
    in_axes=(0, 0),
    out_axes=(None, 0, 0),
    axis_name="batch",
)
def energy_weights(x, y, *, model: Model):
    """Sum of the unweighted up and down energies (no alpha scaling).

    Returns ``(batch-mean total energy, down prediction, up prediction)``.
    """
    with pxu.step(model, clear_params=pxc.VodeParam.Cache):
        up_pred = model.model_up(y)
        up_term = model.energy()

    with pxu.step(model, clear_params=pxc.VodeParam.Cache):
        down_pred = model.model_down(x)
        down_term = model.energy()

    total = jax.lax.pmean(up_term + down_term, "batch")
    return total, down_pred, up_pred


@pxf.vmap(
    pxu.Mask(pxc.VodeParam | pxc.VodeParam.Cache, (None, 0)),
    in_axes=(0, 0),
    out_axes=(None, 0, 0),
    axis_name="batch",
)
def energy_graph(x, y, *, model: Model):
    """Energy after a single ``model.model(x, y)`` evaluation.

    Returns ``(batch-mean energy, first output of model.model, second output)``.
    """
    with pxu.step(model, clear_params=pxc.VodeParam.Cache):
        down_pred, up_pred = model.model(x, y)
        total = model.energy()
    return jax.lax.pmean(total, "batch"), down_pred, up_pred


@pxf.vmap(
    pxu.Mask(pxc.VodeParam | pxc.VodeParam.Cache, (None, 0)),
    in_axes=(0, 0, 0),
    out_axes=(None, 0, 0),
    axis_name="batch",
)
def energy_latent(x, y, latent, *, model: Model):
    """Alpha-weighted up plus down energy, with ``latent`` passed to the down pass.

    The model's ``model_down`` must accept ``(x, latent)``. Returns
    ``(batch-mean total energy, down prediction, up prediction)``.
    """
    with pxu.step(model, clear_params=pxc.VodeParam.Cache):
        up_pred = model.model_up(y)
        up_term = model.energy() * model.alpha_up

    with pxu.step(model, clear_params=pxc.VodeParam.Cache):
        down_pred = model.model_down(x, latent)
        down_term = model.energy() * model.alpha_down

    total = jax.lax.pmean(up_term + down_term, "batch")
    return total, down_pred, up_pred


@pxf.vmap(
    pxu.Mask(pxc.VodeParam | pxc.VodeParam.Cache, (None, 0)),
    in_axes=(0, 0),
    out_axes=(None, None),
    axis_name="batch",
)
def energy_per_stream(x, y, *, model: Model):
    """Batch-mean alpha-weighted energies of each stream.

    Returns ``(mean up energy, mean down energy)``.
    """
    with pxu.step(model, clear_params=pxc.VodeParam.Cache):
        model.model_up(y)  # run for its effect on the vodes; result unused
        up_term = model.energy() * model.alpha_up

    with pxu.step(model, clear_params=pxc.VodeParam.Cache):
        model.model_down(x)  # run for its effect on the vodes; result unused
        down_term = model.energy() * model.alpha_down

    mean_up = jax.lax.pmean(up_term, "batch")
    mean_down = jax.lax.pmean(down_term, "batch")
    return mean_up, mean_down


def pmedian(x, axis_name):
    """
    Median of ``x`` over every member of the named mapped axis.

    All values are gathered along ``axis_name``, flattened and sorted. An odd
    count yields the middle element; an even count yields the mean of the two
    middle elements.
    """
    gathered = jax.lax.all_gather(x, axis_name=axis_name)
    ordered = jnp.sort(jnp.ravel(gathered))

    count = ordered.size
    upper = count // 2
    if count % 2:
        # Odd number of elements: the middle value is the median.
        return ordered[upper]
    # Even number of elements: average the two middle values.
    return (ordered[upper - 1] + ordered[upper]) / 2.0


@pxf.vmap(
    pxu.Mask(pxc.VodeParam | pxc.VodeParam.Cache, (None, 0)),
    in_axes=(0, 0),
    out_axes=(None, None),
    axis_name="batch",
)
def energy_per_stream_median(x, y, *, model: Model):
    """Batch-median alpha-weighted energies of each stream.

    Returns ``(median up energy, median down energy)``.
    """
    with pxu.step(model, clear_params=pxc.VodeParam.Cache):
        model.model_up(y)  # run for its effect on the vodes; result unused
        up_term = model.energy() * model.alpha_up

    with pxu.step(model, clear_params=pxc.VodeParam.Cache):
        model.model_down(x)  # run for its effect on the vodes; result unused
        down_term = model.energy() * model.alpha_down

    median_up = pmedian(up_term, "batch")
    median_down = pmedian(down_term, "batch")
    return median_up, median_down


@pxf.vmap(
    pxu.Mask(pxc.VodeParam | pxc.VodeParam.Cache, (None, 0)),
    in_axes=(0, 0),
    out_axes=(0, 0, 0),
)
def energy_per_data(x, y, *, model: Model):
    """Per-sample alpha-weighted energies (no reduction over the batch).

    Returns ``(down + up energy, up energy, down energy)``.
    """
    with pxu.step(model, clear_params=pxc.VodeParam.Cache):
        model.model_up(y)  # run for its effect on the vodes; result unused
        up_term = model.energy() * model.alpha_up

    with pxu.step(model, clear_params=pxc.VodeParam.Cache):
        model.model_down(x)  # run for its effect on the vodes; result unused
        down_term = model.energy() * model.alpha_down

    return down_term + up_term, up_term, down_term


def setup_infer_on_batch(
    local_energy, is_free_latents=False, init=pxc.STATUS.INIT, optim_h_latent=None
):
    """Build a jitted batch-inference function bound to ``local_energy``.

    Args:
        local_energy: energy function; it is differentiated with
            ``has_aux=True`` and called as ``local_energy(x, y, model=model)``.
        is_free_latents: when true, the returned function is
            ``infer_on_batch_with_free_latent``; otherwise ``infer_on_batch``.
        init: default value of the ``init`` keyword of the returned function,
            which is passed to ``pxu.step`` for the initialisation pass.
        optim_h_latent: when not ``None``, pre-bound onto the returned
            function with ``functools.partial``.

    Returns:
        The selected inference function. It returns
        ``(model.vodes[0].get("h"), model.vodes[-1].get("h"))``.
    """

    # Positions 0, 3, 4 are T, is_up_initialisation and mode.
    # NOTE(review): index 7 presumably designates `init` — confirm how
    # pxf.jit numbers keyword-only arguments.
    @pxf.jit(static_argnums=(0, 3, 4, 7))
    def infer_on_batch(
        T: int,
        x: jax.Array,
        y: jax.Array,
        is_up_initialisation: bool,
        mode: int,
        *,
        model: Model,
        optim_h: pxu.Optim,
        init: str = init,
    ):
        """Initialise the vodes, then run ``T`` optimiser steps on non-frozen vode params."""
        # Integer modes are translated to names; any other value is kept as given.
        mode_mapping = {
            0: "constrained",
            1: "label-only",
            2: "data-only",
            3: "unconstrained",
        }
        mode = mode_mapping.get(mode, mode)

        def h_step(i, x, y, *, model, optim_h):
            # One inference step: gradient of local_energy w.r.t. the
            # non-frozen VodeParams, followed by an optimiser update.
            with pxu.step(model, clear_params=pxc.VodeParam.Cache):
                (e, pred), g = pxf.value_and_grad(
                    pxu.Mask(pxu.m(pxc.VodeParam).has_not(frozen=True), [False, True]),
                    has_aux=True,
                )(local_energy)(x, y, model=model)
            optim_h.step(model, g["model"], True)  #
            return (x, y), None

        model.train()

        # Each mode sets which end vodes are frozen and which inputs are
        # dropped (set to None) before initialisation. A mode outside the four
        # names leaves the flags and inputs untouched.
        if mode == "constrained":
            model.vodes[0].h.frozen = True
            model.vodes[-1].h.frozen = True
        elif mode == "label-only":
            model.vodes[0].h.frozen = True
            model.vodes[-1].h.frozen = False
            # y is only kept when the up pass needs it for initialisation.
            if not is_up_initialisation:
                y = None
            else:
                assert y is not None
            # is_up_initialisation = False
        elif mode == "data-only":
            model.vodes[0].h.frozen = False
            model.vodes[-1].h.frozen = True
            # x is only kept when the down pass needs it for initialisation.
            if is_up_initialisation:
                x = None
            else:
                assert x is not None
            # is_up_initialisation = True
        elif mode == "unconstrained":
            model.vodes[0].h.frozen = False
            model.vodes[-1].h.frozen = False
            x = None
            y = None

        # Init step
        with pxu.step(model, init, clear_params=pxc.VodeParam.Cache):
            initialisation(x, y, model=model, is_up_initialisation=is_up_initialisation)

        # The optimiser is initialised only over the non-frozen vode params,
        # so the frozen flags above must be set before this point.
        optim_h.init(pxu.Mask(pxu.m(pxc.VodeParam).has_not(frozen=True))(model))

        # Inference steps
        pxf.scan(h_step, xs=jax.numpy.arange(T))(x, y, model=model, optim_h=optim_h)

        optim_h.clear()

        # restore frozen states
        model.vodes[0].h.frozen = True
        model.vodes[-1].h.frozen = True
        return model.vodes[0].get("h"), model.vodes[-1].get("h")

    # Positions 0, 3, 4 are T, is_up_initialisation and mode.
    # NOTE(review): index 8 presumably designates a keyword-only argument —
    # confirm how pxf.jit numbers them.
    @pxf.jit(static_argnums=(0, 3, 4, 8))
    def infer_on_batch_with_free_latent(
        T: int,
        x: jax.Array,
        y: jax.Array,
        is_up_initialisation: bool,
        mode: int,
        latent: jax.Array = None,
        *,
        model: Model,
        optim_h: pxu.Optim,
        optim_h_latent: pxu.Optim,
        init: str = init,
    ):
        """Variant that updates ``latent=True`` params with a separate optimiser.

        When ``latent`` is not ``None`` the latent vode is frozen for the
        duration of the call.
        """
        mode_mapping = {
            0: "constrained",
            1: "label-only",
            2: "data-only",
            3: "unconstrained",
        }
        mode = mode_mapping.get(mode, mode)

        def h_step(i, x, y, *, model, optim_h, optim_h_latent):
            # One inference step; the gradient is split by the `latent` flag
            # and each part is applied by its own optimiser.
            with pxu.step(model, clear_params=pxc.VodeParam.Cache):
                (e, pred), g = pxf.value_and_grad(
                    pxu.Mask(pxu.m(pxc.VodeParam).has_not(frozen=True), [False, True]),
                    has_aux=True,
                )(local_energy)(x, y, model=model)
            optim_h.step(
                model,
                pxu.Mask(pxu.m(pxc.VodeParam).has_not(latent=True))(g["model"]),
                True,
            )
            optim_h_latent.step(
                model, pxu.Mask(pxu.m(pxc.VodeParam).has(latent=True))(g["model"]), True
            )
            return (x, y), None

        model.train()

        # Same mode handling as in infer_on_batch above.
        if mode == "constrained":
            model.vodes[0].h.frozen = True
            model.vodes[-1].h.frozen = True
        elif mode == "label-only":
            model.vodes[0].h.frozen = True
            model.vodes[-1].h.frozen = False
            if not is_up_initialisation:
                y = None
            else:
                assert y is not None
            # y = None
            # is_up_initialisation = False
        elif mode == "data-only":
            model.vodes[0].h.frozen = False
            model.vodes[-1].h.frozen = True
            if is_up_initialisation:
                x = None
            else:
                assert x is not None
            # x = None
            # is_up_initialisation = True
        elif mode == "unconstrained":
            model.vodes[0].h.frozen = False
            model.vodes[-1].h.frozen = False
            x = None
            y = None

        # A provided latent is treated as fixed during inference.
        if latent is not None:
            model.latent_vode.h.frozen = True

        # Init step
        with pxu.step(model, init, clear_params=pxc.VodeParam.Cache):
            initialisation_latent(
                x, y, latent, model=model, is_up_initialisation=is_up_initialisation
            )

        # Two disjoint parameter sets: non-latent and latent, both non-frozen.
        optim_h.init(
            pxu.Mask(pxu.m(pxc.VodeParam).has_not(frozen=True).has_not(latent=True))(
                model
            )
        )
        optim_h_latent.init(
            pxu.Mask(pxu.m(pxc.VodeParam).has(latent=True).has_not(frozen=True))(model)
        )

        # Inference steps
        pxf.scan(h_step, xs=jax.numpy.arange(T))(
            x, y, model=model, optim_h=optim_h, optim_h_latent=optim_h_latent
        )

        optim_h.clear()
        optim_h_latent.clear()

        # restore frozen states
        model.vodes[0].h.frozen = True
        model.vodes[-1].h.frozen = True
        model.latent_vode.h.frozen = False
        return model.vodes[0].get("h"), model.vodes[-1].get("h")

    if is_free_latents:
        local_infer_on_batch = infer_on_batch_with_free_latent
    else:
        local_infer_on_batch = infer_on_batch

    # NOTE(review): optim_h_latent is bound regardless of is_free_latents, but
    # infer_on_batch declares no such keyword — confirm callers never combine
    # is_free_latents=False with a non-None optim_h_latent.
    if optim_h_latent is not None:
        local_infer_on_batch = partial(
            local_infer_on_batch, optim_h_latent=optim_h_latent
        )

    return local_infer_on_batch


@pxf.jit(static_argnums=(0, 3, 4))
def infer_on_batch(
    T: int,
    x: jax.Array,
    y: jax.Array,
    is_up_initialisation: bool,
    mode: int,
    *,
    model: Model,
    optim_h: pxu.Optim,
):
    """Initialise the vodes, then run ``T`` optimiser steps on the unfrozen ones.

    ``mode`` is either an integer code or directly the name it stands for, and
    decides which boundary vodes (``model.vodes[0]`` / ``model.vodes[-1]``)
    are frozen and which inputs are discarded before initialisation:

    - 0 / "constrained": both frozen, ``x`` and ``y`` kept.
    - 1 / "label-only": ``vodes[0]`` frozen, ``vodes[-1]`` free, ``y`` dropped.
    - 2 / "data-only": ``vodes[0]`` free, ``vodes[-1]`` frozen, ``x`` dropped.
    - 3 / "unconstrained": both free, ``x`` and ``y`` dropped.

    Any other value leaves the frozen flags and the inputs untouched.

    Returns:
        The ``h`` values of ``model.vodes[0]`` and ``model.vodes[-1]`` after
        inference. Both boundary vodes are frozen again before returning.
    """
    names_by_code = {
        0: "constrained",
        1: "label-only",
        2: "data-only",
        3: "unconstrained",
    }
    mode = names_by_code.get(mode, mode)

    # mode name -> (freeze vodes[0], freeze vodes[-1], drop x, drop y)
    settings_by_mode = {
        "constrained": (True, True, False, False),
        "label-only": (True, False, False, True),
        "data-only": (False, True, True, False),
        "unconstrained": (False, False, True, True),
    }

    def h_step(i, x, y, *, model, optim_h):
        # One inference step: gradient of the energy on the unfrozen vode
        # parameters, followed by an optimiser update.
        with pxu.step(model, clear_params=pxc.VodeParam.Cache):
            (_e, (_y_down, _x_up)), grads = pxf.value_and_grad(
                pxu.Mask(pxu.m(pxc.VodeParam).has_not(frozen=True), [False, True]),
                has_aux=True,
            )(energy)(x, y, model=model)
        optim_h.step(model, grads["model"], True)
        return (x, y), None

    model.train()

    settings = settings_by_mode.get(mode)
    if settings is not None:
        freeze_first, freeze_last, drop_x, drop_y = settings
        model.vodes[0].h.frozen = freeze_first
        model.vodes[-1].h.frozen = freeze_last
        if drop_x:
            x = None
        if drop_y:
            y = None

    # Initialisation pass.
    with pxu.step(model, pxc.STATUS.INIT, clear_params=pxc.VodeParam.Cache):
        initialisation(x, y, model=model, is_up_initialisation=is_up_initialisation)

    unfrozen_vode_params = pxu.Mask(pxu.m(pxc.VodeParam).has_not(frozen=True))
    optim_h.init(unfrozen_vode_params(model))

    # Inference steps.
    pxf.scan(h_step, xs=jnp.arange(T))(x, y, model=model, optim_h=optim_h)

    optim_h.clear()

    # Put the boundary vodes back into their default frozen state.
    model.vodes[0].h.frozen = True
    model.vodes[-1].h.frozen = True
    return model.vodes[0].get("h"), model.vodes[-1].get("h")


@pxf.jit(static_argnums=(0, 3, 4, 5, 6))
def infer_on_batch_latent_gen(
    T: int,
    x: jax.Array,
    y: jax.Array,
    is_up_initialisation: bool,
    mode: int,
    e_th_up: float,
    e_th_down: float,
    *,
    model: Model,
    optim_h: pxu.Optim,
):
    """Initialise the vodes, then run up to ``T`` thresholded inference steps.

    After every executed step the two values returned by
    ``energy_per_stream`` are compared with ``e_th_up`` and ``e_th_down``.
    As soon as neither exceeds its threshold, the remaining scan iterations
    skip the update.

    ``mode`` is an integer code or the name it stands for:

    - 0 / "constrained": both boundary vodes frozen.
    - 1 / "label-only": ``vodes[0]`` frozen, ``vodes[-1]`` free, ``y`` dropped.
    - 2 / "data-only": ``vodes[0]`` free, ``vodes[-1]`` frozen, ``x`` dropped.
    - 3 / "unconstrained": both free, ``x`` and ``y`` dropped.

    Any other value leaves the frozen flags and the inputs untouched.

    NOTE(review): the previous docstring advised setting ``T`` to a multiple
    of ``track_indent``; no such name is used in this function -- confirm
    whether that advice is still relevant.

    Returns:
        ``(x_out, y_out, e_up, e_down)``: the ``h`` of ``vodes[0]`` and
        ``vodes[-1]`` after inference and the final ``energy_per_stream``
        values. Both boundary vodes are frozen again before returning.
    """
    names_by_code = {
        0: "constrained",
        1: "label-only",
        2: "data-only",
        3: "unconstrained",
    }
    mode = names_by_code.get(mode, mode)

    # mode name -> (freeze vodes[0], freeze vodes[-1], drop x, drop y)
    settings_by_mode = {
        "constrained": (True, True, False, False),
        "label-only": (True, False, False, True),
        "data-only": (False, True, True, False),
        "unconstrained": (False, False, True, True),
    }

    def infer(i, x, y, keep_computing, *, model, optim_h):
        def run_step(x, y, *, model, optim_h):
            # Regular inference step, reporting the energies afterwards.
            with pxu.step(model, clear_params=pxc.VodeParam.Cache):
                (_e, (_y_down, _x_up)), grads = pxf.value_and_grad(
                    pxu.Mask(pxu.m(pxc.VodeParam).has_not(frozen=True), [False, True]),
                    has_aux=True,
                )(energy)(x, y, model=model)
            optim_h.step(model, grads["model"], True)
            e_up, e_down = energy_per_stream(x, y, model=model)
            return e_up, e_down

        def skip_step(x, y, *, model, optim_h):
            # Zero energies can never exceed a threshold check below once
            # keep_computing is already False.
            return (0.0, 0.0)

        eup, edown = pxf.cond(run_step, skip_step)(
            keep_computing, x, y, model=model, optim_h=optim_h
        )
        above_threshold = (eup > e_th_up) | (edown > e_th_down)
        new_keep_computing = keep_computing & above_threshold
        return (x, y, new_keep_computing), None

    model.train()

    settings = settings_by_mode.get(mode)
    if settings is not None:
        freeze_first, freeze_last, drop_x, drop_y = settings
        model.vodes[0].h.frozen = freeze_first
        model.vodes[-1].h.frozen = freeze_last
        if drop_x:
            x = None
        if drop_y:
            y = None

    # Initialisation pass (run under the "latent_gen" status).
    with pxu.step(model, "latent_gen", clear_params=pxc.VodeParam.Cache):
        initialisation(x, y, model=model, is_up_initialisation=is_up_initialisation)

    unfrozen_vode_params = pxu.Mask(pxu.m(pxc.VodeParam).has_not(frozen=True))
    optim_h.init(unfrozen_vode_params(model))

    keep_compute = jnp.array(True)
    pxf.scan(infer, xs=jnp.arange(T))(
        x, y, keep_compute, model=model, optim_h=optim_h
    )

    x_out = model.vodes[0].get("h")
    y_out = model.vodes[-1].get("h")
    e_up, e_down = energy_per_stream(x, y, model=model)
    optim_h.clear()

    # Put the boundary vodes back into their default frozen state.
    model.vodes[0].h.frozen = True
    model.vodes[-1].h.frozen = True
    return x_out, y_out, e_up, e_down


@pxf.jit(static_argnums=(0, 3, 4, 5))
def infer_on_batch_latent_gen_no_init(
    T: int,
    x: jax.Array,
    y: jax.Array,
    mode: int,
    e_th_up: float,
    e_th_down: float,
    *,
    model: Model,
    optim_h: pxu.Optim,
):
    """Run up to ``T`` thresholded inference steps without an initialisation pass.

    The vode states already stored in ``model`` are used as the starting
    point. After every executed step the two values returned by
    ``energy_per_stream`` are compared with ``e_th_up`` and ``e_th_down``;
    once neither exceeds its threshold the remaining scan iterations skip
    the update.

    ``mode`` is an integer code or the name it stands for:

    - 0 / "constrained": both boundary vodes frozen.
    - 1 / "label-only": ``vodes[0]`` frozen, ``vodes[-1]`` free, ``y`` dropped.
    - 2 / "data-only": ``vodes[0]`` free, ``vodes[-1]`` frozen, ``x`` dropped.
    - 3 / "unconstrained": both free, ``x`` and ``y`` dropped.

    Any other value leaves the frozen flags and the inputs untouched.

    NOTE(review): the previous docstring advised setting ``T`` to a multiple
    of ``track_indent``; no such name is used in this function -- confirm
    whether that advice is still relevant.

    Returns:
        ``(x_out, y_out, e_up, e_down)``: the ``h`` of ``vodes[0]`` and
        ``vodes[-1]`` after inference and the final ``energy_per_stream``
        values. Both boundary vodes are frozen again before returning.
    """
    names_by_code = {
        0: "constrained",
        1: "label-only",
        2: "data-only",
        3: "unconstrained",
    }
    mode = names_by_code.get(mode, mode)

    # mode name -> (freeze vodes[0], freeze vodes[-1], drop x, drop y)
    settings_by_mode = {
        "constrained": (True, True, False, False),
        "label-only": (True, False, False, True),
        "data-only": (False, True, True, False),
        "unconstrained": (False, False, True, True),
    }

    def infer(i, x, y, keep_computing, *, model, optim_h):
        def run_step(x, y, *, model, optim_h):
            # Regular inference step, reporting the energies afterwards.
            with pxu.step(model, clear_params=pxc.VodeParam.Cache):
                (_e, (_y_down, _x_up)), grads = pxf.value_and_grad(
                    pxu.Mask(pxu.m(pxc.VodeParam).has_not(frozen=True), [False, True]),
                    has_aux=True,
                )(energy)(x, y, model=model)
            optim_h.step(model, grads["model"], True)
            e_up, e_down = energy_per_stream(x, y, model=model)
            return e_up, e_down

        def skip_step(x, y, *, model, optim_h):
            # Placeholder energies used once inference has been stopped.
            return (0.0, 0.0)

        eup, edown = pxf.cond(run_step, skip_step)(
            keep_computing, x, y, model=model, optim_h=optim_h
        )
        above_threshold = (eup > e_th_up) | (edown > e_th_down)
        new_keep_computing = keep_computing & above_threshold
        return (x, y, new_keep_computing), None

    model.train()

    settings = settings_by_mode.get(mode)
    if settings is not None:
        freeze_first, freeze_last, drop_x, drop_y = settings
        model.vodes[0].h.frozen = freeze_first
        model.vodes[-1].h.frozen = freeze_last
        if drop_x:
            x = None
        if drop_y:
            y = None

    unfrozen_vode_params = pxu.Mask(pxu.m(pxc.VodeParam).has_not(frozen=True))
    optim_h.init(unfrozen_vode_params(model))

    keep_compute = jnp.array(True)
    pxf.scan(infer, xs=jnp.arange(T))(
        x, y, keep_compute, model=model, optim_h=optim_h
    )

    x_out = model.vodes[0].get("h")
    y_out = model.vodes[-1].get("h")
    e_up, e_down = energy_per_stream(x, y, model=model)
    optim_h.clear()

    # Put the boundary vodes back into their default frozen state.
    model.vodes[0].h.frozen = True
    model.vodes[-1].h.frozen = True
    return x_out, y_out, e_up, e_down


@pxf.jit(static_argnums=(0, 3, 4, 5))
def infer_on_batch_latent_gen_no_init_median(
    T: int,
    x: jax.Array,
    y: jax.Array,
    mode: int,
    e_th_up: float,
    e_th_down: float,
    *,
    model: Model,
    optim_h: pxu.Optim,
):
    """Thresholded inference without initialisation, stopping on median energies.

    Same procedure as ``infer_on_batch_latent_gen_no_init`` except that the
    stopping test inside the loop uses ``energy_per_stream_median``: once
    neither of its two values exceeds ``e_th_up`` / ``e_th_down`` the
    remaining scan iterations skip the update.

    ``mode`` is an integer code or the name it stands for:

    - 0 / "constrained": both boundary vodes frozen.
    - 1 / "label-only": ``vodes[0]`` frozen, ``vodes[-1]`` free, ``y`` dropped.
    - 2 / "data-only": ``vodes[0]`` free, ``vodes[-1]`` frozen, ``x`` dropped.
    - 3 / "unconstrained": both free, ``x`` and ``y`` dropped.

    Any other value leaves the frozen flags and the inputs untouched.

    NOTE(review): the energies *returned* at the end come from
    ``energy_per_stream``, not ``energy_per_stream_median`` -- confirm that
    this asymmetry with the stopping criterion is intended.

    NOTE(review): the previous docstring advised setting ``T`` to a multiple
    of ``track_indent``; no such name is used in this function.

    Returns:
        ``(x_out, y_out, e_up, e_down)``: the ``h`` of ``vodes[0]`` and
        ``vodes[-1]`` after inference and the final ``energy_per_stream``
        values. Both boundary vodes are frozen again before returning.
    """
    names_by_code = {
        0: "constrained",
        1: "label-only",
        2: "data-only",
        3: "unconstrained",
    }
    mode = names_by_code.get(mode, mode)

    # mode name -> (freeze vodes[0], freeze vodes[-1], drop x, drop y)
    settings_by_mode = {
        "constrained": (True, True, False, False),
        "label-only": (True, False, False, True),
        "data-only": (False, True, True, False),
        "unconstrained": (False, False, True, True),
    }

    def infer(i, x, y, keep_computing, *, model, optim_h):
        def run_step(x, y, *, model, optim_h):
            # Regular inference step, reporting the median-based energies.
            with pxu.step(model, clear_params=pxc.VodeParam.Cache):
                (_e, (_y_down, _x_up)), grads = pxf.value_and_grad(
                    pxu.Mask(pxu.m(pxc.VodeParam).has_not(frozen=True), [False, True]),
                    has_aux=True,
                )(energy)(x, y, model=model)
            optim_h.step(model, grads["model"], True)
            e_up, e_down = energy_per_stream_median(x, y, model=model)
            return e_up, e_down

        def skip_step(x, y, *, model, optim_h):
            # Placeholder energies used once inference has been stopped.
            return (0.0, 0.0)

        eup, edown = pxf.cond(run_step, skip_step)(
            keep_computing, x, y, model=model, optim_h=optim_h
        )
        above_threshold = (eup > e_th_up) | (edown > e_th_down)
        new_keep_computing = keep_computing & above_threshold
        return (x, y, new_keep_computing), None

    model.train()

    settings = settings_by_mode.get(mode)
    if settings is not None:
        freeze_first, freeze_last, drop_x, drop_y = settings
        model.vodes[0].h.frozen = freeze_first
        model.vodes[-1].h.frozen = freeze_last
        if drop_x:
            x = None
        if drop_y:
            y = None

    unfrozen_vode_params = pxu.Mask(pxu.m(pxc.VodeParam).has_not(frozen=True))
    optim_h.init(unfrozen_vode_params(model))

    keep_compute = jnp.array(True)
    pxf.scan(infer, xs=jnp.arange(T))(
        x, y, keep_compute, model=model, optim_h=optim_h
    )

    x_out = model.vodes[0].get("h")
    y_out = model.vodes[-1].get("h")
    e_up, e_down = energy_per_stream(x, y, model=model)
    optim_h.clear()

    # Put the boundary vodes back into their default frozen state.
    model.vodes[0].h.frozen = True
    model.vodes[-1].h.frozen = True
    return x_out, y_out, e_up, e_down


@pxf.jit(static_argnums=(0, 3))
def infer_on_batch_no_init(
    T: int,
    x: jax.Array,
    y: jax.Array,
    mode: int,
    *,
    model: Model,
    optim_h: pxu.Optim,
):
    """Run ``T`` inference steps starting from the vode states already in ``model``.

    No initialisation pass is performed. ``mode`` is an integer code or the
    name it stands for:

    - 0 / "constrained": both boundary vodes frozen.
    - 1 / "label-only": ``vodes[0]`` frozen, ``vodes[-1]`` free, ``y`` dropped.
    - 2 / "data-only": ``vodes[0]`` free, ``vodes[-1]`` frozen, ``x`` dropped.
    - 3 / "unconstrained": both free; only ``x`` is dropped here (``y`` is
      deliberately kept, see the comment on the settings table).

    Any other value leaves the frozen flags and the inputs untouched.

    Returns:
        The ``h`` values of ``model.vodes[0]`` and ``model.vodes[-1]`` after
        inference. Both boundary vodes are frozen again before returning.
    """
    names_by_code = {
        0: "constrained",
        1: "label-only",
        2: "data-only",
        3: "unconstrained",
    }
    mode = names_by_code.get(mode, mode)

    # mode name -> (freeze vodes[0], freeze vodes[-1], drop x, drop y)
    # In "unconstrained" mode y is not dropped: one input should be kept for
    # vmap, and it will be ignored because there is no init.
    settings_by_mode = {
        "constrained": (True, True, False, False),
        "label-only": (True, False, False, True),
        "data-only": (False, True, True, False),
        "unconstrained": (False, False, True, False),
    }

    def h_step(i, x, y, *, model, optim_h):
        # One inference step: gradient of the energy on the unfrozen vode
        # parameters, followed by an optimiser update.
        with pxu.step(model, clear_params=pxc.VodeParam.Cache):
            (_e, (_y_down, _x_up)), grads = pxf.value_and_grad(
                pxu.Mask(pxu.m(pxc.VodeParam).has_not(frozen=True), [False, True]),
                has_aux=True,
            )(energy)(x, y, model=model)
        optim_h.step(model, grads["model"], True)
        return (x, y), None

    model.train()

    settings = settings_by_mode.get(mode)
    if settings is not None:
        freeze_first, freeze_last, drop_x, drop_y = settings
        model.vodes[0].h.frozen = freeze_first
        model.vodes[-1].h.frozen = freeze_last
        if drop_x:
            x = None
        if drop_y:
            y = None

    unfrozen_vode_params = pxu.Mask(pxu.m(pxc.VodeParam).has_not(frozen=True))
    optim_h.init(unfrozen_vode_params(model))

    # Inference steps.
    pxf.scan(h_step, xs=jnp.arange(T))(x, y, model=model, optim_h=optim_h)

    optim_h.clear()

    # Put the boundary vodes back into their default frozen state.
    model.vodes[0].h.frozen = True
    model.vodes[-1].h.frozen = True
    return model.vodes[0].get("h"), model.vodes[-1].get("h")


class ConvTranspose(Layer):
    """pcax ``Layer`` that wraps ``eqx.nn.ConvTranspose``.

    All convolution arguments are forwarded by keyword to the equinox
    constructor, together with a fresh random key drawn from ``rkg``.
    """

    def __init__(
        self,
        num_spatial_dims: int,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Sequence[int]],
        stride: Union[int, Sequence[int]] = 1,
        padding: Union[str, int, Sequence[int], Sequence[tuple[int, int]]] = 0,
        output_padding: Union[int, Sequence[int]] = 0,
        dilation: Union[int, Sequence[int]] = 1,
        groups: int = 1,
        use_bias: bool = True,
        padding_mode: str = "ZEROS",
        dtype=None,
        *,
        rkg: RandomKeyGenerator = RKG,
    ):
        """Create the wrapped transposed convolution.

        Note that ``dtype`` is accepted for signature compatibility but is
        not passed on to ``eqx.nn.ConvTranspose``.
        """
        conv_kwargs = {
            "num_spatial_dims": num_spatial_dims,
            "in_channels": in_channels,
            "out_channels": out_channels,
            "kernel_size": kernel_size,
            "stride": stride,
            "padding": padding,
            "output_padding": output_padding,
            "dilation": dilation,
            "groups": groups,
            "use_bias": use_bias,
            "padding_mode": padding_mode,
        }
        super().__init__(eqx.nn.ConvTranspose, **conv_kwargs, key=rkg())


def _cnn_reverse_config(layers, input_channels):
    """Mirror a bottom-up CNN configuration into its top-down counterpart.

    Pooling markers ("M" / "A") keep their mirrored position. Every channel
    entry is replaced by the number of channels that *feeds* the matching
    bottom-up convolution, so the last channel entry is ``input_channels``.
    """
    # Channel counts only; both pooling markers must be removed, otherwise
    # the negative indexing below picks up markers instead of channels.
    channels = [input_channels] + [c for c in layers if c not in ("M", "A")]
    mirrored = []
    position = -2
    for entry in layers[::-1]:
        if entry in ("M", "A"):
            mirrored.append(entry)
        else:
            mirrored.append(channels[position])
            position -= 1
    return mirrored


def _cnn_add_noise(value, noise_var, alpha):
    """Resample ``value`` with diagonal variance ``noise_var / alpha``.

    The array is flattened for ``sample_multivariate_Gauss_diag_cov`` and
    restored to its original shape afterwards.
    """
    original_shape = value.shape
    flat = jnp.reshape(value, (-1))
    flat = sample_multivariate_Gauss_diag_cov(
        flat,
        noise_var * jnp.ones_like(flat) / alpha,
        px.RKG(),
    )
    return jnp.reshape(flat, original_shape)


def _cnn_forward_pass(layers, value, noise_var, alpha):
    """Apply ``layers`` in order without touching any vode state.

    Vodes are never called; at each vode position noise is injected when
    ``noise_var > 0``. Every other layer is applied to the running value.
    """
    for layer in layers:
        if isinstance(layer, pxc.Vode):
            if noise_var > 0:
                value = _cnn_add_noise(value, noise_var, alpha)
        else:
            value = layer(value)
    return value


class CNNModel(pxc.EnergyModule):
    """Bidirectional convolutional predictive-coding model.

    The model owns one list of vodes shared by two layer stacks:

    - ``self.up``: image -> label/latent (``Conv2d`` + optional pooling,
      flatten, ``Linear``).
    - ``self.down``: label/latent -> image (``Linear``, reshape,
      ``ConvTranspose`` layers).

    Vodes are ordered like in an MLP: ``vodes[0]`` is the label/latent end
    and ``vodes[-1]`` is the image end.
    """

    def __init__(
        self,
        input_size: Tuple[int, int],
        output_size: int,
        input_channels: int,
        cnn_name: Union[str, list],
        activation: str,
        input_var=1.0,
        alpha_up=1.0,
        alpha_down=1.0,
        activity_init="ff",
        activity_init_kwargs=None,
        is_supervised=True,
        cov_output=False,
        out_activation_down="tanh",
        out_activation_up=None,
    ) -> None:
        """Build the vodes and the up / down layer stacks.

        Args:
            input_size: ``(height, width)`` of the image.
            output_size: size of the label/latent vode(s).
            input_channels: number of image channels.
            cnn_name: key of a predefined configuration, or a list giving the
                bottom-up configuration directly (ints are channel counts,
                "M" / "A" are max / average pooling).
            activation: hidden activation; "l-relu" and "h-tanh" are aliases
                for ``leaky_relu`` and ``hard_tanh``, anything else is looked
                up by name in ``jax.nn``.
            input_var: divisor of the squared error in the image vode energy;
                also scales the image-vode noise for "noisy-ff".
            alpha_up, alpha_down: stored on the instance; used here to scale
                the noise of the ``*_fp`` passes.
            activity_init: one of "ff", "zero", "randn", "noisy-ff", "xavier".
            activity_init_kwargs: extra options; "noisy-ff" reads
                ``layer_var`` from it.
            is_supervised: when False an additional ``output_size`` vode is
                added and the up output activation maps to zeros.
            cov_output: when True an extra stride-1 layer is added at the
                image end of both stacks.
            out_activation_down: name in ``jax.nn`` of the final top-down
                activation, or None for identity.
            out_activation_up: name in ``jax.nn`` of the final bottom-up
                activation, or None for identity (supervised case only).

        Raises:
            ValueError: for an unknown ``activity_init`` or an invalid entry
                in the CNN configuration.
            KeyError: for an unknown ``cnn_name`` string.
        """
        super().__init__()

        init_kwargs = {} if activity_init_kwargs is None else activity_init_kwargs
        pooling_markers = {"M", "A"}

        # Predefined CNN configurations ("<name>reverse" is the top-down one).
        cfg = {
            "SmallVGG": [32, 64, 128, 256, 512],
            "SmallVGGreverse": [256, 128, 64, 32, input_channels],
            "SmallVGGpool": [32, "M", 64, "M", 128, "M", 256, "M", 512, "M"],
            "SmallVGGpoolreverse": [
                "M", 256, "M", 128, "M", 64, "M", 32, "M", input_channels,
            ],
            "SmallVGGavg": [32, "A", 64, "A", 128, "A", 256, "A", 512, "A"],
            "SmallVGGavgreverse": [
                "A", 256, "A", 128, "A", 64, "A", 32, "A", input_channels,
            ],
            "MidVGG": [64, 128, 256, 512, 512],
            "MidVGGreverse": [512, 256, 128, 64, input_channels],
            "VGG5": [128, "M", 256, "M", 512, "M", 512, "M"],
            "VGG5reverse": ["M", 512, "M", 256, "M", 128, "M", input_channels],
            "VGG5avg": [128, "A", 256, "A", 512, "A", 512, "A"],
            "VGG5avgreverse": ["A", 512, "A", 256, "A", 128, "A", input_channels],
            "VGG5np": [128, 256, 512, 512],
            "VGG5npreverse": [512, 256, 128, input_channels],
            "VGG11": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
            "VGG11reverse": [
                "M", 512, 512, "M", 512, 256, "M", 256, 128, "M", 64, "M",
                input_channels,
            ],
            "MiniVGG": [32, "M", 64, "M", 128, "M", 256, "M", 256, "M"],
            "MiniVGGreverse": [
                "M", 256, "M", 128, "M", 64, "M", 32, "M", input_channels,
            ],
            "MonoVGG": [64, "M", 128, "M", 256, "M", 512, "M", 512, "M"],
            "MonoVGGreverse": [
                "M", 512, "M", 256, "M", 128, "M", 64, "M", input_channels,
            ],
            "TestCNN": [64, 128, 128],
            "TestCNNreverse": [128, 64, input_channels],
        }

        # A list passed as cnn_name is the bottom-up configuration itself.
        if isinstance(cnn_name, list):
            cfg["manual"] = cnn_name
            cfg["manualreverse"] = _cnn_reverse_config(cnn_name, input_channels)
            cnn_name = "manual"

        layers_up = cfg[cnn_name]
        layers_down = cfg[cnn_name + "reverse"]

        # Convolutions downsample (stride 2) only when no pooling is present.
        stride = 1 if any(entry in pooling_markers for entry in layers_up) else 2

        # Optional extra layer at the image end of both stacks.
        if cov_output:
            layers_up = [layers_up[0]] + layers_up
            layers_down = layers_down[:-1] + [layers_up[0]] + [input_channels]

        # Hidden activation: two aliases, everything else by name in jax.nn.
        activation_aliases = {"l-relu": "leaky_relu", "h-tanh": "hard_tanh"}
        activation_fn = getattr(jax.nn, activation_aliases.get(activation, activation))

        self.activation = px.static(activation_fn)
        self.out_activation_down = (
            px.static(lambda x: x)
            if out_activation_down is None
            else px.static(getattr(jax.nn, out_activation_down))
        )
        # 0.0 because vodes[0] contains constant zeros so vodes[1] should be
        # able to predict it.
        self.out_activation_up = (
            px.static(lambda x: 0.0 * x)
            if not is_supervised
            else (
                px.static(lambda x: x)
                if out_activation_up is None
                else px.static(getattr(jax.nn, out_activation_up))
            )
        )

        # Energy of the image vode: squared error scaled by 1 / input_var.
        def se_energy_input(vode, rkg: px.RandomKeyGenerator = px.RKG):
            e = vode.get("h") - vode.get("u")
            return 0.5 * (e * e) / input_var

        # Initialisation rules: `tforms` for hidden vodes, `tforms_out` for the
        # image vode, `tforms_in` for the label/latent end.
        if activity_init == "ff":
            ruleset = {}
            tforms = {}
            tforms_out = {}
            tforms_in = {}
        elif activity_init == "zero":
            ruleset = {pxc.STATUS.INIT: ("h, u <- u:to_zero",)}
            tforms = {"to_zero": lambda n, k, v, rkg: jnp.zeros(n.shape.get())}
            tforms_out = {"to_zero": lambda n, k, v, rkg: -jnp.ones(n.shape.get())}
            tforms_in = {
                "to_zero": lambda n, k, v, rkg: (
                    jnp.zeros(n.shape.get())
                    if is_supervised
                    else 0.1 * jnp.ones(n.shape.get())
                )
            }
        elif activity_init == "randn":
            ruleset = {pxc.STATUS.INIT: ("h, u <- u:randn",)}
            tforms = {
                "randn": lambda n, k, v, rkg: jax_random.normal(rkg(), n.shape.get())
            }
            tforms_out = {
                "randn": lambda n, k, v, rkg: jax_random.uniform(
                    rkg(), n.shape.get(), minval=-1, maxval=1
                )
            }
            tforms_in = {
                "randn": lambda n, k, v, rkg: (
                    jax_random.normal(rkg(), n.shape.get())
                    if not is_supervised
                    else 0.1 * jnp.ones(n.shape.get())
                )
            }
        elif activity_init == "noisy-ff":
            layer_var = init_kwargs["layer_var"]
            ruleset = {pxc.STATUS.INIT: ("h, u <- u:ff_randn",)}
            tforms = {
                "ff_randn": lambda n, k, v, rkg: v
                + layer_var * jax_random.normal(rkg(), n.shape.get())
            }
            tforms_out = {
                "ff_randn": lambda n, k, v, rkg: v
                + input_var * layer_var * jax_random.normal(rkg(), n.shape.get())
            }
            tforms_in = {
                "ff_randn": lambda n, k, v, rkg: v
                + layer_var * jax_random.normal(rkg(), n.shape.get())
            }
        elif activity_init == "xavier":
            ruleset = {pxc.STATUS.INIT: ("h, u <- u:xav",)}

            def xavier_init(n, k, v, rkg):
                bound = jnp.sqrt(6 / n.shape.get()[0])
                return jax.random.uniform(
                    rkg(), shape=(n.shape.get()), minval=-bound, maxval=bound
                )

            tforms = {"xav": xavier_init}
            tforms_out = {"xav": xavier_init}
            tforms_in = {"xav": lambda n, k, v, rkg: 0.1 * jnp.ones(n.shape.get())}
        else:
            # Previously this fell through and failed later with an
            # UnboundLocalError on `ruleset`.
            raise ValueError(f"Unknown activity_init {activity_init!r}")

        ## Initialize vodes (built image-first, reversed afterwards)
        self.vodes = []
        h, w = input_size
        in_channels = input_channels
        for idx, entry in enumerate(layers_up):
            if isinstance(entry, int):
                stride_loc = 1 if (idx == 0 and cov_output) else stride
                self.vodes.append(
                    pxc.Vode(
                        (in_channels, h, w),
                        energy_fn=se_energy_input if idx == 0 else se_energy,
                        ruleset=ruleset,
                        tforms=tforms if idx != 0 else tforms_out,
                    )
                )
                in_channels = entry
                h //= stride_loc
                w //= stride_loc
            elif entry in pooling_markers:
                h //= 2
                w //= 2
            else:
                raise ValueError(f"Invalid value {entry} in CNN configuration")
        self.vodes += [
            pxc.Vode(
                ((in_channels * h * w,)),
                ruleset=ruleset,
                tforms=tforms,
            ),
            pxc.Vode(
                (output_size,),
                ruleset=ruleset,
                tforms=tforms_in if is_supervised else tforms,
            ),
            *(
                [
                    pxc.Vode(
                        (output_size,),
                        ruleset=ruleset,
                        tforms=tforms_in,
                    )
                ]
                if not is_supervised
                else []
            ),
        ]

        # Same order as an MLP: vodes[0] is label/latent and vodes[-1] is img.
        self.vodes = self.vodes[::-1]

        # Build the up model layers, walking the vodes from the image end.
        self.up = []
        h, w = input_size
        idx_vode = -1
        in_channels = input_channels
        for entry in layers_up:
            if entry == "M":
                self.up.append(pxnn.MaxPool2d(kernel_size=2, stride=2))
                h //= 2
                w //= 2
            elif entry == "A":
                self.up.append(pxnn.AvgPool2d(kernel_size=2, stride=2))
                h //= 2
                w //= 2
            else:
                stride_loc = (
                    1 if (in_channels == input_channels and cov_output) else stride
                )
                self.up += [
                    self.vodes[idx_vode],
                    pxnn.Conv2d(
                        in_channels, entry, kernel_size=3, padding=1, stride=stride_loc
                    ),
                    self.activation,
                ]
                idx_vode -= 1
                in_channels = entry
                h //= stride_loc
                w //= stride_loc
        # Flatten and add linear layer to output_size
        flatten_size = in_channels * h * w
        self.up += [
            px.static(
                lambda x, flatten_size=flatten_size: jnp.reshape(x, (flatten_size,))
            ),
            self.vodes[idx_vode],
            pxnn.Linear(flatten_size, output_size),
            *(
                [
                    self.vodes[1],
                    px.static(lambda x: x),
                ]
                if not is_supervised
                else []
            ),
            self.out_activation_up,
            self.vodes[0],
        ]

        # Build the down model layers
        zero_prior = px.static(lambda x: x)
        self.down = [
            self.vodes[0],
            *(
                [
                    zero_prior,  # equivalent of Identity layer
                    zero_prior,  # equivalent of activation function
                    self.vodes[1],
                ]
                if not is_supervised
                else []
            ),
            pxnn.Linear(output_size, flatten_size),
            self.activation,
            self.vodes[1 if is_supervised else 2],
            px.static(
                lambda x, in_channels=in_channels, h=h, w=w: jnp.reshape(
                    x, (in_channels, h, w)
                )
            ),
        ]
        idx_vode = 2 if is_supervised else 3
        down_stride = 2
        output_padding = down_stride - 1
        last_idx = len(layers_down) - 1
        for idx, entry in enumerate(layers_down):
            # Pooling markers have no layer of their own in the down stack.
            if entry in pooling_markers:
                continue
            if entry == input_channels and cov_output:
                layer = pxnn.Conv2d(
                    in_channels, entry, kernel_size=3, padding=1, stride=1
                )
            else:
                layer = ConvTranspose(
                    2,
                    in_channels,
                    entry,
                    kernel_size=3,
                    padding=1,
                    stride=down_stride,
                    output_padding=output_padding,
                )
            self.down += [
                layer,
                self.activation if idx != last_idx else self.out_activation_down,
                self.vodes[idx_vode],
            ]
            in_channels = entry
            idx_vode += 1
        # Internal invariant: the down stack must have consumed every vode.
        assert idx_vode == len(self.vodes)

        # Both boundary vodes start frozen (regardless of is_supervised).
        self.vodes[0].h.frozen = True
        self.vodes[-1].h.frozen = True

        # Set other attributes
        self.alpha_up = alpha_up
        self.alpha_down = alpha_down
        self.input_var = input_var
        self.key = px.RKG()

    def model_down(self, x):
        """Run the whole down stack from ``vodes[0].h``; return ``vodes[-1].u``.

        The ``x`` argument is not read; the starting value is taken from the
        first vode.
        """
        value = self.vodes[0].get("h")
        for layer in self.down:
            value = layer(value)
        return self.vodes[-1].get("u")

    def model_up(self, y):
        """Run the whole up stack from ``vodes[-1].h``; return ``vodes[0].u``.

        The ``y`` argument is not read; the starting value is taken from the
        last vode.
        """
        value = self.vodes[-1].get("h")
        for layer in self.up:
            value = layer(value)
        return self.vodes[0].get("u")

    def model_down_fp(self, x, noise_var: float = 0.0):
        """Forward ``x`` through ``self.down[1:-1]`` without calling any vode.

        With ``noise_var > 0`` noise of variance ``noise_var / alpha_down`` is
        injected at each vode position.
        """
        return _cnn_forward_pass(self.down[1:-1], x, noise_var, self.alpha_down)

    def model_up_fp(self, y, noise_var: float = 0.0):
        """Forward ``y`` through ``self.up[1:-1]`` without calling any vode.

        With ``noise_var > 0`` noise of variance ``noise_var / alpha_up`` is
        injected at each vode position.
        """
        return _cnn_forward_pass(self.up[1:-1], y, noise_var, self.alpha_up)

    def hidden_to_input_fp(self, input, noise_var: float = 0.0):
        """Forward ``input`` through ``self.down[3:-1]`` without calling any vode.

        NOTE(review): the noise here is scaled by ``alpha_up`` although this
        is a top-down pass (``model_down_fp`` uses ``alpha_down``) -- kept as
        is, confirm whether it is intended.
        """
        return _cnn_forward_pass(self.down[3:-1], input, noise_var, self.alpha_up)

    def input_to_hidden_fp(self, input, noise_var: float = 0.0):
        """Forward ``input`` through ``self.up[1:-4]`` without calling any vode.

        With ``noise_var > 0`` noise of variance ``noise_var / alpha_up`` is
        injected at each vode position.
        """
        return _cnn_forward_pass(self.up[1:-4], input, noise_var, self.alpha_up)

    def __call__(self, x, y, is_up_initialisation: bool = True):
        """Clamp the given inputs and run one full pass in the chosen direction.

        ``x`` (if not None) is written to ``vodes[0].h`` and ``y`` (if not
        None) to ``vodes[-1].h``, both before and again after the pass.
        Returns the result of ``model_up`` when ``is_up_initialisation`` is
        True, otherwise of ``model_down``.
        """
        if x is not None:
            self.vodes[0].set("h", x)
        if y is not None:
            self.vodes[-1].set("h", y)

        if is_up_initialisation:
            output = self.model_up(y)
        else:
            output = self.model_down(x)

        # Re-apply the clamped values after the pass.
        if x is not None:
            self.vodes[0].set("h", x)
        if y is not None:
            self.vodes[-1].set("h", y)
        return output


class AE(pxc.EnergyModule):
    """Auto-encoder built on the layers and vodes of an existing ``CNNModel``.

    The encoder is ``model.up`` without its last three entries and the
    decoder is ``model.down`` from its fourth entry on. Layer lists, vodes
    and the ``alpha`` factors are taken from the wrapped model.
    """

    def __init__(
        self,
        model: CNNModel,
    ):
        super().__init__()
        self.down = model.down
        self.up = model.up

        self.vodes = model.vodes

        self.alpha_up = model.alpha_up
        self.alpha_down = model.alpha_down

    def __call__(self, x, y, is_up_initialisation: bool = True):
        """Write ``x`` / ``y`` (when given) into the boundary vodes; return ``x``.

        ``is_up_initialisation`` is accepted but not used.
        """
        for vode, value in ((self.vodes[0], x), (self.vodes[-1], y)):
            if value is not None:
                vode.set("h", value)
        return x

    def model_down(self, x):
        """Encode ``vodes[-1].h``, feed the code to ``vodes[0]``, then decode.

        Returns the ``u`` value of ``vodes[-1]``. The ``x`` argument is not
        read.
        """
        # image to latent
        hidden = self.vodes[-1].get("h")
        for layer in self.up[:-3]:
            hidden = layer(hidden)

        # set latent regularisation
        self.vodes[0](hidden)

        # latent to image
        for layer in self.down[3:]:
            hidden = layer(hidden)

        return self.vodes[-1].get("u")

    def model_up(self, y):
        """Do nothing: this model only has one energy."""
        return None

    def model_down_fp(self, x, noise_var: float = 0.0):
        """Decode ``x`` through ``self.down[3:]``, skipping every vode.

        ``noise_var`` is accepted but not used.
        """
        out = x
        for layer in self.down[3:]:
            if isinstance(layer, pxc.Vode):
                continue
            out = layer(out)
        return out

    def model_up_fp(self, y, noise_var: float = 0.0):
        """Encode ``y`` through ``self.up[:-3]``, skipping every vode.

        ``noise_var`` is accepted but not used.
        """
        out = y
        for layer in self.up[:-3]:
            if isinstance(layer, pxc.Vode):
                continue
            out = layer(out)
        return out

    def hidden_to_input_fp(self, input, noise_var: float = 0.0):
        """Alias of ``model_down_fp``."""
        return self.model_down_fp(input, noise_var=noise_var)

    def input_to_hidden_fp(self, input, noise_var: float = 0.0):
        """Alias of ``model_up_fp``."""
        return self.model_up_fp(input, noise_var=noise_var)


@pxf.vmap(
    pxu.Mask(pxc.VodeParam | pxc.VodeParam.Cache, (None, 0)),
    in_axes=(0, 0),
    out_axes=(None, 0, 0),
    axis_name="batch",
)
def energy_ae_latent(x, y, *, model: Model):
    """Batch-averaged weighted energy of an up pass followed by a down pass.

    The up pass yields ``(x_up, latent)``; the latent is then fed, together
    with ``x``, to the down pass. Each pass runs in its own ``pxu.step``
    context and its energy is scaled by ``alpha_up`` / ``alpha_down``.

    Returns:
        ``(energy, y_down, x_up)`` where ``energy`` is the ``pmean`` over the
        "batch" axis of the summed weighted energies.
    """
    # Bottom-up pass.
    with pxu.step(model, clear_params=pxc.VodeParam.Cache):
        x_up, latent = model.model_up(y)
        weighted_up = model.energy() * model.alpha_up

    # Top-down pass, conditioned on the latent produced above.
    with pxu.step(model, clear_params=pxc.VodeParam.Cache):
        y_down = model.model_down(x, latent)
        weighted_down = model.energy() * model.alpha_down

    total = weighted_up + weighted_down
    return jax.lax.pmean(total, "batch"), y_down, x_up


class AE_Latent(pxc.EnergyModule):
    def __init__(
        self,
        model: AddLatent,
    ) -> None:
        """Build the module from the components of an ``AddLatent`` model.

        Every attribute set here is read from ``model``. The only new objects
        are the ``px.static`` wrappers placed around three of its ``*_fp``
        attributes.
        """
        super().__init__()
        # Layer stacks for the two directions.
        self.down = model.down
        self.up = model.up

        self.vodes = model.vodes

        # Weights applied to the up / down energies by the caller.
        self.alpha_up = model.alpha_up
        self.alpha_down = model.alpha_down

        # Layers that map to / from the latent in model_up and model_down.
        self.latent_layer_up = model.latent_layer_up
        self.latent_layer_down = model.latent_layer_down

        # Applied to the activation before latent_layer_up in model_up.
        self.grad_transform = model.grad_transform

        # model_down calls combination_fn instead of the layer at
        # combination_idx to merge the latent into the downward pass.
        self.combination_fn = model.combination_fn
        self.combination_idx = model.combination_idx

        # Wrapped with px.static -- presumably so they are held as static
        # values rather than tracked module state; confirm against pcax.
        self.hidden_to_input_fp = px.static(model.hidden_to_input_fp)
        self.model_down_fp = px.static(model.model_down_fp)
        self.model_up_fp = px.static(model.model_up_fp)

        # repurpose the latent vode to add regularisation to the latent space
        self.latent_vode = model.latent_vode

    def __call__(self, x, y, is_up_initialisation: bool = True, latent=None):
        """Initialise the latent vode and clamp the first / last vodes.

        The latent vode's ``h`` is set to ``latent``, or to zeros when
        ``latent`` is ``None``. ``x`` and ``y`` are written to the ``h`` of
        the first and last vode respectively; a ``None`` value is skipped.
        ``is_up_initialisation`` is accepted but not read by this method.

        Returns:
            ``x``, unchanged.
        """
        if latent is None:
            # NOTE(review): other code in this file reads node shapes via
            # ``shape.get()``; here the attribute is passed to jnp.zeros
            # directly -- confirm latent_vode.shape is usable as a shape.
            latent = jnp.zeros(self.latent_vode.shape)
        self.latent_vode.set("h", latent)

        # (index into self.vodes, value to store) -- first vode, then last.
        for position, value in ((0, x), (-1, y)):
            if value is not None:
                self.vodes[position].set("h", value)
        return x

    def model_up(self, y):
        """Run the upward pass and compute the latent.

        Starting from the ``h`` of ``up[0]``, the activation is passed through
        ``up[:-3]``. That activation is then fed through the last three
        entries of ``up`` (their result is discarded) and, separately, through
        ``grad_transform`` followed by ``latent_layer_up`` to obtain the
        latent. ``y`` is not read by this method.

        Returns:
            A tuple ``(u, latent)`` where ``u`` is the ``u`` of the first vode.
        """
        # set latent regularisation to zero energy
        current_latent = self.latent_vode.get("h")
        self.latent_vode(current_latent)

        # get latent and classification
        features = self.up[0].get("h")
        for layer in self.up[:-3]:
            features = layer(features)

        # Final three entries of the stack; only their side effects are kept.
        third_last_out = self.up[-3](features)
        second_last_out = self.up[-2](third_last_out)
        self.up[-1](second_last_out)

        transformed = self.grad_transform(features)
        latent = self.latent_layer_up(transformed)
        return self.vodes[0].get("u"), latent

    def model_down(self, x, latent):
        """Run the downward (reconstruction) pass with the latent merged in.

        The latent vode is called on ``latent``, then the ``h`` of ``down[0]``
        is passed through every entry of ``down`` in order. At position
        ``combination_idx`` the entry is not called directly; instead
        ``combination_fn`` receives the current activation, the latent, that
        entry and ``latent_layer_down``. ``x`` is not read by this method.

        Returns:
            The ``u`` of the last entry of ``down``.
        """
        # setup latent regularisation
        self.latent_vode(latent)

        # reconstruct
        hidden = self.down[0].get("h")
        for position, layer in enumerate(self.down):
            if position == self.combination_idx:
                hidden = self.combination_fn(
                    hidden, latent, layer, self.latent_layer_down
                )
                continue
            hidden = layer(hidden)
        return self.down[-1].get("u")
