import sys
sys.path.append('../../')

import argparse
import torch
import gpvae
import pdb

from experiments.weather.train_japan_tridaily import train_japan_tridaily

# Make float64 the default dtype for floating-point tensors created without an
# explicit dtype from here on. This runs at import time, so merely importing
# this module changes torch's process-wide default.
torch.set_default_dtype(torch.float64)


def str2bool(v):
    """Parse a command-line value into a bool; meant as an argparse ``type``.

    Args:
        v: The raw argument value. A ``bool`` is returned unchanged. Any other
            value is converted with ``str()`` and matched case-insensitively,
            ignoring surrounding whitespace.

    Returns:
        ``True`` for 'yes', 'true', 't', 'y', '1'; ``False`` for 'no',
        'false', 'f', 'n', '0'.

    Raises:
        argparse.ArgumentTypeError: If ``v`` is not a recognised spelling, so
            argparse reports a clean usage error instead of a traceback.
    """
    if isinstance(v, bool):
        return v
    # Normalise once. str() makes non-string inputs (e.g. None or an int)
    # end in ArgumentTypeError or a valid match instead of AttributeError.
    normalized = str(v).strip().lower()
    if normalized in ('yes', 'true', 't', 'y', '1'):
        return True
    if normalized in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')


# Initialisation keyword arguments shared by every encoder architecture below.
# Spread into a fresh dict per use, so this mapping itself is never mutated.
_ENCODER_INIT_ARGS = {'initial_sigma': .1,
                      'initial_mu': 0.,
                      'min_sigma': 0.01}


def _select_observations(prcp, snwd):
    """Return the observed variable names for the given flags, in order.

    The three temperature channels are always present. 'PRCP' is prepended
    when ``prcp`` is truthy and 'SNWD' is appended when ``snwd`` is truthy.
    """
    observations = ['TMAX', 'TMIN', 'TAVG']
    if prcp:
        observations.insert(0, 'PRCP')
    if snwd:
        observations.append('SNWD')
    return observations


def _split_encoder_dims(encoder_dims, layout):
    """Split ``[*before, middle, *after]`` into ``(before, middle, after)``.

    Args:
        encoder_dims: Sequence of layer sizes with an odd length.
        layout: Human-readable description of the expected layout, used only
            in the error message.

    Raises:
        ValueError: If ``encoder_dims`` has even length (no unique middle).
    """
    if len(encoder_dims) % 2 != 1:
        raise ValueError(
            'args.encoder_dims must have odd length and be laid out as {}; '
            'got {}'.format(layout, encoder_dims))
    middle_idx = len(encoder_dims) // 2
    return (encoder_dims[:middle_idx],
            encoder_dims[middle_idx],
            encoder_dims[middle_idx + 1:])


def _build_decoder(args):
    """Construct the decoder selected by ``args``; return ``(decoder, kwargs)``.

    Side effects on ``args`` (kept because ``main`` records ``vars(args)`` as
    hyperparameters): the linear likelihood appends ``'_ll'`` to
    ``args.model`` when it is not already there. ``args.complex_decoder``
    replaces ``args.decoder_dims`` with four hidden layers of width 20.
    """
    out_dim = len(args.observations)

    if args.linear_likelihood:
        # NOTE(review): fixed sigma; args.sigma / args.train_sigma are not
        # used on this path -- confirm that is intended.
        decoder_args = {'in_dim': args.latent_dim,
                        'out_dim': out_dim,
                        'sigma': .1}
        decoder = gpvae.networks.AffineGaussian(**decoder_args)
        # Guard so '--model factornet_ll' does not become 'factornet_ll_ll',
        # which _build_encoder would reject as unknown.
        if not args.model.endswith('_ll'):
            args.model += '_ll'
        return decoder, decoder_args

    if args.complex_decoder:
        args.decoder_dims = [20, 20, 20, 20]

    # DNN likelihood hyperparameters.
    decoder_args = {'in_dim': args.latent_dim,
                    'out_dim': out_dim,
                    'hidden_dims': args.decoder_dims,
                    'sigma': args.sigma,
                    'train_sigma': args.train_sigma,
                    'min_sigma': 1e-3
                    }
    return gpvae.networks.LinearGaussian(**decoder_args), decoder_args


def _build_encoder(args):
    """Construct the encoder named by ``args.model``; return ``(encoder, kwargs)``.

    Each architecture is accepted with or without the ``'_ll'`` suffix.

    Raises:
        ValueError: If ``args.model`` is not a known architecture, or if
            ``args.encoder_dims`` has even length for IndexNet / PointNet.
    """
    in_dim = len(args.observations)

    if args.model in ('factornet', 'factornet_ll'):
        encoder_args = {'in_dim': in_dim,
                        'out_dim': args.latent_dim,
                        'hidden_dims': args.encoder_dims,
                        **_ENCODER_INIT_ARGS}
        return gpvae.networks.FactorNet(**encoder_args), encoder_args

    if args.model in ('indexnet', 'indexnet_ll'):
        hidden_dims, middle_dim, shared_hidden_dims = _split_encoder_dims(
            args.encoder_dims,
            '[hidden_dims, middle_dim, shared_hidden_dims]')
        encoder_args = {'in_dim': in_dim,
                        'out_dim': args.latent_dim,
                        'hidden_dims': hidden_dims,
                        'shared_hidden_dims': shared_hidden_dims,
                        'middle_dim': middle_dim,
                        **_ENCODER_INIT_ARGS}
        return gpvae.networks.IndexNet(**encoder_args), encoder_args

    if args.model in ('pointnet', 'pointnet_ll'):
        first_hidden_dims, middle_dim, second_hidden_dims = \
            _split_encoder_dims(
                args.encoder_dims,
                '[first_hidden_dims, middle_dim, second_hidden_dims]')
        # PointNet takes no in_dim.
        encoder_args = {'out_dim': args.latent_dim,
                        'middle_dim': middle_dim,
                        'first_hidden_dims': first_hidden_dims,
                        'second_hidden_dims': second_hidden_dims,
                        **_ENCODER_INIT_ARGS}
        return gpvae.networks.PointNet(**encoder_args), encoder_args

    if args.model in ('zeroimputation', 'zeroimputation_ll'):
        encoder_args = {'in_dim': in_dim,
                        'out_dim': args.latent_dim,
                        'hidden_dims': args.encoder_dims,
                        **_ENCODER_INIT_ARGS}
        return gpvae.networks.LinearGaussian(**encoder_args), encoder_args

    raise ValueError('{} is not a model.'.format(args.model))


def main(args):
    """Build a GP-VAE (or plain VAE) for the tridaily Japan data and train it.

    Args:
        args: Parsed command-line namespace. This function sets
            ``args.observations`` and ``args.inputs``. It may rewrite
            ``args.model`` / ``args.decoder_dims`` (see ``_build_decoder``).
            Because ``vars(args)`` is the namespace's own ``__dict__``, the
            ``encoder_args`` / ``decoder_args`` entries also become attributes
            of ``args``.

    Returns:
        The value returned by ``train_japan_tridaily``.

    Raises:
        ValueError: For an unknown ``args.model`` or malformed
            ``args.encoder_dims``.
    """
    args.observations = _select_observations(args.prcp, args.snwd)
    args.inputs = ['lat', 'lon', 'elev', 'time']

    # GP prior kernel.
    kernel = gpvae.kernels.RBFKernel(lengthscale=args.init_lengthscale,
                                     scale=args.init_scale)

    # Decoder first: it may append '_ll' to args.model before the encoder
    # lookup, matching the original construction order.
    decoder, decoder_args = _build_decoder(args)
    encoder, encoder_args = _build_encoder(args)

    hyperparameters = vars(args)
    hyperparameters['encoder_args'] = encoder_args
    hyperparameters['decoder_args'] = decoder_args

    if args.gpvae:
        model = gpvae.models.GPVAE(encoder, decoder, args.latent_dim,
                                   kernel, add_jitter=args.add_jitter)
        estimators = gpvae.estimators.gpvae_estimators
        if args.modified_elbo:
            train_estimator = estimators.conditional_td_estimator
        else:
            train_estimator = estimators.td_estimator
    else:
        # The plain VAE ignores the kernel and args.modified_elbo.
        model = gpvae.models.VAE(encoder, decoder, args.latent_dim)
        estimators = gpvae.estimators.vae_estimators
        train_estimator = estimators.td_estimator

    return train_japan_tridaily(model, train_estimator, hyperparameters,
                                estimators.elbo_estimator)


if __name__ == '__main__':
    # Command-line interface. Each entry is (flag, keyword arguments passed to
    # ``ArgumentParser.add_argument``), listed in the order they are added.
    option_table = (
        # General.
        ('--results_dir', {'default': './_tridaily_results/'}),
        ('--elev', {'default': True, 'type': str2bool}),
        ('--evaluate_1980', {'default': True, 'type': str2bool}),
        ('--evaluate_1981', {'default': True, 'type': str2bool}),
        ('--save_model', {'default': True, 'type': str2bool}),
        ('--prcp', {'default': True, 'type': str2bool}),
        ('--snwd', {'default': True, 'type': str2bool}),

        # Kernel.
        ('--init_lengthscale', {'default': 0.5, 'type': float}),
        ('--init_scale', {'default': 1., 'type': float}),

        # GPVAE.
        ('--gpvae', {'default': True, 'type': str2bool}),
        ('--model', {'default': 'indexnet'}),
        ('--latent_dim', {'default': 4, 'type': int}),
        ('--decoder_dims', {'default': [20, 20], 'nargs': '+',
                            'type': int}),
        ('--complex_decoder', {'default': False, 'type': str2bool}),
        ('--sigma', {'default': 0.1, 'type': float}),
        ('--train_sigma', {'default': True, 'type': str2bool}),
        ('--encoder_dims', {'default': [50, 50, 50, 50, 50], 'nargs': '+',
                            'type': int}),
        ('--linear_likelihood', {'default': False, 'type': str2bool}),
        ('--add_jitter', {'default': True, 'type': str2bool}),

        # Training.
        ('--epochs', {'default': 21, 'type': int}),
        ('--cache_freq', {'default': 5, 'type': int}),
        ('--lr', {'default': 0.001, 'type': float}),
        ('--num_samples', {'default': 1, 'type': int}),
        ('--elbo_samples', {'default': 100, 'type': int}),
        ('--test_samples', {'default': 100, 'type': int}),
        ('--modified_elbo', {'default': False, 'type': str2bool}),
        ('--p', {'default': 0.2, 'type': float}),
    )

    parser = argparse.ArgumentParser()
    for flag, options in option_table:
        parser.add_argument(flag, **options)

    args = parser.parse_args()
    main(args)