import sys
sys.path.append('../../')

import time
import os
import pickle
import numpy as np
import pandas as pd
import torch
import torch.optim as optim
import gpvae
import data.japan_weekly
import data.japan_weekly2
import data.japan_weekly3
import data.japan_monthly

from gpvae.utils import metric_utils
from tqdm import tqdm

__all__ = ['train_japan_weekly', 'preprocess']


def train_japan_weekly(model, loss_fn, args, elbo_estimator=None,
                       iwae_estimator=None):
    """Preprocess the Japan weather data, train ``model`` and optionally save.

    Args:
        model: Model to train; its ``state_dict`` is saved when requested.
        loss_fn: Training loss, called by the training loop as
            ``loss_fn(model, x=..., y=..., mask=..., num_samples=1)``.
        args (dict): Experiment configuration. ``args['save_model']``
            decides whether results are written to disk; the remaining keys
            are read by the preprocessing/training helpers.
        elbo_estimator: Forwarded unchanged to the training loop.
        iwae_estimator: Forwarded unchanged to the training loop.

    Returns:
        dict: Metrics collected during training.
    """
    # Output columns modelled by the experiment.
    observed_columns = ['PRCP', 'TMAX', 'TMIN', 'TAVG', 'SNWD']

    prepared = preprocess(observed_columns, args)
    metrics = _train(model, loss_fn, args, elbo_estimator, iwae_estimator,
                     prepared, observed_columns)

    if not args['save_model']:
        return metrics

    # Persist model weights, hyperparameters and metrics.
    _save(model, args, metrics)
    return metrics


def _train(model, loss_fn, args, elbo_estimator, iwae_estimator, data,
           observations):
    """Train ``model`` with Adam and periodically evaluate it.

    Each epoch iterates over the meta-batches built from ``data['x_train']``
    and ``data['y_train']``. When ``args['batch_size']`` is not ``None``,
    every meta-batch is further split into mini-batches of that size with one
    optimiser step per mini-batch; otherwise one step is taken per
    meta-batch.

    The model is evaluated at every epoch that is a positive multiple of
    ``args['cache_freq']`` and at the final epoch. The 1980 / 1981 metrics
    are computed when ``args['evaluate_1980']`` / ``args['evaluate_1981']``
    is truthy, and always at the final epoch.

    Args:
        model: Model being trained; evaluation calls its ``predict_y``.
        loss_fn: Callable ``loss_fn(model, x=, y=, mask=, num_samples=)``
            returning a scalar loss tensor (the negative ELBO).
        args (dict): Reads ``'lr'``, ``'epochs'``, ``'batch_size'``,
            ``'cache_freq'``, ``'evaluate_1980'`` and ``'evaluate_1981'``.
        elbo_estimator: Not used by this function.
        iwae_estimator: Not used by this function.
        data (dict): Preprocessed data as returned by ``preprocess``.
        observations (list of str): Output column names.

    Returns:
        dict: Lists of cached values keyed by ``'epochs'``, ``'losses'``,
        ``'times'``, ``'rmses_<year>'``, ``'mlls_<year>'`` and per column
        ``'<col>_rmses_<year>'`` / ``'<col>_mlls_<year>'`` for 1980 and 1981.
    """
    years = ('1980', '1981')

    # Metrics to cache.
    metrics = {'epochs': [],
               'losses': [],
               'rmses_1980': [],
               'mlls_1980': [],
               'rmses_1981': [],
               'mlls_1981': [],
               'times': []
               }
    t0 = time.time()

    # Per-column metrics, inserted in the order col_rmses_1980,
    # col_mlls_1980, col_rmses_1981, col_mlls_1981.
    for col in observations:
        for year in years:
            metrics[col + '_rmses_' + year] = []
            metrics[col + '_mlls_' + year] = []

    model.train(True)
    optimiser = optim.Adam(model.parameters(), lr=args['lr'])

    # Set up dataset and dataloaders.
    train_dataset = gpvae.utils.dataset_utils.MetaTupleDataset(
        data['x_train'], data['y_train'], contains_nan=True)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=1, shuffle=False)

    # Per-year datasets used only for model evaluation.
    eval_datasets = {
        year: gpvae.utils.dataset_utils.MetaTupleDataset(
            data['x_train_' + year], data['y_train_' + year],
            contains_nan=True)
        for year in years}

    last_epoch = args['epochs'] - 1
    epoch_iter = tqdm(range(args['epochs']), desc='Epoch', leave=True)
    for epoch in epoch_iter:
        epoch_losses = []
        batch_iter = tqdm(iter(train_loader), desc='Batch', leave=False)
        for x_b, y_b, m_b, _ in batch_iter:
            # The loader uses batch_size=1; drop that leading dimension.
            batch_loss = _train_on_meta_batch(
                model, loss_fn, optimiser, x_b.squeeze(0), y_b.squeeze(0),
                m_b.squeeze(0), args['batch_size'])
            epoch_losses.append(batch_loss)
            batch_iter.set_postfix(loss=batch_loss)

        mean_loss = np.mean(epoch_losses)
        epoch_iter.set_postfix(loss=mean_loss)

        # Evaluate on positive multiples of cache_freq and on the last epoch.
        is_last = epoch == last_epoch
        if not ((epoch > 0 and epoch % args['cache_freq'] == 0) or is_last):
            continue

        tqdm.write('Evaluating model performance...')
        model.eval()

        # Time taken so far.
        metrics['times'].append(time.time() - t0)

        # Average loss over the epoch just finished.
        tqdm.write('Epoch {}\n Loss: {:.3f}\n'.format(epoch, mean_loss))
        metrics['epochs'].append(epoch)
        metrics['losses'].append(mean_loss)

        for year in years:
            if args['evaluate_' + year] or is_last:
                _evaluate_year(model, eval_datasets[year],
                               data['test_' + year], observations,
                               data['y_mean'], data['y_std'], metrics, year)

        model.train(True)

    return metrics


def _backward_and_step(loss, optimiser):
    """Back-propagate ``loss`` and take one optimiser step (best effort).

    A ``RuntimeError`` raised by ``backward`` or ``step`` skips this update
    rather than aborting training, and the error is reported via
    ``tqdm.write`` instead of being silently discarded.
    """
    try:
        loss.backward()
        optimiser.step()
    except RuntimeError as err:
        tqdm.write('Warning: skipping optimiser step after error: {}'.format(
            err))


def _train_on_meta_batch(model, loss_fn, optimiser, x_b, y_b, m_b,
                         batch_size):
    """Run optimiser steps on one meta-batch and return its summed loss.

    With ``batch_size`` of ``None`` the whole meta-batch is one step;
    otherwise it is split into mini-batches of ``batch_size`` points.
    """
    if batch_size is None:
        optimiser.zero_grad()
        # Compute gradient of negative ELBO.
        loss = loss_fn(model, x=x_b, y=y_b, mask=m_b, num_samples=1)
        _backward_and_step(loss, optimiser)
        return loss.item()

    batch_dataset = gpvae.utils.dataset_utils.TupleDataset(
        x_b, y_b, contains_nan=True)
    # Reuse the meta-batch's mask rather than the one the dataset derives.
    batch_dataset.m = m_b
    batch_loader = torch.utils.data.DataLoader(
        batch_dataset, batch_size=batch_size)

    batch_loss = 0
    for x_mb, y_mb, m_mb, _ in batch_loader:
        optimiser.zero_grad()
        # Compute gradient of negative ELBO.
        loss = loss_fn(model, x=x_mb, y=y_mb, mask=m_mb, num_samples=1)
        _backward_and_step(loss, optimiser)
        batch_loss += loss.item()
    return batch_loss


def _evaluate_year(model, dataset, test_frames, observations, y_mean, y_std,
                   metrics, year):
    """Evaluate on one year's held-out frames and append to ``metrics``.

    ``year`` is the string suffix (e.g. ``'1980'``) of the metric keys.
    """
    tqdm.write('Evaluating {} performance.'.format(year))
    pred, var = _predict(model, dataset, test_frames, observations,
                         y_mean, y_std)

    targets = pd.concat(test_frames)
    rmse = metric_utils.rmse(pred, targets)
    mll = metric_utils.mll(pred, var, targets)
    for col in observations:
        tqdm.write('{}: RMSE = {:.3f} MLL = {:.3f}'.format(
            col, rmse[col], mll[col]))
        metrics[col + '_rmses_' + year].append(rmse[col])
        metrics[col + '_mlls_' + year].append(mll[col])

    # Averages across output columns (rmse/mll presumably pandas Series).
    metrics['rmses_' + year].append(rmse.mean())
    metrics['mlls_' + year].append(mll.mean())


def _predict(model, train_dataset, per_test, observations, y_mean, y_std):
    """Predict every evaluation meta-batch and return results in data units.

    Args:
        model: Object exposing ``predict_y(x=, y=, mask=, num_samples=)``
            whose first two outputs are the predictive mean and standard
            deviation tensors in normalised units.
        train_dataset: Dataset whose parallel ``x``, ``y`` and ``m``
            sequences hold the conditioning inputs, outputs and masks.
        per_test (list of pandas.DataFrame): One frame per meta-batch; only
            its index is used, to label the predicted rows.
        observations (list of str): Column names for the predictions.
        y_mean: Output mean used to undo normalisation.
        y_std: Output standard deviation used to undo normalisation.

    Returns:
        tuple(pandas.DataFrame, pandas.DataFrame): Concatenated predictive
        means and variances (standard deviation squared).
    """
    mean_frames = []
    var_frames = []

    batches = zip(train_dataset.x, train_dataset.y, train_dataset.m, per_test)
    for x_b, y_b, m_b, test_frame in tqdm(batches, desc='Evaluation'):
        mean, sigma = model.predict_y(x=x_b.squeeze(0), y=y_b.squeeze(0),
                                      mask=m_b.squeeze(0),
                                      num_samples=10)[:2]

        # Map back from normalised to original units.
        mean_np = mean.numpy() * y_std + y_mean
        var_np = (sigma.numpy() * y_std) ** 2

        labels = test_frame.index
        mean_frames.append(
            pd.DataFrame(mean_np, index=labels, columns=observations))
        var_frames.append(
            pd.DataFrame(var_np, index=labels, columns=observations))

    return pd.concat(mean_frames), pd.concat(var_frames)


def _save(model, args, metrics):
    if 'model' not in args:
        print("Error: 'model' does not exist in args. Aborting save.")
        return

    if 'results_dir' in args.keys():
        results_dir = args['results_dir'] + args['model']
    else:
        results_dir = '_results/' + args['model']

    if os.path.isdir(results_dir):
        i = 1
        while os.path.isdir(results_dir + '_' + str(i)):
            i += 1

        results_dir = results_dir + '_' + str(i)

    os.makedirs(results_dir)
    results_path = results_dir + '/report.txt'
    model_path = results_dir + '/model_state_dict.pt'

    # Pickle args and metrics.
    with open(results_dir + '/args.pkl', 'wb') as f:
        pickle.dump(args, f)

    with open(results_dir + '/metrics.pkl', 'wb') as f:
        pickle.dump(metrics, f)

    # Save args and results in text format.
    with open(results_path, 'w') as f:
        f.write('Args: \n')
        if isinstance(args, list):
            for d in args:
                f.write(str(d) + '\n')
        else:
            f.write(str(args) + '\n')

        f.write('\nPerformance: \n')
        for (key, values) in metrics.items():
            try:
                f.write('{}: {}\n'.format(key, values[-1]))
            except IndexError:
                pass

    # Save model.state_dict().
    torch.save(model.state_dict(), model_path)


def preprocess(observations, args):
    """Load one experiment's data and convert it into normalised tensors.

    Args:
        observations (list of str or None): Output columns to extract and
            normalise. ``None`` selects the default
            ``['PRCP', 'TMAX', 'TMIN', 'TAVG', 'SNWD']``.
        args (dict): Reads ``args['experiment']`` (1-4, selecting the loader
            from the ``data`` package) and ``args['elev']`` (whether
            ``'elev'`` is used as an input alongside lat, lon and time).

    Returns:
        dict: ``x_train``/``y_train``, ``x_train_1980``/``y_train_1980`` and
        ``x_train_1981``/``y_train_1981`` (lists of per-frame tensors,
        normalised); ``x_test_1980``/``x_test_1981`` (normalised test-input
        tensors); ``test_1980``/``test_1981`` (the raw test frames as loaded);
        and the statistics ``y_mean``, ``y_std``, ``x_mean``, ``x_std``
        (numpy arrays computed on the full dataset).

    Raises:
        ValueError: If ``args['experiment']`` is not 1, 2, 3 or 4.
    """
    experiment = args['experiment']
    if experiment == 1:
        all_data = data.japan_weekly.load()
    elif experiment == 2:
        all_data = data.japan_weekly2.load()
    elif experiment == 3:
        all_data = data.japan_weekly3.load()
    elif experiment == 4:
        all_data = data.japan_monthly.load()
    else:
        raise ValueError('Only experiment 1, 2, 3 or 4 available.')

    if observations is None:
        observations = ['PRCP', 'TMAX', 'TMIN', 'TAVG', 'SNWD']
    # A list is required for pandas column selection (a tuple would be
    # treated as a single key).
    observations = list(observations)

    df = all_data['all']
    test_1980 = all_data['test_1980']
    test_1981 = all_data['test_1981']

    if args['elev']:
        inputs = ['lat', 'lon', 'elev', 'time']
    else:
        inputs = ['lat', 'lon', 'time']

    # Normalisation statistics come from the full dataset.
    y_mean = df[observations].mean().to_numpy()
    y_std = df[observations].std().to_numpy()
    x_mean = df[inputs].mean().to_numpy()
    x_std = df[inputs].std().to_numpy()

    def normalised_tensors(frames, columns, mean, std):
        # One array per frame, deliberately not stacked with np.array: frames
        # may have different row counts, and stacking ragged arrays raises a
        # ValueError in NumPy >= 1.24.
        return [torch.tensor((frame[columns].to_numpy() - mean) / std)
                for frame in frames]

    preprocessed_data = {
        'x_train': normalised_tensors(
            all_data['train'], inputs, x_mean, x_std),
        'y_train': normalised_tensors(
            all_data['train'], observations, y_mean, y_std),
        'x_train_1980': normalised_tensors(
            all_data['train_1980'], inputs, x_mean, x_std),
        'y_train_1980': normalised_tensors(
            all_data['train_1980'], observations, y_mean, y_std),
        'x_train_1981': normalised_tensors(
            all_data['train_1981'], inputs, x_mean, x_std),
        'y_train_1981': normalised_tensors(
            all_data['train_1981'], observations, y_mean, y_std),
        'x_test_1980': normalised_tensors(test_1980, inputs, x_mean, x_std),
        'x_test_1981': normalised_tensors(test_1981, inputs, x_mean, x_std),
        'test_1980': test_1980,
        'test_1981': test_1981,
        'y_mean': y_mean,
        'y_std': y_std,
        'x_mean': x_mean,
        'x_std': x_std,
    }

    return preprocessed_data
