import torch
import numpy as np
import torch
from torch import distributions
import torchvision
import random
import os


def add_imgs(imgs, outdir, nrow=8):
    """Rescale a batch of images, tile it into a grid and write it to disk.

    Args:
        imgs: Image tensor. It is mapped through ``x / 2 + 0.5`` and clamped
            to ``[0, 1]`` (presumably the inputs live in ``[-1, 1]`` -- confirm
            against the caller). The caller's tensor is not modified, because
            the clamp runs on the freshly computed rescaled copy.
        outdir: Destination handed to ``torchvision.utils.save_image``;
            despite the name it is used as the output file, not a directory.
        nrow: Forwarded as ``nrow`` to both ``make_grid`` and ``save_image``.
    """
    rescaled = (imgs / 2 + 0.5).clip_(0.0, 1.0)
    grid = torchvision.utils.make_grid(rescaled, nrow=nrow, pad_value=0.8)
    # The already-tiled grid is saved with the same layout arguments.
    torchvision.utils.save_image(grid, outdir, nrow=nrow, pad_value=0.8)


def imgs_grid(imgs, nrow=8):
    """Return a tiled grid built from a rescaled batch of images.

    Args:
        imgs: Image tensor. It is mapped through ``x / 2 + 0.5`` and clamped
            to ``[0, 1]`` (presumably the inputs live in ``[-1, 1]`` -- confirm
            against the caller). The caller's tensor is left untouched.
        nrow: Forwarded as ``nrow`` to ``torchvision.utils.make_grid``.

    Returns:
        The tensor produced by ``make_grid`` with ``pad_value=1``.
    """
    rescaled = (imgs / 2 + 0.5).clip_(0.0, 1.0)
    return torchvision.utils.make_grid(rescaled, nrow=nrow, pad_value=1)


def test_incomplete_inference(test_data, vae_net, p_list_ST, r_ind, ddp):

    with torch.no_grad():
        N, P, L = test_data.shape[0], len(p_list_ST), r_ind.shape[1] 
        ids_restore = torch.argsort(r_ind, dim=1)

        mask_st_all = torch.zeros([P, L], device=test_data.device)
        for ii, STratio in enumerate(p_list_ST):
            len_st = int(L * STratio)
            mask_st_all[ii, :len_st] = 1
            mask_st_all[ii] = torch.gather(mask_st_all[ii], dim=0, index=ids_restore[0])

        with torch.no_grad():
            imgs_all = []
            for dataj in test_data:
                img = dataj.unsqueeze(0)  

                if ddp:
                    img_masked = vae_net.module.unpatchify(
                        vae_net.module.patchify(img) * mask_st_all.unsqueeze(-1)
                    )
                    imgs_all.append(img_masked)

                    mu_z, lstd_z = vae_net.module.forward_encoder(img, mask_st=mask_st_all)
                    recon_ospt_img, pred_img = vae_net.module.forward_decoder(zqst=mu_z, imgs=img,
                                                                              mask_s=torch.zeros_like(mask_st_all),
                                                                              toimg=True)
                else:
                    img_masked = vae_net.unpatchify(
                        vae_net.patchify(img) * mask_st_all.unsqueeze(-1)
                    )
                    imgs_all.append(img_masked)

                    mu_z, lstd_z = vae_net.forward_encoder(img, mask_st=mask_st_all)
                    recon_ospt_img, pred_img = vae_net.forward_decoder(zqst=mu_z, imgs=img,
                                                                       mask_s=torch.zeros_like(mask_st_all),
                                                                       toimg=True)

                imgs_all.append(pred_img)

            imgs_all = torch.cat(imgs_all, dim=0)

    return imgs_all


def test_inpainting(test_data, vae_net, p_list_ST, r_ind, ddp):

    with torch.no_grad():
        N, P, L = test_data.shape[0], len(p_list_ST), r_ind.shape[1]  
        ids_restore = torch.argsort(r_ind, dim=1)

        mask_s = torch.zeros([P, L], device=test_data.device)
        for ii, STratio in enumerate(p_list_ST):
            len_st = int(L * STratio)
            mask_s[ii, :len_st] = 1
            mask_s[ii] = torch.gather(mask_s[ii], dim=0, index=ids_restore[0])

        with torch.no_grad():
            imgs_all = []
            for dataj in test_data:
                img = dataj.unsqueeze(0)  

                if ddp:
                    img_masked = vae_net.module.unpatchify(
                        vae_net.module.patchify(img) * mask_s.unsqueeze(-1)
                    )
                    imgs_all.append(img_masked)

                    mu_z, lstd_z = vae_net.module.forward_encoder(img, mask_st=mask_s)
                    recon_ospt_img, pred_img = vae_net.module.forward_decoder(zqst=mu_z, imgs=img,
                                                                              mask_s=mask_s,
                                                                              toimg=True)
                else:
                    img_masked = vae_net.unpatchify(
                        vae_net.patchify(img) * mask_s.unsqueeze(-1)
                    )
                    imgs_all.append(img_masked)

                    mu_z, lstd_z = vae_net.forward_encoder(img, mask_st=mask_s)
                    recon_ospt_img, pred_img = vae_net.forward_decoder(zqst=mu_z, imgs=img,
                                                                       mask_s=mask_s,
                                                                       toimg=True)

                imgs_all.append(recon_ospt_img)

            imgs_all = torch.cat(imgs_all, dim=0)

    return imgs_all


def seed_torch(seed=1029):
    """Seed every RNG this module relies on and pin cuDNN to deterministic mode.

    Seeds Python's ``random``, NumPy, and PyTorch (CPU and CUDA generators),
    writes ``PYTHONHASHSEED`` into the process environment, disables cuDNN
    benchmarking and enables cuDNN deterministic mode.

    Args:
        seed: Integer seed applied to all generators.
    """
    # NOTE: setting PYTHONHASHSEED here does not re-seed hashing in the
    # already-running interpreter; it is only visible to the environment.
    os.environ['PYTHONHASHSEED'] = str(seed)

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True


def _random_patch_mask(batch_size, num_patches, keep_ratio, device):
    """Return a ``[batch_size, num_patches]`` 0/1 mask with random kept positions.

    Each row holds ``int(num_patches * keep_ratio)`` ones, scattered by an
    independent ``torch.randperm`` draw per row (drawn on the default device,
    then moved to ``device``).
    """
    mask = torch.zeros([batch_size, num_patches], device=device)
    mask[:, :int(num_patches * keep_ratio)] = 1
    perm = torch.stack(
        [torch.randperm(num_patches) for _ in range(batch_size)], dim=0
    ).to(device)
    return torch.gather(mask, dim=1, index=perm)


def test_classification(vae_net, data_loader_train, data_loader_test, STratio=1., train_epoch=5, device=None):
    """Train a linear classifier on frozen encoder means and report test accuracy.

    A single ``Linear(vae_net.z_dim, n_classes)`` layer is fitted with RMSprop
    (lr 1e-3) and cross-entropy on ``mu_z`` returned by
    ``vae_net.forward_encoder``; the encoder runs under ``no_grad`` and only
    the classifier's parameters are optimised. Every batch uses a fresh random
    patch mask keeping ``int(L * STratio)`` of the ``L = vae_net.num_patches``
    patches. RNGs are re-seeded with 1029 before training and again before
    evaluation.

    Args:
        vae_net: Model exposing ``num_patches``, ``z_dim`` and
            ``forward_encoder(imgs, mask_st=...)`` returning ``(mu_z, lstd_z)``.
        data_loader_train: Loader yielding ``(samples, labels)``; the class
            count is read from ``data_loader_train.sampler.data_source.classes``.
        data_loader_test: Loader yielding ``(samples, labels)`` for evaluation.
        STratio: Fraction of patches kept in each random mask.
        train_epoch: Number of passes over ``data_loader_train``; ``0`` skips
            training and evaluates the freshly initialised classifier.
        device: Device for the classifier, inputs, labels and masks.

    Returns:
        Test accuracy as a Python float.

    Raises:
        ZeroDivisionError: If ``data_loader_test`` yields no batches.
    """
    seed_torch(seed=1029)
    y_dim = len(data_loader_train.sampler.data_source.classes)
    L = vae_net.num_patches

    classifier = torch.nn.Sequential(
        torch.nn.Linear(vae_net.z_dim, y_dim),
    ).to(device)
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.RMSprop(classifier.parameters(), lr=1e-3)

    # Stays None until the first optimisation step, so the progress line is
    # skipped (instead of raising NameError) when the train loader is empty.
    loss = None
    for epoch in range(train_epoch):
        for samples, y in data_loader_train:
            real_imgs = samples.to(device, non_blocking=True)
            y = y.to(device)
            mask_s = _random_patch_mask(real_imgs.shape[0], L, STratio, device)

            with torch.no_grad():
                mu_z, _ = vae_net.forward_encoder(real_imgs, mask_st=mask_s)

            loss = criterion(classifier(mu_z), y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if loss is not None:
            print('STratio = ' + str(STratio) + ', Training Epoch ' + str(epoch) + ', loss = ' + str(loss.item()))

    # Re-seed so the evaluation masks do not depend on how long training ran.
    seed_torch(seed=1029)
    numacc = numall = 0
    for samples, y in data_loader_test:
        real_imgs = samples.to(device, non_blocking=True)
        y = y.to(device)
        mask_s = _random_patch_mask(real_imgs.shape[0], L, STratio, device)

        with torch.no_grad():
            mu_z, _ = vae_net.forward_encoder(real_imgs, mask_st=mask_s)
            ylogits = classifier(mu_z)

            numacc = numacc + (ylogits.argmax(dim=1) == y).sum()
            numall = numall + y.numel()
    ACC = (numacc / numall).item()

    # Index of the last completed epoch; equals the loop variable whenever at
    # least one epoch ran and stays defined (-1) when train_epoch == 0.
    last_epoch = train_epoch - 1
    print('STratio = ' + str(STratio) + ', Training Epoch = ' + str(last_epoch) + ', Testing accuracy = ' + str(ACC))

    return ACC


def test_generation(vae_net, N=100, device=None):

    with torch.no_grad():
        L = vae_net.num_patches
        imgs = torch.zeros(N, 3, vae_net.img_size, vae_net.img_size, device=device)

        mask_st_all = torch.zeros([N, L], device=device)

        mu_z, lstd_z = vae_net.forward_encoder(imgs, mask_st=mask_st_all)
        z = mu_z + lstd_z * torch.randn_like(lstd_z)
        recon_ospt_img, pred_img = vae_net.forward_decoder(zqst=z, imgs=imgs,
                                                           mask_s=torch.zeros_like(mask_st_all),
                                                           toimg=True)
    return pred_img


def patch_test_reconstruction(test_data, patch_ae, vae_net, p_list_ST, r_ind, ddp, bar_std=1.):
    """Collect masked-input / reconstruction panels via a patch autoencoder.

    For every ratio in ``p_list_ST`` a 0/1 patch mask is built: the first
    ``int(L * ratio)`` entries are set to 1 and the row is then permuted with
    the argsort of the first row of ``r_ind`` (``L = r_ind.shape[1]``).
    Each sample is encoded by ``patch_ae``, perturbed with Gaussian noise of
    scale ``bar_std``, passed through ``vae_net``'s encoder/decoder and mapped
    back by ``patch_ae.decoder``.

    Both models are unwrapped once when ``ddp`` is true, so the DDP and
    non-DDP modes run exactly the same computation: attributes such as
    ``num_patch`` are read from the underlying model (a DDP wrapper does not
    forward them), ``bar_std`` is honoured, and the decoder output is
    flattened before ``patch_ae.decoder`` in either mode.

    Args:
        test_data: Tensor iterated along dim 0; each sample gets a leading
            batch dim of 1.
        patch_ae: Patch autoencoder exposing ``encoder``, ``decoder``,
            ``num_patch`` and ``latent_dim`` (under ``.module`` when ``ddp``).
        vae_net: Model exposing ``patchify``, ``unpatchify``,
            ``forward_encoder`` and ``forward_decoder`` (under ``.module``
            when ``ddp``).
        p_list_ST: Iterable of keep ratios, one mask row per entry.
        r_ind: 2-D index tensor; only the ordering of its first row is used.
        ddp: Whether both models are wrappers holding the model in ``.module``.
        bar_std: Standard deviation of the noise added to the patch codes;
            ``0`` disables the perturbation.

    Returns:
        Concatenation along dim 0 of, per sample, the masked input followed by
        the output of ``patch_ae.decoder``.
    """
    vae = vae_net.module if ddp else vae_net
    ae = patch_ae.module if ddp else patch_ae

    with torch.no_grad():
        P, L = len(p_list_ST), r_ind.shape[1]
        ids_restore = torch.argsort(r_ind, dim=1)

        mask_st_all = torch.zeros([P, L], device=test_data.device)
        for ii, STratio in enumerate(p_list_ST):
            mask_st_all[ii, :int(L * STratio)] = 1
            mask_st_all[ii] = torch.gather(mask_st_all[ii], dim=0, index=ids_restore[0])

        imgs_all = []
        for dataj in test_data:
            img = dataj.unsqueeze(0)

            imgs_all.append(
                vae.unpatchify(vae.patchify(img) * mask_st_all.unsqueeze(-1))
            )

            # presumably the second encoder output holds the per-patch codes
            # -- confirm against patch_ae.encoder
            _, bar_imgs = ae.encoder(img)
            bar_imgs = bar_imgs.reshape([1, ae.num_patch, ae.latent_dim])
            bar_imgs = bar_imgs + bar_std * torch.randn_like(bar_imgs)

            mu_z, _ = vae.forward_encoder(bar_imgs, mask_st=mask_st_all)
            mu_xp, _ = vae.forward_decoder(
                zqst=mu_z,
                bar_imgs=bar_imgs,
                mask_s=torch.zeros_like(mask_st_all),
            )
            # Collapse all leading dims so the decoder sees one row per patch.
            mu_xp = mu_xp.reshape(-1, mu_xp.shape[-1])
            imgs_all.append(ae.decoder(mu_xp, bsz=P, to_img=True))

        imgs_all = torch.cat(imgs_all, dim=0)

    return imgs_all


def test_jointELBO(data_loader_test, vae_net, betaKL, device, ddp):

    with torch.no_grad():
        losscum = NLLcum = KLlosscum = nsamcum = 0

        for data_iter_step, (samples, y) in enumerate(data_loader_test):
            real_imgs = samples.to(device, non_blocking=True)
            y = y.to(device)
            nsam = real_imgs.shape[0]

            if ddp:
                loss, NLL, KLloss = vae_net.module(real_imgs, Sratio=0., Tratio=1., betaKL=betaKL)
            else:
                loss, NLL, KLloss = vae_net(real_imgs, Sratio=0., Tratio=1., betaKL=betaKL)

            losscum = losscum + loss * nsam
            NLLcum = NLLcum + NLL * nsam
            KLlosscum = KLlosscum + KLloss * nsam
            nsamcum = nsamcum + nsam

        test_jointELBO = losscum / nsamcum
        test_NLL = NLLcum / nsamcum
        test_KLloss = KLlosscum / nsamcum

    return test_jointELBO, test_NLL, test_KLloss
