###########################
# Latent ODEs for Irregularly-Sampled Time Series
# Author: Yulia Rubanova
###########################


"""
try:
        import sys

        sys.path.insert(0, '/home/soraxas/git-repo/pytoolbox/')

        import soraxas_toolbox

        printer = soraxas_toolbox.torch.TorchNetworkPrinter(
                ignore_modules=['_PerturbFunc'],
                ignore_modules_within=['LVField', 'GRU_unit', 'ODEFunc'],
        )
except Exception as e:
        # doesn't matter
        pass
        raise e
"""


import os
import sys
import matplotlib
matplotlib.use('Agg')
#matplotlib.use('tkAgg')
import matplotlib.pyplot
import matplotlib.pyplot as plt

import time
import datetime
import argparse
import numpy as np
import pandas as pd
from random import SystemRandom
from sklearn import model_selection

import torch
import torch.nn as nn
from torch.nn.functional import relu
import torch.optim as optim

import lib.utils as utils
from lib.plotting import *

from lib.rnn_baselines import *
from lib.ode_rnn import *
from lib.create_latent_ode_model import create_LatentODE_model
from lib.parse_datasets import parse_datasets
from lib.ode_func import ODEFunc, ODEFunc_w_Poisson
from lib.diffeq_solver import DiffeqSolver
from mujoco_physics import HopperPhysics

from lib.utils import compute_loss_all_batches

# Command-line interface for the Latent ODE experiments (generative model for
# noisy data based on an ODE).
#
# Help strings that span several adjacent literals end each fragment with an
# explicit space: Python concatenates adjacent literals verbatim, so without it
# the sentences run together in the `--help` output.
parser = argparse.ArgumentParser('Latent ODE')

# --- Data size / optimisation ---
parser.add_argument('-n',  type=int, default=100, help="Size of the dataset")
parser.add_argument('--niters', type=int, default=300)
parser.add_argument('--lr',  type=float, default=1e-2, help="Starting learning rate.")
parser.add_argument('-b', '--batch-size', type=int, default=50)
parser.add_argument('--viz', action='store_true', help="Show plots while training")

# --- Checkpointing / reproducibility ---
parser.add_argument('--save', type=str, default='experiments/', help="Path for save checkpoints")
parser.add_argument('--load', type=str, default=None, help="ID of the experiment to load for evaluation. If None, run a new experiment.")
parser.add_argument('-r', '--random-seed', type=int, default=1991, help="Random_seed")

# --- Dataset selection and pre-processing ---
parser.add_argument('--dataset', type=str, default='periodic', help="Dataset to load. Available: physionet, activity, hopper, periodic")
parser.add_argument('-s', '--sample-tp', type=float, default=None, help="Number of time points to sub-sample. "
        "If > 1, subsample exact number of points. If the number is in [0,1], take a percentage of available points per time series. If None, do not subsample")

parser.add_argument('-c', '--cut-tp', type=int, default=None, help="Cut out the section of the timeline of the specified length (in number of points). "
        "Used for periodic function demo.")

parser.add_argument('--quantization', type=float, default=0.1, help="Quantization on the physionet dataset. "
        "Value 1 means quantization by 1 hour, value 0.1 means quantization by 0.1 hour = 6 min")

# --- Model family selection (the main block checks these flags in the order
#     rnn-vae, classic-rnn, ode-rnn, latent-ode and uses the first one set) ---
parser.add_argument('--latent-ode', action='store_true', help="Run Latent ODE seq2seq model")
parser.add_argument('--z0-encoder', type=str, default='odernn', help="Type of encoder for Latent ODE model: odernn or rnn")

parser.add_argument('--classic-rnn', action='store_true', help="Run RNN baseline: classic RNN that sees true points at every point. Used for interpolation only.")
parser.add_argument('--rnn-cell', default="gru", help="RNN Cell type. Available: gru (default), expdecay")
parser.add_argument('--input-decay', action='store_true', help="For RNN: use the input that is the weighted average of empirical mean and previous value (like in GRU-D)")

parser.add_argument('--ode-rnn', action='store_true', help="Run ODE-RNN baseline: RNN-style that sees true points at every point. Used for interpolation only.")

parser.add_argument('--rnn-vae', action='store_true', help="Run RNN baseline: seq2seq model with sampling of the h0 and ELBO loss.")

# --- Model sizes ---
parser.add_argument('-l', '--latents', type=int, default=6, help="Size of the latent state")
parser.add_argument('--rec-dims', type=int, default=20, help="Dimensionality of the recognition model (ODE or RNN).")

parser.add_argument('--rec-layers', type=int, default=1, help="Number of layers in ODE func in recognition ODE")
parser.add_argument('--gen-layers', type=int, default=1, help="Number of layers in ODE func in generative ODE")

parser.add_argument('-u', '--units', type=int, default=100, help="Number of units per layer in ODE func")
parser.add_argument('-g', '--gru-units', type=int, default=100, help="Number of units per layer in each of GRU update networks")

# --- Extra losses / task mode ---
parser.add_argument('--poisson', action='store_true', help="Model poisson-process likelihood for the density of events in addition to reconstruction.")
parser.add_argument('--classif', action='store_true', help="Include binary classification loss -- used for Physionet dataset for hospital mortality")

parser.add_argument('--linear-classif', action='store_true', help="If using a classifier, use a linear classifier instead of 1-layer NN")
parser.add_argument('--extrap', action='store_true', help="Set extrapolation mode. If this flag is not set, run interpolation mode.")

# --- Synthetic data generation ---
parser.add_argument('-t', '--timepoints', type=int, default=100, help="Total number of time-points")
parser.add_argument('--max-t',  type=float, default=5., help="We subsample points in the interval [0, args.max_t]")
parser.add_argument('--noise-weight', type=float, default=0.01, help="Noise amplitude for generated trajectories")


# NOTE: arguments are parsed at import time, so importing this module from
# another program consumes that program's sys.argv.
args = parser.parse_args()

# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# device = 'cpu'
# Script name without its last three characters (the ".py" suffix); used to
# name log files and plots.
file_name = os.path.basename(__file__)[:-3]
utils.makedirs(args.save)

#####################################################################################################

from torch.utils.tensorboard import SummaryWriter


def consolidate_best_results(master_dict, new_dict, highest=None, lowest=None):
        """
        Fold the results in new_dict into the best-so-far results in master_dict.

        A key that is not yet in master_dict is copied over unchanged. For a key
        present in both dicts the *smaller* value is kept, unless the key is
        listed in `highest`, in which case the *larger* value is kept.

        Args:
                master_dict: dict of the best results seen so far. Updated in place.
                new_dict: dict of newly computed results.
                highest: optional collection of keys for which a higher value is
                        better. None (the default) means there are no such keys.
                lowest: accepted for backward compatibility but not consulted;
                        every key that is not in `highest` is already treated as
                        lower-is-better.

        Returns:
                None. master_dict is modified in place.
        """
        # None instead of a mutable `set()` default, which would be a single
        # object shared by every call.
        if highest is None:
                highest = frozenset()

        for k, v in new_dict.items():
                if k in master_dict:
                        # The new value is passed first so that, on ties, the new value
                        # is the one stored (max/min return their first argument on ties).
                        pick_best = max if k in highest else min
                        v = pick_best(v, master_dict[k])
                master_dict[k] = v

def add_all_results_to_tensorboard_writer(writer, results, iter):
        """
        Log every entry of `results` as a scalar on `writer`.

        Each (key, value) pair is forwarded as `writer.add_scalar(key, value, iter)`,
        in the iteration order of `results`, so `iter` is the step shared by all
        entries of one call.
        """
        log_scalar = writer.add_scalar
        for tag, value in results.items():
                log_scalar(tag, value, iter)


if __name__ == '__main__':
        # Sweep driver: build a list of model configurations, then for each one
        # load the data, build the model selected on the command line, train it,
        # log metrics to TensorBoard and save checkpoints.

        # Imported once here rather than on every sweep iteration.
        from tqdm import trange

        ##################################################################
        # Enumerate the configurations to run.
        ALL_SETTINGS = []

        for method in [
                'ours',
                'node'
                ]:
                # The periodic dataset is additionally run with 1000 time points.
                available_timepts = [100]
                if args.dataset == 'periodic':
                        available_timepts = [100, 1000]
                for timepts in available_timepts:
                        if method == 'ours':
                                # (hidden dim, number of layers, extra augmented dim)
                                for _hid_dim, _num_lay, _aug_dim in (
                                        [1000, 2, 5],
                                        # [200, 5, 5],
                                        # [1000, 8, 5],
                                        # [1500, 10, 5],
                                ):
                                        ALL_SETTINGS.append(dict(
                                                model=method,

                                                dataset=args.dataset,
                                                timepoints=timepts,

                                                ours__inn_hidden_dim=_hid_dim,
                                                ours__num_layers=_num_lay,
                                                ours__extra_augnmented_dim=_aug_dim,
                                                ))
                        elif method == 'node':
                                # One configuration per ODE integration method.
                                for int_method in [
                                        'dopri5',
                                        'euler',
                                        'rk4',
                                        'dopri8',
                                        'midpoint',
                                        ]:
                                        ALL_SETTINGS.append(dict(
                                                model=method,

                                                dataset=args.dataset,
                                                timepoints=timepts,
                                                #timepoints=args.timepoints,

                                                node__int_method=int_method,
                                                ))

        ##################################################################
        # Run every configuration in turn.
        for BLOB_OF_MODEL_SETTINGS in ALL_SETTINGS:
                # The sweep overrides the --timepoints command-line value.
                args.timepoints = BLOB_OF_MODEL_SETTINGS['timepoints']
                print("=" * 50)
                print(BLOB_OF_MODEL_SETTINGS)
                print("=" * 50)

                settings_str = '|'.join(f'{k}={v}' for (k, v) in sorted(BLOB_OF_MODEL_SETTINGS.items()))

                # Re-seed for every configuration so each run starts from the same
                # torch / numpy random state.
                torch.manual_seed(args.random_seed)
                np.random.seed(args.random_seed)

                experimentID = args.load
                if experimentID is None:
                        # Make a new experiment ID
                        experimentID = int(SystemRandom().random()*100000)

                identity_str = f"[{settings_str}]_{experimentID}"


                ########################################################
                writer = SummaryWriter(comment=identity_str)
                # writer.add_hparams(BLOB_OF_MODEL_SETTINGS, {})
                ########################################################

                ckpt_path = os.path.join(args.save, f'experiment_{identity_str}')

                print("Sampling dataset of {} training examples".format(args.n))

                # Reconstruct the command line, minus "--load <id>", for the log file.
                input_command = sys.argv
                ind = [i for i in range(len(input_command)) if input_command[i] == "--load"]
                if len(ind) == 1:
                        ind = ind[0]
                        input_command = input_command[:ind] + input_command[(ind+2):]
                input_command = " ".join(input_command)

                utils.makedirs("results/")

                ##################################################################
                data_obj = parse_datasets(args, device)
                input_dim = data_obj["input_dim"]

                classif_per_tp = False
                if ("classif_per_tp" in data_obj):
                        # do classification per time point rather than on a time series as a whole
                        classif_per_tp = data_obj["classif_per_tp"]

                if args.classif and (args.dataset == "hopper" or args.dataset == "periodic"):
                        raise Exception("Classification task is not available for MuJoCo and 1d datasets")

                n_labels = 1
                if args.classif:
                        if ("n_labels" in data_obj):
                                n_labels = data_obj["n_labels"]
                        else:
                                raise Exception("Please provide number of labels for classification task")

                ##################################################################
                # Create the model
                obsrv_std = 0.01
                if args.dataset == "hopper":
                        obsrv_std = 1e-3

                obsrv_std = torch.Tensor([obsrv_std]).to(device)

                z0_prior = Normal(torch.Tensor([0.0]).to(device), torch.Tensor([1.]).to(device))


                if args.rnn_vae:
                        if args.poisson:
                                print("Poisson process likelihood not implemented for RNN-VAE: ignoring --poisson")

                        # Create RNN-VAE model
                        model = RNN_VAE(input_dim, args.latents,
                                device = device,
                                rec_dims = args.rec_dims,
                                concat_mask = True,
                                obsrv_std = obsrv_std,
                                z0_prior = z0_prior,
                                use_binary_classif = args.classif,
                                classif_per_tp = classif_per_tp,
                                linear_classifier = args.linear_classif,
                                n_units = args.units,
                                input_space_decay = args.input_decay,
                                cell = args.rnn_cell,
                                n_labels = n_labels,
                                train_classif_w_reconstr = (args.dataset == "physionet")
                                ).to(device)


                elif args.classic_rnn:
                        if args.poisson:
                                print("Poisson process likelihood not implemented for RNN: ignoring --poisson")

                        if args.extrap:
                                raise Exception("Extrapolation for standard RNN not implemented")
                        # Create RNN model
                        model = Classic_RNN(input_dim, args.latents, device,
                                concat_mask = True, obsrv_std = obsrv_std,
                                n_units = args.units,
                                use_binary_classif = args.classif,
                                classif_per_tp = classif_per_tp,
                                linear_classifier = args.linear_classif,
                                input_space_decay = args.input_decay,
                                cell = args.rnn_cell,
                                n_labels = n_labels,
                                train_classif_w_reconstr = (args.dataset == "physionet")
                                ).to(device)
                elif args.ode_rnn:
                        # Create ODE-GRU model
                        n_ode_gru_dims = args.latents

                        if args.poisson:
                                print("Poisson process likelihood not implemented for ODE-RNN: ignoring --poisson")

                        if args.extrap:
                                raise Exception("Extrapolation for ODE-RNN not implemented")

                        ode_func_net = utils.create_net(n_ode_gru_dims, n_ode_gru_dims,
                                n_layers = args.rec_layers, n_units = args.units, nonlinear = nn.Tanh)

                        rec_ode_func = ODEFunc(
                                input_dim = input_dim,
                                latent_dim = n_ode_gru_dims,
                                ode_func_net = ode_func_net,
                                device = device).to(device)

                        z0_diffeq_solver = DiffeqSolver(input_dim, rec_ode_func, "euler", args.latents,
                                odeint_rtol = 1e-3, odeint_atol = 1e-4, device = device)

                        model = ODE_RNN(input_dim, n_ode_gru_dims, device = device,
                                z0_diffeq_solver = z0_diffeq_solver, n_gru_units = args.gru_units,
                                concat_mask = True, obsrv_std = obsrv_std,
                                use_binary_classif = args.classif,
                                classif_per_tp = classif_per_tp,
                                n_labels = n_labels,
                                train_classif_w_reconstr = (args.dataset == "physionet")
                                ).to(device)
                elif args.latent_ode:
                        # Only this branch receives the per-run sweep settings.
                        model = create_LatentODE_model(args, input_dim, z0_prior, obsrv_std, device,
                                classif_per_tp = classif_per_tp,
                                n_labels = n_labels,
                                BLOB_OF_MODEL_SETTINGS=BLOB_OF_MODEL_SETTINGS,
                                )
                else:
                        raise Exception("Model not specified")

                ##################################################################

                if args.viz:
                        viz = Visualizations(device)

                ##################################################################

                # Load checkpoint and stop.
                # NOTE(review): ckpt_path carries no "-it{itr}.ckpt" / "-full.ckpt"
                # suffix, while the checkpoints written below do -- confirm that
                # utils.get_ckpt_model can resolve it. No evaluation is run here:
                # the process exits right after loading, on the first configuration.
                if args.load is not None:
                        utils.get_ckpt_model(ckpt_path, model, device)
                        writer.close()
                        # sys.exit() rather than the `site`-provided exit(), which is
                        # not guaranteed to be defined.
                        sys.exit()

                ##################################################################
                # Training

                log_path = "logs/" + file_name + "_" + str(experimentID) + ".log"
                if not os.path.exists("logs/"):
                        utils.makedirs("logs/")
                logger = utils.get_logger(logpath=log_path, filepath=os.path.abspath(__file__))
                logger.info(input_command)

                optimizer = optim.Adamax(model.parameters(), lr=args.lr)

                num_batches = data_obj["n_train_batches"]

                weight = None
                if 'class_weight' in data_obj:
                        weight = data_obj['class_weight']

                # Number of epochs during which the KL coefficient stays at 0.
                wait_until_kl_inc = 10
                # Evaluate / checkpoint / plot every n_iters_to_viz epochs.
                n_iters_to_viz = 1

                best_results = {}
                with trange(1, num_batches * (args.niters + 1)) as pbar:
                        for itr in pbar:
                                optimizer.zero_grad()
                                utils.update_learning_rate(optimizer, decay_rate = 0.999, lowest = args.lr / 10)

                                # itr // num_batches is the current epoch index. After the
                                # warm-up, the coefficient follows 1 - 0.99 ** (epochs since warm-up).
                                if itr // num_batches < wait_until_kl_inc:
                                        kl_coef = 0.
                                else:
                                        kl_coef = (1-0.99** (itr // num_batches - wait_until_kl_inc))

                                batch_dict = utils.get_next_batch(data_obj["train_dataloader"])
                                train_res = model.compute_all_losses(batch_dict, n_traj_samples = 3, kl_coef = kl_coef, weight=weight)
                                train_res["loss"].backward()

                                # Detached copy for logging / bookkeeping, so the values kept
                                # in best_results do not hold a reference to the autograd graph.
                                train_loss = train_res["loss"].detach()
                                pbar.set_postfix(
                                        loss = train_loss.item()
                                )
                                writer.add_scalar('train/Loss', train_loss, itr)
                                optimizer.step()

                                if itr % (n_iters_to_viz * num_batches) == 0:
                                        with torch.no_grad():

                                                test_res = compute_loss_all_batches(model,
                                                        data_obj["test_dataloader"], args,
                                                        n_batches = data_obj["n_test_batches"],
                                                        experimentID = experimentID,
                                                        device = device,
                                                        n_traj_samples = 1, kl_coef = kl_coef)

                                                _results = {
                                                        'train/Loss': train_loss,
                                                        'train/CE_loss': train_res["ce_loss"],
                                                        'test/Loss': test_res["loss"],
                                                        'test/Likelihood': test_res["likelihood"],
                                                        'test/KL_fp': test_res["kl_first_p"],
                                                        'test/FP_STD': test_res["std_first_p"],
                                                        'test/kl_coef': kl_coef
                                                }

                                                # Metrics that are only reported by some models / tasks:
                                                # (source dict, key in source, TensorBoard tag).
                                                # NOTE(review): 'test/CE_lost' looks like a typo for
                                                # 'test/CE_loss'; the tag is kept unchanged so that new
                                                # runs stay comparable with existing logs.
                                                optional_metrics = (
                                                        (test_res, "auc", 'test/AUC'),
                                                        (test_res, "mse", 'test/MSE'),
                                                        (train_res, "accuracy", 'train/accuracy'),
                                                        (test_res, "accuracy", 'test/accuracy'),
                                                        (test_res, "pois_likelihood", 'test/Poisson_likelihood'),
                                                        (test_res, "ce_loss", 'test/CE_lost'),
                                                )
                                                for source, key, tag in optional_metrics:
                                                        if key in source:
                                                                _results[tag] = source[key]

                                                add_all_results_to_tensorboard_writer(writer, _results, itr)
                                                # Keys listed in `highest` are higher-is-better; every
                                                # other key keeps its lowest value.
                                                consolidate_best_results(best_results, _results, highest={
                                                        'test/Likelihood',
                                                        'test/AUC',
                                                        'train/accuracy',
                                                        'test/accuracy',
                                                        'test/Poisson_likelihood',
                                                })

                                        # Per-epoch checkpoint.
                                        torch.save({
                                                'args': args,
                                                'state_dict': model.state_dict(),
                                        }, f'{ckpt_path}-it{itr}.ckpt')


                                        # Plotting
                                        if args.viz:
                                                with torch.no_grad():
                                                        test_dict = utils.get_next_batch(data_obj["test_dataloader"])

                                                        if isinstance(model, LatentODE) and (args.dataset == "periodic"): #and not args.classic_rnn and not args.ode_rnn:
                                                                plot_id = itr // num_batches // n_iters_to_viz
                                                                # Pick the first numeric suffix for which no file
                                                                # with that name exists yet.
                                                                i = 0
                                                                _fname = f"{file_name}_{identity_str}_{plot_id:03d}"
                                                                while os.path.isfile(f"{_fname}_{i}.png"):
                                                                        i += 1
                                                                viz.draw_all_plots_one_dim(test_dict, model,
                                                                        plot_name = f"{_fname}_{i}.png",
                                                                        experimentID = experimentID, save=True)

                # Record the run's hyper-parameters together with its best metrics,
                # then close the writer so each configuration releases its event
                # file before the next one opens a new writer.
                writer.add_hparams(BLOB_OF_MODEL_SETTINGS, best_results)
                writer.close()

                # Final checkpoint for this configuration.
                torch.save({
                        'args': args,
                        'state_dict': model.state_dict(),
                }, f'{ckpt_path}-full.ckpt')

