#!/usr/bin/env python

"""
For usage, run `train.py -h` and read README.org
"""

import os
import shutil
import math
import argparse
import errno
import random
import numpy
import torch
import pytorch_lightning as pl
from pytorch_lightning import seed_everything
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.accelerators import find_usable_cuda_devices
import numpy as np
import lifted_pddl
import json

from typing import Optional

from util import *
from constants import (
    HEURISTICS,
    COSTS,
    DOMAIN_TO_GENERATOR,
    TRAINING_LOGS_DIR,
    TRAINING_CKPT_DIR,
    TRAINING_JSON_DIR,
)

from pddlsl.datamodule import (
    CommonDataModule,
    NLMDataModule,
    RRDataModule,
    HGNDataModule,
)
from pddlsl.model      import (
    HeuristicLearner,
    HeuristicLearnerNLM,
    HeuristicLearnerRR,
    HeuristicLearnerHGN,
)

def parse_arguments():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description="It trains a model and stores the test losses into a JSON file.")
    parser.add_argument('-o', '--output', default="data", help="root directory for storing the generated problem and data.")
    parser.add_argument('--seed', default=42, type=int, help="""Seed for the experiment.""")
    parser.add_argument('--train-size', type=int, default=400, help="""The number of problems to use for training.""")
    parser.add_argument('--sigma', default='learn', choices=("learn","fixed"))
    parser.add_argument('--target', default='h*', choices=COSTS)
    parser.add_argument('--l', choices=COSTS, help="If specified, overrides --l-train and --l-test.")
    parser.add_argument('--l-train', default='ninf', choices=COSTS)
    parser.add_argument('--l-test', default='ninf', choices=COSTS)
    parser.add_argument('--u', choices=COSTS, help="If specified, overrides --u-train and --u-test.")
    parser.add_argument('--u-train', default='inf', choices=COSTS)
    parser.add_argument('--u-test', default='inf', choices=COSTS)
    parser.add_argument('--res', choices=COSTS, help="If specified, overrides --res-train and --res-test.")
    parser.add_argument('--res-train', default='zero', choices=COSTS)
    parser.add_argument('--res-test', default='zero', choices=COSTS)
    parser.add_argument('--l-as-input', action="store_true", help="feed the lower bound as an additional nullary input.")
    parser.add_argument('-f', '--force', action="store_true", help="overwrite an existing test result (json file).")
    parser.add_argument('-v', '--verbose', action="store_true")
    parser.add_argument('--lr', type=float, default=1e-3, help="learning rate")
    parser.add_argument('--decay', type=float, default=0.0, help="weight decay")
    parser.add_argument('--clip', type=float, default=0.1, help="gradient clipping")
    parser.add_argument('--deterministic', action="store_true", help="use slow deterministic training")
    parser.add_argument('--batch-size', type=int, default=256, help="batch size")
    parser.add_argument('--steps', type=int, default=10000, help="maximum training update steps")
    parser.add_argument('--compute', choices=("auto","cpu"), default="auto",
                        help=("compute mode. "
                              "'cpu' forces the CPU mode. "
                              "'auto' tries to use a GPU, and falls back to the CPU mode when failed."))
    parser.add_argument('--mode', choices=("train","test","both"), default="both",
                        help=("behavior. "
                              "'train' and 'test' are self-explanatory. "
                              "'both' performs both."))
    parser.add_argument('--if-ckpt-exists',
                        choices=("supersede","resume","error","skip"),
                        default="skip",
                        help=("define the behavior when a checkpoint exists in the training mode. "
                              "supersede: remove the existing checkpoint and create a new one. "
                              "resume: resume the training from the existing checkpoint. "
                              "error: quit immediately raising an error. "
                              "skip: skip the training and proceed to the testing, if necessary. "))
    parser.add_argument('--if-ckpt-does-not-exist',
                        choices=("create","error"),
                        default="create",
                        help=("define the behavior when a checkpoint does not exist in the training mode. "
                              "create: create a new one. "
                              "error: quit immediately raising an error. "))
    parser.add_argument('model', choices=('NLM','RR','HGN'))
    parser.add_argument('dist', choices=('gaussian','truncated'))
    parser.add_argument('domain', choices=DOMAIN_TO_GENERATOR.keys())
    parser.add_argument('rest', nargs="*", help="model-specific options for NLM, RR, HGN")
    args = parser.parse_args()

    if args.l is not None:
        args.l_train = args.l_test = args.l
    if args.u is not None:
        args.u_train = args.u_test = args.u
    if args.res is not None:
        args.res_train = args.res_test = args.res
    if args.l_as_input:
        assert args.l_train != "ninf", "--l-as-input requires a non-trivial training lower bound"
        assert args.l_test != "ninf", "--l-as-input requires a non-trivial testing lower bound"

    if args.model == "NLM":
        args.learner_cls = HeuristicLearnerNLM
        args.dm_cls      = NLMDataModule
    elif args.model == "RR":
        args.learner_cls = HeuristicLearnerRR
        args.dm_cls      = RRDataModule
    elif args.model == "HGN":
        args.learner_cls = HeuristicLearnerHGN
        args.dm_cls      = HGNDataModule
    else:
        raise "huh?"

    args.learner_cls.parse_arguments(args) # parse additional arguments specific to learner

    return args


def train(args, id, dm, training_ckpt_path : Optional[str]):
    if args.verbose:
        print("--- Training started ---")

    model = args.learner_cls(args)

    if args.compute == "cpu" or pl.pytorch_lightning.accelerators.cuda.num_cuda_devices() <= 0:
        if args.compute == "cpu":
            print(f"running in the CPU mode (forced)")
        else:
            print(f"running in the CPU mode (GPU not found)")
        device_options = {"accelerator":"cpu"}
    else: # Train on GPU
        print("running in the GPU mode")
        device_options = {"accelerator":"cuda","devices":find_usable_cuda_devices(1)}

    trainer = pl.Trainer(max_steps               = args.steps,
                         check_val_every_n_epoch = math.ceil(400 / args.train_size),
                         logger                  = TensorBoardLogger(TRAINING_LOGS_DIR, name=id, version=""),
                         callbacks               = [ModelCheckpoint(monitor='v_nll',
                                                                    every_n_epochs=1,
                                                                    dirpath=TRAINING_CKPT_DIR,
                                                                    filename=id,
                                                                    mode='min',
                                                                    save_top_k=1,
                                                                    save_on_train_epoch_end=False),
                                                    ModelCheckpoint(monitor='v_mse',
                                                                    every_n_epochs=1,
                                                                    dirpath=TRAINING_CKPT_DIR,
                                                                    filename=id+"-mse",
                                                                    mode='min',
                                                                    save_top_k=1,
                                                                    save_on_train_epoch_end=False)],
                         deterministic           = args.deterministic,
                         gradient_clip_val       = args.clip,
                         **device_options)

    # note: this restores the training step. max_step works as expected.
    # note2: we always resume from the best v_nll checkpoint, not the best v_mse checkpoint
    trainer.fit(model, datamodule=dm, ckpt_path=training_ckpt_path)
    if args.verbose:
        print("--- Training finished ---")
    pass


def load(args, training_ckpt_path):
    model = args.learner_cls.load_from_checkpoint(training_ckpt_path)

    if args.compute == "cpu" or pl.pytorch_lightning.accelerators.cuda.num_cuda_devices() <= 0:
        if args.compute == "cpu":
            print(f"running in the CPU mode (forced)")
        else:
            print(f"running in the CPU mode (GPU not found)")
        trainer = pl.Trainer(accelerator='cpu',
                             deterministic=args.deterministic)
    else: # Train on GPU
        print("running in the GPU mode")
        trainer = pl.Trainer(accelerator='cuda',
                             devices=find_usable_cuda_devices(1),
                             deterministic=args.deterministic)

    return model, trainer


def test(args, dm, training_json_path, training_ckpt_path):
    if args.verbose:
        print("--- Test started ---")

    def fn(path):
        model, trainer = load(args, path)
        losses = trainer.test(model=model, datamodule=dm)[0]
        step = model.persistent_global_step.item()
        return losses, step

    best_nll_losses, best_nll_step = fn(training_ckpt_path)
    best_mse_losses, best_mse_step = fn(append_to_name(training_ckpt_path, "-mse"))

    if args.verbose:
        print("--- Test finished ---")
        print("> Training step that achived the best validation NLL:", best_nll_step)
        print("> Its checkpoint path:", training_ckpt_path)
        print("> Training step that achived the best validation MSE:", best_mse_step)
        print("> Its checkpoint path:", append_to_name(training_ckpt_path, "-mse"))

    save_json(
        {
            **{k:v for k,v in vars(args).items() if k not in {"learner_cls","dm_cls"}},
            **{'best_nll_'+k:v for k,v in best_nll_losses.items()},
            'best_nll_step': best_nll_step,
            'best_nll_path': training_ckpt_path,
            **{'best_mse_'+k:v for k,v in best_mse_losses.items()},
            'best_mse_step': best_mse_step,
            'best_mse_path': append_to_name(training_ckpt_path, "-mse"),
        },
        training_json_path)

    if args.verbose:
        print("--- Test saved ---")


def main(args):
    os.chdir(os.path.dirname(os.path.realpath(__file__)))
    seed_everything(args.seed, workers=True)
    id = args.learner_cls.id(args)

    training_json_path = os.path.join(TRAINING_JSON_DIR, f'{id}.json')
    training_logs_path = os.path.join(TRAINING_LOGS_DIR, f'{id}')
    training_ckpt_path = os.path.join(TRAINING_CKPT_DIR, f'{id}.ckpt')

    dm = args.dm_cls(args)

    if args.mode in {"train", "both"}:
        if os.path.exists(training_ckpt_path):
            if args.if_ckpt_exists == "supersede":
                print(f"checkpoint exist, rerunning the training: {training_ckpt_path}")
                os.remove(training_ckpt_path)
                train(args, id, dm, None)
            elif args.if_ckpt_exists == "resume":
                print(f"checkpoint exist, resuming the training: {training_ckpt_path}")
                train(args, id, dm, training_ckpt_path)
            elif args.if_ckpt_exists == "error":
                raise f"checkpoint should not exist: {training_ckpt_path}"
            elif args.if_ckpt_exists == "skip":
                print(f"checkpoint exist, skipping the training: {training_ckpt_path}")
            else:
                raise "huh?"
        else:
            if args.if_ckpt_does_not_exist == "create":
                train(args, id, dm, None)
            elif args.if_ckpt_does_not_exist == "error":
                raise f"checkpoint should exist: {training_ckpt_path}"
            else:
                raise "huh?"

    if args.mode in {"test", "both"} and \
       target_required(training_json_path,
                       force=args.force,
                       verbose=args.verbose):

        test(args, dm,
             training_json_path,
             training_ckpt_path)

    pass


if __name__ == '__main__':
    args = parse_arguments()
    try:
        main(args)
    except:
        import pddlsl.stacktrace
        pddlsl.stacktrace.format(arraytypes={numpy.ndarray,torch.Tensor},include_self=False)


