import sys
sys.path.append('../../')

import copy
import os
import pdb
import pickle
import numpy as np
import pandas as pd
import torch
import torch.optim as optim
import gpvae

from data.eeg import load
from gpvae.utils import metric_utils
from tqdm import tqdm

__all__ = ['train_eeg']

# Model classes handled in "mean-field" mode: for an exact type listed here,
# train_eeg calls the loss and the ELBO/IWAE estimators with ``mf=True`` and
# the data indices ``idx``, and predicts from ``x`` and ``idx`` only.
mf_models = [gpvae.models.TitsiasSparseGPVAE]

# Load the EEG train/test frames at import time (the first returned value is
# unused). NOTE(review): per the original note this loader comes from Wessel's
# wbml package — confirm against data.eeg.
_, train, test = load()
# Per-column statistics of the training frame; train_eeg uses them to map
# normalised predictions back to the original data scale.
y_mean = train.mean().to_numpy()
y_std = train.std().to_numpy()


def _unpack_batch(batch, contains_nan):
    """Split a dataset/batch tuple into ``(x, y, mask, idx)``.

    Datasets flagged ``contains_nan`` provide ``(x, y, mask, idx)``; all
    others provide ``(x, y, idx)``, in which case ``mask`` is ``None``.
    """
    if contains_nan:
        x, y, mask, idx = batch
    else:
        x, y, idx = batch
        mask = None
    return x, y, mask, idx


def _hide_random_entries(y, mask, p):
    """Artificially hide a random proportion ``p`` of the entries of ``y``.

    Args:
        y: Observation tensor of a batch.
        mask: Observation mask indexed like ``y``, or ``None`` when every
            entry is observed.
        p: Probability with which each entry is hidden.

    Returns:
        ``(y_c, mask_c)``: copies of ``y`` and of the observation mask in
        which the hidden entries are set to ``0.`` and ``False`` respectively.
    """
    # Draw from NumPy's global RNG (as before) so seeded runs stay
    # reproducible, then convert so the mask can index ``y`` on any device.
    hide = np.random.choice(a=[False, True], size=tuple(y.shape),
                            p=[1 - p, p])
    hide = torch.from_numpy(hide).to(y.device)

    y_c = y.clone()
    y_c[hide] = 0.
    # Without an explicit mask every entry counts as observed.
    if mask is None:
        mask_c = torch.ones_like(y, dtype=torch.bool)
    else:
        mask_c = mask.clone()
    mask_c[hide] = False
    return y_c, mask_c


def _estimate(estimator, model, x, y, mask, idx, num_samples, mf):
    """Evaluate an ELBO/IWAE-style estimator on the full dataset.

    Mean-field models (``mf``) additionally receive ``mf=True`` and ``idx``.
    """
    if mf:
        return estimator(model, x, y, mask=mask, num_samples=num_samples,
                         mf=True, idx=idx)
    return estimator(model, x, y, mask=mask, num_samples=num_samples)


def _test_metrics(model, x, y, mask, idx, num_samples, normalised, mf):
    """Return ``(smse, smll, mll)`` of the model's predictions on ``test``.

    Predictions are made for every row of the training frame, mapped back to
    the data scale with the training mean/std when ``normalised``, wrapped in
    DataFrames aligned with ``train`` and scored against the module-level
    ``test`` frame with ``metric_utils`` (each result reduced with
    ``.mean()``).
    """
    if mf:
        mean, sigma = model.predict_y(
            x=x, idx=idx, num_samples=num_samples)[:2]
    else:
        mean, sigma = model.predict_y(
            x=x, y=y, mask=mask, num_samples=num_samples)[:2]

    # detach()/cpu() make the conversion safe for tensors that track
    # gradients or live on a GPU.
    mean = mean.detach().cpu().numpy()
    sigma = sigma.detach().cpu().numpy()

    if normalised:
        mean = mean * y_std + y_mean
        sigma = sigma * y_std

    pred = pd.DataFrame(mean, index=train.index, columns=train.columns)
    var = pd.DataFrame(sigma ** 2, index=train.index, columns=train.columns)

    smse = metric_utils.smse(pred, test).mean()
    smll = metric_utils.smll(pred, var, test).mean()
    mll = metric_utils.mll(pred, var, test).mean()
    return smse, smll, mll


def train_eeg(model, loss_fn, loader, args, elbo_estimator=None,
              iwae_estimator=None, normalised=True, save_model=True, p=0.1):
    """Train ``model`` on the EEG data, hiding a random subset of each batch.

    In every batch a proportion ``p`` of the entries is additionally hidden,
    and ``loss_fn`` receives both the original (``y``, ``mask``) and the
    corrupted (``y_c``, ``mask_c``) views. Every ``args['cache_freq']``
    epochs, and after the final epoch, the model is evaluated and a report
    is written with ``tqdm.write``.

    Args:
        model: Model to train. Models whose exact type is in ``mf_models``
            are treated as mean-field (called with ``mf=True`` and ``idx``).
        loss_fn: Training objective, called with ``model`` and keyword
            arguments; its result must support ``backward()`` and ``item()``.
        loader: Data loader whose ``dataset`` exposes ``contains_nan`` and a
            ``dataset()`` method returning the full data tuple.
        args: Hyperparameters. Reads ``lr``, ``epochs``, ``decoder_scale``,
            ``cache_freq``, ``elbo_samples`` and ``test_samples`` (``save``
            additionally reads ``model`` and ``results_dir``).
        elbo_estimator: Optional ELBO estimator evaluated on the full data.
        iwae_estimator: Optional IWAE estimator evaluated on the full data.
        normalised: If True, predictions are mapped back to the data scale
            with the training mean/std before scoring.
        save_model: If True, ``save`` is called after training.
        p: Probability with which each batch entry is hidden.

    Returns:
        Dict of metric lists keyed by ``epochs``, ``losses``, ``elbos``,
        ``iwaes``, ``smses``, ``smlls`` and ``mlls``. Values are appended at
        each evaluation (estimator/test metrics only when available).
    """
    metrics = {'epochs': [],
               'losses': [],
               'elbos': [],
               'iwaes': [],
               'smses': [],
               'smlls': [],
               'mlls': []
               }
    mf = type(model) in mf_models

    model.train(True)
    optimiser = optim.Adam(model.parameters(), lr=args['lr'])

    # Full dataset, used for evaluation.
    contains_nan = loader.dataset.contains_nan
    x, y, m, idx = _unpack_batch(loader.dataset.dataset(), contains_nan)

    # Training.
    for epoch in tqdm(range(args['epochs'])):
        epoch_losses = []
        num_failed, last_error = 0, None

        for batch in loader:
            x_b, y_b, m_b, idx_b = _unpack_batch(batch, contains_nan)
            y_c_b, m_c_b = _hide_random_entries(y_b, m_b, p)

            # Drop rows left without any observed value. If none remain the
            # batch carries no signal (and could yield a NaN loss), so skip.
            has_values = torch.sum(m_c_b, 1) != 0
            if not has_values.any():
                continue
            x_b = x_b[has_values]
            y_b = y_b[has_values]
            idx_b = idx_b[has_values]
            y_c_b = y_c_b[has_values]
            m_c_b = m_c_b[has_values]
            if m_b is not None:
                m_b = m_b[has_values]

            mf_kwargs = {'mf': True, 'idx': idx_b} if mf else {}
            try:
                optimiser.zero_grad()
                loss = loss_fn(
                    model, x=x_b, y=y_b, y_c=y_c_b, mask=m_b,
                    mask_c=m_c_b, num_samples=1,
                    decoder_scale=args['decoder_scale'], **mf_kwargs)
                loss.backward()
                optimiser.step()
                epoch_losses.append(loss.item())
            except Exception as err:  # pylint: disable=broad-except
                # Best effort: a batch whose update fails is skipped rather
                # than aborting training, but the failure is reported below.
                num_failed += 1
                last_error = err

        if num_failed:
            tqdm.write('Epoch {}: skipped {} batch(es) that raised; '
                       'last error: {!r}'.format(epoch, num_failed,
                                                 last_error))

        # Evaluate model.
        if (epoch % args['cache_freq'] == 0) or (epoch == args['epochs'] - 1):
            model.eval()

            report = 'Epoch {}\n'.format(epoch)
            metrics['epochs'].append(epoch)

            # Average loss over the epoch's successful batches; NaN if none
            # succeeded (avoids NumPy's empty-mean RuntimeWarning).
            mean_loss = np.mean(epoch_losses) if epoch_losses else np.nan
            metrics['losses'].append(mean_loss)
            report += 'Loss: {:.3f}\n'.format(mean_loss)

            # Evaluation needs no autograd graph; building one would waste
            # memory and could keep graphs alive through stored results.
            with torch.no_grad():
                if elbo_estimator is not None:
                    elbo = _estimate(elbo_estimator, model, x, y, m, idx,
                                     args['elbo_samples'], mf)
                    metrics['elbos'].append(elbo)
                    report += 'ELBO: {:.3f}\n'.format(elbo)

                if iwae_estimator is not None:
                    iwae = _estimate(iwae_estimator, model, x, y, m, idx,
                                     args['elbo_samples'], mf)
                    metrics['iwaes'].append(iwae)
                    report += 'IWAE: {:.3f}\n'.format(iwae)

                if test is not None:
                    smse, smll, mll = _test_metrics(
                        model, x, y, m, idx, args['test_samples'],
                        normalised, mf)
                    metrics['smses'].append(smse)
                    metrics['smlls'].append(smll)
                    metrics['mlls'].append(mll)
                    report += 'SMSE: {:.3f}\n'.format(smse)
                    report += 'SMLL: {:.3f}\n'.format(smll)
                    report += 'MLL: {:.3f}\n'.format(mll)

            # Report model performance.
            tqdm.write(report)

            model.train(True)

    if save_model:
        # Save model, hyperparameters and metrics.
        save(model, args, metrics)

    return metrics


def save(model, args, metrics):
    """Write a model's state dict, its hyperparameters and its metrics to disk.

    The target directory is ``args['results_dir'] + args['model']`` (prefix
    ``'_conditional_results/'`` when ``results_dir`` is absent). If it already
    exists, the first free ``<dir>_<n>`` for ``n = 1, 2, ...`` is used
    instead, so earlier runs are never overwritten. Inside it ``args.pkl``,
    ``metrics.pkl``, a plain-text ``report.txt`` and ``model_state_dict.pt``
    are written. If ``args`` has no ``'model'`` entry, a message is printed
    and nothing is saved.
    """
    if 'model' not in args:
        print("Error: 'model' does not exist in args. Aborting save.")
        return

    base_dir = args.get('results_dir', '_conditional_results/') + args['model']

    # Append _1, _2, ... until the directory name is free.
    out_dir = base_dir
    suffix = 0
    while os.path.isdir(out_dir):
        suffix += 1
        out_dir = base_dir + '_' + str(suffix)
    os.makedirs(out_dir)

    # Binary copies of the hyperparameters and metrics.
    for file_name, payload in (('/args.pkl', args), ('/metrics.pkl', metrics)):
        with open(out_dir + file_name, 'wb') as handle:
            pickle.dump(payload, handle)

    # Human-readable summary: the args, then the latest value of each metric.
    with open(out_dir + '/report.txt', 'w') as handle:
        handle.write('Args: \n')
        for entry in (args if isinstance(args, list) else [args]):
            handle.write(str(entry) + '\n')

        handle.write('\nPerformance: \n')
        for key, values in metrics.items():
            try:
                latest = values[-1]
            except IndexError:
                # Metric never recorded; leave it out of the report.
                continue
            handle.write('{}: {}\n'.format(key, latest))

    torch.save(model.state_dict(), out_dir + '/model_state_dict.pt')
