import sys, pdb, time
sys.path.append('../../')

import argparse
import numpy as np
import torch
import gpvae

from torch.utils.data import DataLoader
from experiments.big_eeg.train_big_eeg import train_big_eeg, \
    train_big_eeg_modified
from data import big_eeg, big_eeg2, big_eeg3

# Make float64 the default floating-point dtype for torch tensors created
# from here on (presumably for numerical precision in the GP computations --
# TODO confirm).
torch.set_default_dtype(torch.float64)


def str2bool(v):
    """Interpret a command-line value as a boolean.

    Booleans are passed through untouched. Any other value is lower-cased
    and compared against the accepted spellings: 'yes', 'true', 't', 'y' and
    '1' map to ``True``; 'no', 'false', 'f', 'n' and '0' map to ``False``.

    Raises:
        argparse.ArgumentTypeError: if the value matches neither group.
    """
    if isinstance(v, bool):
        return v

    token = v.lower()
    if token in ('yes', 'true', 't', 'y', '1'):
        return True
    if token in ('no', 'false', 'f', 'n', '0'):
        return False

    raise argparse.ArgumentTypeError('Boolean value expected.')


def _split_encoder_dims(encoder_dims, layout):
    """Split ``encoder_dims`` around its middle entry.

    Args:
        encoder_dims: list of layer widths with an odd number of entries.
        layout: description of the expected layout, used in the error
            message only.

    Returns:
        Tuple ``(leading, middle, trailing)`` where ``middle`` is the central
        entry and ``leading`` / ``trailing`` are the entries before / after.

    Raises:
        ValueError: if ``encoder_dims`` has an even number of entries.
    """
    if len(encoder_dims) % 2 != 1:
        raise ValueError(
            'args.encoder_dims must have an odd number of entries laid out '
            'as [{}], got {}.'.format(layout, encoder_dims))

    middle_idx = len(encoder_dims) // 2
    return (encoder_dims[:middle_idx], encoder_dims[middle_idx],
            encoder_dims[middle_idx + 1:])


def _extract_xy(frames, observations):
    """Stack the index and the ``observations`` columns of each frame."""
    x = np.array([frame.index for frame in frames])
    y = np.array([frame[observations].values for frame in frames])
    return x, y


def _to_tensors(array):
    """Convert each entry along the first axis of ``array`` to a tensor."""
    return [torch.tensor(row) for row in array]


def _make_loader(x, y):
    """Wrap ``x`` and ``y`` in a NaN-aware dataset served one item at a time."""
    dataset = gpvae.utils.dataset_utils.MetaTupleDataset(
        x, y, contains_nan=True)
    return DataLoader(dataset, batch_size=1, shuffle=False)


def _build_decoder(args, y_dim):
    """Build the likelihood network and return ``(decoder, decoder_args)``.

    Side effect: when ``args.linear_likelihood`` is set, '_ll' is appended to
    ``args.pinference_net``.
    """
    if args.linear_likelihood:
        # Linear likelihood hyperparameters.
        decoder_args = {'in_dim': args.latent_dim,
                        'out_dim': y_dim,
                        'sigma': .1}
        decoder = gpvae.networks.AffineGaussian(**decoder_args)
        args.pinference_net += '_ll'
    else:
        # DNN likelihood hyperparameters.
        decoder_args = {'in_dim': args.latent_dim,
                        'out_dim': y_dim,
                        'hidden_dims': args.decoder_dims,
                        'sigma': args.sigma,
                        'train_sigma': args.train_sigma
                        }
        decoder = gpvae.networks.LinearGaussian(**decoder_args)

    return decoder, decoder_args


def _build_encoder(args, y_dim):
    """Build the partial inference network named by ``args.pinference_net``.

    Returns:
        Tuple ``(encoder, encoder_args)``.

    Raises:
        ValueError: if ``args.pinference_net`` is not a known network, or if
            ``args.encoder_dims`` has an even length for 'indexnet' /
            'pointnet'.
    """
    name = args.pinference_net

    if name in ('factornet', 'factornet_ll'):
        # FactorNet hyperparameters.
        encoder_args = {'in_dim': y_dim,
                        'out_dim': args.latent_dim,
                        'hidden_dims': args.encoder_dims[:-1],
                        'initial_sigma': .1,
                        'initial_mu': 0.,
                        'min_sigma': 0.01
                        }
        encoder = gpvae.networks.FactorNet(**encoder_args)

    elif name in ('indexnet', 'indexnet_ll'):
        hidden_dims, middle_dim, shared_hidden_dims = _split_encoder_dims(
            args.encoder_dims, 'hidden_dims, middle_dim, shared_hidden_dims')

        # IndexNet hyperparameters.
        encoder_args = {'in_dim': y_dim,
                        'out_dim': args.latent_dim,
                        'hidden_dims': hidden_dims,
                        'shared_hidden_dims': shared_hidden_dims,
                        'middle_dim': middle_dim,
                        'initial_sigma': .1,
                        'initial_mu': 0.,
                        'min_sigma': 0.01
                        }
        encoder = gpvae.networks.IndexNet(**encoder_args)

    elif name in ('pointnet', 'pointnet_ll'):
        first_hidden_dims, middle_dim, second_hidden_dims = \
            _split_encoder_dims(
                args.encoder_dims,
                'first_hidden_dims, middle_dim, second_hidden_dims')

        # PointNet encoder hyperparameters.
        encoder_args = {'out_dim': args.latent_dim,
                        'middle_dim': middle_dim,
                        'first_hidden_dims': first_hidden_dims,
                        'second_hidden_dims': second_hidden_dims,
                        'initial_sigma': .1,
                        'initial_mu': 0.,
                        'min_sigma': 0.01
                        }
        encoder = gpvae.networks.PointNet(**encoder_args)

    elif name in ('zeroimputation', 'zeroimputation_ll'):
        # Zero imputation hyperparameters.
        encoder_args = {'in_dim': y_dim,
                        'out_dim': args.latent_dim,
                        'hidden_dims': args.encoder_dims[:-1],
                        'initial_sigma': .1,
                        'initial_mu': 0.,
                        'min_sigma': 0.01
                        }
        encoder = gpvae.networks.LinearGaussian(**encoder_args)

    elif name in ('resnet', 'resnet_ll'):
        # ResNet encoder hyperparameters.
        encoder_args = {'in_dim': y_dim,
                        'out_dim': args.latent_dim,
                        'h_dims': args.encoder_dims[:-1],
                        'rho_dims': [50],
                        'initial_sigma': .1,
                        'initial_mu': 0.,
                        'min_sigma': 0.01
                        }
        encoder = gpvae.networks.ResNet(**encoder_args)

    else:
        raise ValueError('{} is not a partial inference network.'.format(
            name))

    return encoder, encoder_args


def main(args):
    """Build the model selected by ``args`` and train it on big EEG data.

    The function loads the chosen experiment, normalises the observations
    with the training mean / standard deviation, builds the kernel, decoder
    and encoder, constructs the model named by ``args.model`` and hands
    everything to the training routine.

    Args:
        args: parsed command-line namespace. Attributes read here:
            ``experiment``, ``all_observations``, ``init_lengthscale``,
            ``init_scale``, ``latent_dim``, ``linear_likelihood``,
            ``decoder_dims``, ``sigma``, ``train_sigma``, ``pinference_net``,
            ``encoder_dims``, ``model``, ``add_jitter``, ``modified_elbo``,
            ``num_inducing``, ``fixed_inducing`` and ``k``. The whole
            namespace is also forwarded to training as hyperparameters.

    Returns:
        Whatever the training routine returns.

    Raises:
        ValueError: if ``args.experiment``, ``args.pinference_net`` or
            ``args.model`` is not recognised, or if ``args.encoder_dims`` has
            an even length for 'indexnet' / 'pointnet'.

    Side effects on ``args``: ``observations``, ``encoder_args`` and
    ``decoder_args`` are added; ``model`` is rewritten to
    ``'<model>_<pinference_net>'``; ``pinference_net`` gains an '_ll' suffix
    when ``linear_likelihood`` is set.
    """
    ######################
    # Data preprocessing #
    ######################
    # Load data.
    if args.experiment == 1:
        splits = big_eeg.load()
    elif args.experiment == 2:
        splits = big_eeg2.load()
    elif args.experiment == 3:
        splits = big_eeg3.load()
    else:
        raise ValueError('Only experiments 1, 2 or 3 exist.')

    # The second split is not used by this script.
    train, _, train_test, eval_test, train_eval_train, eval_eval_train = \
        splits

    if args.all_observations:
        observations = list(train[0].columns)
    else:
        observations = ['FZ', 'F1', 'F2', 'F3', 'F4', 'F5', 'F6']

    args.observations = observations
    y_dim = len(observations)

    # Extract data into numpy arrays.
    x, y = _extract_xy(train, observations)
    x_test, y_test = _extract_xy(train_test, observations)
    x_eval_train, y_eval_train = _extract_xy(train_eval_train, observations)

    # Normalise every split with the statistics of the training
    # observations, ignoring NaNs.
    y_mean, y_std = np.nanmean(y, axis=(0, 1)), np.nanstd(y, axis=(0, 1))
    y = (y - y_mean) / y_std
    y_test = (y_test - y_mean) / y_std
    y_eval_train = (y_eval_train - y_mean) / y_std

    # Convert to tensors. ``x`` is kept around: the sparse models place
    # their inducing points using its last entry.
    x, y = _to_tensors(x), _to_tensors(y)

    # Set up loaders.
    train_loader = _make_loader(x, y)
    test_loader = _make_loader(_to_tensors(x_test), _to_tensors(y_test))
    eval_train_loader = _make_loader(
        _to_tensors(x_eval_train), _to_tensors(y_eval_train))

    ###################
    # GP prior kernel #
    ###################
    kernel = gpvae.kernels.RBFKernel(
        lengthscale=args.init_lengthscale, scale=args.init_scale)

    ###############
    # GPVAE model #
    ###############
    # The decoder must be built first: it may append '_ll' to
    # args.pinference_net, which the encoder selection then sees.
    decoder, decoder_args = _build_decoder(args, y_dim)
    encoder, encoder_args = _build_encoder(args, y_dim)

    # Construct the model and pick the estimator / training routine.
    train_fn = train_big_eeg

    if args.model == 'gpvae':
        args.model = 'gpvae_' + args.pinference_net
        # GPVAE model.
        model = gpvae.models.GPVAE(encoder, decoder, args.latent_dim,
                                   kernel, add_jitter=args.add_jitter)

        if args.modified_elbo:
            # NOTE(review): the 'vae' branch pairs its conditional estimator
            # with train_big_eeg_modified, whereas this branch keeps
            # train_big_eeg -- confirm that this is intended.
            estimator = \
                gpvae.estimators.gpvae_estimators.conditional_td_estimator
        else:
            estimator = gpvae.estimators.gpvae_estimators.td_estimator

    elif args.model == 'sgpvae':
        args.model = 'sgpvae_' + args.pinference_net
        z = torch.linspace(
            0, x[0][-1].item(), steps=args.num_inducing).unsqueeze(1)

        # Michael's SGP-VAE model.
        model = gpvae.models.SparseGPVAE2(
            encoder, decoder, args.latent_dim, kernel, z,
            add_jitter=args.add_jitter, fixed_inducing=args.fixed_inducing)

        estimator = gpvae.estimators.sgpvae_estimators.sa_estimator

    elif args.model == 'sgpvae_old':
        args.model = 'sgpvae_old_' + args.pinference_net
        z = torch.linspace(
            0, x[0][-1].item(), steps=args.num_inducing).unsqueeze(1)

        # This model replaces the encoder built above with a SparseNet.
        encoder_args = {'in_dim': y_dim,
                        'out_dim': args.latent_dim,
                        'z': z,
                        'hidden_dims': args.encoder_dims[:-1],
                        'k': args.k,
                        'min_sigma': 0.01,
                        'fixed_inducing': args.fixed_inducing,
                        'pinference_net': args.pinference_net
                        }
        encoder = gpvae.networks.SparseNet(**encoder_args)

        # Sparse GPVAE model, using zero imputation for now.
        model = gpvae.models.SparseGPVAE(encoder, decoder, args.latent_dim,
                                         kernel, add_jitter=args.add_jitter)

        estimator = gpvae.estimators.gpvae_estimators.vfe_td_estimator

    elif args.model == 'vae':
        args.model = 'vae_' + args.pinference_net
        # VAE model
        model = gpvae.models.VAE(encoder, decoder, args.latent_dim)

        if args.modified_elbo:
            estimator = \
                gpvae.estimators.vae_estimators.conditional_td_estimator
            train_fn = train_big_eeg_modified
        else:
            estimator = gpvae.estimators.vae_estimators.td_estimator

    else:
        # Previously an unknown model fell through and nothing was trained.
        raise ValueError('{} is not a supported model.'.format(args.model))

    # vars() returns the namespace's own dict, so the two entries below also
    # become attributes of ``args``.
    hyperparameters = vars(args)
    hyperparameters['encoder_args'] = encoder_args
    hyperparameters['decoder_args'] = decoder_args

    # Train model.
    return train_fn(
        model, estimator, train_loader, test_loader, eval_test,
        eval_train_loader, eval_eval_train, y_mean, y_std, hyperparameters)


if __name__ == '__main__':
    # Command-line options as (flag, add_argument keyword arguments) pairs,
    # listed in the order in which they are registered.
    option_specs = (
        # General.
        ('--results_dir', {'default': './_results/'}),
        ('--experiment', {'default': 1, 'type': int}),

        # Kernel.
        ('--init_lengthscale', {'default': 0.05, 'type': float}),
        ('--init_scale', {'default': 1., 'type': float}),

        # GPVAE.
        ('--model', {'default': 'gpvae'}),
        ('--pinference_net', {'default': 'indexnet'}),
        ('--latent_dim', {'default': 10, 'type': int}),
        ('--auxiliary_dim', {'default': 1, 'type': int}),
        ('--decoder_dims', {'default': [50, 50], 'nargs': '+', 'type': int}),
        ('--sigma', {'default': 0.1, 'type': float}),
        ('--train_sigma', {'default': True, 'type': str2bool}),
        ('--encoder_dims',
         {'default': [50, 50, 50], 'nargs': '+', 'type': int}),
        ('--inducing_spacing', {'default': 0.025, 'type': float}),
        ('--k', {'default': 1, 'type': int}),
        ('--num_inducing', {'default': 64, 'type': int}),
        ('--fixed_inducing', {'default': False, 'type': str2bool}),
        ('--add_jitter', {'default': True, 'type': str2bool}),
        ('--linear_likelihood', {'default': False, 'type': str2bool}),

        # Training.
        ('--epochs', {'default': 50, 'type': int}),
        ('--cache_freq', {'default': 5, 'type': int}),
        ('--lr', {'default': 0.001, 'type': float}),
        ('--num_samples', {'default': 1, 'type': int}),
        ('--elbo_samples', {'default': 100, 'type': int}),
        ('--test_samples', {'default': 100, 'type': int}),
        ('--modified_elbo', {'default': False, 'type': str2bool}),
        ('--p', {'default': 0.25, 'type': float}),
        ('--all_observations', {'default': True, 'type': str2bool}),
    )

    parser = argparse.ArgumentParser()
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)

    # Parse the command line and run the experiment.
    args = parser.parse_args()
    main(args)
