import numpy as np
from .utils import freeze, upsample
import matplotlib
import matplotlib.pyplot as plt
from tqdm import tqdm

import torch
from torch import autograd

def prepare_ghost_imgs(Y, X, upsample, G, D, mean=0.5, std=0.5):
    """Build one image grid showing generator outputs and critic-gradient "pushes".

    Rows, stacked top to bottom, with one column per batch element:
      1. ``upsample(Y)``
      2. ``G(Y)``
      3. ``G(Y) - d * dD/d(G(Y))``  -- generated images moved along the critic gradient
      4. ``X``
      5. ``X - d * dD/dX``          -- ``X`` moved along the critic gradient

    Here ``d`` is the number of elements in a single sample. Every row is mapped
    back with ``x * std + mean`` and clipped to ``[0, 1]``.

    Args:
        Y: Input batch fed to ``G`` and ``upsample``. It must be 4-D; the
            ``permute(0, 2, 3, 1)`` below moves dim 1 last, so it presumably
            is NCHW (TODO confirm).
        X: Batch shown in rows 4-5. It must share batch size, width and
            channels with ``G(Y)`` so the rows can be stacked. The caller's
            tensor is not modified: gradients are taken through a detached
            alias.
        upsample: Callable applied to ``Y`` for row 1. Inside this function
            it shadows the module-level ``upsample`` import.
        G: Generator network; ``freeze`` (from ``.utils``) is called on it.
        D: Critic network; ``freeze`` (from ``.utils``) is called on it.
        mean, std: Constants used to undo the input normalisation. Presumably
            the inputs were normalised as ``(x - mean) / std``.

    Returns:
        ``np.ndarray`` of shape ``(sum of row heights, B * W, C)``, channels-last.
    """

    def to_display(t):
        # Tensor (B, C?, H, W) -> channels-last numpy batch, de-normalised and clipped.
        return t.detach().permute(0, 2, 3, 1).cpu().mul(std).add(mean).numpy().clip(0, 1)

    freeze(G)
    freeze(D)

    with torch.no_grad():
        G_Y = G(Y)

    # Critic gradient w.r.t. the generated images. A single forward pass of D
    # is used, and grad_outputs mirrors D's output shape, whatever it is.
    G_Y = G_Y.detach().requires_grad_(True)
    D_G_Y = D(G_Y)
    D_grad_G_Y = autograd.grad(D_G_Y, G_Y, grad_outputs=torch.ones_like(D_G_Y))[0]

    # Use a detached alias so the caller's X keeps its original requires_grad flag.
    X_req = X.detach().requires_grad_(True)
    D_X = D(X_req)
    D_grad_X = autograd.grad(D_X, X_req, grad_outputs=torch.ones_like(D_X))[0]

    with torch.no_grad():
        # The step size is the per-sample element count (C * H * W for 4-D input).
        D_push_G_Y = G_Y - G_Y[0].numel() * D_grad_G_Y
        D_push_X = X_req - X_req[0].numel() * D_grad_X

        # All rows, including D_push_X, use the same de-normalisation
        # (previously D_push_X had an extra, inconsistent `.mul(0.5)`).
        rows = [to_display(t) for t in (upsample(Y), G_Y, D_push_G_Y, X_req, D_push_X)]

    # Stack the rows vertically: axis 1 is height in (B, H, W, C).
    imgs = np.concatenate(rows, axis=1)
    # Put batch elements side by side: axis 1 of each (H, W, C) image is width.
    grid = np.concatenate(list(imgs), axis=1)

    torch.cuda.empty_cache()
    return grid

def prepare_train_imgs_for_plotting(Y, upsample, G, mean=0.5, std=0.5):
    """Build an image grid comparing the upsampled input with the model output.

    Rows, stacked top to bottom, with one column per batch element:
      1. ``upsample(Y)``
      2. ``G(Y)``

    Both rows are mapped back with ``x * std + mean`` and clipped to ``[0, 1]``.

    Args:
        Y: Input batch. It must be 4-D; it is permuted ``(0, 2, 3, 1)`` to
            channels-last, so it is presumably NCHW (TODO confirm).
        upsample: Callable producing the baseline row. Its output must have
            the same shape as ``G(Y)`` (required by ``torch.stack``).
        G: Generator network; ``freeze`` (from ``.utils``) is called on it.
        mean, std: Constants used to undo the input normalisation.

    Returns:
        ``np.ndarray`` of shape ``(2 * H, B * W, C)``. An empty batch cannot
        be tiled, and ``np.concatenate`` raises ``ValueError`` for it.
    """
    freeze(G)
    with torch.no_grad():
        G_Y = G(Y)  # model output
        G0_Y = upsample(Y)  # upsampled baseline

        # Move to channels-last on the CPU first, so the two outputs may come
        # from different devices.
        rows = [t.permute(0, 2, 3, 1).detach().cpu() for t in (G0_Y, G_Y)]
        # (2, B, H, W, C), de-normalised and clipped.
        imgs = torch.stack(rows).mul(std).add(mean).numpy().clip(0, 1)

    # Stack rows vertically: (2, B, H, W, C) -> (B, 2H, W, C).
    imgs = np.concatenate(list(imgs), axis=1)
    # Put batch elements side by side in one concatenate (no quadratic loop).
    return np.concatenate(list(imgs), axis=1)

def prepare_imgs_for_plotting(Y, X, upsample, G, mean=0.5, std=0.5):
    """Build an image grid comparing upsampled input, model output and reference.

    Rows, stacked top to bottom, with one column per batch element:
      1. ``upsample(Y)``
      2. ``G(Y)``
      3. ``X``

    All rows are mapped back with ``x * std + mean`` and clipped to ``[0, 1]``.

    Args:
        Y: Input batch. It must be 4-D; it is permuted ``(0, 2, 3, 1)`` to
            channels-last, so it is presumably NCHW (TODO confirm).
        X: Batch shown in the third row. It must have the same shape as
            ``G(Y)`` (required by ``torch.stack``) and is not modified.
        upsample: Callable producing the baseline row. Its output has the
            same shape requirement as ``X``.
        G: Generator network; ``freeze`` (from ``.utils``) is called on it.
        mean, std: Constants used to undo the input normalisation.

    Returns:
        ``np.ndarray`` of shape ``(3 * H, B * W, C)``. An empty batch cannot
        be tiled, and ``np.concatenate`` raises ``ValueError`` for it.
    """
    freeze(G)
    with torch.no_grad():
        G_Y = G(Y)  # model output
        G0_Y = upsample(Y)  # upsampled baseline

        # Move to channels-last on the CPU first, so inputs may live on
        # different devices.
        rows = [t.permute(0, 2, 3, 1).detach().cpu() for t in (G0_Y, G_Y, X)]
        # (3, B, H, W, C), de-normalised and clipped.
        imgs = torch.stack(rows).mul(std).add(mean).numpy().clip(0, 1)

    # Stack rows vertically: (3, B, H, W, C) -> (B, 3H, W, C).
    imgs = np.concatenate(list(imgs), axis=1)
    # Put batch elements side by side in one concatenate (no quadratic loop).
    return np.concatenate(list(imgs), axis=1)