import sys
sys.path.append('../../')

import argparse
import numpy as np
import torch
import gpvae
import data
import gpytorch

from torch.utils.data import DataLoader
from scipy.cluster.vq import kmeans2
from experiments.jura.train_conditional_jura import train_jura
torch.set_default_dtype(torch.float64)


def str2bool(v):
    """Parse a boolean command-line value.

    Intended as an argparse ``type=`` converter. Values that are already
    ``bool`` are returned unchanged; strings are matched case-insensitively
    against common true/false spellings.

    Args:
        v (bool or str): The raw value.

    Returns:
        bool: The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``v`` is a string that is not a
            recognised boolean spelling.
    """
    if isinstance(v, bool):
        return v

    normalised = v.lower()
    true_words = ('yes', 'true', 't', 'y', '1')
    false_words = ('no', 'false', 'f', 'n', '0')

    if normalised in true_words:
        return True
    if normalised in false_words:
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')


def _load_training_data(args):
    """Load the Jura training split as tensors with standardised targets.

    Side effect: sets ``args.decoder_scale`` to the reciprocal of the
    fraction of observed (non-NaN) entries in the raw training targets,
    i.e. ``(total entries) / (observed entries)``.

    Args:
        args (argparse.Namespace): Parsed command-line arguments; reads
            ``log_transform``.

    Returns:
        tuple: ``(x, y, y_dim)``. ``x`` is a tensor of shape ``(n, 2)``
        holding the ``(i, j)`` pairs of the training index; ``y`` is a
        tensor of the (optionally log-transformed) targets standardised per
        column, which may contain NaNs for missing entries; ``y_dim`` is the
        number of target columns.
    """
    # Load from Wessel's wbml package; the test split is not used here.
    train, _ = data.jura.load()

    x = np.array([[i, j] for (i, j) in train.index])
    y = np.array(train)
    y_dim = y.shape[1]

    # (total entries) / (observed entries), computed on the raw targets
    # before any transform and consumed downstream through ``args``.
    args.decoder_scale = 1. / np.mean(~np.isnan(y))

    if args.log_transform:
        y = np.log(y)

    # Standardise each target column, ignoring missing (NaN) entries.
    # The inputs ``x`` are deliberately left unnormalised.
    y_mean, y_std = np.nanmean(y, axis=0), np.nanstd(y, axis=0)
    y = (y - y_mean) / y_std

    return torch.tensor(x), torch.tensor(y), y_dim


def _build_kernel(args):
    """Construct the GP prior kernel.

    Args:
        args (argparse.Namespace): Reads ``gpytorch_kernel``,
            ``init_lengthscale`` and ``init_scale``.

    Returns:
        A gpytorch ``ScaleKernel(RBFKernel)`` when ``args.gpytorch_kernel``
        is true, otherwise a ``gpvae.kernels.RBFKernel``.
    """
    if args.gpytorch_kernel:
        # NOTE(review): ``args.init_scale`` is not applied on this path, and
        # whether the gpytorch constructor honours ``lengthscale=`` depends
        # on the gpytorch version -- confirm.
        return gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel(
            lengthscale=args.init_lengthscale))
    return gpvae.kernels.RBFKernel(lengthscale=args.init_lengthscale,
                                   scale=args.init_scale)


def _split_encoder_dims(encoder_dims, layout=None):
    """Split ``encoder_dims`` around its middle entry.

    Args:
        encoder_dims (list[int]): Layer widths from ``--encoder_dims``.
        layout (str, optional): Description of the expected layout, used in
            the error message. When given, ``encoder_dims`` must have an odd
            number of entries so the middle entry is unambiguous.

    Returns:
        tuple: ``(first_dims, middle_dim, second_dims)`` where
        ``middle_dim = encoder_dims[len(encoder_dims) // 2]``.

    Raises:
        ValueError: If ``layout`` is given and ``encoder_dims`` has an even
            number of entries.
    """
    if layout is not None and len(encoder_dims) % 2 != 1:
        raise ValueError(
            'args.encoder_dims must be [{}] with an odd number of entries, '
            'got {}.'.format(layout, encoder_dims))
    middle_idx = len(encoder_dims) // 2
    return (encoder_dims[:middle_idx], encoder_dims[middle_idx],
            encoder_dims[middle_idx + 1:])


def _train_gpvae_or_vae(encoder, decoder, kernel, loader, hyperparameters,
                        args, vae_model_name):
    """Wrap ``encoder``/``decoder`` in a GPVAE or a VAE and train it.

    When ``args.gpvae`` is true a ``GPVAE`` with ``kernel`` is trained with
    the GPVAE conditional estimators. Otherwise ``hyperparameters['model']``
    is overwritten with ``vae_model_name`` (presumably so the run is
    labelled as the VAE variant) and a plain ``VAE`` is trained with the VAE
    estimators.

    Args:
        encoder: Encoder network.
        decoder: Decoder network.
        kernel: GP prior kernel (only used for the GPVAE).
        loader: Training ``DataLoader``.
        hyperparameters (dict): Hyperparameter record passed to
            ``train_jura``; mutated in the VAE case.
        args (argparse.Namespace): Reads ``gpvae``, ``latent_dim``,
            ``add_jitter``, ``log_transform`` and ``p``.
        vae_model_name (str): Model name recorded for the VAE variant.

    Returns:
        The value returned by ``train_jura``.
    """
    if args.gpvae:
        model = gpvae.models.GPVAE(encoder, decoder, args.latent_dim,
                                   kernel, add_jitter=args.add_jitter)
        return train_jura(
            model,
            gpvae.estimators.gpvae_estimators.conditional_td_estimator,
            loader, hyperparameters,
            gpvae.estimators.gpvae_estimators.elbo_estimator,
            log_transform=args.log_transform, p=args.p)

    hyperparameters['model'] = vae_model_name
    model = gpvae.models.VAE(encoder, decoder, args.latent_dim)
    return train_jura(
        model, gpvae.estimators.vae_estimators.td_estimator,
        loader, hyperparameters,
        gpvae.estimators.vae_estimators.elbo_estimator,
        log_transform=args.log_transform)


def main(args):
    """Train the model selected by ``args.model`` on the Jura dataset.

    Args:
        args (argparse.Namespace): Parsed command-line arguments. ``args``
            is mutated: ``decoder_scale`` is added, and the per-model
            network configurations are stored on it, because ``vars(args)``
            (which is ``args.__dict__`` itself) serves as the hyperparameter
            record.

    Returns:
        The value returned by ``train_jura`` for the selected model.

    Raises:
        ValueError: If ``args.model`` is not a recognised model name, or if
            ``args.encoder_dims`` has an even number of entries for a model
            that splits it around a middle entry.
    """
    x, y, y_dim = _load_training_data(args)

    # Set up loaders.
    dataset = gpvae.utils.dataset_utils.TupleDataset(x, y, contains_nan=True)
    loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
    # Models using the VFE approximation take the whole dataset as one batch.
    vfe_loader = DataLoader(dataset, batch_size=len(x))

    kernel = _build_kernel(args)

    decoder_args = {'in_dim': args.latent_dim,
                    'out_dim': y_dim,
                    'hidden_dims': args.decoder_dims,
                    'sigma': args.sigma,
                    'train_sigma': args.train_sigma
                    }

    # vars(args) is args.__dict__, so every entry written here also becomes
    # an attribute of ``args``.
    hyperparameters = vars(args)

    if args.model == 'pog':
        # PoG encoder hyperparameters.
        encoder_args = {'in_dim': y_dim,
                        'out_dim': args.latent_dim,
                        'hidden_dims': args.encoder_dims,
                        'initial_sigma': args.init_encoder_sigma,
                        'initial_mu': 0.,
                        'min_sigma': 0.01
                        }
        hyperparameters['encoder_args'] = encoder_args
        hyperparameters['decoder_args'] = decoder_args

        encoder = gpvae.networks.FactorNet(**encoder_args)
        decoder = gpvae.networks.LinearGaussian(**decoder_args)
        metrics = _train_gpvae_or_vae(encoder, decoder, kernel, loader,
                                      hyperparameters, args, 'pog_vae')

    elif args.model == 'indexnet':
        hidden_dims, middle_dim, shared_hidden_dims = _split_encoder_dims(
            args.encoder_dims, 'hidden_dims, middle_dim, shared_hidden_dims')

        # IndexNet encoder hyperparameters.
        encoder_args = {'in_dim': y_dim,
                        'out_dim': args.latent_dim,
                        'hidden_dims': hidden_dims,
                        'shared_hidden_dims': shared_hidden_dims,
                        'middle_dim': middle_dim,
                        'initial_sigma': args.init_encoder_sigma,
                        'initial_mu': 0.,
                        'min_sigma': 0.01
                        }
        hyperparameters['encoder_args'] = encoder_args
        hyperparameters['decoder_args'] = decoder_args

        encoder = gpvae.networks.IndexNet(**encoder_args)
        decoder = gpvae.networks.LinearGaussian(**decoder_args)
        metrics = _train_gpvae_or_vae(encoder, decoder, kernel, loader,
                                      hyperparameters, args, 'indexnet_vae')

    elif args.model == 'pointnet':
        first_hidden_dims, middle_dim, second_hidden_dims = \
            _split_encoder_dims(
                args.encoder_dims,
                'first_hidden_dims, middle_dim, second_hidden_dims')

        # PointNet encoder hyperparameters.
        encoder_args = {'out_dim': args.latent_dim,
                        'middle_dim': middle_dim,
                        'first_hidden_dims': first_hidden_dims,
                        'second_hidden_dims': second_hidden_dims,
                        'initial_sigma': args.init_encoder_sigma,
                        'initial_mu': 0.,
                        'min_sigma': 0.01
                        }
        hyperparameters['encoder_args'] = encoder_args
        hyperparameters['decoder_args'] = decoder_args

        encoder = gpvae.networks.PointNet(**encoder_args)
        decoder = gpvae.networks.LinearGaussian(**decoder_args)
        metrics = _train_gpvae_or_vae(encoder, decoder, kernel, loader,
                                      hyperparameters, args, 'pointnet_vae')

    elif args.model == 'zeroimputation':
        # Zero imputation encoder hyperparameters.
        encoder_args = {'in_dim': y_dim,
                        'out_dim': args.latent_dim,
                        'hidden_dims': args.encoder_dims,
                        'initial_sigma': args.init_encoder_sigma,
                        'initial_mu': 0.,
                        'min_sigma': 0.01
                        }
        hyperparameters['encoder_args'] = encoder_args
        hyperparameters['decoder_args'] = decoder_args

        encoder = gpvae.networks.LinearGaussian(**encoder_args)
        decoder = gpvae.networks.LinearGaussian(**decoder_args)
        # The VAE variant now uses the VAE estimators, like every other
        # VAE branch (it previously used the GPVAE estimators).
        metrics = _train_gpvae_or_vae(encoder, decoder, kernel, loader,
                                      hyperparameters, args,
                                      'zeroimputation_vae')

    elif args.model == 'hvi':
        # No odd-length requirement is enforced for this model.
        hidden_dims, middle_dim, shared_hidden_dims = _split_encoder_dims(
            args.encoder_dims)

        # HVI GPVAE encoder/decoder hyperparameters.
        latent_encoder_args = {'in_dims': [args.auxiliary_dim, y_dim],
                               'out_dim': args.latent_dim,
                               'hidden_dims': args.encoder_dims,
                               'initial_sigma': 1.,
                               'initial_mu': 0.,
                               'contains_nans': [False, True],
                               'min_sigma': 0.001
                               }
        auxiliary_encoder_args = {'in_dim': y_dim,
                                  'out_dim': args.auxiliary_dim,
                                  'middle_dim': middle_dim,
                                  'hidden_dims': hidden_dims,
                                  'shared_hidden_dims': shared_hidden_dims,
                                  'initial_sigma': .1,
                                  'initial_mu': 0.,
                                  'min_sigma': 0.001
                                  }
        auxiliary_decoder_args = {'in_dims': [y_dim, args.latent_dim],
                                  'out_dim': args.auxiliary_dim,
                                  'hidden_dims': args.decoder_dims,
                                  'initial_sigma': 1.,
                                  'initial_mu': 0.,
                                  'contains_nans': [True, False]
                                  }
        hyperparameters['latent_encoder_args'] = latent_encoder_args
        hyperparameters['latent_decoder_args'] = decoder_args
        hyperparameters['auxiliary_encoder_args'] = auxiliary_encoder_args
        hyperparameters['auxiliary_decoder_args'] = auxiliary_decoder_args

        # HVI GPVAE model.
        latent_encoder = gpvae.networks.MultiInputLinearGaussian(
            **latent_encoder_args)
        latent_decoder = gpvae.networks.LinearGaussian(**decoder_args)
        auxiliary_encoder = gpvae.networks.IndexNet(
            **auxiliary_encoder_args)
        auxiliary_decoder = gpvae.networks.MultiInputLinearGaussian(
            **auxiliary_decoder_args)
        model = gpvae.models.HVIGPVAE(
            latent_encoder, latent_decoder, auxiliary_encoder,
            auxiliary_decoder, args.latent_dim, args.auxiliary_dim, kernel,
            args.add_jitter)

        metrics = train_jura(
            model, gpvae.estimators.hvi_gpvae_estimators.td_estimator,
            loader, hyperparameters,
            gpvae.estimators.hvi_gpvae_estimators.elbo_estimator,
            log_transform=args.log_transform)

    elif args.model == 'tvfe':
        # Titsias' VFE encoder.
        encoder_args = {'out_dim': args.latent_dim}
        hyperparameters['encoder_args'] = encoder_args
        hyperparameters['decoder_args'] = decoder_args

        # VFE inducing point locations: k-means centroids of the inputs.
        # NOTE(review): x has shape (n, 2), so kmeans2 returns centroids of
        # shape (num_inducing, 2) and unsqueeze(1) makes them
        # (num_inducing, 1, 2). Confirm MeanFieldSparseNet expects this.
        z = kmeans2(x.numpy(), args.num_inducing, minit='points')[0]
        z = torch.tensor(z).unsqueeze(1)

        # Titsias' VFE GPVAE model.
        encoder = gpvae.networks.mf_networks.MeanFieldSparseNet(
            z, **encoder_args)
        decoder = gpvae.networks.LinearGaussian(**decoder_args)
        model = gpvae.models.SparseGPVAE(encoder, decoder, args.latent_dim,
                                         kernel, args.add_jitter)

        metrics = train_jura(
            model, gpvae.estimators.gpvae_estimators.vfe_analytical_estimator,
            vfe_loader, hyperparameters,
            gpvae.estimators.gpvae_estimators.vfe_elbo_estimator,
            log_transform=args.log_transform)

    elif args.model == 'rvfe':
        # Rich's VFE encoder hyperparameters.
        encoder_args = {'in_dim': y_dim,
                        'out_dim': args.latent_dim,
                        'hidden_dims': args.encoder_dims,
                        'inducing_spacing': args.inducing_spacing,
                        'k': args.k,
                        'contains_nan': True,
                        'min_sigma': 0.001
                        }
        hyperparameters['encoder_args'] = encoder_args
        hyperparameters['decoder_args'] = decoder_args

        # Rich's VFE GPVAE model.
        encoder = gpvae.networks.FixedSparseNet(**encoder_args)
        decoder = gpvae.networks.LinearGaussian(**decoder_args)
        model = gpvae.models.SparseGPVAE(encoder, decoder, args.latent_dim,
                                         kernel, args.add_jitter)

        metrics = train_jura(
            model, gpvae.estimators.gpvae_estimators.vfe_analytical_estimator,
            vfe_loader, hyperparameters,
            gpvae.estimators.gpvae_estimators.vfe_elbo_estimator,
            log_transform=args.log_transform)

    elif args.model == 'vae':
        # Vanilla VAE with an IndexNet encoder.
        hidden_dims, middle_dim, shared_hidden_dims = _split_encoder_dims(
            args.encoder_dims, 'hidden_dims, middle_dim, shared_hidden_dims')

        encoder_args = {'in_dim': y_dim,
                        'out_dim': args.latent_dim,
                        'hidden_dims': hidden_dims,
                        'shared_hidden_dims': shared_hidden_dims,
                        'middle_dim': middle_dim,
                        'initial_sigma': 1.,
                        'initial_mu': 0.,
                        'min_sigma': 0.01
                        }
        hyperparameters['encoder_args'] = encoder_args
        hyperparameters['decoder_args'] = decoder_args

        encoder = gpvae.networks.IndexNet(**encoder_args)
        decoder = gpvae.networks.LinearGaussian(**decoder_args)
        model = gpvae.models.VAE(encoder, decoder, args.latent_dim)

        metrics = train_jura(
            model, gpvae.estimators.vae_estimators.td_estimator,
            loader, hyperparameters,
            gpvae.estimators.vae_estimators.elbo_estimator,
            log_transform=args.log_transform)

    else:
        raise ValueError('{} is not a model.'.format(args.model))

    return metrics


if __name__ == '__main__':
    # Command-line flags as (flag, add_argument keyword arguments), declared
    # in the order they should appear in --help.
    cli_options = [
        # General.
        ('--results_dir', dict(default='./_results_conditional/')),
        ('--gpvae', dict(default=True, type=str2bool)),

        # Kernel.
        ('--init_lengthscale', dict(default=1., type=float)),
        ('--init_scale', dict(default=1., type=float)),
        ('--init_period', dict(default=1., type=float)),
        ('--gpytorch_kernel', dict(default=False, type=str2bool)),

        # GPVAE.
        ('--model', dict(default='indexnet')),
        ('--latent_dim', dict(default=2, type=int)),
        ('--auxiliary_dim', dict(default=1, type=int)),
        ('--decoder_dims', dict(default=[20, 20], nargs='+', type=int)),
        ('--sigma', dict(default=0.1, type=float)),
        ('--train_sigma', dict(default=True, type=str2bool)),
        ('--encoder_dims', dict(default=[20, 20, 20], nargs='+', type=int)),
        ('--inducing_spacing', dict(default=0.025, type=float)),
        ('--k', dict(default=5, type=int)),
        ('--num_inducing', dict(default=100, type=int)),
        ('--add_jitter', dict(default=True, type=str2bool)),
        ('--init_encoder_sigma', dict(default=.1, type=float)),

        # Training.
        ('--epochs', dict(default=3000, type=int)),
        ('--cache_freq', dict(default=100, type=int)),
        ('--batch_size', dict(default=100, type=int)),
        ('--lr', dict(default=0.001, type=float)),
        ('--num_samples', dict(default=1, type=int)),
        ('--elbo_samples', dict(default=100, type=int)),
        ('--test_samples', dict(default=100, type=int)),
        ('--log_transform', dict(default=False, type=str2bool)),
        ('--p', dict(default=0.1, type=float)),
    ]

    parser = argparse.ArgumentParser()
    for flag, options in cli_options:
        parser.add_argument(flag, **options)

    args = parser.parse_args()
    main(args)
