import sys
sys.path.append('../../')

import argparse
import torch
import gpvae

from scipy.cluster.vq import kmeans2
from experiments.weather.train_japan_weekly import train_japan_weekly, \
    preprocess
from experiments.weather.train_japan_week import train_japan_week

# Make float64 the default floating point dtype for tensors created by torch
# from this point on.
torch.set_default_dtype(torch.float64)

# Names of the observed variables. The list is handed to `preprocess`, and its
# length is used in `main` as the encoder input / decoder output dimension.
# NOTE(review): presumably GHCN-Daily element codes -- confirm against the
# data loading code.
observations = ['PRCP', 'TMAX', 'TMIN', 'TAVG', 'SNWD']


def str2bool(v):
    """Convert a command-line token into a bool.

    Intended for use as an argparse ``type=`` callable.

    :param v: an actual ``bool`` (returned unchanged) or a string spelling
        a boolean, matched case-insensitively.
    :return: ``True`` for yes/true/t/y/1 and ``False`` for no/false/f/n/0.
    :raises argparse.ArgumentTypeError: if the string is not one of the
        recognised spellings.
    """
    if isinstance(v, bool):
        return v

    # Table of every accepted (lower-cased) spelling and its meaning.
    spellings = dict.fromkeys(('yes', 'true', 't', 'y', '1'), True)
    spellings.update(dict.fromkeys(('no', 'false', 'f', 'n', '0'), False))

    parsed = spellings.get(v.lower())
    if parsed is None:
        raise argparse.ArgumentTypeError('Boolean value expected.')
    return parsed


def _split_encoder_dims(encoder_dims, layout):
    """Split an odd-length list of widths around its middle element.

    :param encoder_dims: list of layer widths (``args.encoder_dims``).
    :param layout: human readable description of the expected layout, used
        only in the assertion message.
    :return: tuple ``(dims_before_middle, middle_dim, dims_after_middle)``.
    """
    assert len(encoder_dims) % 2 == 1, \
        'args.encoder_dims is [{}]'.format(layout)
    middle_dim_idx = len(encoder_dims) // 2
    return (encoder_dims[:middle_dim_idx],
            encoder_dims[middle_dim_idx],
            encoder_dims[middle_dim_idx + 1:])


def _build_decoder(args):
    """Construct the likelihood network; return ``(decoder, decoder_args)``."""
    if args.linear_likelihood:
        # Linear likelihood hyperparameters.
        decoder_args = {'in_dim': args.latent_dim,
                        'out_dim': len(observations),
                        'sigma': .1}
        return gpvae.networks.AffineGaussian(**decoder_args), decoder_args

    # DNN likelihood hyperparameters.
    decoder_args = {'in_dim': args.latent_dim,
                    'out_dim': len(observations),
                    'hidden_dims': args.decoder_dims,
                    'sigma': args.sigma,
                    'train_sigma': args.train_sigma,
                    'min_sigma': 1e-3
                    }
    return gpvae.networks.LinearGaussian(**decoder_args), decoder_args


def _build_encoder(args):
    """Construct the partial inference network named by args.pinference_net.

    :return: tuple ``(encoder, encoder_args)``.
    :raises ValueError: if ``args.pinference_net`` is not a known network.
    """
    name = args.pinference_net

    if name in ['factornet', 'factornet_ll']:
        # FactorNet hyperparameters.
        encoder_args = {'in_dim': len(observations),
                        'out_dim': args.latent_dim,
                        'hidden_dims': args.encoder_dims[:-1],
                        'initial_sigma': 1.,
                        'initial_mu': 0.,
                        'min_sigma': 0.01
                        }
        return gpvae.networks.FactorNet(**encoder_args), encoder_args

    if name in ['indexnet', 'indexnet_ll']:
        hidden_dims, middle_dim, shared_hidden_dims = _split_encoder_dims(
            args.encoder_dims,
            'hidden_dims, middle_dim, shared_hidden_dims')

        # IndexNet hyperparameters.
        encoder_args = {'in_dim': len(observations),
                        'out_dim': args.latent_dim,
                        'hidden_dims': hidden_dims,
                        'shared_hidden_dims': shared_hidden_dims,
                        'middle_dim': middle_dim,
                        'initial_sigma': 1.,
                        'initial_mu': 0.,
                        'min_sigma': 0.01
                        }
        return gpvae.networks.IndexNet(**encoder_args), encoder_args

    if name in ['pointnet', 'pointnet_ll']:
        first_hidden_dims, middle_dim, second_hidden_dims = \
            _split_encoder_dims(
                args.encoder_dims,
                'first_hidden_dims, middle_dim, second_hidden_dims')

        # PointNet encoder hyperparameters.
        encoder_args = {'out_dim': args.latent_dim,
                        'middle_dim': middle_dim,
                        'first_hidden_dims': first_hidden_dims,
                        'second_hidden_dims': second_hidden_dims,
                        'initial_sigma': 1.,
                        'initial_mu': 0.,
                        'min_sigma': 0.01
                        }
        return gpvae.networks.PointNet(**encoder_args), encoder_args

    if name in ['zeroimputation', 'zeroimputation_ll']:
        # Zero imputation hyperparameters.
        encoder_args = {'in_dim': len(observations),
                        'out_dim': args.latent_dim,
                        'hidden_dims': args.encoder_dims[:-1],
                        'initial_sigma': 1.,
                        'initial_mu': 0.,
                        'min_sigma': 0.01
                        }
        # Zero imputation uses a plain LinearGaussian network as encoder.
        return gpvae.networks.LinearGaussian(**encoder_args), encoder_args

    raise ValueError('{} is not a partial inference network.'.format(name))


def _init_inducing_points(all_data, num_inducing):
    """Initialise inducing points by k-means on all_data['x_train'][0]."""
    centroids = kmeans2(all_data['x_train'][0].numpy(), k=num_inducing,
                        minit='points')[0]
    return torch.tensor(centroids)


def main(args):
    """Build the model selected by ``args`` and train it.

    The function preprocesses the data, builds an RBF kernel, a decoder and
    a partial inference network (encoder), wraps them in the requested model
    and hands everything to ``train_japan_weekly`` (when ``args.week == -1``)
    or ``train_japan_week`` (otherwise).

    ``args`` is modified in place: ``args.pinference_net`` gets an ``'_ll'``
    suffix when a linear likelihood is used, ``args.model`` is replaced by
    ``'<model>_<pinference_net>'``, and ``encoder_args`` / ``decoder_args``
    are added to its attribute dict, which doubles as the hyperparameter
    dict passed to the training routine.

    :param args: ``argparse.Namespace`` holding the command-line options.
    :raises ValueError: if ``args.pinference_net`` or ``args.model`` is not
        recognised.
    :raises AssertionError: if ``args.encoder_dims`` has even length for
        the 'indexnet' or 'pointnet' inference networks.
    """
    all_data = preprocess(observations, vars(args))

    ###################
    # GP prior kernel #
    ###################
    kernel = gpvae.kernels.RBFKernel(lengthscale=args.init_lengthscale,
                                     scale=args.init_scale)

    ###############
    # GPVAE model #
    ###############
    # Decoder.
    decoder, decoder_args = _build_decoder(args)
    if args.linear_likelihood:
        # Tag the inference network name so that the model name derived
        # from it below records the linear likelihood.
        args.pinference_net += '_ll'

    # Partial inference network. It is built for every model type, even
    # 'sgpvae_old' which replaces it, so the order of network construction
    # stays the same for all models.
    encoder, encoder_args = _build_encoder(args)

    # Construct VAE model and pick the matching training / ELBO estimators.
    if args.model == 'gpvae':
        args.model = 'gpvae_' + args.pinference_net
        # GPVAE model.
        model = gpvae.models.GPVAE(encoder, decoder, args.latent_dim,
                                   kernel, add_jitter=args.add_jitter)
        train_estimator = gpvae.estimators.gpvae_estimators.td_estimator
        elbo_estimator = gpvae.estimators.gpvae_estimators.elbo_estimator

    elif args.model == 'sgpvae':
        args.model = 'sgpvae_' + args.pinference_net
        z = _init_inducing_points(all_data, args.num_inducing)

        # Michael's SGP-VAE model.
        model = gpvae.models.SparseGPVAE2(
            encoder, decoder, args.latent_dim, kernel, z,
            add_jitter=args.add_jitter)
        train_estimator = gpvae.estimators.sgpvae_estimators.sa_estimator
        elbo_estimator = gpvae.estimators.sgpvae_estimators.elbo_estimator

    elif args.model == 'sgpvae_old':
        args.model = 'sgpvae_old_' + args.pinference_net
        z = _init_inducing_points(all_data, args.num_inducing)

        # Sparse GPVAE encoder hyperparameters. These replace the encoder
        # built above.
        encoder_args = {'in_dim': len(observations),
                        'out_dim': args.latent_dim,
                        'z': z,
                        'hidden_dims': args.encoder_dims,
                        'k': args.k,
                        'min_sigma': 0.01,
                        'fixed_inducing': args.fixed_inducing,
                        'pinference_net': args.pinference_net
                        }
        encoder = gpvae.networks.SparseNet(**encoder_args)

        # Sparse GPVAE model.
        model = gpvae.models.SparseGPVAE(encoder, decoder, args.latent_dim,
                                         kernel, add_jitter=args.add_jitter)
        train_estimator = gpvae.estimators.gpvae_estimators.vfe_td_estimator
        elbo_estimator = gpvae.estimators.gpvae_estimators.vfe_elbo_estimator

    elif args.model == 'vae':
        args.model = 'vae_' + args.pinference_net
        # VAE model.
        model = gpvae.models.VAE(encoder, decoder, args.latent_dim)
        train_estimator = gpvae.estimators.vae_estimators.td_estimator
        elbo_estimator = gpvae.estimators.vae_estimators.elbo_estimator

    else:
        raise ValueError('{} is not a model.'.format(args.model))

    # vars(args) is the namespace's own attribute dict, so the two entries
    # added here are visible on args as well.
    hyperparameters = vars(args)
    hyperparameters['encoder_args'] = encoder_args
    hyperparameters['decoder_args'] = decoder_args

    # Train model: --week -1 selects the weekly routine, any other value the
    # single-week routine.
    if args.week == -1:
        train_japan_weekly(model, train_estimator, hyperparameters,
                           elbo_estimator)
    else:
        train_japan_week(model, train_estimator, hyperparameters,
                         elbo_estimator)


# Script entry point: declare the command-line options, parse them and run
# the experiment.
if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    # General.
    parser.add_argument('--results_dir', default='./_weekly_results/')
    parser.add_argument('--elev', default=True, type=str2bool)
    parser.add_argument('--evaluate_1980', default=True, type=str2bool)
    parser.add_argument('--evaluate_1981', default=True, type=str2bool)
    parser.add_argument('--save_model', default=True, type=str2bool)
    # -1 makes main() call train_japan_weekly; any other value makes it call
    # train_japan_week.
    parser.add_argument('--week', default=-1, type=int)
    parser.add_argument('--experiment', default=1, type=int)

    # Kernel.
    parser.add_argument('--init_lengthscale', default=1., type=float)
    parser.add_argument('--init_scale', default=1., type=float)

    # GPVAE.
    parser.add_argument('--model', default='sgpvae')
    parser.add_argument('--pinference_net', default='indexnet')
    parser.add_argument('--latent_dim', default=3, type=int)
    parser.add_argument('--decoder_dims', default=[20, 20], nargs='+',
                        type=int)
    parser.add_argument('--sigma', default=0.1, type=float)
    parser.add_argument('--train_sigma', default=True, type=str2bool)
    parser.add_argument('--encoder_dims', default=[20, 20, 20, 20, 20],
                        nargs='+', type=int)
    parser.add_argument('--linear_likelihood', default=False, type=str2bool)
    # str2bool rather than bool: bool('False') is True, so with type=bool
    # any non-empty value on the command line would enable jitter.
    parser.add_argument('--add_jitter', default=True, type=str2bool)
    parser.add_argument('--num_inducing', default=100, type=int)
    parser.add_argument('--fixed_inducing', default=False, type=str2bool)
    parser.add_argument('--k', default=20, type=int)

    # Training.
    parser.add_argument('--epochs', default=21, type=int)
    parser.add_argument('--cache_freq', default=5, type=int)
    parser.add_argument('--lr', default=0.001, type=float)
    parser.add_argument('--num_samples', default=1, type=int)
    parser.add_argument('--elbo_samples', default=100, type=int)
    parser.add_argument('--test_samples', default=100, type=int)
    # No default given, so args.batch_size is None unless set explicitly.
    parser.add_argument('--batch_size', type=int)

    args = parser.parse_args()
    main(args)
