import sys
sys.path.append('../../')

import pdb
import time
import os
import pickle
import numpy as np
import pandas as pd
import torch
import torch.optim as optim
import gpvae

from gpvae.utils import metric_utils
from data.japan_weekly import load
from experiments.weather.train_japan_weekly import preprocess
from tqdm import tqdm

# Public API of this module.
__all__ = ['train_japan_week']

# "Mean-field" model classes. For these, the ELBO estimator is called with
# ``mf=True, idx=idx`` and ``predict_y`` with ``idx`` instead of ``y``/``mask``.
mf_models = [
    gpvae.models.TitsiasSparseGPVAE,
]


def train_japan_week(model, loss_fn, args, elbo_estimator=None):
    """Train ``model`` on one week of the Japan weather data and evaluate it.

    Args:
        model: The model to optimise. It is trained in place.
        loss_fn: Callable ``loss_fn(model, x=, y=, mask=, num_samples=)``
            returning a scalar loss tensor.
        args: Dict of settings. This function reads 'elev' (whether elevation
            is used as an input) and 'save_model'. ``_train`` and ``_save``
            read further keys.
        elbo_estimator: Optional callable used to track an ELBO estimate, or
            ``None`` to skip it.

    Returns:
        The metrics dict produced by ``_train``.
    """
    # Set observations.
    observations = ['PRCP', 'TMAX', 'TMIN', 'TAVG', 'SNWD']

    # preprocess's second positional parameter is ``inputs``, not ``elev``.
    # Passing the flag positionally (as before) silently dropped it, so
    # elevation was always included. Pass it by keyword instead.
    all_data = preprocess(observations, inputs=None, elev=args['elev'])
    metrics = _train(model, loss_fn, args, elbo_estimator, all_data,
                     observations)

    if args['save_model']:
        # Save model, hyperparameters and metrics.
        _save(model, args, metrics)

    return metrics


def _train(model, loss_fn, args, elbo_estimator, data, observations):
    """Optimise ``model`` with Adam, periodically recording evaluation metrics.

    Args:
        model: Model to train. Must provide ``parameters()``, ``train()`` and
            ``eval()``, plus the ``predict_y`` method used by ``_predict``.
        loss_fn: Callable ``loss_fn(model, x=, y=, mask=, num_samples=)``
            returning a scalar loss tensor (the negative ELBO).
        args: Dict of settings. Keys read here are 'lr', 'week',
            'batch_size', 'epochs', 'cache_freq', 'evaluate_1980' and
            'evaluate_1981'. 'elbo_samples' is read only when
            ``elbo_estimator`` is given.
        elbo_estimator: Optional callable returning an ELBO estimate on the
            full training set, or ``None`` to skip ELBO tracking.
        data: Dict as returned by ``preprocess``.
        observations: Output column names, used to label predictions.

    Returns:
        Dict of metric lists.
        - Every ``cache_freq`` epochs and at the final epoch, one entry is
          appended to 'times', 'epochs' and 'losses'.
        - One entry is appended to 'elbos' at the same epochs, when an
          estimator is given.
        - Per-year RMSE/MLL entries (overall and per column) are appended at
          the same epochs when that year's evaluation is enabled, and always
          at the final epoch.
    """
    # Metrics to cache.
    metrics = {'epochs': [],
               'losses': [],
               'elbos': [],
               'rmses_1980': [],
               'mlls_1980': [],
               'rmses_1981': [],
               'mlls_1981': [],
               'times': []
               }
    t0 = time.time()

    for col in observations:
        for year in (1980, 1981):
            metrics['{}_rmses_{}'.format(col, year)] = []
            metrics['{}_mlls_{}'.format(col, year)] = []

    model.train(True)
    optimiser = optim.Adam(model.parameters(), lr=args['lr'])

    week = args['week']

    # Set up dataset and dataloaders.
    train_dataset = gpvae.utils.dataset_utils.TupleDataset(
        data['x_train'][week], data['y_train'][week], contains_nan=True)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args['batch_size'], shuffle=True)

    # Held-out evaluation datasets, keyed by year.
    eval_datasets = {
        year: gpvae.utils.dataset_utils.TupleDataset(
            data['x_train_{}'.format(year)][week],
            data['y_train_{}'.format(year)][week],
            contains_nan=True)
        for year in (1980, 1981)
    }

    # Full training set, used for the ELBO estimate.
    if train_dataset.contains_nan:
        x, y, m, idx = train_dataset.dataset()
    else:
        x, y, idx = train_dataset.dataset()
        m = None

    last_epoch = args['epochs'] - 1
    epoch_iter = tqdm(range(args['epochs']), desc='Epoch', leave=True)
    for epoch in epoch_iter:
        epoch_losses = []
        batch_iter = tqdm(iter(train_loader), desc='Batch', leave=False)
        # Batches unpack into 4 items because contains_nan=True above.
        for x_b, y_b, m_b, _idx_b in batch_iter:
            optimiser.zero_grad()

            # Compute gradient of negative ELBO.
            loss = loss_fn(model, x=x_b, y=y_b, mask=m_b, num_samples=1)

            try:
                loss.backward()
                optimiser.step()
            except RuntimeError as err:
                # Best effort: skip this batch's update rather than abort the
                # whole run, but make the failure visible.
                tqdm.write('Warning: skipped optimiser step in epoch {}: {}'
                           .format(epoch, err))

            epoch_losses.append(loss.item())
            batch_iter.set_postfix(loss=loss.item())

        mean_loss = np.mean(epoch_losses)
        epoch_iter.set_postfix(loss=mean_loss)

        is_last = epoch == last_epoch
        if epoch % args['cache_freq'] != 0 and not is_last:
            continue

        # Evaluate model.
        tqdm.write('Evaluating model performance...')
        model.eval()

        # Time taken so far.
        metrics['times'].append(time.time() - t0)

        # Average loss over previous epoch.
        tqdm.write('Epoch {}\n Loss: {:.3f}\n'.format(epoch, mean_loss))
        metrics['epochs'].append(epoch)
        metrics['losses'].append(mean_loss)

        if elbo_estimator is not None:
            # ELBO estimate.
            if type(model) in mf_models:
                elbo = elbo_estimator(
                    model, x, y, mask=m, num_samples=args['elbo_samples'],
                    mf=True, idx=idx)
            else:
                elbo = elbo_estimator(
                    model, x, y, mask=m, num_samples=args['elbo_samples'])

            metrics['elbos'].append(elbo)
            tqdm.write('ELBO: {:.3f}\n'.format(elbo))

        for year in (1980, 1981):
            if args['evaluate_{}'.format(year)] or is_last:
                _evaluate_year(model, eval_datasets[year], data, week, year,
                               observations, metrics)

        model.train(True)

    return metrics


def _evaluate_year(model, dataset, data, week, year, observations, metrics):
    """Score predictions for ``year``/``week`` and append them to ``metrics``."""
    tqdm.write('Evaluating {} performance.'.format(year))
    df_test = data['test_{}'.format(year)][week]
    pred, var = _predict(model, dataset, df_test, observations,
                         data['y_mean'], data['y_std'])

    rmse = metric_utils.rmse(pred, df_test)
    mll = metric_utils.mll(pred, var, df_test)
    for col in observations:
        tqdm.write('{}: RMSE = {:.3f} MLL = {:.3f}'.format(
            col, rmse[col], mll[col]))
        metrics['{}_rmses_{}'.format(col, year)].append(rmse[col])
        metrics['{}_mlls_{}'.format(col, year)].append(mll[col])

    metrics['rmses_{}'.format(year)].append(rmse.mean())
    metrics['mlls_{}'.format(year)].append(mll.mean())


def _predict(model, train_dataset, df_test, observations, y_mean, y_std,
             num_samples=10):
    """Predict outputs for ``train_dataset`` and undo the output normalisation.

    Args:
        model: Trained model providing ``predict_y``. Mean-field models
            (``mf_models``) are called with ``idx``; others with ``y``/``mask``.
        train_dataset: Dataset whose ``dataset()`` returns
            ``(x, y, mask, idx)`` when ``contains_nan`` is true, else
            ``(x, y, idx)``.
        df_test: DataFrame whose index labels the prediction rows. It
            presumably has one row per row of ``x`` (TODO confirm).
        observations: Output column names, used as the result's columns.
        y_mean, y_std: Per-column statistics used to map predictions back to
            original units.
        num_samples: Number of samples passed to ``predict_y`` (default 10).

    Returns:
        ``(pred, var)``: DataFrames holding the predictive mean and the
        variance (``sigma ** 2``) in original units.
    """
    if train_dataset.contains_nan:
        x, y, m, idx = train_dataset.dataset()
    else:
        x, y, idx = train_dataset.dataset()
        m = None

    # Prediction needs no gradients. Disabling them also saves memory and
    # avoids ``.numpy()`` failing on tensors that require grad.
    with torch.no_grad():
        if type(model) in mf_models:
            mean, sigma = model.predict_y(
                x=x, idx=idx, num_samples=num_samples)[:2]
        else:
            mean, sigma = model.predict_y(
                x=x, y=y, mask=m, num_samples=num_samples)[:2]

    # detach()/cpu() make the NumPy conversion safe for GPU tensors too.
    mean = mean.detach().cpu().numpy() * y_std + y_mean
    sigma = sigma.detach().cpu().numpy() * y_std

    # Convert to DataFrames aligned with the test index.
    pred = pd.DataFrame(mean, index=df_test.index, columns=observations)
    var = pd.DataFrame(sigma ** 2, index=df_test.index, columns=observations)

    return pred, var


def _save(model, args, metrics):
    if 'model' not in args:
        print("Error: 'model' does not exist in args. Aborting save.")
        return

    if 'results_dir' in args.keys():
        results_dir = args['results_dir'] + args['model']
    else:
        results_dir = '_results/' + args['model']

    if os.path.isdir(results_dir):
        i = 1
        while os.path.isdir(results_dir + '_' + str(i)):
            i += 1

        results_dir = results_dir + '_' + str(i)

    os.makedirs(results_dir)
    results_path = results_dir + '/report.txt'
    model_path = results_dir + '/model_state_dict.pt'

    # Pickle args and metrics.
    with open(results_dir + '/args.pkl', 'wb') as f:
        pickle.dump(args, f)

    with open(results_dir + '/metrics.pkl', 'wb') as f:
        pickle.dump(metrics, f)

    # Save args and results in text format.
    with open(results_path, 'w') as f:
        f.write('Args: \n')
        if isinstance(args, list):
            for d in args:
                f.write(str(d) + '\n')
        else:
            f.write(str(args) + '\n')

        f.write('\nPerformance: \n')
        for (key, values) in metrics.items():
            try:
                f.write('{}: {}\n'.format(key, values[-1]))
            except IndexError:
                pass

    # Save model.state_dict().
    torch.save(model.state_dict(), model_path)


def _to_normalised_tensors(frames, columns, mean, std):
    """Return one tensor per DataFrame: ``(frame[columns] - mean) / std``."""
    # Each frame is converted on its own, so frames may have different row
    # counts. Stacking them with np.array first fails on ragged input.
    return [torch.tensor((frame[columns].to_numpy() - mean) / std)
            for frame in frames]


def preprocess(observations=None, inputs=None, elev=True):
    """Load the weekly Japan weather data, normalise it and convert to tensors.

    Args:
        observations: Output column names. ``None`` falls back to
            ``['PRCP', 'TMAX', 'TMIN', 'TAVG', 'SNWD']``. Previously this
            argument was ignored and that list was always used.
        inputs: Unused; the input columns are chosen by ``elev``. Kept only
            for backward compatibility with existing call sites.
        elev: If true the inputs are ``['lat', 'lon', 'elev', 'time']``,
            otherwise ``['lat', 'lon', 'time']``.

    Returns:
        Dict with these entries:
        - 'x_train', 'y_train', 'x_train_1980', 'y_train_1980',
          'x_train_1981', 'y_train_1981', 'x_test_1980', 'x_test_1981':
          lists of normalised tensors, one per entry of the corresponding
          list returned by ``load()``.
        - 'test_1980' and 'test_1981': the raw test DataFrame lists.
        - 'y_mean', 'y_std', 'x_mean' and 'x_std': the per-column
          normalisation statistics, as NumPy arrays.
    """
    if observations is None:
        observations = ['PRCP', 'TMAX', 'TMIN', 'TAVG', 'SNWD']
    # A list (not tuple) so pandas treats it as a column selection.
    observations = list(observations)

    if elev:
        inputs = ['lat', 'lon', 'elev', 'time']
    else:
        inputs = ['lat', 'lon', 'time']

    all_data = load()
    df = all_data['all']
    test_1980 = all_data['test_1980']
    test_1981 = all_data['test_1981']

    # Normalisation statistics come from the 'all' frame.
    # NOTE(review): if 'all' includes the held-out test rows, test statistics
    # leak into the normalisation. Confirm this is intended.
    y_mean = df[observations].mean().to_numpy()
    y_std = df[observations].std().to_numpy()
    x_mean = df[inputs].mean().to_numpy()
    x_std = df[inputs].std().to_numpy()

    def norm_x(frames):
        return _to_normalised_tensors(frames, inputs, x_mean, x_std)

    def norm_y(frames):
        return _to_normalised_tensors(frames, observations, y_mean, y_std)

    preprocessed_data = {'x_train': norm_x(all_data['train']),
                         'y_train': norm_y(all_data['train']),
                         'x_train_1980': norm_x(all_data['train_1980']),
                         'y_train_1980': norm_y(all_data['train_1980']),
                         'x_train_1981': norm_x(all_data['train_1981']),
                         'y_train_1981': norm_y(all_data['train_1981']),
                         'x_test_1980': norm_x(test_1980),
                         'x_test_1981': norm_x(test_1981),
                         'test_1980': test_1980,
                         'test_1981': test_1981,
                         'y_mean': y_mean,
                         'y_std': y_std,
                         'x_mean': x_mean,
                         'x_std': x_std
                         }

    return preprocessed_data
