import numpy as np
import torch
import torch.nn.functional as F
from timeit import default_timer
from torch.optim import Adam
import os
import matplotlib.pyplot as plt
import shutil
from autodecoder import SuperCNN, SuperCNNv2
from my_utils import *
from datasets import *
import sys
from torch.utils.data.sampler import SubsetRandomSampler
from results import evaluator
import PDE_solver

# Seed both random number generators with a fixed value so that repeated runs
# draw the same random numbers.
np.random.seed(0)
torch.manual_seed(0)

# Run configuration returned by the project's flags() helper
# (presumably parsed from the command line -- confirm in my_utils).
FLAGS, unparsed = flags()

# Data selection and experiment identity.
dataset = FLAGS.dataset
exp_desc = FLAGS.exp_desc
image_size = FLAGS.res
c = FLAGS.c  # used further down as the channel dimension when reshaping images
gpu_num = FLAGS.gpu_num

# Optimisation settings.
epochs_supercnn = FLAGS.epochs_supercnn
batch_size = FLAGS.batch_size
training_mode = FLAGS.training_mode

# Switches are coerced to real booleans with bool().
remove_all = bool(FLAGS.remove_all)
train_supercnn = bool(FLAGS.train_supercnn)
run_pde = bool(FLAGS.run_pde)

# Use the requested CUDA device when CUDA is both enabled here and available,
# otherwise fall back to the CPU.
enable_cuda = True
if enable_cuda and torch.cuda.is_available():
    device = torch.device('cuda:' + str(gpu_num))
else:
    device = torch.device('cpu')

# Only the 'cars_sdf' dataset sets this flag.
color_map = (dataset == 'cars_sdf')

# Root folder that holds every experiment; created on first use.
all_experiments = 'experiments/'
if not os.path.exists(all_experiments):
    os.mkdir(all_experiments)

# experiment path: experiments/Supercnn_<dataset>_<image_size>_<exp_desc>
exp_path = all_experiments + '_'.join(
    ['Supercnn', dataset, str(image_size), exp_desc])

# When remove_all is set, wipe a previous run of the same experiment so the
# run starts from an empty folder (this also removes any saved checkpoint).
if remove_all and os.path.exists(exp_path):
    shutil.rmtree(exp_path)

if not os.path.exists(exp_path):
    os.mkdir(exp_path)



# Adam learning rate, plus the StepLR settings used when the scheduler is
# built further down: the rate is multiplied by `gamma` every `step_size`
# scheduler steps.
learning_rate = 1e-4
step_size = 50
gamma = 0.5

# Training objective: L1 loss (F.mse_loss is the drop-in alternative).
myloss = F.l1_loss

# Per mini-batch, the training loop performs `num_batch_pixels` optimizer
# steps, each on `batch_pixels` randomly drawn pixel indices.
num_batch_pixels = 3
batch_pixels = 512

k = 2  # super resolution factor for image generation
# Factor applied to image_size for the test set; larger for celeba-hq.
if dataset == 'celeba-hq':
    k_test = 8
else:
    k_test = 4

# Print the experiment setup:
print('Experiment setup:')
for _template, _value in (
    ('---> epochs_supercnn: {}', epochs_supercnn),
    ('---> batch_size: {}', batch_size),
    ('---> dataset: {}', dataset),
    ('---> Learning rate: {}', learning_rate),
    ('---> experiment path: {}', exp_path),
    ('---> image size: {}', image_size),
):
    print(_template.format(_value))


# Dataset:
# The train and test splits are addressed by suffixing the dataset name.
train_set = dataset + '_train'
test_set = dataset + '_test'

# Training images are loaded at the base resolution, test images at k_test
# times that resolution, and the out-of-distribution set ('lsun') at twice
# the base resolution.
train_dataset = Dataset_loader(
    dataset=train_set, size=(image_size,) * 2, c=c)
test_dataset = Dataset_loader(
    dataset=test_set, size=(k_test * image_size,) * 2, c=c)
ood_dataset = Dataset_loader(
    dataset='lsun', size=(2 * image_size,) * 2, c=c)

# Only the training loader shuffles; the evaluation loaders use fixed batch
# sizes and keep the dataset order.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    num_workers=24,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=100,
    num_workers=24,
)
ood_loader = torch.utils.data.DataLoader(
    ood_dataset,
    batch_size=300,
    num_workers=24,
)

# Sample counts of the three datasets (each loader wraps exactly one of them).
ntrain, n_test, n_ood = map(len, (train_dataset, test_dataset, ood_dataset))
print('---> Number of training, test and ood samples: {}, {}, {}'.format(ntrain,n_test, n_ood))


# Evaluation/plot interval in epochs: every 10 epochs for large training sets
# (more than 20000 samples), otherwise 20000 // ntrain epochs.
plot_per_num_epoch = 10 if ntrain > 20000 else 20000//ntrain

# Training SuperCNN:
# Build the network on the selected device and wrap it in DataParallel.
# NOTE(review): DataParallel is created with its default device_ids; confirm
# this behaves as intended when gpu_num is not 0.
model = SuperCNN(c=c).to(device)
model = torch.nn.DataParallel(model)
num_param_supercnn = count_parameters(model)
print('---> Number of trainable parameters of supercnn: {}'.format(num_param_supercnn))

# Adam with a step-wise learning-rate decay (x gamma every step_size steps).
optimizer_supercnn = Adam(model.parameters(), lr=learning_rate)
scheduler_supercnn = torch.optim.lr_scheduler.StepLR(optimizer_supercnn, step_size=step_size, gamma=gamma)

# Resume from an existing checkpoint of this experiment, if there is one.
# Only the model and optimizer states are restored here; the scheduler object
# is not reloaded.
checkpoint_exp_path = os.path.join(exp_path, 'supercnn.pt')
if os.path.exists(checkpoint_exp_path):
    # map_location remaps the stored tensors onto the device chosen for this
    # run, so a checkpoint written on one GPU can still be restored on a
    # different GPU or on a CPU-only machine.
    checkpoint_supercnn = torch.load(checkpoint_exp_path, map_location=device)
    model.load_state_dict(checkpoint_supercnn['model_state_dict'])
    optimizer_supercnn.load_state_dict(checkpoint_supercnn['optimizer_state_dict'])
    print('supercnn is restored...')

if train_supercnn:
    # Train SuperCNN. For every mini-batch, num_batch_pixels optimizer steps
    # are taken, each on batch_pixels randomly chosen pixel coordinates of the
    # high-resolution target produced by training_strategy(). After every
    # epoch the loss curve and a checkpoint are written; the evaluator runs
    # every plot_per_num_epoch epochs and on the last epoch.

    if plot_per_num_epoch == -1:
        # Sentinel value: with an interval larger than the number of epochs,
        # `ep % plot_per_num_epoch == 0` only holds for ep == 0, so the
        # evaluator runs on the first and on the last epoch only.
        plot_per_num_epoch = epochs_supercnn + 1

    # Output folders used by the evaluator. They are created once up front;
    # exist_ok makes this safe when they are already present.
    samples_folder = os.path.join(exp_path, 'Generated_samples')
    image_path_reconstructions = os.path.join(
        samples_folder, 'Reconstructions')
    os.makedirs(image_path_reconstructions, exist_ok=True)

    loss_supercnn_plot = np.zeros([epochs_supercnn])
    for ep in range(epochs_supercnn):
        model.train()
        t1 = default_timer()
        loss_supercnn_epoch = 0
        num_steps = 0  # optimizer steps taken in this epoch

        for image in train_loader:

            # Size of this particular batch. A separate name is used so the
            # configured module-level `batch_size` is not overwritten.
            cur_batch_size = image.shape[0]
            image = image.to(device)

            for _ in range(num_batch_pixels):
                # Channels-last (N, H, W, c) -> channels-first (N, c, H, W).
                image_mat = image.reshape(-1, image_size, image_size, c).permute(0, 3, 1, 2)

                image_high, image_low, image_size_high = training_strategy(
                    image_mat, image_size, factor=k, mode=training_mode)

                # One 2-D coordinate per target pixel, shared by every sample
                # in the batch.
                coords = get_mgrid(image_size_high).reshape(-1, 2)
                coords = torch.unsqueeze(coords, dim=0)
                coords = coords.expand(cur_batch_size, -1, -1).to(device)

                # Flatten the target to (N, pixels, c) so that pixels can be
                # picked with the same indices as the coordinates.
                image_high = image_high.permute(0, 2, 3, 1).reshape(
                    -1, image_size_high * image_size_high, c)

                optimizer_supercnn.zero_grad()
                # Random subset of pixel indices (drawn with replacement).
                pixels = np.random.randint(low=0, high=image_size_high**2, size=batch_pixels)
                batch_coords = coords[:, pixels]
                batch_image = image_high[:, pixels]

                out = model(batch_coords, image_low)
                step_loss = myloss(out.reshape(cur_batch_size, -1),
                                   batch_image.reshape(cur_batch_size, -1))
                step_loss.backward()
                optimizer_supercnn.step()
                loss_supercnn_epoch += step_loss.item()
                num_steps += 1

        scheduler_supercnn.step()
        t2 = default_timer()

        # Every accumulated term is already a mean over its batch, so the
        # epoch loss is the average over optimizer steps. Dividing by the
        # number of samples would make the value depend on the batch size.
        loss_supercnn_epoch /= max(num_steps, 1)
        loss_supercnn_plot[ep] = loss_supercnn_epoch

        # Plot and save the curve including the epoch that just finished
        # (indices 0..ep).
        plt.plot(np.arange(ep + 1), loss_supercnn_plot[:ep + 1], 'o-', linewidth=2)
        plt.title('supercnn_loss')
        plt.xlabel('epoch')
        plt.ylabel('L1 loss')  # the objective used in this script is F.l1_loss
        plt.savefig(os.path.join(exp_path, 'supercnn_loss.jpg'))
        np.save(os.path.join(exp_path, 'supercnn_loss.npy'), loss_supercnn_plot[:ep + 1])
        plt.close()

        # Overwrite the checkpoint after every epoch so training can resume.
        torch.save({
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer_supercnn.state_dict()
                    }, checkpoint_exp_path)

        if ep % plot_per_num_epoch == 0 or (ep + 1) == epochs_supercnn:

            evaluator(training_mode, 'test', test_loader, model, device, image_size, c,
                k_test, image_path_reconstructions, exp_path, ep,
                t1, t2, epochs_supercnn, loss_supercnn_epoch)
            # The out-of-distribution set is only evaluated for celeba-hq.
            if dataset == 'celeba-hq':

                evaluator(training_mode, 'ood', ood_loader, model, device, image_size, c,
                    k_test, image_path_reconstructions, exp_path, ep,
                    t1, t2, epochs_supercnn, loss_supercnn_epoch)


print('testing...')

# Make sure the output folders exist even when the training section above
# was skipped.
samples_folder = os.path.join(exp_path, 'Generated_samples')
image_path_reconstructions = os.path.join(
    samples_folder, 'Reconstructions')
for _folder in (samples_folder, image_path_reconstructions):
    if not os.path.exists(_folder):
        os.mkdir(_folder)

# Final evaluation: always on the test split, and additionally on the
# out-of-distribution split for celeba-hq.
_final_splits = [('test', test_loader)]
if dataset == 'celeba-hq':
    _final_splits.append(('ood', ood_loader))

for _split, _loader in _final_splits:
    # -1 is passed in the epoch slot and zeros in the timer and loss slots,
    # since no training epoch is associated with this call.
    evaluator(training_mode, _split, _loader, model, device, image_size, c,
        k_test, image_path_reconstructions, exp_path, -1,
        0, 0, epochs_supercnn, 0)


# Optional PDE experiment, run with the final model and the data loaders.
if run_pde:

    PDE_solver.poisson_first_order(model, test_loader, train_loader)





    
