import sys, copy, time
sys.path.append('../../')

import os
import pickle
import numpy as np
import pandas as pd
import torch
import torch.optim as optim
import gpvae

from gpvae.utils import metric_utils
from tqdm import tqdm

# Public API of this module: the two training entry points defined below.
__all__ = ['train_big_eeg', 'train_big_eeg_modified']

# Mean-field models.
# NOTE(review): this list is not referenced by any function visible in this
# file -- confirm whether other modules still rely on it.
mf_models = [gpvae.models.TitsiasSparseGPVAE]


def train_big_eeg(model, loss_fn, train_loader, test_loader, eval_test,
                  eval_train_loader, eval_eval_train, y_mean, y_std, args,
                  save_model=True):
    """Train ``model`` with Adam, evaluating it every few epochs.

    Every item yielded by ``train_loader`` is unpacked as
    ``(x, y, mask, idx)``. The leading dimension of ``x``, ``y`` and ``mask``
    is squeezed out, rows whose mask sums to zero are dropped, and
    ``loss_fn(model, x=..., y=..., mask=..., num_samples=1)`` is minimised.

    Args:
        model: Model to optimise; must provide ``parameters()``, ``train()``
            and ``eval()``.
        loss_fn: Callable returning a scalar loss tensor for one batch.
        train_loader: Iterable of training batches.
        test_loader: Loader handed to ``_predict`` together with
            ``eval_test``.
        eval_test: Test targets for ``_predict``; ``None`` disables the test
            evaluation.
        eval_train_loader: Loader handed to ``_predict`` together with
            ``eval_eval_train``.
        eval_eval_train: Train targets for ``_predict``; ``None`` disables
            the train evaluation.
        y_mean: Forwarded unchanged to ``_predict``.
        y_std: Forwarded unchanged to ``_predict``.
        args: Dict providing ``'lr'``, ``'epochs'`` and ``'cache_freq'``, plus
            whatever ``_predict`` and ``_save`` read from it.
        save_model: If true, call ``_save(model, args, metrics)`` once
            training has finished.

    Returns:
        Dict of lists with one entry per evaluation: ``'epochs'``,
        ``'losses'`` and ``'times'`` always, ``'smses'``/``'mlls'`` when
        ``eval_test`` is given and ``'train_smses'``/``'train_mlls'`` when
        ``eval_eval_train`` is given.
    """
    metrics = {'epochs': [],
               'losses': [],
               'smses': [],
               'mlls': [],
               'train_smses': [],
               'train_mlls': [],
               'times': []
               }
    t0 = time.time()

    model.train(True)
    optimiser = optim.Adam(model.parameters(), lr=args['lr'])

    # Training.
    optimiser.zero_grad()
    epoch_iter = tqdm(range(args['epochs']), desc='Epoch', leave=True)
    for epoch in epoch_iter:
        epoch_losses = []
        failed_steps = 0
        last_error = None

        # Hand tqdm the loader itself rather than iter(loader) so that it can
        # pick up the loader's length for the progress bar.
        loader_iter = tqdm(train_loader, desc='Iter', leave=False)
        for x_b, y_b, m_b, _ in loader_iter:
            # Get rid of the leading batch dimension.
            x_b = x_b.squeeze(0)
            y_b = y_b.squeeze(0)
            m_b = m_b.squeeze(0)

            # Keep only the rows that have at least one unmasked entry.
            has_values = ~(torch.sum(m_b, 1) == 0)
            x_b = x_b[has_values]
            y_b = y_b[has_values]
            m_b = m_b[has_values]

            optimiser.zero_grad()

            # Compute gradient of negative ELBO.
            loss = loss_fn(model, x=x_b, y=y_b, mask=m_b, num_samples=1)

            try:
                loss.backward()
                optimiser.step()
            except RuntimeError as err:
                # Best effort: skip the update for this batch and carry on,
                # but remember the failure so it can be reported below.
                failed_steps += 1
                last_error = err

            epoch_losses.append(loss.item())

        if failed_steps:
            tqdm.write(
                'Epoch {}: skipped {} parameter update(s) after a '
                'RuntimeError (last: {})'.format(epoch, failed_steps,
                                                 last_error))

        # Average loss over previous epoch (NaN if the loader was empty).
        mean_loss = np.mean(epoch_losses) if epoch_losses else float('nan')
        epoch_iter.set_postfix({'Loss': mean_loss})

        # Evaluate model.
        if (epoch % args['cache_freq'] == 0) or (epoch == args['epochs'] - 1):
            model.eval()

            report = 'Epoch {}\n'.format(epoch)

            # Record which epoch the entries appended below belong to.
            metrics['epochs'].append(epoch)

            # Time taken so far.
            metrics['times'].append(time.time() - t0)

            # Average loss over previous epoch.
            metrics['losses'].append(mean_loss)
            report += 'Loss: {:.3f}\n'.format(mean_loss)

            if eval_test is not None:
                test_smse, test_mll = _predict(model, test_loader, eval_test,
                                               y_mean, y_std, args)

                metrics['smses'].append(test_smse)
                metrics['mlls'].append(test_mll)
                report += 'SMSE: {:.3f}\n'.format(test_smse)
                report += 'MLL: {:.3f}\n'.format(test_mll)

            if eval_eval_train is not None:
                train_smse, train_mll = _predict(model, eval_train_loader,
                                                 eval_eval_train,
                                                 y_mean, y_std, args)

                metrics['train_smses'].append(train_smse)
                metrics['train_mlls'].append(train_mll)
                report += 'Train SMSE: {:.3f}\n'.format(train_smse)
                report += 'Train MLL: {:.3f}\n'.format(train_mll)

            # Report model performance.
            tqdm.write(report)

            model.train(True)

    if save_model:
        # Save model, hyperparameters and metrics.
        _save(model, args, metrics)

    return metrics


def train_big_eeg_modified(model, loss_fn, train_loader, test_loader,
                           eval_test, y_mean, y_std, args, save_model=True):
    """Train ``model`` with Adam on randomly corrupted copies of each batch.

    Every item yielded by ``train_loader`` is unpacked as
    ``(x, y, mask, idx)`` and the leading dimension of ``x``, ``y`` and
    ``mask`` is squeezed out. A copy of ``y``/``mask`` is then made in which
    each entry is, independently with probability ``args['p']``, zeroed and
    marked as missing. Rows whose corrupted mask sums to zero are dropped and
    ``loss_fn(model, x=..., y=..., y_c=..., mask=..., mask_c=...,
    num_samples=1)`` is minimised.

    Args:
        model: Model to optimise; must provide ``parameters()``, ``train()``
            and ``eval()``.
        loss_fn: Callable returning a scalar loss tensor for one batch.
        train_loader: Iterable of training batches.
        test_loader: Loader handed to ``_predict`` together with
            ``eval_test``.
        eval_test: Test targets for ``_predict``; ``None`` disables the test
            evaluation.
        y_mean: Forwarded unchanged to ``_predict``.
        y_std: Forwarded unchanged to ``_predict``.
        args: Dict providing ``'lr'``, ``'epochs'``, ``'cache_freq'`` and
            ``'p'``, plus whatever ``_predict`` and ``_save`` read from it.
        save_model: If true, call ``_save(model, args, metrics)`` once
            training has finished.

    Returns:
        Dict of lists with one entry per evaluation: ``'epochs'``,
        ``'losses'`` and ``'times'`` always, ``'smses'``/``'mlls'`` when
        ``eval_test`` is given.
    """
    metrics = {'epochs': [],
               'losses': [],
               'smses': [],
               'mlls': [],
               'times': []
               }

    t0 = time.time()

    model.train(True)
    optimiser = optim.Adam(model.parameters(), lr=args['lr'])

    # Training.
    optimiser.zero_grad()
    epoch_iter = tqdm(range(args['epochs']), desc='Epoch', leave=True)
    for epoch in epoch_iter:
        epoch_losses = []
        failed_batches = 0
        last_error = None

        # Hand tqdm the loader itself rather than iter(loader) so that it can
        # pick up the loader's length for the progress bar.
        loader_iter = tqdm(train_loader, desc='Iter', leave=False)
        for x_b, y_b, m_b, _ in loader_iter:
            # Get rid of the leading batch dimension.
            x_b = x_b.squeeze(0)
            y_b = y_b.squeeze(0)
            m_b = m_b.squeeze(0)

            # Set random proportion of batch to missing: selected entries are
            # zeroed in the copy of y and switched off in the copy of the
            # mask. The originals stay untouched.
            drop = np.random.choice(a=[False, True], size=y_b.shape,
                                    p=[1 - args['p'], args['p']])
            y_c_b = y_b.clone()
            y_c_b[drop] = 0.
            m_c_b = m_b.clone()
            m_c_b[drop] = False

            # Keep only rows that still have an observed entry after the
            # corruption.
            has_values = ~(torch.sum(m_c_b, 1) == 0)
            x_b = x_b[has_values]
            y_b = y_b[has_values]
            m_b = m_b[has_values]
            y_c_b = y_c_b[has_values]
            m_c_b = m_c_b[has_values]

            optimiser.zero_grad()

            try:
                # Compute gradient of negative ELBO.
                loss = loss_fn(model, x=x_b, y=y_b, y_c=y_c_b, mask=m_b,
                               mask_c=m_c_b, num_samples=1)
                loss.backward()
                optimiser.step()
                epoch_losses.append(loss.item())
            except RuntimeError as err:
                # Best effort: skip this batch and carry on, but remember
                # the failure so it can be reported below.
                failed_batches += 1
                last_error = err

        if failed_batches:
            tqdm.write(
                'Epoch {}: skipped {} batch(es) after a RuntimeError '
                '(last: {})'.format(epoch, failed_batches, last_error))

        # Average loss over previous epoch (NaN if no batch succeeded).
        mean_loss = np.mean(epoch_losses) if epoch_losses else float('nan')
        epoch_iter.set_postfix({'Loss': mean_loss})

        # Evaluate model.
        if (epoch % args['cache_freq'] == 0) or (epoch == args['epochs'] - 1):
            model.eval()

            report = 'Epoch {}\n'.format(epoch)

            # Record which epoch the entries appended below belong to.
            metrics['epochs'].append(epoch)

            # Time taken so far.
            metrics['times'].append(time.time() - t0)

            # Average loss over previous epoch.
            metrics['losses'].append(mean_loss)
            report += 'Loss: {:.3f}\n'.format(mean_loss)

            if eval_test is not None:
                smse, mll = _predict(model, test_loader, eval_test, y_mean,
                                     y_std, args)

                metrics['smses'].append(smse)
                metrics['mlls'].append(mll)
                report += 'SMSE: {:.3f}\n'.format(smse)
                report += 'MLL: {:.3f}\n'.format(mll)

            # Report model performance.
            tqdm.write(report)

            model.train(True)

    if save_model:
        # Save model, hyperparameters and metrics.
        _save(model, args, metrics)

    return metrics


def _save(model, args, metrics):
    """Write args, metrics, a text report and the model weights to disk.

    The output directory is ``args['results_dir'] + args['model']`` (plain
    string concatenation, so ``results_dir`` should end with a path
    separator), or ``'_results/' + args['model']`` when no ``'results_dir'``
    is given. If that path is already taken, ``_1``, ``_2``, ... is appended
    until a fresh directory can be created.

    Files written: ``args.pkl``, ``metrics.pkl``, ``report.txt`` (the args
    followed by the last recorded value of every non-empty metric) and
    ``model_state_dict.pt``.

    Args:
        model: Object whose ``state_dict()`` is stored with ``torch.save``.
        args: Dict of hyperparameters; must contain ``'model'``. If it does
            not, an error is printed and nothing is saved.
        metrics: Dict mapping metric names to sequences of recorded values.
    """
    if 'model' not in args:
        print("Error: 'model' does not exist in args. Aborting save.")
        return

    base_dir = args.get('results_dir', '_results/') + args['model']

    # Create the directory first and react to a collision afterwards. A
    # separate "does it exist?" check could be outdated by the time the
    # directory is made, which would lose the results of a finished run.
    results_dir = base_dir
    suffix = 0
    while True:
        try:
            os.makedirs(results_dir)
            break
        except FileExistsError:
            suffix += 1
            results_dir = base_dir + '_' + str(suffix)

    results_path = results_dir + '/report.txt'
    model_path = results_dir + '/model_state_dict.pt'

    # Pickle args and metrics.
    with open(results_dir + '/args.pkl', 'wb') as f:
        pickle.dump(args, f)

    with open(results_dir + '/metrics.pkl', 'wb') as f:
        pickle.dump(metrics, f)

    # Save args and results in text format.
    with open(results_path, 'w') as f:
        f.write('Args: \n')
        f.write(str(args) + '\n')

        f.write('\nPerformance: \n')
        for key, values in metrics.items():
            try:
                latest = values[-1]
            except IndexError:
                # Nothing was recorded for this metric; leave it out.
                continue
            f.write('{}: {}\n'.format(key, latest))

    # Save model.state_dict().
    torch.save(model.state_dict(), model_path)


def _predict(model, loader, df, y_mean, y_std, args):
    """Predict every batch of ``loader`` and score it against ``df``.

    ``loader`` and ``df`` are walked in lockstep: each loader item is
    unpacked as ``(x, y, mask, idx)`` and paired with one element of ``df``,
    whose ``index`` labels the rows of the prediction. Predicted means and
    standard deviations are rescaled with ``y_std``/``y_mean`` before being
    scored with ``metric_utils.smse`` and ``metric_utils.mll``.

    Args:
        model: Model exposing ``predict_y``.
        loader: Iterable of batches.
        df: Re-iterable collection of target frames, one per batch.
        y_mean: Offset added to the predicted means.
        y_std: Scale applied to the predicted means and standard deviations.
        args: Dict providing ``'test_samples'`` and ``'observations'``.

    Returns:
        Tuple ``(smse, mll)`` holding the NaN-ignoring mean of each metric
        over all batches.
    """
    mean_frames = []
    var_frames = []

    batches = tqdm(zip(iter(loader), df),
                   desc='Test', leave=False)
    for batch, target in batches:
        x_full, y_obs, mask_obs, _ = batch
        x_full = x_full.squeeze(0)
        y_obs = y_obs.squeeze(0)
        mask_obs = mask_obs.squeeze(0)

        # Rows with at least one unmasked entry are used for conditioning.
        keep = torch.sum(mask_obs, 1) != 0
        x_cond = x_full[keep]
        y_obs = y_obs[keep]
        mask_obs = mask_obs[keep]

        # Test predictions. A plain VAE only takes the inputs; every other
        # model conditions on the kept rows and predicts at all inputs.
        if type(model) == gpvae.models.VAE:
            outputs = model.predict_y(
                x=x_full, num_samples=args['test_samples'])
        else:
            outputs = model.predict_y(
                x=x_cond, y=y_obs, mask=mask_obs, x_test=x_full,
                num_samples=args['test_samples'])
        mean, sigma = outputs[:2]

        # Map the predictions back through y_std / y_mean.
        mean = mean.numpy() * y_std + y_mean
        sigma = sigma.numpy() * y_std

        # Convert to DataFrame and add to lists.
        mean_frames.append(pd.DataFrame(mean, index=target.index,
                                        columns=args['observations']))
        var_frames.append(pd.DataFrame(sigma ** 2, index=target.index,
                                       columns=args['observations']))

    # Evaluate test predictions.
    smse_values = np.array([metric_utils.smse(*pair).values
                            for pair in zip(mean_frames, df)])
    mll_values = np.array([metric_utils.mll(*triple).values
                           for triple in zip(mean_frames, var_frames, df)])

    return np.nanmean(smse_values), np.nanmean(mll_values)
