import sys
sys.path.append('../../')

import os
import pickle
import numpy as np
import pandas as pd
import torch
import torch.optim as optim
import gpvae

from data.eeg import load
from gpvae.utils import metric_utils
from tqdm import tqdm

# Public API of this module.
__all__ = ['train_eeg']

# Mean-field models. For these, ``train_eeg`` passes ``mf=True`` and the data
# indices ``idx`` to the loss/ELBO/IWAE callables, and calls ``predict_y``
# with ``x`` and ``idx`` only (no ``y``/``mask``).
mf_models = [gpvae.models.TitsiasSparseGPVAE]

# NOTE: the EEG data is loaded at import time (module-level side effect).
# ``load()`` returns three values; the first is unused here. ``train`` and
# ``test`` are presumably pandas DataFrames (``train.index`` and
# ``train.columns`` are reused to build prediction frames in ``train_eeg``)
# -- TODO confirm against ``data.eeg.load``.
_, train, test = load()
# Training-set mean and standard deviation (presumably per column), used in
# ``train_eeg`` to map normalised predictions back to the original scale.
y_mean = train.mean().to_numpy()
y_std = train.std().to_numpy()


def _evaluate_eeg(model, epoch, epoch_losses, data, mean_field, args,
                  elbo_estimator, iwae_estimator, normalised, metrics):
    """Score ``model`` on the full dataset and append results to ``metrics``.

    Args:
        model: Model to evaluate; the caller puts it in eval mode.
        epoch: Index of the epoch that has just finished.
        epoch_losses: Per-batch training losses recorded during that epoch.
        data: ``(x, y, m, idx)`` for the full dataset; ``m`` may be None.
        mean_field: Whether ``model`` takes the mean-field code path.
        args, elbo_estimator, iwae_estimator, normalised: As for
            ``train_eeg``.
        metrics: Metric dict built by ``train_eeg``; its lists are appended
            to in place.

    Returns:
        A multi-line text report of the metrics computed for this epoch.
    """
    x, y, m, idx = data

    metrics['epochs'].append(epoch)
    report = 'Epoch {}\n'.format(epoch)

    # Average loss over previous epoch. An empty loader gives NaN without
    # triggering NumPy's "mean of empty slice" RuntimeWarning.
    mean_loss = np.mean(epoch_losses) if epoch_losses else np.nan
    metrics['losses'].append(mean_loss)
    report += 'Loss: {:.3f}\n'.format(mean_loss)

    # Mean-field models additionally need ``mf=True`` and the data indices.
    mf_kwargs = {'mf': True, 'idx': idx} if mean_field else {}

    # Evaluation needs no gradients: disabling autograd avoids building a
    # graph over the whole dataset and lets predictions be converted to
    # NumPy arrays.
    with torch.no_grad():
        if elbo_estimator is not None:
            # ELBO estimate.
            elbo = elbo_estimator(model, x, y, mask=m,
                                  num_samples=args['elbo_samples'],
                                  **mf_kwargs)
            metrics['elbos'].append(elbo)
            report += 'ELBO: {:.3f}\n'.format(elbo)

        if iwae_estimator is not None:
            # IWAE estimate (uses the same sample budget as the ELBO).
            iwae = iwae_estimator(model, x, y, mask=m,
                                  num_samples=args['elbo_samples'],
                                  **mf_kwargs)
            metrics['iwaes'].append(iwae)
            report += 'IWAE: {:.3f}\n'.format(iwae)

        if test is not None:
            # Test predictions. Mean-field models predict from ``x`` and
            # ``idx`` only; the others also condition on ``y`` and the mask.
            if mean_field:
                mean, sigma = model.predict_y(
                    x=x, idx=idx, num_samples=args['test_samples'])[:2]
            else:
                mean, sigma = model.predict_y(
                    x=x, y=y, mask=m, num_samples=args['test_samples'])[:2]

            # ``.cpu()`` supports models on a GPU; it is a no-op on CPU.
            mean = mean.detach().cpu().numpy()
            sigma = sigma.detach().cpu().numpy()

            if normalised:
                # Undo the standardisation by the training mean/std.
                mean = mean * y_std + y_mean
                sigma = sigma * y_std

            # Evaluate test predictions against ``test``.
            pred = pd.DataFrame(mean, index=train.index,
                                columns=train.columns)
            var = pd.DataFrame(sigma ** 2, index=train.index,
                               columns=train.columns)

            smse = metric_utils.smse(pred, test).mean()
            smll = metric_utils.smll(pred, var, test).mean()
            mll = metric_utils.mll(pred, var, test).mean()

            metrics['smses'].append(smse)
            metrics['smlls'].append(smll)
            metrics['mlls'].append(mll)
            report += 'SMSE: {:.3f}\n'.format(smse)
            report += 'SMLL: {:.3f}\n'.format(smll)
            report += 'MLL: {:.3f}\n'.format(mll)

    return report


def train_eeg(model, loss_fn, loader, args, elbo_estimator=None,
              iwae_estimator=None, normalised=True, save_model=True):
    """Train a GP-VAE on the EEG data, evaluating it periodically.

    Args:
        model: Model to train. Must provide ``parameters()``, ``train()``,
            ``eval()``, ``predict_y(...)`` and (when saving) ``state_dict()``.
        loss_fn: Callable returning a scalar loss tensor for one batch.
        loader: Iterable of training batches. ``loader.dataset`` must expose
            ``contains_nan`` (whether batches carry a mask) and a
            ``dataset()`` method returning the full ``(x, y, [mask,] idx)``.
        args: Hyperparameter dict. Keys read here: ``'lr'``, ``'epochs'``,
            ``'decoder_scale'``, ``'cache_freq'``, ``'test_samples'`` and,
            when an estimator is given, ``'elbo_samples'``. ``save`` also
            reads ``'model'`` and optionally ``'results_dir'``.
        elbo_estimator: Optional callable estimating the ELBO on the full
            dataset.
        iwae_estimator: Optional callable estimating the IWAE bound on the
            full dataset.
        normalised: If True, predictions are rescaled with the training-set
            mean/std before being scored.
        save_model: If True, persist model, args and metrics via ``save``.

    Returns:
        Dict mapping ``'epochs'``, ``'losses'``, ``'elbos'``, ``'iwaes'``,
        ``'smses'``, ``'smlls'`` and ``'mlls'`` to lists with one entry per
        evaluation. The ELBO/IWAE lists stay empty when no estimator is given.
    """
    metrics = {'epochs': [],
               'losses': [],
               'elbos': [],
               'iwaes': [],
               'smses': [],
               'smlls': [],
               'mlls': []
               }

    # Mean-field models take extra ``mf``/``idx`` arguments. ``isinstance``
    # (rather than an exact type match) also covers their subclasses.
    mean_field = isinstance(model, tuple(mf_models))
    # Whether the dataset yields a missing-value mask alongside x and y.
    has_mask = loader.dataset.contains_nan

    model.train(True)
    optimiser = optim.Adam(model.parameters(), lr=args['lr'])

    # Full dataset, used for evaluation.
    dataset = loader.dataset.dataset()
    if has_mask:
        x, y, m, idx = dataset
    else:
        x, y, idx = dataset
        m = None

    # Training.
    for epoch in tqdm(range(args['epochs'])):
        epoch_losses = []
        for batch in loader:
            if has_mask:
                x_b, y_b, m_b, idx_b = batch
            else:
                x_b, y_b, idx_b = batch
                m_b = None

            mf_kwargs = {'mf': True, 'idx': idx_b} if mean_field else {}

            optimiser.zero_grad()
            loss = loss_fn(
                model, x=x_b, y=y_b, mask=m_b, num_samples=1,
                decoder_scale=args['decoder_scale'], **mf_kwargs)
            loss.backward()
            optimiser.step()

            epoch_losses.append(loss.item())

        # Evaluate every ``cache_freq`` epochs and after the final epoch.
        if (epoch % args['cache_freq'] == 0) or (epoch == args['epochs'] - 1):
            model.eval()
            try:
                report = _evaluate_eeg(
                    model, epoch, epoch_losses, (x, y, m, idx), mean_field,
                    args, elbo_estimator, iwae_estimator, normalised, metrics)
            finally:
                # Restore training mode even if evaluation raises.
                model.train(True)

            # Report model performance.
            tqdm.write(report)

    if save_model:
        # Save model, hyperparameters and metrics.
        save(model, args, metrics)

    return metrics


def _make_unique_dir(base):
    """Create and return a new directory ``base``, ``base_1``, ``base_2``, ...

    Creation is attempted directly and retried with the next suffix on
    ``FileExistsError``. A separate ``isdir`` check followed by ``makedirs``
    is racy (two runs can pick the same name) and crashes when a plain file
    already occupies the name.
    """
    candidate = base
    suffix = 0
    while True:
        try:
            os.makedirs(candidate)
        except FileExistsError:
            suffix += 1
            candidate = '{}_{}'.format(base, suffix)
        else:
            return candidate


def save(model, args, metrics):
    """Write model weights, hyperparameters and metrics to a new directory.

    The directory is ``args['results_dir'] + args['model']`` (default prefix
    ``'_results/'``). The two are concatenated as strings, so ``results_dir``
    should normally end with a path separator. If that path already exists,
    ``_1``, ``_2``, ... is appended to get a fresh directory.

    Files written: ``args.pkl`` and ``metrics.pkl`` (pickles of the inputs),
    ``report.txt`` (args plus the latest value of each metric) and
    ``model_state_dict.pt`` (``model.state_dict()`` via ``torch.save``).

    Args:
        model: Object exposing ``state_dict()``.
        args: Hyperparameter dict; must contain ``'model'``.
        metrics: Dict mapping metric names to sequences of recorded values.

    Returns:
        None. If ``'model'`` is missing from ``args``, an error is printed and
        nothing is written.
    """
    if 'model' not in args:
        print("Error: 'model' does not exist in args. Aborting save.")
        return

    # Plain string concatenation (not os.path.join) is kept on purpose so a
    # ``results_dir`` used as a name prefix keeps working.
    results_dir = _make_unique_dir(
        args.get('results_dir', '_results/') + args['model'])

    # Pickle args and metrics.
    with open(os.path.join(results_dir, 'args.pkl'), 'wb') as f:
        pickle.dump(args, f)

    with open(os.path.join(results_dir, 'metrics.pkl'), 'wb') as f:
        pickle.dump(metrics, f)

    # Save args and results in text format.
    with open(os.path.join(results_dir, 'report.txt'), 'w') as f:
        f.write('Args: \n')
        if isinstance(args, list):
            for d in args:
                f.write(str(d) + '\n')
        else:
            f.write(str(args) + '\n')

        f.write('\nPerformance: \n')
        for key, values in metrics.items():
            try:
                f.write('{}: {}\n'.format(key, values[-1]))
            except IndexError:
                # Metric never recorded (empty sequence): leave it out.
                pass

    # Save model.state_dict().
    torch.save(model.state_dict(),
               os.path.join(results_dir, 'model_state_dict.pt'))
