import torch
import numpy as np
import torch
from torch import distributions
import torchvision


def add_imgs(imgs, outdir, nrow=8):
    """Rescale a batch of images to [0, 1] and save it as a single grid image.

    Args:
        imgs: Image tensor. It is mapped with ``x / 2 + 0.5`` and clipped to
            [0, 1] (assumes the input is normalized to [-1, 1] — TODO confirm
            against the data pipeline). The caller's tensor is not modified.
        outdir: Destination handed straight to ``torchvision.utils.save_image``.
            Despite the name, this is the output *file* target, not a
            directory.
        nrow: Number of images per grid row.
    """
    # The out-of-place arithmetic produces a fresh tensor, so the in-place
    # clip never touches the caller's data.
    rescaled = (imgs / 2 + 0.5).clip_(0.0, 1.0)
    # save_image lays out the grid itself (it forwards nrow/pad_value to
    # make_grid), so no separate make_grid call is needed beforehand.
    torchvision.utils.save_image(rescaled, outdir, nrow=nrow, pad_value=0.8)


def imgs_grid(imgs, nrow=8):
    """Return a grid tensor built from a batch of images rescaled to [0, 1].

    The input is mapped with ``x / 2 + 0.5`` and clipped to [0, 1] (assumes
    images normalized to [-1, 1] — TODO confirm), then tiled by
    ``torchvision.utils.make_grid`` with ``nrow`` images per row and a
    padding value of 1. The caller's tensor is left untouched.
    """
    # The arithmetic allocates a new tensor, so clipping it in place is safe.
    rescaled = (imgs / 2 + 0.5).clip_(0.0, 1.0)
    return torchvision.utils.make_grid(rescaled, nrow=nrow, pad_value=1)

def test_reconstruction(test_data, ae_net, noise_std, ddp):
    """Encode a batch, perturb its latent code with Gaussian noise, and decode.

    Args:
        test_data: Input batch fed to the encoder. It also forms the first
            chunk of the returned tensor.
        ae_net: Autoencoder exposing ``encoder`` and ``decoder`` callables,
            reached through ``ae_net.module`` when ``ddp`` is truthy.
        noise_std: Scale applied to standard-normal noise that is added to
            the latent code before decoding.
        ddp: Truthy when the real model lives under ``ae_net.module``
            (presumably a DistributedDataParallel wrapper — confirm).

    Returns:
        ``test_data`` followed by the reconstruction, concatenated along
        dim 0. Computed under ``torch.no_grad()``.
    """
    # Unwrap once instead of duplicating the whole pipeline per branch.
    net = ae_net.module if ddp else ae_net

    with torch.no_grad():
        # The encoder returns a pair; only the latent code (second item) is
        # used here.
        _, z = net.encoder(test_data)
        noise_z = z + noise_std * torch.randn_like(z)

        if ddp:
            recon_img = net.decoder(noise_z, to_img=True)
        else:
            # NOTE(review): only this branch passes the batch size to the
            # decoder. The two call signatures are kept exactly as they were;
            # confirm against the decoder definition which one is intended.
            recon_img = net.decoder(noise_z, test_data.shape[0], to_img=True)

        imgs_all = torch.cat([test_data, recon_img], dim=0)

    return imgs_all


# The name matches pytest's ``test_*`` pattern although this is a helper, not
# a test; opt out of collection so importing it into a test module does not
# produce a "fixture not found" error.
test_reconstruction.__test__ = False

