import torch
import torchvision.utils
import numpy as np

import wandb


def wandb_plot_elbos(results):
    """Log the train and test ELBO curves to Weights & Biases as one chart.

    Args:
        results: Mapping read at the keys ``"train_elbo_avg"`` and
            ``"test_elbo_avg"``; both sequences are drawn as separate
            series against a shared, 1-based epoch axis.
    """
    train_curve = results["train_elbo_avg"]
    test_curve = results["test_elbo_avg"]

    # The x axis is numbered from 1 and sized by the training curve.
    epoch_axis = [index + 1 for index in range(len(train_curve))]

    elbo_chart = wandb.plot.line_series(
        xs=epoch_axis,
        ys=[train_curve, test_curve],
        keys=["Train", "Test"],
        title="ELBOs",
        xname="Epochs",
    )
    wandb.log({"ELBOs/Gap": elbo_chart})


def _grid_as_hwc_array(batch, nrow, padding):
    """Tile an image batch into a single grid and return it as a NumPy array.

    Args:
        batch: Batch of image tensors accepted by
            ``torchvision.utils.make_grid``.
        nrow: Number of images placed in each row of the grid.
        padding: Padding, in pixels, inserted around every tile.

    Returns:
        The grid with values clamped to ``[0, 1]``, with axes reordered by
        ``(1, 2, 0)`` (channel-first to channel-last for the grid produced
        by ``make_grid``).
    """
    grid = torchvision.utils.make_grid(batch, nrow=nrow, padding=padding)
    # detach/cpu make the NumPy conversion safe for tensors that live on an
    # accelerator or still carry autograd history.
    grid = grid.detach().cpu().clamp(0, 1).numpy()
    return np.transpose(grid, (1, 2, 0))


def wandb_plot_recon_and_sample(model, test_dataloader, device):
    """Log original, reconstructed and sampled image grids to W&B.

    Takes the first batch of ``test_dataloader``, logs a 3x3 grid of its
    first nine images, logs a 3x3 grid of the model's reconstructions of
    those images, and finally logs a grid of 25 images drawn with
    ``model.sample_x`` (five per row).

    Note:
        The model is switched to evaluation mode via ``model.eval()`` and is
        left in that mode when the function returns.

    Args:
        model: Object exposing ``eval()``, ``reconstruction(images)`` and
            ``sample_x(num=...)``.
        test_dataloader: Iterable whose first item unpacks into
            ``(images, labels)``; only the images are used.
        device: Device the images are moved to before reconstruction.
    """
    # Reconstruction test
    print('Testing...(reconstruction)')
    model.eval()

    images, _ = next(iter(test_dataloader))
    # Only nine images fit in the 3x3 grids logged below.
    images = images[:9]

    # NOTE: the second positional argument of make_grid is the padding, not
    # a column count, hence the explicit keywords (values unchanged).
    originals = _grid_as_hwc_array(images, nrow=3, padding=3)
    wandb.log({"original": wandb.Image(originals, caption="Original Images")})

    with torch.no_grad():
        reconstructed = model.reconstruction(images.to(device))
    reconstructions = _grid_as_hwc_array(reconstructed[:9], nrow=3, padding=3)
    wandb.log(
        {"reconstruction": wandb.Image(reconstructions, caption="Reconstructions")}
    )

    # Sampling
    print('Testing...(sampling)')
    with torch.no_grad():
        sampled = model.sample_x(num=25)
    samples = _grid_as_hwc_array(sampled, nrow=5, padding=5)
    wandb.log({"Samples": wandb.Image(samples, caption="Samples")})
