import sys
sys.path.append('../../')

import argparse
import numpy as np
import torch
import gpvae
import data

from torch.utils.data import DataLoader
from experiments.eeg.train_eeg import train_eeg

# Make float64 the default floating-point dtype for tensors created after this
# point. NOTE(review): presumably so that model parameters match the tensors
# built from the NumPy arrays in main() -- confirm those arrays are float64.
torch.set_default_dtype(torch.float64)


def str2bool(v):
    """Convert a command-line token into a boolean.

    Intended for use as an ``argparse`` ``type=`` callable.

    Args:
        v: Either an actual ``bool`` (returned unchanged) or a string.
            Matching is case-insensitive.

    Returns:
        ``True`` for 'yes', 'true', 't', 'y', '1'; ``False`` for 'no',
        'false', 'f', 'n', '0'.

    Raises:
        argparse.ArgumentTypeError: If the string is none of the accepted
            spellings.
    """
    # Booleans pass straight through.
    if isinstance(v, bool):
        return v

    token = v.lower()
    if token in ('yes', 'true', 't', 'y', '1'):
        return True
    if token in ('no', 'false', 'f', 'n', '0'):
        return False

    raise argparse.ArgumentTypeError('Boolean value expected.')


def _build_kernel(args):
    """Construct the GP prior kernel selected by ``args.periodic_kernel``."""
    if args.periodic_kernel:
        # Sum of an RBF and a periodic component, each initialised with half
        # of the requested scale.
        rbf = gpvae.kernels.RBFKernel(
            lengthscale=args.init_lengthscale, scale=args.init_scale/2)
        periodic = gpvae.kernels.PeriodicKernel(
            lengthscale=args.init_lengthscale, period=args.init_period,
            scale=args.init_scale/2)
        return gpvae.kernels.AdditiveKernel(rbf, periodic)

    return gpvae.kernels.RBFKernel(
        lengthscale=args.init_lengthscale, scale=args.init_scale)


def _build_decoder(args, y_dim):
    """Construct the likelihood network; return ``(decoder, decoder_args)``."""
    if args.linear_likelihood:
        # Linear likelihood hyperparameters.
        decoder_args = {'in_dim': args.latent_dim,
                        'out_dim': y_dim,
                        'sigma': .1}
        return gpvae.networks.AffineGaussian(**decoder_args), decoder_args

    # DNN likelihood hyperparameters.
    decoder_args = {'in_dim': args.latent_dim,
                    'out_dim': y_dim,
                    'hidden_dims': args.decoder_dims,
                    'sigma': args.sigma,
                    'train_sigma': args.train_sigma
                    }
    return gpvae.networks.LinearGaussian(**decoder_args), decoder_args


def _split_encoder_dims(encoder_dims, layout):
    """Split an odd-length dims list into ``(leading, middle, trailing)``.

    ``layout`` names the three parts and is only used in the error message.
    """
    if len(encoder_dims) % 2 != 1:
        raise ValueError(
            'args.encoder_dims must have odd length: [{}]'.format(layout))

    middle_idx = len(encoder_dims) // 2
    return (encoder_dims[:middle_idx],
            encoder_dims[middle_idx],
            encoder_dims[middle_idx + 1:])


def _build_encoder(args, y_dim):
    """Construct the partial inference network named by ``args.pinference_net``.

    Returns ``(encoder, encoder_args)``; raises ``ValueError`` for an unknown
    network name.
    """
    net = args.pinference_net

    if net in ['factornet', 'factornet_ll']:
        # FactorNet hyperparameters.
        encoder_args = {'in_dim': y_dim,
                        'out_dim': args.latent_dim,
                        'hidden_dims': args.encoder_dims[:-1],
                        'initial_sigma': .1,
                        'initial_mu': 0.,
                        'min_sigma': 0.01
                        }
        return gpvae.networks.FactorNet(**encoder_args), encoder_args

    if net in ['indexnet', 'indexnet_ll']:
        hidden_dims, middle_dim, shared_hidden_dims = _split_encoder_dims(
            args.encoder_dims,
            'hidden_dims, middle_dim, shared_hidden_dims')

        # IndexNet hyperparameters.
        encoder_args = {'in_dim': y_dim,
                        'out_dim': args.latent_dim,
                        'hidden_dims': hidden_dims,
                        'shared_hidden_dims': shared_hidden_dims,
                        'middle_dim': middle_dim,
                        'initial_sigma': .1,
                        'initial_mu': 0.,
                        'min_sigma': 0.01
                        }
        return gpvae.networks.IndexNet(**encoder_args), encoder_args

    if net in ['pointnet', 'pointnet_ll']:
        first_hidden_dims, middle_dim, second_hidden_dims = \
            _split_encoder_dims(
                args.encoder_dims,
                'first_hidden_dims, middle_dim, second_hidden_dims')

        # PointNet encoder hyperparameters.
        encoder_args = {'out_dim': args.latent_dim,
                        'middle_dim': middle_dim,
                        'first_hidden_dims': first_hidden_dims,
                        'second_hidden_dims': second_hidden_dims,
                        'initial_sigma': .1,
                        'initial_mu': 0.,
                        'min_sigma': 0.01
                        }
        return gpvae.networks.PointNet(**encoder_args), encoder_args

    if net in ['zeroimputation', 'zeroimputation_ll']:
        # Zero imputation hyperparameters.
        encoder_args = {'in_dim': y_dim,
                        'out_dim': args.latent_dim,
                        'hidden_dims': args.encoder_dims[:2],
                        'initial_sigma': .1,
                        'initial_mu': 0.,
                        'min_sigma': 0.01
                        }
        return gpvae.networks.LinearGaussian(**encoder_args), encoder_args

    raise ValueError('{} is not a partial inference network.'.format(net))


def main(args):
    """Build the model described by ``args`` and train it on the EEG data.

    ``args`` is modified in place: ``decoder_scale``, ``encoder_args`` and
    ``decoder_args`` are added, ``pinference_net`` gets a ``'_ll'`` suffix
    when ``linear_likelihood`` is set, and ``model`` is rewritten to
    ``'<model>_<pinference_net>'``. The resulting attribute dict is what is
    passed to ``train_eeg`` as the hyperparameters.

    Args:
        args: Namespace of options as produced by this script's argument
            parser.

    Raises:
        ValueError: If ``args.pinference_net`` or ``args.model`` is not
            recognised, or if ``args.encoder_dims`` has even length for the
            indexnet / pointnet encoders.
    """
    ######################
    # Data preprocessing #
    ######################
    # Only the training split is used here.
    _, train, _ = data.eeg.load()

    # Extract data into numpy arrays.
    x = np.array(train.index)
    y = np.array(train)
    y_dim = y.shape[1]

    # Reciprocal of the fraction of non-NaN entries in y.
    decoder_scale = 1. / (np.sum(~np.isnan(y)) / (y.shape[0] * y.shape[1]))
    args.decoder_scale = decoder_scale

    # Normalise observations per column, ignoring NaNs.
    y_mean, y_std = np.nanmean(y, axis=0), np.nanstd(y, axis=0)
    y = (y - y_mean) / y_std

    # Convert to tensors.
    x = torch.tensor(x)
    y = torch.tensor(y)

    # Set up the mini-batch loader.
    # NOTE(review): a full-batch loader "for models using the VFE
    # approximation" used to be built here but was never passed to any
    # training call, so it has been dropped. Confirm whether the
    # 'sgpvae_old' model is meant to train full-batch.
    dataset = gpvae.utils.dataset_utils.TupleDataset(x, y, contains_nan=True)
    loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)

    ####################
    # Model components #
    ####################
    kernel = _build_kernel(args)

    decoder, decoder_args = _build_decoder(args, y_dim)
    if args.linear_likelihood:
        # Tag the inference-network name so the run is identifiable.
        args.pinference_net += '_ll'

    encoder, encoder_args = _build_encoder(args, y_dim)

    # Fall back to False when the namespace carries no `fixed_inducing`
    # attribute, so the sparse models do not fail with AttributeError.
    fixed_inducing = getattr(args, 'fixed_inducing', False)

    ###################
    # Model selection #
    ###################
    # A single if/elif chain: every recognised model is handled exactly once
    # and only an unrecognised name reaches the final ValueError.
    model_name = args.model

    if model_name == 'gpvae':
        # GPVAE model.
        model = gpvae.models.GPVAE(encoder, decoder, args.latent_dim,
                                   kernel, add_jitter=args.add_jitter)

        # The estimator depends on the inference network, so test
        # `pinference_net` (whose values are the ones listed) and not the
        # model name.
        # NOTE(review): this makes the analytical estimator reachable for
        # indexnet encoders -- confirm that is the intended estimator.
        if args.pinference_net in ['indexnet', 'indexnet_ll']:
            train_estimator = \
                gpvae.estimators.gpvae_estimators.analytical_estimator
        else:
            train_estimator = gpvae.estimators.gpvae_estimators.td_estimator
        elbo_estimator = gpvae.estimators.gpvae_estimators.elbo_estimator

    elif model_name == 'gppvae':
        # GPPVAE model.
        model = gpvae.models.GPPVAE(encoder, decoder, args.latent_dim,
                                    kernel, add_jitter=args.add_jitter)
        train_estimator = \
            gpvae.estimators.gppvae_estimators.analytical_estimator
        elbo_estimator = gpvae.estimators.gppvae_estimators.elbo_estimator

    elif model_name == 'sgpvae':
        # Inducing locations evenly spaced from 0 to the last input.
        z = torch.linspace(
            0, x[-1].item(), steps=args.num_inducing).unsqueeze(1)

        # SGP-VAE model.
        model = gpvae.models.SparseGPVAE2(
            encoder, decoder, args.latent_dim, kernel, z,
            add_jitter=args.add_jitter, fixed_inducing=fixed_inducing)
        train_estimator = gpvae.estimators.sgpvae_estimators.sa_estimator
        elbo_estimator = gpvae.estimators.sgpvae_estimators.elbo_estimator

    elif model_name == 'sgpvae_old':
        z = torch.linspace(
            0, x[-1].item(), steps=args.num_inducing).unsqueeze(1)

        # This model uses its own SparseNet encoder, replacing the one
        # built above.
        encoder_args = {'in_dim': y_dim,
                        'out_dim': args.latent_dim,
                        'z': z,
                        'hidden_dims': args.encoder_dims[:2],
                        'k': args.k,
                        'min_sigma': 0.01,
                        'fixed_inducing': fixed_inducing,
                        'pinference_net': args.pinference_net
                        }
        encoder = gpvae.networks.SparseNet(**encoder_args)

        # Sparse GPVAE model.
        model = gpvae.models.SparseGPVAE(encoder, decoder, args.latent_dim,
                                         kernel, add_jitter=args.add_jitter)
        train_estimator = gpvae.estimators.gpvae_estimators.vfe_td_estimator
        elbo_estimator = gpvae.estimators.gpvae_estimators.vfe_elbo_estimator

    elif model_name == 'vae':
        # VAE model.
        model = gpvae.models.VAE(encoder, decoder, args.latent_dim)
        train_estimator = gpvae.estimators.vae_estimators.analytical_estimator
        elbo_estimator = gpvae.estimators.vae_estimators.elbo_estimator

    else:
        raise ValueError('{} is not a model.'.format(args.model))

    ############
    # Training #
    ############
    args.model = model_name + '_' + args.pinference_net

    # vars() exposes the namespace's own attribute dict, so the two extra
    # entries are stored on `args` as well.
    hyperparameters = vars(args)
    hyperparameters['encoder_args'] = encoder_args
    hyperparameters['decoder_args'] = decoder_args

    train_eeg(model, train_estimator, loader, hyperparameters,
              elbo_estimator)


if __name__ == '__main__':
    # Command-line entry point: parse the experiment options and run main().
    parser = argparse.ArgumentParser()

    # General.
    parser.add_argument('--results_dir', default='./_results/')
    parser.add_argument('--gpvae', default=True, type=str2bool)

    # Kernel.
    parser.add_argument('--init_lengthscale', default=0.05, type=float)
    parser.add_argument('--init_scale', default=1., type=float)
    parser.add_argument('--init_period', default=.1, type=float)
    parser.add_argument('--periodic_kernel', default=False, type=str2bool)

    # GPVAE.
    parser.add_argument('--model', default='gpvae')
    parser.add_argument('--pinference_net', default='indexnet', type=str)
    parser.add_argument('--latent_dim', default=3, type=int)
    parser.add_argument('--auxiliary_dim', default=1, type=int)
    parser.add_argument('--decoder_dims', default=[20], nargs='+',
                        type=int)
    parser.add_argument('--sigma', default=0.1, type=float)
    parser.add_argument('--train_sigma', default=True, type=str2bool)
    parser.add_argument('--encoder_dims', default=[20, 20, 20],
                        nargs='+', type=int)
    parser.add_argument('--inducing_spacing', default=0.025, type=float)
    parser.add_argument('--k', default=1, type=int)
    parser.add_argument('--num_inducing', default=64, type=int)
    # main() reads args.fixed_inducing for the sparse models ('sgpvae' and
    # 'sgpvae_old'); without this option those models fail with
    # AttributeError.
    # NOTE(review): the default of False is an assumption -- confirm.
    parser.add_argument('--fixed_inducing', default=False, type=str2bool)
    parser.add_argument('--add_jitter', default=True, type=str2bool)
    parser.add_argument('--linear_likelihood', default=False, type=str2bool)

    # Training.
    parser.add_argument('--epochs', default=2000, type=int)
    parser.add_argument('--cache_freq', default=100, type=int)
    parser.add_argument('--batch_size', default=256, type=int)
    parser.add_argument('--lr', default=0.001, type=float)
    parser.add_argument('--num_samples', default=1, type=int)
    parser.add_argument('--elbo_samples', default=100, type=int)
    parser.add_argument('--test_samples', default=100, type=int)

    args = parser.parse_args()
    main(args)
