import torch
import torch.nn as nn
import torch.optim as optim
import utils
from torchvision.utils import save_image
import sys
from tqdm import tqdm
sys.path.append('/models')
from models.vae import VAE
import torchvision.transforms as transforms

import argparse

# Run on the first visible CUDA device when available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Command-line options. The model-size options are integers.
# NOTE(review): --scq and --local are string-valued ('True'/'False') and are
# parsed but not referenced elsewhere in this script; quant_type is assigned
# directly in the __main__ section.
parser = argparse.ArgumentParser()
parser.add_argument("--hidden_channel", default=32, type=int)
parser.add_argument("--n_res_channel", default=16, type=int)
parser.add_argument("--n_res_block", default=2, type=int)
parser.add_argument("--batch_size", default=128, type=int)
parser.add_argument("--scq", default='True', choices=('True', 'False'))
parser.add_argument("--local", default='True', choices=('True', 'False'))

args = parser.parse_args()

def rev_normalize(images, mean=0.5, std=0.5):
    """Undo a ``Normalize(mean, std)`` transform so images can be written to disk.

    Computes ``images * std + mean``. With the defaults (mean 0.5, std 0.5)
    this maps values in [-1, 1] back to [0, 1], which is the range
    ``torchvision.utils.save_image`` expects. The defaults presumably match
    the normalization applied by ``utils.load_data_and_data_loaders`` —
    TODO confirm if the data pipeline changes.

    Plain broadcasting arithmetic is used instead of a chain of torchvision
    ``Normalize`` transforms. So this works for any channel count, such as
    grayscale or RGB, and for any tensor rank. It also avoids building a new
    transform pipeline on every call.

    Args:
        images: Tensor of normalized images, e.g. a (B, C, H, W) batch.
        mean: Scalar mean that was subtracted during normalization.
        std: Scalar standard deviation that was divided out.

    Returns:
        A new tensor with the normalization undone; ``images`` is not modified.
    """
    return images * std + mean


# Results/filename prefix for each supported quantizer type. Any other
# quant_type value is treated as the Gumbel variant.
_QUANT_TYPE_PREFIXES = {
    'scq': 'SCQ',
    'vq': 'VQ',
    'vqreplace': 'VQReplace',
    'vqaffineopt': 'VQAffineOpt',
    'vqreplaceaffineopt': 'VQReplaceAffineOpt',
    'rv': 'RV',
}


def _quant_prefix(quant_type):
    """Return the results/filename prefix for ``quant_type`` ('Gumbel' if unknown)."""
    return _QUANT_TYPE_PREFIXES.get(quant_type, 'Gumbel')


def _to_numpy(value):
    """Move a tensor to the CPU, detach it and return it as a NumPy array."""
    return value.cpu().detach().numpy()


def train_loop(j, n_epochs):
    """Train one freshly initialized VAE and save it together with its loss history.

    Epochs 0..n_epochs (inclusive) are run. Each epoch first evaluates on
    ``validation_loader``, so epoch 0 reports the untrained model. It then
    writes the last validation batch and its reconstruction to PDF files under
    ``Results_{dataset}/Images/``, and finally makes one training pass over
    ``training_loader``. The loss is mean-squared reconstruction error plus
    the model's embedding loss.

    The model, the ``results`` dict and ``hyperparameters`` are passed to
    ``utils.save_model_and_results`` at epoch 0 and at epoch ``n_epochs``.

    This function reads module-level names assigned in this script's
    ``__main__`` section: ``hyperparameters``, ``learning_rate_model``,
    ``quant_type``, ``dataset``, ``training_loader`` and ``validation_loader``.

    Args:
        j: Trial index; used only in the saved-model name.
        n_epochs: Index of the last epoch to run (``n_epochs + 1`` epochs in total).

    Raises:
        ValueError: If either data loader contains no batches.
    """
    n_train_batches = len(training_loader)
    n_val_batches = len(validation_loader)
    if n_train_batches == 0 or n_val_batches == 0:
        raise ValueError(
            'training_loader and validation_loader must each contain at least one batch')

    results = {
        'epoch': [],
        'training_recon_errors': [],
        'validation_recon_errors': [],
        'training_quant_errors': [],
        'validation_quant_errors': [],
    }

    # Initialize model and optimizer.
    model = VAE(**hyperparameters)
    model.to(device)
    optimizer = optim.Adam(model.parameters(), learning_rate_model, amsgrad=True)

    prefix = _quant_prefix(quant_type)
    image_dir = f'Results_{dataset}/Images'

    for epoch in range(n_epochs + 1):
        # ---- Evaluation pass ----
        # Validation losses are never back-propagated, so skip graph construction.
        model.eval()
        validation_recon_loss_per_epoch = 0
        validation_quant_error_per_epoch = 0
        validation_perplexity_per_epoch = 0
        with torch.no_grad():
            for x, _ in tqdm(validation_loader):
                x = x.to(device)
                embedding_loss, x_hat, perplexity = model(x)
                recon_loss = torch.mean((x_hat - x) ** 2)
                validation_recon_loss_per_epoch += recon_loss.detach()
                validation_quant_error_per_epoch += embedding_loss.detach()
                validation_perplexity_per_epoch += perplexity.detach()
        validation_recon_loss_per_epoch /= n_val_batches
        validation_quant_error_per_epoch /= n_val_batches
        validation_perplexity_per_epoch /= n_val_batches

        # Save the last validation batch and its reconstruction every epoch.
        save_image(rev_normalize(x),
                   fp=f'{image_dir}/{prefix}_Original_image_epoch{epoch}.pdf')
        save_image(rev_normalize(x_hat),
                   fp=f'{image_dir}/{prefix}_Reconstructed_image_epoch{epoch}.pdf')

        # ---- Training pass ----
        model.train()
        training_recon_loss_per_epoch = 0
        training_quant_error_per_epoch = 0
        for x, _ in tqdm(training_loader):
            x = x.to(device)
            optimizer.zero_grad()
            # Training perplexity is not reported, so it is discarded.
            embedding_loss, x_hat, _ = model(x)
            recon_loss = torch.mean((x_hat - x) ** 2)

            training_recon_loss_per_epoch += recon_loss.detach()
            training_quant_error_per_epoch += embedding_loss.detach()

            loss = recon_loss + embedding_loss
            loss.backward()
            optimizer.step()
        training_recon_loss_per_epoch /= n_train_batches
        training_quant_error_per_epoch /= n_train_batches

        # Convert each per-epoch average once and reuse it for logging and results.
        train_recon = _to_numpy(training_recon_loss_per_epoch)
        val_recon = _to_numpy(validation_recon_loss_per_epoch)
        train_quant = _to_numpy(training_quant_error_per_epoch)
        val_quant = _to_numpy(validation_quant_error_per_epoch)

        results["training_recon_errors"].append(train_recon)
        results["validation_recon_errors"].append(val_recon)
        results["training_quant_errors"].append(train_quant)
        results["validation_quant_errors"].append(val_quant)
        results["epoch"].append(epoch)

        # print results per epoch
        print('Epoch #', epoch, 'Training Recon Error:', train_recon,
              'Validation Recon Error:', val_recon,
              'Training Quantization Error', train_quant,
              'Validation Quantization Error', val_quant,
              'Perplexity', _to_numpy(validation_perplexity_per_epoch))

        # Save at epoch 0 and at the final epoch; guard n_epochs == 0, where the
        # modulo would otherwise divide by zero.
        if n_epochs == 0 or epoch % n_epochs == 0:
            utils.save_model_and_results(
                model, results, hyperparameters, f'/{prefix}_{dataset}_Model{j}')
            

if __name__ == "__main__":
    # Experiment schedule: number of independent trials, quantizer variant and
    # the last epoch index handed to train_loop.
    trials, quant_type, n_epochs = 1, 'scq', 50

    # VQ-VAE / quantizer settings (model components come from models.vae).
    n_embeddings, embedding_dim = 128, 16
    learning_rate_model = 3e-4
    batch_size = args.batch_size
    dataset = 'LFW'
    iterations_scq = 20
    generate_latents, num_samples_latent = False, None
    n_res_block, n_res_channel, hidden_channel = (
        args.n_res_block, args.n_res_channel, args.hidden_channel)

    # Constructor keyword arguments for VAE; train_loop reads this module-level
    # dict and also passes it to utils.save_model_and_results.
    hyperparameters = dict(
        in_channel=3,
        out_channel=3,
        hidden_channel=hidden_channel,
        n_res_block=n_res_block,
        n_res_channel=n_res_channel,
        embed_dim=embedding_dim,
        n_embed=n_embeddings,
        decay=0.99,
        quant_type=quant_type,
        iterations_scq=iterations_scq,
        generate_latents=generate_latents,
        num_samples_latent=num_samples_latent,
    )

    # Log the configuration, then build the datasets and batch loaders.
    print(hyperparameters)
    (training_data, validation_data,
     training_loader, validation_loader) = utils.load_data_and_data_loaders(
        dataset, batch_size)

    # train_loop also reads quant_type, dataset, learning_rate_model and the
    # loaders above as module-level globals.
    for trial in range(trials):
        train_loop(trial, n_epochs)

