import sys
sys.path.append('../../')

import argparse
import numpy as np
import torch
import gpvae
import pickle
import pandas as pd

from gpvae.utils import metric_utils
from data.eeg2 import load
# Use float64 as the default floating-point dtype for tensors and module
# parameters created from here on (numpy float64 arrays converted with
# torch.tensor stay float64 regardless).
torch.set_default_dtype(torch.float64)

# Mean-field models.
# NOTE(review): not referenced anywhere else in this script as visible here;
# presumably kept for parity with the training script — confirm before removing.
mf_models = [gpvae.models.TitsiasSparseGPVAE]


def str2bool(v):
    """Parse a command-line flag value into a ``bool``.

    Intended for use as an ``argparse`` ``type=`` converter.

    Args:
        v: A ``bool`` (returned unchanged) or a string. Matching is
            case-insensitive: ``yes/true/t/y/1`` -> ``True`` and
            ``no/false/f/n/0`` -> ``False``.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``v`` is neither a bool nor one of the
            recognised strings (including non-string inputs, which previously
            leaked an ``AttributeError``).
    """
    if isinstance(v, bool):
        return v
    if not isinstance(v, str):
        raise argparse.ArgumentTypeError('Boolean value expected.')

    value = v.lower()
    if value in ('yes', 'true', 't', 'y', '1'):
        return True
    if value in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')


def main(args):
    """Evaluate a saved model on the EEG train/test split and write a report.

    Restores the model stored under ``args.model_path`` and estimates the
    ELBO on the training data. It then predicts means and standard
    deviations at the test inputs and scores them with SMSE and MLL
    against the held-out test data.

    Args:
        args: Parsed command-line namespace providing ``model_path``,
            ``elbo_samples`` and ``test_samples``. ``decoder_scale`` is
            added to it as a side effect.

    Side effects:
        Prints the metrics and writes them to
        ``<model_path>/new_task_report.txt``.
    """
    ######################
    # Data preprocessing #
    ######################
    _, train, test = load()

    # Extract data into numpy arrays.
    x = np.array(train.index)
    y = np.array(train)
    x_test = np.array(test.index)

    # Inverse of the fraction of observed (non-NaN) entries of y.
    # NOTE(review): stored on `args` but not used further in this function.
    decoder_scale = 1. / (np.sum(~np.isnan(y)) / (y.shape[0] * y.shape[1]))
    setattr(args, 'decoder_scale', decoder_scale)

    # Normalise each column; NaN entries are ignored by the statistics and
    # remain NaN.
    y_mean, y_std = np.nanmean(y, axis=0), np.nanstd(y, axis=0)
    y = (y - y_mean) / y_std

    # Convert to tensors. Train and test inputs get the same (num_points, 1)
    # shape because both are passed to the same model below.
    x = torch.tensor(x).unsqueeze(1)
    y = torch.tensor(y)
    x_test = torch.tensor(x_test).unsqueeze(1)

    # Set up dataset (missing observations are NaN).
    dataset = gpvae.utils.dataset_utils.TupleDataset(x, y, contains_nan=True)

    # Load model; the metrics saved at training time are not needed here.
    model, old_args, _ = _load_model(args.model_path)

    # ELBO estimate and test predictions. This is evaluation only, so no
    # autograd graph is built (and .numpy() cannot be called on tensors
    # that require grad).
    with torch.no_grad():
        if not old_args['gpvae']:
            elbo = gpvae.estimators.vae_estimators.elbo_estimator(
                model, dataset.x, dataset.y, mask=dataset.m,
                num_samples=args.elbo_samples)
            mean, sigma = model.predict_y(
                x=x_test, num_samples=args.test_samples)[:2]
        else:
            elbo = gpvae.estimators.gpvae_estimators.elbo_estimator(
                model, dataset.x, dataset.y, mask=dataset.m,
                num_samples=args.elbo_samples)
            mean, sigma = model.predict_y(
                x=dataset.x, y=dataset.y, mask=dataset.m, x_test=x_test,
                num_samples=args.test_samples)[:2]

    # Plain Python float (works for numbers and 0-d tensors) so the report
    # shows a number rather than a tensor repr.
    elbo = float(elbo)
    mean, sigma = mean.numpy(), sigma.numpy()

    # Undo the normalisation of the observations.
    mean = mean * y_std + y_mean
    sigma = sigma * y_std

    # Evaluate test predictions.
    pred = pd.DataFrame(mean, index=test.index, columns=test.columns)
    var = pd.DataFrame(sigma ** 2, index=test.index, columns=test.columns)

    smse = metric_utils.smse(pred, test).mean()
    mll = metric_utils.mll(pred, var, test).mean()
    print('ELBO: {:.4f}\nSMSE: {:.4f}\nMLL: {:.4f}'.format(elbo, smse, mll))

    metrics = {'elbo': elbo,
               'smse': smse,
               'mll': mll}

    # Save results in text format.
    results_path = args.model_path + '/new_task_report.txt'
    with open(results_path, 'w', encoding='utf-8') as f:
        f.write('\nPerformance: \n')
        f.writelines('{}: {}\n'.format(key, value)
                     for key, value in metrics.items())


def _load_model(path):
    """Rebuild a trained model from the files saved under ``path``.

    ``path`` must contain ``model_state_dict.pt``, ``args.pkl`` (the
    training hyperparameters, indexed as a dict) and ``metrics.pkl``.

    Args:
        path: Directory the model and its metadata were saved to.

    Returns:
        Tuple ``(model, args, metrics)``: the model with its saved weights
        loaded, the unpickled training hyperparameters and the unpickled
        metrics.

    Raises:
        ValueError: If ``args['model']`` is not a recognised model name.
    """
    # NOTE: torch.load and pickle.load can execute arbitrary code; only load
    # result directories from trusted sources.
    # map_location='cpu' lets weights saved from a GPU run be restored on a
    # CPU-only machine (the model below is constructed on CPU anyway).
    state_dict = torch.load(path + '/model_state_dict.pt', map_location='cpu')

    # Hyperparameters.
    with open(path + '/args.pkl', 'rb') as f:
        args = pickle.load(f)

    # Metrics.
    with open(path + '/metrics.pkl', 'rb') as f:
        metrics = pickle.load(f)

    # Decoder: AffineGaussian when `linear_likelihood` is set, otherwise
    # LinearGaussian.
    if args['linear_likelihood']:
        decoder = gpvae.networks.AffineGaussian(**args['decoder_args'])
    else:
        decoder = gpvae.networks.LinearGaussian(**args['decoder_args'])

    # Kernel. NOTE(review): the single-RBF branch also halves init_scale,
    # mirroring the additive case; presumably harmless because
    # load_state_dict restores the trained values — confirm.
    if args.get('periodic_kernel', False):
        k1 = gpvae.kernels.RBFKernel(
            lengthscale=args['init_lengthscale'], scale=args['init_scale']/2)
        k2 = gpvae.kernels.PeriodicKernel(
            lengthscale=args['init_lengthscale'], period=args['init_period'],
            scale=args['init_scale']/2)
        kernel = gpvae.kernels.AdditiveKernel(k1, k2)
    else:
        kernel = gpvae.kernels.RBFKernel(
            lengthscale=args['init_lengthscale'], scale=args['init_scale']/2)

    # Each encoder family accepts the names '<base>', '<base>_ll' and
    # '<base>_vae'. Classes are stored by attribute name so only the
    # selected one is looked up on gpvae.networks.
    encoder_names = {
        base + suffix: class_name
        for base, class_name in (('pog', 'FactorNet'),
                                 ('indexnet', 'IndexNet'),
                                 ('pointnet', 'PointNet'),
                                 ('zeroimputation', 'LinearGaussian'))
        for suffix in ('', '_ll', '_vae')
    }
    if args['model'] not in encoder_names:
        raise ValueError('Model name not correct.')
    encoder_cls = getattr(gpvae.networks, encoder_names[args['model']])
    encoder = encoder_cls(**args['encoder_args'])

    if args['gpvae']:
        model = gpvae.models.GPVAE(encoder, decoder, args['latent_dim'],
                                   kernel=kernel)
    else:
        model = gpvae.models.VAE(encoder, decoder, args['latent_dim'])
    model.load_state_dict(state_dict)

    return model, args, metrics


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Evaluate a saved model on the EEG test set.')

    # General.
    parser.add_argument('--model_path', default='./_results/K3/indexnet',
                        help='Directory containing model_state_dict.pt, '
                             'args.pkl and metrics.pkl; the report is '
                             'written here too.')

    # Evaluation (sample counts passed to the estimators / predictor).
    parser.add_argument('--elbo_samples', default=100, type=int,
                        help='Number of samples used for the ELBO estimate.')
    parser.add_argument('--test_samples', default=100, type=int,
                        help='Number of samples used for test predictions.')

    args = parser.parse_args()
    main(args)
