import tempfile
import warnings

import numpy as np
import torch
import torchvision
from torchvision.utils import save_image

import os
from tqdm import tqdm
import matplotlib.pyplot as plt

from pcax.predictive_coding._vode import Ruleset
from pytorch_fid.fid_score_natural_imgs import fid_natural, save_stats_natural
from pytorch_fid.fid_score_mnist import save_stats_mnist, fid_mnist
from pytorch_fid.fid_score_fashion_mnist import (
    save_stats_fashion_mnist,
    fid_fashion_mnist,
)
from inception_score import (
    get_mnist_inception_score,
    get_fashion_mnist_inception_score,
    get_cifar_inception_score,
)

import jax
import jax.numpy as jnp
import jax.random as jax_random
import pcax.predictive_coding as pxc
import pcax.utils as pxu
import pcax as px

from utils_pcax.models import (
    SubModel,
    energy_per_data,
    energy_per_stream,
    energy_per_stream_median,
    fp_down_latents,
    fp_down_w_latent,
    fp_up_latents,
    infer_on_batch_latent_gen,
    infer_on_batch_latent_gen_no_init,
    infer_on_batch_latent_gen_no_init_median,
    initialisation,
    infer_on_batch,
    fp_down,
    fp_up,
)


def eval_fid_inception_latent(imgs, dataset, data_dl, verbose=True, subset="val"):
    """Compute the Inception score and FID of already-generated images.

    Reference statistics are read from ``./data/<dataset>_<subset>.npz``; if
    that file is missing it is built from ``data_dl`` with
    ``make_compressed_files``.

    Args:
        imgs: Array of images with values in [-1, 1]. They are rescaled to
            [0, 1] and clipped before being written to disk and scored.
        dataset: One of ``"mnist"``, ``"fashion_mnist"`` or ``"cifar10"``.
        data_dl: Loader used to build the reference statistics when missing.
        verbose: Print the scores when True.
        subset: Name of the reference split (part of the stats file name).

    Returns:
        ``(is_mean, fid, imgs)`` where ``imgs`` is the rescaled numpy array.

    Raises:
        ValueError: If ``dataset`` is not supported.
    """
    # Validate first: an unknown dataset used to fail only after all images
    # had been written, with an UnboundLocalError on ``inception_fn``.
    if dataset not in ("mnist", "fashion_mnist", "cifar10"):
        raise ValueError(f"Dataset not supported: {dataset}")

    # check if summary statistics of test dataset used for FID exist
    data_folder = "./data"
    data_subset_folder = data_folder + "/" + dataset + "_" + subset
    data_subset_filename = data_subset_folder + ".npz"
    if not os.path.exists(data_subset_filename):
        os.makedirs(data_subset_folder, exist_ok=True)
        print(data_subset_folder + " does not exist")
        print("Creating compressed files for faster FID measure ...")
        make_compressed_files(data_dl, data_subset_folder, dataset)

    imgs = imgs / 2 + 0.5  # from -1 -> 1 to 0 -> 1
    imgs = np.clip(imgs, 0, 1)
    imgs = torch.tensor(imgs)
    # save images to a scratch folder that is removed after scoring
    with tempfile.TemporaryDirectory() as img_folder:
        for img_idx in range(len(imgs)):
            save_image(imgs[img_idx], img_folder + "/" + str(img_idx) + ".png")

        if dataset == "mnist":
            inception_fn = get_mnist_inception_score
            fid_fn = fid_mnist
        elif dataset == "fashion_mnist":
            inception_fn = get_fashion_mnist_inception_score
            fid_fn = fid_fashion_mnist
        else:  # "cifar10" (validated above)
            inception_fn = get_cifar_inception_score
            fid_fn = fid_natural

        is_mean, is_std = inception_fn(img_folder)
        fid = fid_fn(
            data_subset_filename,
            img_folder,
            device=torch.device("cpu"),
            num_workers=0,
            verbose=False,
        )

    if verbose:
        print(f"Inception score: {is_mean:.2f} +/- {is_std:.2f} - FID: {fid:.2f}")

    return is_mean, fid, imgs.numpy()


def eval_rmse(
    model,
    batch_size,
    dataset,
    datasubset,
    mode="fp",
    mode_kwargs={"noise_var": 0.0},
    verbose=True,
):
    """RMSE between images generated by ``eval_gen`` and stored mean images.

    Mean images are loaded from
    ``mean_images/<dataset>_<datasubset>[_cnn].pt``. The ``_cnn`` suffix is
    used when the output vode shape is neither an int nor one-dimensional.
    Generated images are mapped with ``x / 2 + 0.5`` before the comparison.

    Args:
        model: Model whose ``vodes[-1].shape`` selects the mean-image file.
        batch_size: Generation batch size forwarded to ``eval_gen``.
        dataset: Dataset name used in the mean-image file name.
        datasubset: Split name used in the mean-image file name.
        mode, mode_kwargs: Forwarded to ``eval_gen``.
        verbose: Print the RMSE.

    Returns:
        ``(rmse, MAP)``, or ``None`` if the mean images could not be loaded.
    """
    img_folder = "mean_images"

    datasubset_name = datasubset
    # non-flat (e.g. convolutional) output layers use a separate file
    if not (isinstance(model.vodes[-1].shape, int) or len(model.vodes[-1].shape) == 1):
        datasubset_name += "_cnn"

    # load correct mean images
    path = f"{img_folder}/{dataset}_{datasubset_name}.pt"
    # The load is retried up to 100 times. Presumably the file may still be
    # being written by another job (TODO confirm). Only ``Exception``
    # subclasses are caught so that Ctrl-C / SystemExit still abort the loop.
    open_state = False
    last_error = None
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", FutureWarning)
        for _ in range(100):
            try:
                mean_images = torch.load(path).numpy()
                open_state = True
                break
            except Exception as err:
                last_error = err
    if not open_state:
        print(f"Could not load mean images ({last_error})")
        return

    MAP = eval_gen(model, batch_size, mode=mode, mode_kwargs=mode_kwargs, verbose=False)
    MAP = MAP / 2 + 0.5

    rmse = np.sqrt(((mean_images - MAP) ** 2).mean())

    if verbose:
        print(f"RMSE: {rmse:.4f}")
    return rmse, MAP


def make_compressed_files(data_dl, data_folder, dataset):
    """Write the target images of ``data_dl`` as PNGs and save FID statistics.

    The ``y`` element of every ``(X, y)`` batch is reshaped to the dataset's
    image shape, mapped with ``y / 2 + 0.5`` (undoing the [-1, 1]
    normalisation) and saved as ``<data_folder>/<index>.png``. The
    dataset-specific ``save_stats_*`` function then writes summary statistics
    to ``<data_folder>.npz``.

    Raises:
        ValueError: If ``dataset`` is not "mnist", "fashion_mnist" or
            "cifar10". This is raised after ``data_folder`` has been created.
    """
    stats_path = data_folder + ".npz"
    os.makedirs(data_folder, exist_ok=True)

    if dataset in ("mnist", "fashion_mnist"):
        img_shape = (28, 28)
        stats_fn = save_stats_mnist if dataset == "mnist" else save_stats_fashion_mnist
    elif dataset == "cifar10":
        img_shape = (3, 32, 32)
        stats_fn = save_stats_natural
    else:
        raise ValueError("Dataset not supported")

    next_idx = 0
    for _, targets in tqdm(data_dl):
        # undo the normalisation before writing the PNGs
        batch = torch.tensor(targets.reshape(-1, *img_shape) / 2 + 0.5)
        for img in batch:
            save_image(img, f"{data_folder}/{next_idx}.png")
            next_idx += 1

    # summary statistics of the written images
    stats_fn(data_folder, stats_path)


def make_compressed_MNIST_files(data_dl, data_folder):
    """Write MNIST targets from ``data_dl`` as PNGs and save their FID stats.

    Kept for backward compatibility. The previous body was a verbatim copy of
    the MNIST branch of ``make_compressed_files``, so this now delegates to it
    instead of duplicating the loop.

    Args:
        data_dl: Iterable of ``(X, y)`` batches; ``y`` holds the images.
        data_folder: Folder for the PNGs; stats go to ``data_folder + ".npz"``.
    """
    make_compressed_files(data_dl, data_folder, "mnist")


def make_compressed_FashionMNIST_files(data_dl, data_folder):
    """Write FashionMNIST targets from ``data_dl`` as PNGs and save FID stats.

    Kept for backward compatibility. The previous body was a verbatim copy of
    the FashionMNIST branch of ``make_compressed_files``, so this now
    delegates to it instead of duplicating the loop.

    Args:
        data_dl: Iterable of ``(X, y)`` batches; ``y`` holds the images.
        data_folder: Folder for the PNGs; stats go to ``data_folder + ".npz"``.
    """
    make_compressed_files(data_dl, data_folder, "fashion_mnist")


def eval_fid_inception(
    dataset,
    data_dl,
    nm_samples,
    model,
    latent_size,
    batch_size,
    mode="fp",
    mode_kwargs={"noise_var": 1.0},
    verbose=True,
    subset="val",
    labels=False,
):
    """Generate images with ``model`` and score them with Inception score and FID.

    Images come from ``eval_sto_gen`` (all-zero latent input) when ``labels``
    is False, otherwise from ``eval_gen`` (one-hot class inputs). Reference
    statistics are read from ``./data/<dataset>_<subset>.npz`` and built from
    ``data_dl`` if missing.

    Args:
        dataset: One of ``"mnist"``, ``"fashion_mnist"`` or ``"cifar10"``.
        data_dl: Loader used to build the reference statistics when missing.
        nm_samples: Number of images to generate.
        model: Generative model passed to the generation helper.
        latent_size: Latent width (only used when ``labels`` is False).
        batch_size: Generation batch size.
        mode, mode_kwargs: Forwarded to the generation helper.
        verbose: Print the scores when True.
        subset: Name of the reference split (part of the stats file name).
        labels: Generate from class labels instead of a zero latent.

    Returns:
        ``(is_mean, fid, imgs)`` where ``imgs`` is a torch tensor in [0, 1].

    Raises:
        ValueError: If ``dataset`` is not supported.
    """
    # Validate first: an unknown dataset used to fail only after generating
    # and writing all images, with an UnboundLocalError on ``inception_fn``.
    if dataset not in ("mnist", "fashion_mnist", "cifar10"):
        raise ValueError(f"Dataset not supported: {dataset}")

    # check if summary statistics of test dataset used for FID exist
    data_folder = "./data"
    data_subset_folder = data_folder + "/" + dataset + "_" + subset
    data_subset_filename = data_subset_folder + ".npz"
    if not os.path.exists(data_subset_filename):
        os.makedirs(data_subset_folder, exist_ok=True)
        print(data_subset_folder + " does not exist")
        print("Creating compressed files for faster FID measure ...")
        make_compressed_files(data_dl, data_subset_folder, dataset)

    # generate images from model
    if not labels:
        imgs = eval_sto_gen(
            nm_samples,
            model,
            latent_size,
            batch_size,
            mode=mode,
            mode_kwargs=mode_kwargs,
            verbose=False,
        )
    else:
        imgs = eval_gen(
            model,
            batch_size,
            mode=mode,
            mode_kwargs=mode_kwargs,
            verbose=False,
            nm_samples=nm_samples,
        )
    imgs = imgs / 2 + 0.5  # from -1 -> 1 to 0 -> 1
    imgs = np.clip(imgs, 0, 1)
    imgs = torch.tensor(imgs)
    # save images to a scratch folder that is removed after scoring
    with tempfile.TemporaryDirectory() as img_folder:
        for img_idx in range(len(imgs)):
            save_image(imgs[img_idx], img_folder + "/" + str(img_idx) + ".png")

        if dataset == "mnist":
            inception_fn = get_mnist_inception_score
            fid_fn = fid_mnist
        elif dataset == "fashion_mnist":
            inception_fn = get_fashion_mnist_inception_score
            fid_fn = fid_fashion_mnist
        else:  # "cifar10" (validated above)
            inception_fn = get_cifar_inception_score
            fid_fn = fid_natural

        is_mean, is_std = inception_fn(img_folder)
        fid = fid_fn(
            data_subset_filename,
            img_folder,
            device=torch.device("cpu"),
            num_workers=0,
            verbose=False,
        )

    if verbose:
        print(f"Inception score: {is_mean:.2f} +/- {is_std:.2f} - FID: {fid:.2f}")

    return is_mean, fid, imgs


def eval_sto_gen(
    nm_samples,
    model,
    latent_size,
    batch_size,
    mode="fp",
    mode_kwargs={"noise_var": 1.0},
    verbose=True,
):
    """Generate ``nm_samples`` images top-down from an all-zero latent input.

    With ``mode="fp"`` each batch is one ``fp_down`` pass using
    ``mode_kwargs["noise_var"]``. With ``mode="inference"``,
    ``mode_kwargs["infer_on_batch"]`` is run first and the batch is then taken
    from ``initialisation`` (top-down, zero inputs).

    Args:
        nm_samples: Number of images to return.
        model: Model whose ``vodes[-1].shape`` defines the image shape.
        latent_size: Width of the zero latent input.
        batch_size: Generation batch size.
        mode: ``"fp"`` or ``"inference"``.
        mode_kwargs: Extra settings for the chosen mode (see above).
        verbose: Show a progress bar and plot up to 10 samples.

    Returns:
        numpy array ``(nm_samples, *shape)``. Flat 784-dim outputs are
        reshaped to ``(28, 28)``.

    Raises:
        ValueError: If ``mode`` is unknown (previously an UnboundLocalError).
    """
    if mode not in ("fp", "inference"):
        raise ValueError(f"Unknown generation mode: {mode}")

    # number of batches needed to get nm_samples
    nm_batches = np.ceil(nm_samples / batch_size).astype(int)
    samples = []
    X_test = np.zeros((batch_size, latent_size))
    batches = range(nm_batches)

    top_shape = model.vodes[-1].shape
    # an int shape must be wrapped before it can be unpacked with ``*``
    out_dims = (top_shape,) if isinstance(top_shape, int) else top_shape
    shape = (28, 28) if top_shape == (784,) or top_shape == 784 else out_dims

    if verbose:
        batches = tqdm(batches, desc="Generating samples: ")
    for _ in batches:
        if mode == "fp":
            pred_down = fp_down(X_test, model=model, noise_var=mode_kwargs["noise_var"])
        else:
            # NOTE(review): the inference result is discarded and the sample is
            # read from initialisation() on zero inputs - confirm intended.
            mode_kwargs["infer_on_batch"](
                mode_kwargs["T"],
                X_test,
                None,
                mode_kwargs["is_up_initialisation"],
                1,
                model=model,
                optim_h=mode_kwargs["optim_h"],
            )[1]
            pred_down = initialisation(
                jax.numpy.zeros((batch_size, latent_size)),
                jax.numpy.zeros((batch_size, *out_dims)),
                model=model,
                is_up_initialisation=False,
            )
        pred_down = pred_down.reshape(batch_size, *shape)
        samples.append(pred_down)

    samples = np.array(samples).reshape(-1, *shape)[:nm_samples]

    if verbose and len(samples) > 0:
        # plot at most 10 samples; squeeze=False keeps axs 2-D even for 1 axis
        n_show = min(10, len(samples))
        fig, axs = plt.subplots(1, n_show, squeeze=False)
        for i in range(n_show):
            # map [-1, 1] -> [0, 1] so RGB images are not clipped by imshow
            img = np.clip(samples[i] / 2 + 0.5, 0, 1)
            if img.ndim == 3:  # channel-first -> HWC for imshow
                img = np.transpose(img, (1, 2, 0))
            axs[0, i].imshow(img, cmap="gray", vmin=0, vmax=1)
            axs[0, i].axis("off")
        plt.show()

    return samples


def eval_gen(
    model,
    batch_size,
    mode="fp",
    mode_kwargs={"noise_var": 0.0},
    verbose=True,
    nm_samples=None,
):
    """Generate images top-down from one-hot class inputs.

    Row ``i`` of each batch is the one-hot vector of class
    ``i % nm_classes``, where ``nm_classes = model.vodes[0].shape[0]``.

    Args:
        model: Model to generate from.
        batch_size: Generation batch size.
        mode: ``"fp"`` (one ``fp_down`` pass with
            ``mode_kwargs["noise_var"]``) or ``"inference"`` (run
            ``mode_kwargs["infer_on_batch"]`` and take element 1 of its result,
            replaced by ``initialisation`` output when ``model.alpha_down > 0``).
        mode_kwargs: Extra settings for the chosen mode.
        verbose: Show a progress bar and plot one image per class.
        nm_samples: Number of images to return (defaults to ``nm_classes``).

    Returns:
        numpy array ``(nm_samples, *shape)``. Flat 784-dim outputs are
        reshaped to ``(28, 28)``.

    Raises:
        ValueError: If ``mode`` is unknown (previously an UnboundLocalError).
    """
    if mode not in ("fp", "inference"):
        raise ValueError(f"Unknown generation mode: {mode}")

    # assess generation of mean images
    nm_classes = model.vodes[0].shape[0]
    if nm_samples is None:
        nm_samples = nm_classes
    top_shape = model.vodes[-1].shape
    # an int shape must be wrapped before it can be unpacked with ``*``
    out_dims = (top_shape,) if isinstance(top_shape, int) else top_shape
    X_test = jax.nn.one_hot(np.arange(batch_size) % nm_classes, nm_classes)
    Y_test = np.zeros((batch_size, *out_dims))

    nm_batches = np.ceil(nm_samples / batch_size).astype(int)
    samples = []
    batches = range(nm_batches)
    if verbose:
        batches = tqdm(batches, desc="Generating samples: ")
    for _ in batches:
        if mode == "fp":
            pred_down = fp_down(X_test, model=model, noise_var=mode_kwargs["noise_var"])
        else:
            pred_down = mode_kwargs["infer_on_batch"](
                mode_kwargs["T"],
                X_test,
                Y_test,
                mode_kwargs["is_up_initialisation"],
                1,
                model=model,
                optim_h=mode_kwargs["optim_h"],
            )[1]
            if model.alpha_down > 0:
                pred_down = initialisation(
                    X_test, Y_test, model=model, is_up_initialisation=False
                )
        samples.append(pred_down)
    pred_down = np.array(samples)

    pred_down = (
        pred_down.reshape(-1, 28, 28)
        if top_shape == (784,) or top_shape == 784
        else pred_down.reshape(-1, *out_dims)
    )
    pred_down = pred_down[:nm_samples]

    if verbose and len(pred_down) > 0:
        # one axis per shown image; previously 10 images were always drawn on
        # ``nm_classes`` axes, which raised IndexError for fewer than 10.
        n_show = min(nm_classes, len(pred_down))
        fig, axs = plt.subplots(1, n_show, squeeze=False)
        img = pred_down / 2 + 0.5
        if len(img.shape) == 4:  # channel-first -> NHWC for imshow
            img = np.transpose(img, (0, 2, 3, 1))
        for i in range(n_show):
            axs[0, i].imshow(img[i], cmap="gray", vmin=0, vmax=1)
            axs[0, i].axis("off")
        plt.show()
    return pred_down


def eval_acc(test_dl, model, mode="fp", mode_kwargs={"noise_var": 0.0}, verbose=True):
    """Classification accuracy of ``model`` on ``test_dl``.

    Each batch is ``(x, y)``. ``y`` is fed bottom-up and ``x`` holds the
    label vectors, whose argmax is compared with the argmax of the
    prediction. This code is not compatible with older experiments where the
    labels and some prelearned representations are concatenated.

    Args:
        test_dl: Iterable of ``(x, y)`` batches.
        model: Model to evaluate.
        mode: ``"fp"`` (``fp_up`` with ``mode_kwargs["noise_var"]``) or
            ``"inference"`` (``mode_kwargs["infer_on_batch"]`` run with an
            all-zero input, element 0 of its result is the prediction).
        mode_kwargs: Extra settings for the chosen mode.
        verbose: Print the accuracy.

    Returns:
        Fraction of correctly classified samples.

    Raises:
        ValueError: If ``mode`` is unknown (previously a NameError).
    """
    if mode not in ("fp", "inference"):
        raise ValueError(f"Unknown evaluation mode: {mode}")

    # assess classification accuracy
    correct_count = 0
    total_count = 0
    for x, y in test_dl:
        if mode == "fp":
            pred_up = fp_up(y, model=model, noise_var=mode_kwargs["noise_var"])
        else:
            pseudo_input = jnp.zeros_like(x)
            pred_up = mode_kwargs["infer_on_batch"](
                mode_kwargs["T"],
                pseudo_input,
                y,
                mode_kwargs["is_up_initialisation"],
                2,
                model=model,
                optim_h=mode_kwargs["optim_h"],
            )[0]
        # some forward functions return (prediction, extra); keep the prediction
        if isinstance(pred_up, tuple):
            pred_up = pred_up[0]
        correct_count += np.sum(np.argmax(pred_up, axis=1) == np.argmax(x, axis=1))
        total_count += len(x)

    accuracy = correct_count / total_count
    if verbose:
        print(f"Accuracy: {accuracy:.3f}")
    return accuracy


def eval_latent(
    test_dl,
    model,
    mode="inference",
    mode_kwargs={"T": 1000, "optim_h": None, "is_up_initialisation": False},
    verbose=True,
):
    """Collect the inferred state of ``model.vodes[1]`` for every batch.

    For each ``(x, y)`` batch the module-level ``infer_on_batch`` is run with
    ``mode_kwargs["T"]``, ``mode_kwargs["is_up_initialisation"]`` and
    ``mode_kwargs["optim_h"]``. Then ``model.vodes[1].get("h")`` is recorded.

    Args:
        test_dl: Iterable of ``(x, y)`` batches.
        model: Model to run inference on.
        mode: Only ``"inference"`` is implemented.
        mode_kwargs: Inference settings (see above).
        verbose: Show a progress bar.

    Returns:
        numpy array ``(n_samples, D)`` where ``D`` is ``shape[1]`` of the
        first batch's latent.

    Raises:
        NotImplementedError: For ``mode="fp"``.
        ValueError: For any other unknown mode. Previously such modes skipped
            inference and silently returned stale latents.
    """
    if mode == "fp":
        raise NotImplementedError
    if mode != "inference":
        raise ValueError(f"Unknown mode: {mode}")

    # assess latent space
    batches = tqdm(test_dl, desc="Getting latents: ") if verbose else test_dl
    latents = []
    for x, y in batches:
        infer_on_batch(
            mode_kwargs["T"],
            x,
            y,
            mode_kwargs["is_up_initialisation"],
            0,
            model=model,
            optim_h=mode_kwargs["optim_h"],
        )
        latents.append(np.asarray(model.vodes[1].get("h")))
    # concatenate along the batch axis: np.array(list) fails when the final
    # batch is smaller; for equal batches the result is identical
    width = latents[0].shape[1]
    return np.concatenate(latents, axis=0).reshape(-1, width)


def eval_latent_reconstruction(
    dl, model, mode="fp", mode_kwargs={"noise_var": 0.0}, verbose=True
):
    """Reconstruct data through the latent layer and report the MSE per element.

    ``"fp"``: encode ``y`` with ``fp_up_latents`` and decode the latent with
    ``fp_down_latents``.

    ``"inference"``: run ``mode_kwargs["infer_on_batch"]`` on ``(x, y)`` and
    read the latent from ``model.vodes[1]``. Then run it again on a
    ``SubModel`` (``top_layer_idx=3``), passing the latent in the input
    position, and read the reconstruction from ``model.vodes[-1]``.

    Args:
        dl: Iterable of ``(x, y)`` batches.
        model: Model to evaluate.
        mode: ``"fp"`` or ``"inference"``.
        mode_kwargs: Inference settings for ``"inference"``.
        verbose: Print the MSE.

    Returns:
        ``(mse, y_recon)``: the squared error summed over all samples divided
        by ``n_samples * prod(model.vodes[-1].shape)``, and the first
        batch's reconstruction.

    Raises:
        ValueError: If ``mode`` is unknown (previously a NameError).
    """
    if mode not in ("fp", "inference"):
        raise ValueError(f"Unknown mode: {mode}")
    if mode == "inference":
        model_without_prior = SubModel(model, top_layer_idx=3, bottom_layer_idx=None)

    sq_err = 0.0
    nm_imgs = 0
    y_recon = None
    for idx, (x, y) in enumerate(dl):
        if mode == "fp":
            latent = fp_up_latents(y, model=model)
            reconstructed = fp_down_latents(latent, model=model)
        else:
            mode_kwargs["infer_on_batch"](
                mode_kwargs["T"],
                x,
                y,
                mode_kwargs["is_up_initialisation"],
                0,
                model=model,
                optim_h=mode_kwargs["optim_h"],
            )[1]
            latent = model.vodes[1].get("h")
            mode_kwargs["infer_on_batch"](
                mode_kwargs["T"],
                latent,
                None,
                False,
                1,
                model=model_without_prior,
                optim_h=mode_kwargs["optim_h"],
            )
            reconstructed = model.vodes[-1].get("h")
        sq_err += jax.numpy.sum((reconstructed - y) ** 2)
        nm_imgs += y.shape[0]
        if idx == 0:
            y_recon = reconstructed

    mse = sq_err / (nm_imgs * np.prod(model.vodes[-1].shape))
    if verbose:
        print(f"MSE latent reconstruction: {mse:.5f}")

    return mse, y_recon


def eval_latent_reconstruction_free_latent(
    dl, model, mode="fp", mode_kwargs={"noise_var": 0.0}, verbose=True
):
    """Reconstruct ``y`` from its (free) latent and report the MSE per element.

    ``"fp"``: ``fp_up`` gives ``(label, latent)`` and ``fp_down_w_latent``
    decodes ``x`` together with that latent.

    ``"inference"``: ``mode_kwargs["infer_on_batch"]`` is run on ``(x, y)``
    and the latent is read from ``model.latent_vode``. It is then run again
    with ``latent=`` that value and no target, and the reconstruction is read
    from ``model.vodes[-1]``.

    Args:
        dl: Iterable of ``(x, y)`` batches.
        model: Model to evaluate.
        mode: ``"fp"`` or ``"inference"``.
        mode_kwargs: Inference settings for ``"inference"``.
        verbose: Print the MSE.

    Returns:
        ``(mse, y_recon)``: the squared error summed over all samples divided
        by ``n_samples * prod(model.vodes[-1].shape)``, and the first batch's
        reconstruction.

    Raises:
        ValueError: If ``mode`` is unknown (previously a NameError).
    """
    if mode not in ("fp", "inference"):
        raise ValueError(f"Unknown mode: {mode}")

    data_dim = np.prod(model.vodes[-1].shape)

    err = 0.0
    # Count the samples actually seen. The old ``len(dl) * batch_size``
    # denominator over-counted a smaller final batch, and ``next(iter(dl))``
    # consumed a batch from one-shot iterables.
    nm_imgs = 0
    y_recon = None
    for idx, (x, y) in enumerate(dl):
        if mode == "fp":
            _, latent = fp_up(y, model=model, noise_var=0.0)
            recon = fp_down_w_latent(x, latent, model=model, noise_var=0.0)
        else:
            mode_kwargs["infer_on_batch"](
                mode_kwargs["T"],
                x,
                y,
                mode_kwargs["is_up_initialisation"],
                0,
                model=model,
                optim_h=mode_kwargs["optim_h"],
            )[1]
            latent = model.latent_vode.get("h")
            mode_kwargs["infer_on_batch"](
                mode_kwargs["T"],
                x,
                None,
                False,
                1,
                latent=latent,
                model=model,
                optim_h=mode_kwargs["optim_h"],
            )
            recon = model.vodes[-1].get("h")
        err += jnp.sum((recon - y) ** 2)
        nm_imgs += y.shape[0]
        if idx == 0:
            y_recon = recon

    mse = err / (nm_imgs * data_dim)
    if verbose:
        print(f"MSE latent reconstruction: {mse:.3f}")
    return mse, y_recon


import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader

class MNIST_LinearClassifier(nn.Module):
    """Linear probe: one affine layer from ``rep_size`` features to ``n_classes`` logits."""

    def __init__(self, rep_size, n_classes):
        super().__init__()
        # attribute name ``lin`` fixes the state_dict keys ("lin.weight"/"lin.bias")
        self.lin = nn.Linear(rep_size, n_classes)

    def forward(self, x):
        """Return unnormalised class scores for a batch of representations."""
        return self.lin(x)

def test(model, testloader, print_acc=False):
    """Return the classification accuracy of ``model`` over ``testloader``.

    Args:
        model: Callable mapping a data batch to per-class scores (logits).
        testloader: Iterable of ``(data, labels)`` with integer class labels.
        print_acc: Print the accuracy when True.

    Returns:
        ``correct / total``. This is a 0-d tensor when the loader yields at
        least one batch.
    """
    correct_count, all_count = 0.0, 0.0
    # evaluation only: no autograd graph is needed
    with torch.no_grad():
        for data, labels in testloader:
            # Argmax of the raw scores. The old ``torch.exp`` did not change
            # the ranking but overflowed to inf for large logits, turning them
            # into ties and therefore wrong predictions.
            pred = model(data).argmax(dim=1)
            correct = (pred == labels).long()
            correct_count += correct.sum()
            all_count += correct.size(0)
    acc = correct_count / all_count
    if print_acc:
        print("Model Accuracy =", acc)
    return acc


# helper, not a unit test: stop pytest from collecting it if imported into a test module
test.__test__ = False

def get_acc(latent, labels, n_classes, verbose, epochs=50, lr=0.05, batch_size=128):
    """Train a linear probe on ``latent`` and return its best accuracy.

    An ``MNIST_LinearClassifier`` is trained with Adam and cross-entropy for
    ``epochs`` epochs. After each epoch its accuracy is measured with ``test``
    on the same shuffled representations it was trained on, so the returned
    value is a training-set accuracy.

    Args:
        latent: Tensor ``(n_samples, rep_size)`` of representations.
        labels: Class-index tensor with ``n_samples`` entries.
        n_classes: Number of output classes.
        verbose: Print the best accuracy.
        epochs: Number of training epochs (default 50, as before).
        lr: Adam learning rate (default 0.05, as before).
        batch_size: Probe training batch size (default 128, as before).

    Returns:
        ``(best_acc, classifier)``: the best per-epoch accuracy, and the
        classifier after the final epoch.
    """
    loader = DataLoader(TensorDataset(latent, labels), batch_size=batch_size, shuffle=True)
    classifier = MNIST_LinearClassifier(latent.shape[1], n_classes)
    criterion = nn.CrossEntropyLoss()
    optm = optim.Adam(classifier.parameters(), lr=lr)

    best_acc = 0.0
    for _ in range(epochs):
        for data, label in loader:
            optm.zero_grad()  # covers every classifier parameter
            loss = criterion(classifier(data), label)
            # each step builds a fresh graph, so retain_graph is unnecessary
            loss.backward()
            optm.step()
        acc = test(classifier, loader, print_acc=False)
        if acc > best_acc:
            best_acc = acc
    if verbose:
        print("Best accuracy: " + str(best_acc))
    return best_acc, classifier


def eval_latent_decoding_acc(
    dl, dl_labels, model, mode="fp", mode_kwargs={"noise_var": 0.0}, verbose=True
):
    """Linear-probe accuracy of the model's latent representations.

    ``dl`` and ``dl_labels`` are iterated in lockstep with ``zip``, so the
    shorter one ends the loop. Latents come from ``fp_up_latents`` (``"fp"``)
    or from ``model.vodes[1]`` after ``mode_kwargs["infer_on_batch"]``
    (``"inference"``). Targets are ``argmax(axis=1)`` of the first element of
    each ``dl_labels`` batch. The probe is fitted with ``get_acc``.

    Args:
        dl: Iterable of ``(x, y)`` data batches.
        dl_labels: Iterable of ``(x_labels, y_labels)`` batches aligned with ``dl``.
        model: Model providing the representations.
        mode: ``"fp"`` or ``"inference"``.
        mode_kwargs: Inference settings for ``"inference"``.
        verbose: Print the accuracy.

    Returns:
        The accuracy reported by ``get_acc``.

    Raises:
        ValueError: If ``mode`` is unknown (previously a NameError).
    """
    if mode not in ("fp", "inference"):
        raise ValueError(f"Unknown mode: {mode}")

    n_classes = None
    latents = []
    labels = []
    for (x, y), (x_labels, _) in zip(dl, dl_labels):
        if mode == "fp":
            latent = fp_up_latents(y, model=model)
        else:
            mode_kwargs["infer_on_batch"](
                mode_kwargs["T"],
                x,
                y,
                mode_kwargs["is_up_initialisation"],
                0,
                model=model,
                optim_h=mode_kwargs["optim_h"],
            )[1]
            latent = model.vodes[1].get("h")

        latents.append(latent)
        labels.append(x_labels.argmax(axis=1))
        if n_classes is None:
            n_classes = x_labels.shape[1]
    latents = torch.tensor(np.concatenate(latents, axis=0))
    labels = torch.tensor(np.concatenate(labels, axis=0), dtype=torch.long)

    acc, _ = get_acc(latents, labels, n_classes, False)

    if verbose:
        print(f"Acc decoding of representations: {acc:.5f}")

    return acc


def eval_supervised_specificity(
    model,
    optim_h,
    dl,
    is_up_initialisation,
    nm_samples=256,
    is_supervised=True,
    save=False,
    save_path="test_gen.png",
    is_recalibrate=False,
    median=True,
):
    """Generate images by inference until their energy matches real data.

    Steps:
      1. Run ``infer_on_batch`` (T=50000) on every batch of ``dl``.
         Summarise the per-sample total/up/down energies from
         ``energy_per_data`` with ``statistic``: the 10th percentile when
         ``median`` is True, the mean otherwise.
      2. If ``is_recalibrate``, set ``model.alpha_up = 1`` and
         ``model.alpha_down = e_up / e_down``. This mutates ``model``; the
         energy thresholds are rescaled to match.
      3. For each batch of latent inputs, start from uniform pseudo-data in
         [-1, 1], then run ``infer_for_statistic`` with the up/down
         thresholds and collect ``model.vodes[-1]``. The inputs cycle through
         one-hot classes when ``is_supervised`` and are zeros otherwise.
         When ``median`` is True, only samples whose up and down energies are
         both below the thresholds are kept, and about twice as many batches
         are generated.

    Args:
        model: Trained model (its vode rulesets are replaced).
        optim_h: Optimiser for the hidden states during inference.
        dl: Iterable of ``(x, y)`` in-distribution batches.
        is_up_initialisation: Initialisation direction for step 1.
        nm_samples: Maximum number of images to return.
        is_supervised: Use one-hot class inputs instead of zeros.
        save: Save a grid of up to 20 images per class to ``save_path``.
        save_path: Output path for the grid.
        is_recalibrate: Rebalance ``alpha_down`` as described in step 2.
        median: Use percentile-based thresholds with sample selection.

    Returns:
        ``(imgs, labels)``: at most ``nm_samples`` images rescaled to [0, 1]
        (flat 784-dim outputs reshaped to ``(1, 28, 28)``) and their latent
        inputs.
    """
    if median:
        statistic = lambda x, perc=10: np.percentile(x, perc)  # np.median
        infer_for_statistic = infer_on_batch_latent_gen_no_init_median
        energy_for_statistic = energy_per_stream_median
        select = True
    else:
        statistic = np.mean
        infer_for_statistic = infer_on_batch_latent_gen_no_init
        energy_for_statistic = energy_per_stream
        select = False

    ## Get energy of the model for dataset images
    e_up, e_down, e = [], [], []
    for x, y in tqdm(dl, desc="Measuring energy of in distribution data"):
        # measure energy of the trained model
        infer_on_batch(
            50000, x, y, is_up_initialisation, 0, model=model, optim_h=optim_h
        )
        e_b, e_up_b, e_down_b = energy_per_data(
            x, y, model=model
        )  # energy per stream scaled by alphas
        e.append(e_b)
        e_up.append(e_up_b)
        e_down.append(e_down_b)
    e, e_up, e_down = np.concatenate(e), np.concatenate(e_up), np.concatenate(e_down)
    e, e_up, e_down = (
        statistic(e).item(),
        statistic(e_up).item(),
        statistic(e_down).item(),
    )

    # recalibrate model so that relative energy of up and down are equal
    if is_recalibrate:
        alpha_up, alpha_down = model.alpha_up, model.alpha_down
        model.alpha_up = 1.0
        model.alpha_down = 1.0 * e_up / e_down
        e_up = e_up * model.alpha_up / alpha_up
        e_down = e_down * model.alpha_down / alpha_down

    ## generate images by setting the latent to random values and minimising the energy
    # Batch geometry comes from the first batch of ``dl``. The pseudo-data
    # shape used to come from ``y`` leaked out of the loop above, i.e. the
    # LAST batch. That batch is smaller than ``batch_size`` when the dataset
    # size is not a multiple of it, and then no longer matches ``inputs``.
    first_x, first_y = next(iter(dl))
    batch_size = len(first_x)
    data_shape = first_y.shape
    effective_batch_size = batch_size if not select else batch_size // 2
    nm_batches = np.ceil(nm_samples / effective_batch_size).astype(int)
    latent_dim = model.vodes[0].shape
    # number of classes for the one-hot inputs and the saved grid
    # (10 for a (10,) latent, matching the previous hard-coded value)
    n_classes = latent_dim[0]
    if is_supervised:
        inputs = jnp.arange(batch_size * nm_batches) % n_classes
        inputs = jax.nn.one_hot(inputs, n_classes)
    else:
        inputs = jnp.zeros((batch_size * nm_batches, *latent_dim))
    inputs = inputs.reshape(nm_batches, batch_size, *latent_dim)

    # setup random initialisation
    tforms = {
        "randn": lambda n, k, v, rkg: jax_random.normal(rkg(), n.shape.get()),
        "thin_uniform": lambda n, k, v, rkg: jax.random.uniform(
            rkg(), shape=(n.shape.get()), minval=-0.1, maxval=0.1
        ),
        "narrow_uniform": lambda n, k, v, rkg: jax.random.uniform(
            rkg(), shape=(n.shape.get()), minval=-1, maxval=1
        ),
        "mid_uniform": lambda n, k, v, rkg: jax.random.uniform(
            rkg(), shape=(n.shape.get()), minval=-3, maxval=3
        ),
        "wide_uniform": lambda n, k, v, rkg: jax.random.uniform(
            rkg(), shape=(n.shape.get()), minval=-5, maxval=5
        ),
        "xav": lambda n, k, v, rkg: jax.random.uniform(
            rkg(),
            shape=(n.shape.get()),
            minval=-jnp.sqrt(6 / n.shape.get()[0]),
            maxval=jnp.sqrt(6 / n.shape.get()[0]),
        ),
        "to_zero": lambda n, k, v, rkg: jnp.zeros(n.shape.get()),
    }
    ruleset = {
        "ff": ("h, u <- u",),
        "latent_gen": ("h, u <- u:to_zero",),
    }
    for v in model.vodes:
        v.ruleset = Ruleset(ruleset, tforms)

    # run inference till energy level is reached starting from random image initialisation
    imgs = []
    labels = []
    for batch_inputs in tqdm(list(inputs), desc="Generating images: "):
        pseudo_data = jax.random.uniform(
            px.RKG(), shape=data_shape, minval=-1, maxval=1
        )
        # init random hidden states, latent=batch_inputs and output=pseudo_data
        with pxu.step(model, "latent_gen", clear_params=pxc.VodeParam.Cache):
            initialisation(
                batch_inputs, pseudo_data, model=model, is_up_initialisation=True
            )  # init direction is not important here
        infer_for_statistic(
            1000000,
            batch_inputs,
            pseudo_data,
            1,
            e_up,
            e_down,
            model=model,
            optim_h=optim_h,
        )
        # check energy level
        e_up_gen, e_down_gen = energy_for_statistic(
            batch_inputs, pseudo_data, model=model
        )
        e_up_gen, e_down_gen = e_up_gen.item(), e_down_gen.item()
        print(
            f"Up energy: {e_up_gen} for expected {e_up} and down energy: {e_down_gen} for expected {e_down}"
        )
        if e_up_gen > e_up or e_down_gen > e_down:
            warnings.warn("Energy level not reached")

        if not select:
            imgs.append(np.array(model.vodes[-1].get("h")))
            labels.append(batch_inputs)
        else:
            # keep only samples whose per-sample energies are below both thresholds
            e_b, e_up_b, e_down_b = energy_per_data(
                batch_inputs, pseudo_data, model=model
            )
            accepted_images = (e_up_b < e_up) & (e_down_b < e_down)
            imgs.append(np.array(model.vodes[-1].get("h"))[accepted_images])
            labels.append(batch_inputs[accepted_images])

    imgs = np.concatenate(imgs) / 2 + 0.5
    imgs = np.clip(imgs, 0.0, 1.0)
    labels = np.concatenate(labels)
    shape = (
        (1, 28, 28)
        if model.vodes[-1].shape == (784,) or model.vodes[-1].shape == 784
        else model.vodes[-1].shape
    )
    imgs = imgs.reshape(-1, *shape)
    labels = labels.reshape(-1, *latent_dim)
    if save:
        n_save = 20  # number of images to save per class
        imgs_gen_torch = torch.from_numpy(imgs)
        all_idxs = np.arange(len(labels))
        idxs = np.concatenate(
            [all_idxs[np.argmax(labels, 1) == i][:n_save] for i in range(n_classes)]
        )
        torchvision.utils.save_image(imgs_gen_torch[idxs], save_path, nrow=n_save)
    return imgs[:nm_samples], labels[:nm_samples]
