import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import time
import os
import pickle
import matplotlib
import matplotlib.pyplot as plt
matplotlib.rcParams['pdf.fonttype'] = 42
plt.rcParams.update({'font.size': 12})
import numpy as np
import sys
import pytorch_lightning as pl

from models.vae import VAE


class DataModule(pl.LightningDataModule):
    """Lightning data module that serves three ready-made DataLoaders.

    The loaders are supplied at construction time and handed back unchanged
    by the ``*_dataloader`` hooks. ``batch_size`` is only stored on the
    instance; none of the methods in this class read it.
    """

    def __init__(self, train_loader, validation_loader, test_loader, batch_size=32):
        super().__init__()
        # Keep the loaders exactly as given; nothing is rebuilt here.
        self.train_loader, self.validation_loader, self.test_loader = (
            train_loader, validation_loader, test_loader)
        self.batch_size = batch_size

    def train_dataloader(self):
        """Return the training loader passed to ``__init__``."""
        return self.train_loader

    def val_dataloader(self):
        """Return the validation loader passed to ``__init__``."""
        return self.validation_loader

    def test_dataloader(self):
        """Return the test loader passed to ``__init__``."""
        return self.test_loader




def load_cifar():
    """Load the CIFAR-10 train and test splits, downloading into ``data/`` if needed.

    Every image goes through ``ToTensor`` followed by a per-channel
    ``Normalize`` with mean 0.5 and std 0.5.

    Returns:
        (train, val): the ``train=True`` split and the ``train=False`` split.
    """
    def make_transform():
        # A fresh Compose per split; both splits use the same pipeline.
        return transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

    train = datasets.CIFAR10(root="data", train=True, download=True,
                             transform=make_transform())
    val = datasets.CIFAR10(root="data", train=False, download=True,
                           transform=make_transform())
    return train, val

def load_mnist():
    """Load the MNIST train and test splits, downloading into ``data/`` if needed.

    Images are converted with ``PILToTensor`` only; no normalization is applied.

    NOTE(review): ``PILToTensor`` keeps the PIL pixel dtype (uint8 values),
    unlike the other loaders in this file, which use ``ToTensor`` + ``Normalize``.
    Confirm that the training code converts MNIST batches to float.

    Returns:
        (train, val): the ``train=True`` split and the ``train=False`` split.
    """
    def make_transform():
        return transforms.Compose([transforms.PILToTensor()])

    train = datasets.MNIST(root="data", train=True, download=True,
                           transform=make_transform())
    val = datasets.MNIST(root="data", train=False, download=True,
                         transform=make_transform())
    return train, val


def load_LFW():
    """Load the dataset used for the ``'LFW'`` option.

    Despite the name, this loads **GTSRB** (``datasets.GTSRB``), not LFW.
    It downloads into ``data/`` if needed. Images are center-cropped to 48,
    then converted with ``ToTensor`` and normalized with per-channel
    mean 0.5 and std 0.5.

    Returns:
        (train, val): the ``split='train'`` and ``split='test'`` sets.
    """
    def make_transform():
        return transforms.Compose([
            transforms.CenterCrop(48),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

    train = datasets.GTSRB(root="data", split='train', download=True,
                           transform=make_transform())
    val = datasets.GTSRB(root="data", split='test', download=True,
                         transform=make_transform())
    return train, val

def load_lsunchurch():
    """Load the LSUN church_outdoor train and val sets from ``data/``.

    Unlike the other loaders, no ``download`` flag is passed, so the data is
    presumably expected to already be on disk. Images are center-cropped to
    256, then converted with ``ToTensor`` and normalized with per-channel
    mean 0.5 and std 0.5.

    Returns:
        (train, val): the ``church_outdoor_train`` and ``church_outdoor_val`` sets.
    """
    def make_transform():
        return transforms.Compose([
            transforms.CenterCrop(256),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

    train = datasets.LSUN(root="data", classes=['church_outdoor_train'],
                          transform=make_transform())
    val = datasets.LSUN(root="data", classes=['church_outdoor_val'],
                        transform=make_transform())
    return train, val



def data_loaders(train_data, val_data, batch_size, sampler=None):
    """Build the training and validation DataLoaders.

    Both loaders use ``pin_memory=True`` and ``num_workers=0``. The validation
    loader is never shuffled.

    Args:
        train_data: Dataset for the training loader.
        val_data: Dataset for the validation loader.
        batch_size: Batch size used by both loaders.
        sampler: Optional sampler for the training loader. When given, it
            replaces shuffling; when ``None``, the training data is shuffled.

    Returns:
        (train_loader, val_loader)
    """
    train_loader = DataLoader(train_data,
                              batch_size=batch_size,
                              pin_memory=True,
                              # DataLoader raises ValueError if shuffle=True is
                              # combined with a custom sampler, so only shuffle
                              # when no sampler was supplied.
                              shuffle=sampler is None,
                              sampler=sampler,
                              num_workers=0)
    val_loader = DataLoader(val_data,
                            batch_size=batch_size,
                            shuffle=False,
                            pin_memory=True, num_workers=0)
    return train_loader, val_loader


def load_data_and_data_loaders(dataset, batch_size):
    """Load a dataset by name and build its training/validation DataLoaders.

    Args:
        dataset: One of ``'CIFAR10'``, ``'LFW'`` (``load_LFW`` actually loads
            GTSRB), ``'MNIST'`` or ``'LSUNChurch'``.
        batch_size: Batch size for both loaders.

    Returns:
        (training_data, validation_data, training_loader, validation_loader)

    Raises:
        ValueError: if ``dataset`` is not one of the supported names.
    """
    sampler = None
    if dataset == 'CIFAR10':
        training_data, validation_data = load_cifar()
    elif dataset == 'LFW':
        training_data, validation_data = load_LFW()
    elif dataset == 'MNIST':
        training_data, validation_data = load_mnist()
    elif dataset == 'LSUNChurch':
        training_data, validation_data = load_lsunchurch()
        # Each pass over the training loader draws 50000 random samples
        # (without replacement) instead of iterating the full LSUN set.
        sampler = torch.utils.data.RandomSampler(
            training_data, replacement=False, num_samples=50000)
    else:
        raise ValueError(
            'Invalid dataset: only CIFAR10, LFW, MNIST, LSUNChurch datasets are supported.')

    training_loader, validation_loader = data_loaders(
        training_data, validation_data, batch_size, sampler=sampler)
    return training_data, validation_data, training_loader, validation_loader


def readable_timestamp():
    """Return the current local time as a lowercase, filename-safe string.

    Built from ``time.ctime()``: whitespace runs and colons become single
    underscores, e.g. ``thu_sep_5_12_00_00_2024``.
    """
    fields = time.ctime().split()
    return '_'.join(fields).replace(':', '_').lower()


def save_model_and_results(model, results, hyperparameters, filename):
    """Save a model's weights together with its results and hyperparameters.

    The checkpoint is written with ``torch.save`` to ``<cwd>/<filename>.pth``.
    It is a dict with the keys ``'model'`` (``model.state_dict()``),
    ``'results'`` and ``'hyperparameters'``.

    Args:
        model: Object exposing ``state_dict()`` (e.g. a ``torch.nn.Module``).
        results: Training/validation results to store alongside the weights.
        hyperparameters: Hyperparameters to store alongside the weights.
        filename: File name without the ``.pth`` suffix, relative to the cwd.
            A leading path separator is ignored.

    Returns:
        The path the checkpoint was written to.
    """
    save_dir = os.getcwd()
    # Join with a real separator (plain concatenation produced
    # "<cwd>run1.pth"). Strip leading separators so a filename such as
    # "/run1" still lands inside the cwd instead of being treated as an
    # absolute path by os.path.join.
    separators = os.sep + (os.altsep or '')
    save_path = os.path.join(save_dir, filename.lstrip(separators) + '.pth')

    results_to_save = {
        'model': model.state_dict(),
        'results': results,
        'hyperparameters': hyperparameters
    }
    torch.save(results_to_save, save_path)
    return save_path
    
def load_vae_model(PATH):
    """Rebuild a ``VAE`` from a checkpoint written by ``save_model_and_results``.

    Tensors are mapped to CUDA when available, otherwise to CPU. The model is
    constructed from the stored ``'hyperparameters'`` and then loaded with the
    stored ``'model'`` state dict.

    NOTE: ``torch.load`` unpickles the file; only load trusted checkpoints.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    checkpoint = torch.load(PATH, map_location=device)
    vae = VAE(**checkpoint['hyperparameters'])
    vae.load_state_dict(checkpoint['model'])
    return vae

def load_vae_results(PATH):
    """Return the ``'results'`` entry of a checkpoint file.

    Tensors are mapped to CUDA when available, otherwise to CPU.
    NOTE: ``torch.load`` unpickles the file; only load trusted checkpoints.
    """
    if torch.cuda.is_available():
        location = torch.device("cuda")
    else:
        location = torch.device("cpu")
    checkpoint = torch.load(PATH, map_location=location)
    return checkpoint['results']

def _classify_result_path(path):
    """Return the bucket name ('scq', 'vq', 'rq', 'vq2' or 'g') for a results path.

    The substring checks run in a fixed order and the first hit wins, so a
    path containing several markers is filed under the earliest one. A path
    matching none of the markers is treated as a Gumbel run ('g').
    """
    # NOTE(review): the whole path is searched, directory names included;
    # confirm result directories never contain these markers.
    if "SCQ" in path:
        return 'scq'
    if "VQ_" in path:
        return 'vq'
    if "RQ" in path:
        return 'rq'
    if "Opt_" in path:
        return 'vq2'
    return 'g'


def load_all_results(PATH):
    """Collect per-run error curves from every results file matching a glob.

    Each file matching ``PATH`` is loaded with ``load_vae_results``. It is then
    bucketed by path substring: "SCQ" -> SCQ, else "VQ_" -> VQ, else "RQ" -> RQ,
    else "Opt_" -> VQ2, otherwise Gumbel. The four error lists
    ``training_recon_errors``, ``validation_recon_errors``,
    ``training_quant_errors`` and ``validation_quant_errors`` are appended to
    that bucket.

    Args:
        PATH: Glob pattern passed to ``glob.glob``.

    Returns:
        A 20-tuple of lists with one element per matching run, in this order:
        recon train/val for SCQ, VQ, Gumbel, RQ, VQ2, then quant train/val for
        SCQ, VQ, Gumbel, RQ, VQ2.

    Raises:
        KeyError: if a results dict lacks one of the four error keys.
    """
    import glob

    # Bucket order below is the order the lists appear in the return tuple.
    methods = ('scq', 'vq', 'g', 'rq', 'vq2')
    recon_keys = ('training_recon_errors', 'validation_recon_errors')
    quant_keys = ('training_quant_errors', 'validation_quant_errors')
    all_keys = recon_keys + quant_keys

    collected = {method: {key: [] for key in all_keys} for method in methods}

    # glob returns matches in arbitrary, filesystem-dependent order; sort so
    # runs inside each list come back in a reproducible order.
    for path in sorted(glob.glob(PATH)):
        run_results = load_vae_results(path)
        bucket = collected[_classify_result_path(path)]
        for key in all_keys:
            bucket[key].append(run_results[key])

    recon = [collected[m][k] for m in methods for k in recon_keys]
    quant = [collected[m][k] for m in methods for k in quant_keys]
    return tuple(recon + quant)

def _plot_mean_band(runs, label, line_color, band_color):
    """Draw the across-run mean of *runs* with a shaded min-max band on the current axes.

    *runs* is stacked with ``np.array`` and reduced over axis 0. It is
    presumably a sequence of equal-length per-epoch curves, one per run
    (TODO: confirm against callers).
    """
    stacked = np.array(runs)
    means = np.mean(stacked, axis=0)
    lows = np.amin(stacked, axis=0)
    highs = np.amax(stacked, axis=0)
    epochs = range(len(means))
    plt.plot(epochs, means, linewidth=2, color=line_color, label=label)
    plt.fill_between(epochs, lows, highs, alpha=0.5, color=band_color)


def plot_results(title, error, validation_results_scq, validation_results_rq,
                 validation_results_vq2, *, xlim=(0, 49), ylim=(0, 0.025)):
    """Plot mean and min/max bands of SCQ, RQ and VQ2 validation curves and save a PDF.

    Args:
        title: Plot title prefix; also the first part of the output file name.
        error: Name of the plotted error; used in the title ("<error> Error")
            and in the file name.
        validation_results_scq: Per-run curves for SCQ (see ``_plot_mean_band``).
        validation_results_rq: Per-run curves for RQ.
        validation_results_vq2: Per-run curves for "VQ + Rep + Aff + Opt".
        xlim: x-axis limits (epochs); defaults to the previously fixed (0, 49).
        ylim: y-axis limits (MSE); defaults to the previously fixed (0, 0.025).

    Returns:
        The created matplotlib Figure. It is saved to ``<title>_<error>.pdf``
        in the cwd and left open for the caller.
    """
    fig = plt.figure()
    # Each band gets its line's colour explicitly. Without a colour,
    # fill_between takes the next colour from the default patch cycle, which
    # does not match the explicit line colours (e.g. an orange band under the
    # red RQ line).
    series = (
        (validation_results_scq, 'SCQ', 'blue', 'blue'),
        (validation_results_rq, 'RQ', 'red', 'red'),
        (validation_results_vq2, 'VQ + Rep + Aff + Opt', 'black', 'gray'),
    )
    for runs, label, line_color, band_color in series:
        _plot_mean_band(runs, label, line_color, band_color)

    plt.title(f'{title}: {error} Error')
    plt.xlim(*xlim)
    plt.ylim(*ylim)
    plt.ylabel('MSE')
    plt.xlabel('Epoch')
    plt.legend()
    plt.grid()
    plt.savefig(f'{title}_{error}.pdf')
    return fig
