import numpy as np
import torch
import torch.nn.functional as F
from timeit import default_timer
from torch.optim import Adam
import os
import imageio
import matplotlib.pyplot as plt
import shutil
from autodecoder import SuperCNN, autoencoder, encoder, decoder
from flow_model import real_nvp, glow
import normflow as nf
from my_utils import *
from datasets import *
import sys
from torch.utils.data.sampler import SubsetRandomSampler
from results import evaluator
import PDE_solver
from laplacian_loss import LaplacianPyramidLoss

# Seed both random number generators so that repeated runs are reproducible.
np.random.seed(0)
torch.manual_seed(0)

# Command-line configuration; `flags` comes from one of the star imports.
FLAGS, unparsed = flags()

# Number of epochs for each of the three training stages.
epochs_supercnn, epochs_flow, epochs_aeder = (
    FLAGS.epochs_supercnn,
    FLAGS.epochs_flow,
    FLAGS.epochs_aeder,
)

# Model and data sizes.
flow_depth, latent_dim = FLAGS.flow_depth, FLAGS.latent_dim
batch_size = FLAGS.batch_size
image_size = FLAGS.res
c = FLAGS.c  # later passed as the channel count to the models and losses

# Experiment identification.
dataset = FLAGS.dataset
gpu_num = FLAGS.gpu_num
exp_desc, supercnn_desc = FLAGS.exp_desc, FLAGS.supercnn_desc

# On/off switches, coerced to real booleans.
train_aeder, train_supercnn, train_flow, restore_flow, run_pde = (
    bool(getattr(FLAGS, switch))
    for switch in ('train_aeder', 'train_supercnn', 'train_flow',
                   'restore_flow', 'run_pde')
)
training_mode = FLAGS.training_mode

# Run on the requested GPU when CUDA is enabled and available, else on CPU.
enable_cuda = True
if enable_cuda and torch.cuda.is_available():
    device = torch.device('cuda:{}'.format(gpu_num))
else:
    device = torch.device('cpu')

# Only the 'cars_sdf' dataset is written out through a colour map.
color_map = (dataset == 'cars_sdf')

# Root folder that holds every experiment.
all_experiments = 'experiments/'

# Experiment folder, named after the dataset and the main hyper-parameters.
exp_path = '{}Autoencoder_{}_{}_{}_{}_{}'.format(
    all_experiments, dataset, flow_depth, latent_dim, image_size, exp_desc)

# Each SuperCNN variant gets its own sub-folder inside the experiment.
supercnn_path = os.path.join(exp_path, 'supercnns', supercnn_desc)

# One makedirs call creates the whole chain
# (experiments/ -> exp_path -> supercnns/ -> supercnn_path). With
# exist_ok=True it is a no-op for folders that already exist, so there is no
# check-then-create race and no failure when a parent is missing.
# Existing folders are reused as-is; nothing is deleted here.
os.makedirs(supercnn_path, exist_ok=True)




# Optimiser settings: both Adam optimisers start from the same learning rate,
# and the StepLR schedulers multiply it by `gamma` every `step_size` epochs.
learning_rate_aeder = learning_rate_supercnn = 1e-4
step_size, gamma = 50, 0.5

# SuperCNN training: for every image batch, `num_batch_pixels` optimisation
# steps are taken, each on `batch_pixels` randomly chosen pixel positions.
num_batch_pixels, batch_pixels = 3, 512
k = 2  # super resolution factor for image generation

# Autoencoder loss: weight of the latent regulariser and the initial loss type.
lam = 0.01
loss_type = 'style'

# Print the experiment setup:
print('Experiment setup:')
for _label, _value in (
        ('epochs_aeder', epochs_aeder),
        ('epochs_supercnn', epochs_supercnn),
        ('epochs_flow', epochs_flow),
        ('flow_depth', flow_depth),
        ('batch_size', batch_size),
        ('dataset', dataset),
        ('Learning rate_aeder', learning_rate_aeder),
        ('experiment path', exp_path),
        ('latent dim', latent_dim),
        ('image size', image_size),
):
    print('---> {}: {}'.format(_label, _value))


# Dataset:
# The training split is loaded at the working resolution, the test split at
# 8x that resolution, and 'lsun' at a fixed 128x128 as out-of-distribution data.
train_set, test_set = dataset + '_train', dataset + '_test'
_test_res = 8 * image_size

train_dataset = Dataset_loader(
    dataset=train_set, size=(image_size, image_size), c=c, quantize=False)
test_dataset = Dataset_loader(
    dataset=test_set, size=(_test_res, _test_res), c=c, quantize=False)
ood_dataset = Dataset_loader(
    dataset='lsun', size=(128, 128), c=c, quantize=False)

# Only the training loader shuffles; the two evaluation loaders use batches of 25.
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, num_workers=40, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=25, num_workers=8)
ood_loader = torch.utils.data.DataLoader(
    ood_dataset, batch_size=25, num_workers=8)

ntrain = len(train_dataset)
n_test = len(test_dataset)
n_ood = len(ood_dataset)
print('---> Number of training, test and ood samples: {}, {}, {}'.format(
    ntrain, n_test, n_ood))

# Plot every epoch for large training sets; for small ones plot less often,
# roughly once per 20000 processed samples.
plot_per_num_epoch = max(1, 20000 // ntrain)

# Loss
mse_l = F.mse_loss

# One test batch is fetched only to read its dtype for the pyramid loss.
dum_samples = next(iter(test_loader)).to(device)
_pyramid_config = dict(max_levels=3, channels=c, kernel_size=5, sigma=1,
                       device=device, dtype=dum_samples.dtype)
pyramid_l = LaplacianPyramidLoss(**_pyramid_config)

# The VGG network is handed to the autoencoder loss and never trained here,
# so gradient tracking is switched off for all of its weights.
vgg = Vgg16().to(device)
for _vgg_weight in vgg.parameters():
    _vgg_weight.requires_grad_(False)

# 1. Training Autoencoder:
# Build the encoder/decoder pair and wrap them in a single autoencoder module.
enc = encoder(latent_dim = latent_dim, in_res = image_size , c = c).to(device)
dec = decoder(latent_dim = latent_dim, in_res = image_size , c = c).to(device)

aeder = autoencoder(encoder = enc , decoder = dec).to(device)

num_param_aeder= count_parameters(aeder)
print('---> Number of trainable parameters of Autoencoder: {}'.format(num_param_aeder))

optimizer_aeder = Adam(aeder.parameters(), lr=learning_rate_aeder)
scheduler_aeder = torch.optim.lr_scheduler.StepLR(optimizer_aeder, step_size=step_size, gamma=gamma)

# Resume from an existing checkpoint when there is one. Only the model and
# optimizer states are stored in the checkpoint.
checkpoint_autoencoder_path = os.path.join(exp_path, 'autoencoder.pt')
if os.path.exists(checkpoint_autoencoder_path):
    # map_location remaps the stored tensors onto the device used by this run,
    # so a checkpoint written on another GPU index (or loaded on a CPU-only
    # machine) can still be restored.
    checkpoint_autoencoder = torch.load(checkpoint_autoencoder_path, map_location=device)
    aeder.load_state_dict(checkpoint_autoencoder['model_state_dict'])
    optimizer_aeder.load_state_dict(checkpoint_autoencoder['optimizer_state_dict'])
    print('Autoencoder is restored...')


# Autoencoder training stage: optimise reconstruction loss plus a latent
# regulariser, checkpoint every epoch, and periodically write ground-truth and
# reconstructed image grids together with the reconstruction SNR.
if train_aeder:

    if plot_per_num_epoch == -1:
        plot_per_num_epoch = epochs_aeder + 1 # only plot in the last epoch

    # Output folders do not depend on the epoch, so create them once up front.
    samples_folder = os.path.join(exp_path, 'Generated_samples')
    image_path_reconstructions = os.path.join(
        samples_folder, 'Reconstructions_aeder')
    os.makedirs(image_path_reconstructions, exist_ok=True)

    loss_ae_plot = np.zeros([epochs_aeder])
    for ep in range(epochs_aeder):
        aeder.train()
        t1 = default_timer()
        loss_ae_epoch = 0

        # The first 100 epochs use the 'style' loss, later ones 'style_mse'.
        loss_type = 'style' if ep < 100 else 'style_mse'
        print(loss_type)
        for image in train_loader:

            image = image.to(device)

            optimizer_aeder.zero_grad()
            # Reshape the batch to (N, H, W, c) and move channels first.
            image_mat = image.reshape(-1, image_size, image_size, c).permute(0,3,1,2)

            embed = aeder.encoder(image_mat)
            image_recon = aeder.decoder(embed)

            recon_loss = aeder_loss(image_mat, image_recon, loss_type = loss_type,
                pyramid_l = pyramid_l, mse_l = mse_l, vgg = vgg)
            # Pull the latent codes towards zero.
            regularization = mse_l(embed, torch.zeros_like(embed))
            ae_loss = recon_loss + lam * regularization

            ae_loss.backward()

            optimizer_aeder.step()
            loss_ae_epoch += ae_loss.item()

        scheduler_aeder.step()
        t2 = default_timer()

        loss_ae_epoch/= ntrain
        loss_ae_plot[ep] = loss_ae_epoch

        # Slice up to and including the current epoch: `[:ep]` would leave out
        # the value that was just stored at index `ep`.
        num_done = ep + 1
        plt.plot(np.arange(num_done), loss_ae_plot[:num_done], 'o-', linewidth=2)
        plt.title('AE_loss')
        plt.xlabel('epoch')
        plt.ylabel('MSE loss')

        plt.savefig(os.path.join(exp_path, 'Autoencoder_loss.jpg'))
        np.save(os.path.join(exp_path, 'Autoencoder_loss.npy'), loss_ae_plot[:num_done])
        plt.close()

        # Checkpoint after every epoch so that training can be resumed.
        torch.save({
                    'model_state_dict': aeder.state_dict(),
                    'optimizer_state_dict': optimizer_aeder.state_dict()
                    }, checkpoint_autoencoder_path)

        if (ep + 1) % plot_per_num_epoch == 0 or ep + 1 == epochs_aeder:
            sample_number = 25
            ngrid = int(np.sqrt(sample_number))

            # Pure inference: no autograd graph is needed for the preview.
            with torch.no_grad():
                # Take the first test batch, which is stored at 8x resolution,
                # and downsample it to the training resolution.
                images_8k = next(iter(test_loader)).to(device)[:sample_number]
                images_8k = images_8k.reshape(-1, 8*image_size, 8*image_size, c).permute(0,3,1,2)
                image = F.interpolate(images_8k, size = image_size, antialias = True, mode = 'bilinear')

                embed = aeder.encoder(image)
                image_recon = aeder.decoder(embed)

            image_np = image.permute(0,2,3,1).detach().cpu().numpy()

            # Tile the samples into an ngrid x ngrid mosaic scaled to [0, 255].
            image_write = image_np[:sample_number].reshape(
                ngrid, ngrid,
                image_size, image_size,c).swapaxes(1, 2).reshape(ngrid*image_size, -1, c)*255.0
            image_write = image_write.clip(0, 255).astype(np.uint8)

            imageio.imwrite(os.path.join(image_path_reconstructions, '%d_gt.png' % (ep,)),image_write)

            image_recon_np = image_recon.detach().cpu().numpy().transpose(0,2,3,1)
            image_recon_write = image_recon_np[:sample_number].reshape(
                ngrid, ngrid,
                image_size, image_size, c).swapaxes(1, 2).reshape(ngrid*image_size, -1, c)*255.0

            # Value range of the reconstruction before clipping.
            print(image_recon_write.max(), image_recon_write.min())

            image_recon_write = image_recon_write.clip(0, 255).astype(np.uint8)

            imageio.imwrite(os.path.join(image_path_reconstructions, '%d_aeder_recon.png' % (ep,)),
                            image_recon_write)

            snr_aeder = SNR(image_np , image_recon_np)

            # Format the log line once; it goes to the results file and stdout.
            log_line = 'ep: %03d/%03d | time: %.4f | aeder_loss %.4f | SNR_aeder  %.4f' %(ep, epochs_aeder,t2-t1,
                            loss_ae_epoch, snr_aeder)
            with open(os.path.join(exp_path, 'results.txt'), 'a') as file:
                file.write(log_line)
                file.write('\n')

            print(log_line)
        

# Training the flow model
# The flow is defined over the autoencoder's latent codes.
nfm = real_nvp(latent_dim = latent_dim, K = flow_depth)
nfm = nfm.to(device)
num_param_nfm = count_parameters(nfm)
print('Number of trainable parametrs of flow: {}'.format(num_param_nfm))

# Per-batch history of the flow loss.
loss_hist = np.array([])
optimizer_flow = torch.optim.Adam(nfm.parameters(), lr=1e-4, weight_decay=1e-5)
scheduler_flow = torch.optim.lr_scheduler.StepLR(optimizer_flow, step_size=step_size, gamma=gamma)

# Initialize ActNorm
# One batch of latent codes is pushed through the flow; the returned
# likelihood itself is not used afterwards.
batch_img = next(iter(train_loader)).to(device)
batch_img = batch_img.reshape(-1, image_size, image_size, c).permute(0,3,1,2)
# The encoder only supplies inputs here, so no autograd graph is needed.
with torch.no_grad():
    dummy_samples = aeder.encoder(batch_img)
print(dummy_samples.shape, dummy_samples.max(), dummy_samples.min())
dummy_samples = dummy_samples.view(-1, latent_dim)
likelihood = nfm.log_prob(dummy_samples)

# Restore a previous flow checkpoint only when explicitly requested.
checkpoint_flow_path = os.path.join(exp_path, 'flow.pt')
if restore_flow and os.path.exists(checkpoint_flow_path):
    # map_location remaps the stored tensors onto the device used by this run.
    checkpoint_flow = torch.load(checkpoint_flow_path, map_location=device)
    nfm.load_state_dict(checkpoint_flow['model_state_dict'])
    optimizer_flow.load_state_dict(checkpoint_flow['optimizer_state_dict'])
    print('Flow model is restored...')



# Flow training stage: fit the flow to the encoder's latent codes by
# minimising `nfm.forward_kld`, checkpoint every epoch, and periodically
# decode flow samples into an image grid.
if train_flow:

    # Output folders do not depend on the epoch, so create them once up front.
    samples_folder = os.path.join(exp_path, 'Generated_samples')
    image_path_generated = os.path.join(samples_folder, 'generated')
    os.makedirs(image_path_generated, exist_ok=True)

    for ep in range(epochs_flow):

        nfm.train()
        t1 = default_timer()
        loss_flow_epoch = 0
        for image in train_loader:
            optimizer_flow.zero_grad()
            image = image.to(device)
            image = image.reshape(-1, image_size, image_size, c).permute(0,3,1,2)

            # The autoencoder is not optimised in this stage (optimizer_flow
            # only holds the flow parameters), so encode without autograd.
            # Otherwise backward() would keep accumulating gradients into the
            # autoencoder parameters, which are never zeroed here.
            with torch.no_grad():
                x = aeder.encoder(image)

            # Compute loss
            loss_flow = nfm.forward_kld(x)
            loss_value = loss_flow.item()

            # Skip the update for NaN/inf losses and keep them out of the
            # epoch total, since a single one would make the total NaN/inf.
            if np.isfinite(loss_value):
                loss_flow.backward()
                optimizer_flow.step()
                loss_flow_epoch += loss_value

            # Log loss (the raw value, including non-finite ones)
            loss_hist = np.append(loss_hist, loss_value)

        scheduler_flow.step()
        t2 = default_timer()
        loss_flow_epoch /= ntrain

        # Checkpoint after every epoch so that training can be resumed.
        torch.save({
                    'model_state_dict': nfm.state_dict(),
                    'optimizer_state_dict': optimizer_flow.state_dict()
                    }, checkpoint_flow_path)

        if (ep + 1) % plot_per_num_epoch == 0 or ep + 1 == epochs_flow:
            sample_number = 25
            ngrid = int(np.sqrt(sample_number))

            # Pure inference: sample latents from the flow and decode them.
            with torch.no_grad():
                generated_embed, _ = nfm.sample(torch.tensor(sample_number).to(device))
                generated_samples = aeder.decoder(generated_embed)
            generated_samples = generated_samples.detach().cpu().numpy().transpose(0,2,3,1)

            # Tile the samples into an ngrid x ngrid mosaic scaled to [0, 255].
            generated_samples = generated_samples[:sample_number].reshape(
                ngrid, ngrid,
                image_size, image_size, c).swapaxes(1, 2).reshape(ngrid*image_size, -1, c)*255.0
            generated_samples = generated_samples.clip(0, 255).astype(np.uint8)

            generated_file = os.path.join(image_path_generated, 'epoch %d.png' % (ep,))
            if color_map:
                # Only the first channel is written, through the 'seismic' colour map.
                plt.imsave(generated_file, generated_samples[:,:,0], cmap='seismic')
            else:
                imageio.imwrite(generated_file, generated_samples)

            # Format the log line once; it goes to the results file and stdout.
            log_line = 'ep: %03d/%03d | time: %.4f | ML_loss %.4f' %(ep, epochs_flow, t2-t1, loss_flow_epoch)
            with open(os.path.join(exp_path, 'results.txt'), 'a') as file:
                file.write(log_line)
                file.write('\n')

            print(log_line)



# Training SuperCNN:
# Build the super-resolution network with its optimizer and LR schedule.
model = SuperCNN(c=c).to(device)
num_param_supercnn = count_parameters(model)
print('---> Number of trainable parameters of supercnn: {}'.format(num_param_supercnn))

optimizer_supercnn = Adam(model.parameters(), lr=learning_rate_supercnn)
scheduler_supercnn = torch.optim.lr_scheduler.StepLR(optimizer_supercnn, step_size=step_size, gamma=gamma)

# Resume from an existing checkpoint when there is one. Only the model and
# optimizer states are stored in the checkpoint.
checkpoint_exp_path = os.path.join(supercnn_path, 'supercnn.pt')
if os.path.exists(checkpoint_exp_path):
    # map_location remaps the stored tensors onto the device used by this run,
    # so a checkpoint written on another GPU index (or loaded on a CPU-only
    # machine) can still be restored.
    checkpoint_supercnn = torch.load(checkpoint_exp_path, map_location=device)
    model.load_state_dict(checkpoint_supercnn['model_state_dict'])
    optimizer_supercnn.load_state_dict(checkpoint_supercnn['optimizer_state_dict'])
    print('supercnn is restored...')

# SuperCNN training stage: regress the values of randomly chosen pixels of the
# high-resolution target from their coordinates and the low-resolution image,
# checkpoint every epoch, and periodically run the evaluator.
if train_supercnn:

    if plot_per_num_epoch == -1:
        plot_per_num_epoch = epochs_supercnn + 1 # only plot in the last epoch

    # Output folders do not depend on the epoch, so create them once up front.
    samples_folder = os.path.join(supercnn_path, 'Generated_samples')
    image_path_reconstructions = os.path.join(
        samples_folder, 'Reconstructions')
    os.makedirs(image_path_reconstructions, exist_ok=True)

    loss_supercnn_plot = np.zeros([epochs_supercnn])
    for ep in range(epochs_supercnn):
        model.train()
        t1 = default_timer()
        loss_supercnn_epoch = 0

        for image in train_loader:

            batch_size = image.shape[0]
            image = image.to(device)
            # The channels-first view is the same for every pixel batch, so it
            # is built once per image batch.
            image_mat = image.reshape(-1, image_size, image_size, c).permute(0,3,1,2)

            for _ in range(num_batch_pixels):
                # The autoencoder is not optimised in this stage
                # (optimizer_supercnn only holds the SuperCNN parameters), so
                # its forward pass runs without autograd.
                with torch.no_grad():
                    image_recon = aeder.decoder(aeder.encoder(image_mat))

                image_high, image_low, image_size_high = training_strategy(image_mat, image_size,
                    factor = k , mode = training_mode, image_recon = image_recon)

                # Same coordinate grid for every sample of the batch.
                coords = get_mgrid(image_size_high).reshape(-1, 2)
                coords = torch.unsqueeze(coords, dim = 0)
                coords = coords.expand(batch_size , -1, -1).to(device)

                # Flatten the target to (N, pixels, c) so that pixels can be indexed.
                image_high = image_high.permute(0,2,3,1).reshape(-1, image_size_high * image_size_high, c)
                optimizer_supercnn.zero_grad()
                # Random subset of pixel positions for this step.
                pixels = np.random.randint(low = 0, high = image_size_high**2, size = batch_pixels)
                batch_coords = coords[:,pixels]
                batch_image = image_high[:,pixels]

                out = model(batch_coords, image_low)
                mse_loss = mse_l(out.reshape(batch_size, -1) , batch_image.reshape(batch_size, -1) )
                total_loss = mse_loss
                total_loss.backward()
                optimizer_supercnn.step()
                loss_supercnn_epoch += total_loss.item()

        scheduler_supercnn.step()
        t2 = default_timer()
        loss_supercnn_epoch/= ntrain
        loss_supercnn_plot[ep] = loss_supercnn_epoch

        # Slice up to and including the current epoch: `[:ep]` would leave out
        # the value that was just stored at index `ep`.
        num_done = ep + 1
        plt.plot(np.arange(num_done), loss_supercnn_plot[:num_done], 'o-', linewidth=2)
        plt.title('supercnn_loss')
        plt.xlabel('epoch')
        plt.ylabel('MSE loss')
        plt.savefig(os.path.join(supercnn_path, 'supercnn_loss.jpg'))
        # NOTE(review): the curve is saved under exp_path while the plot goes to
        # supercnn_path, so SuperCNN variants of one experiment overwrite each
        # other's .npy file. Confirm whether supercnn_path was intended.
        np.save(os.path.join(exp_path, 'supercnn_loss.npy'), loss_supercnn_plot[:num_done])
        plt.close()

        # Checkpoint after every epoch so that training can be resumed.
        torch.save({
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer_supercnn.state_dict()
                    }, checkpoint_exp_path)

        if (ep + 1) % plot_per_num_epoch == 0 or (ep + 1) == epochs_supercnn:

            evaluator('generative', test_loader, model, device, image_size, c,
                k, image_path_reconstructions, supercnn_path, ep,
                t1, t2, epochs_supercnn, loss_supercnn_epoch,aeder)


# Final evaluation, executed whether or not any training stage ran above.
# Make sure both output folders exist first (parent before child).
samples_folder = os.path.join(supercnn_path, 'Generated_samples')
image_path_reconstructions = os.path.join(samples_folder, 'Reconstructions')
for _folder in (samples_folder, image_path_reconstructions):
    if not os.path.exists(_folder):
        os.mkdir(_folder)

# Epoch index -1 and zeroed timing/loss values are passed for this
# post-training call.
evaluator(
    'generative', test_loader, model, device, image_size, c,
    k, image_path_reconstructions, supercnn_path, -1,
    0, 0, epochs_supercnn, 0, aeder,
)

# Optional PDE experiment using the SuperCNN, the autoencoder and the flow.
if run_pde:
    PDE_solver.poisson_first_order(
        supercnn_path, model, test_loader, train_loader, 'flow', aeder, nfm)





    
