from peagang.experiments.ggan_exp.base import ex, get_model_n_uuid, sacred_copy
from pprint import pprint
from uuid import uuid4
import os
import tempfile
from itertools import chain
from peagang.evaluation.plots.utils.callback import ModelCheckPointWithPlots
from peagang.evaluation.plots.utils.post_experiment_plots import plot_mmd_wrapper
from sacred.observers import FileStorageObserver
from peagang.models.ggg import (
    PEAWGAN,
    PEAWGAN_HyperParameters,forward_clip_hook,backward_trace_hook,backward_clean_hook
)

from peagang.data.dense.PEAWGANDenseData import PEAWGANDenseData
from pytorch_lightning import Trainer

from pytorch_lightning.loggers import TensorBoardLogger
import torch as pt

# TODO(adam): delete this once Sacred issue #498 is resolved
from sacred.run import Run
@ex.main
def run(
        data_dir,
        hyper,
        epochs,
        model_n,
        base_dir,
        overfit_pct,
        ckpt_period,
        detect_anomaly,
        deep,
        deep_gen,
        deep_disc,
        _run: Run,
        _config,
        forward_clip,
):
    """Train a PEAWGAN model for one Sacred run and produce post-training plots.

    Sacred injects every argument from the experiment config by name, so the
    parameter names must not change.

    Args:
        data_dir: Directory passed to ``PEAWGANDenseData`` and stored in the
            model hyper-parameters.
        hyper: Hyper-parameter dict. This function reads ``"dataset"``,
            ``"dataset_kwargs"`` and ``"device"`` directly (and ``"plot_lcc"``
            if present); all of its keys are forwarded to ``PEAWGAN``.
        epochs: ``max_epochs`` for the Lightning ``Trainer``.
        model_n: Model/run name used by the logger; when ``None`` it is
            derived from ``hyper`` via ``get_model_n_uuid``.
        base_dir: Stored in the model hyper-parameters.
        overfit_pct: Forwarded to ``Trainer(overfit_pct=...)``.
        ckpt_period: Epoch period of ``ModelCheckPointWithPlots``.
        detect_anomaly: Value passed to ``torch.autograd.set_detect_anomaly``
            around ``trainer.fit``.
        deep, deep_gen, deep_disc: Stored in the model hyper-parameters.
        _run: The Sacred run. Its first observer must expose ``dir`` (e.g. a
            ``FileStorageObserver``); Lightning logs are written below it.
        _config: Full Sacred config, printed for the record.
        forward_clip: When true, ``forward_clip_hook`` and
            ``backward_trace_hook`` are registered on every sub-module.

    Raises:
        RuntimeError: If the run has no observer to provide a log directory.
    """
    pprint(_config)

    # Fail fast, before the (potentially slow) dataset pass below.
    if not _run.observers:
        raise RuntimeError(
            "run() needs a Sacred observer with a `dir` attribute "
            "(e.g. FileStorageObserver) to decide where to write logs"
        )

    # NOTE(review): sacred_copy presumably converts Sacred's read-only config
    # containers into plain mutable objects -- confirm in ggan_exp.base.
    data_dir, hyper, epochs, model_n, base_dir, overfit_pct = [
        sacred_copy(o)
        for o in [data_dir, hyper, epochs, model_n, base_dir, overfit_pct]
    ]
    if model_n is None:
        model_n = get_model_n_uuid(hyper)

    filename = None
    print("Calculating node_dist")
    node_count_weights = PEAWGANDenseData(
        data_dir=data_dir,
        filename=filename,
        dataset=hyper["dataset"],
        inner_kwargs=hyper["dataset_kwargs"]
    ).node_dist_weights()

    exp_name = _run.observers[0].dir
    pyl_log_dir = os.path.join(exp_name, "lightning_log")

    # Final hyper-parameters: every key of ``hyper``, overridden/extended by the
    # run-specific values below (``hyper`` keys keep their original order).
    changes = dict(
        node_count_weights=node_count_weights,
        data_dir=data_dir,
        filename=filename,
        base_dir=base_dir,
        model_n=model_n,
        deep=deep,
        deep_disc=deep_disc,
        deep_gen=deep_gen,
        save_dir=pyl_log_dir
    )
    hparams = {**hyper, **changes}

    print("Initializing model with parameters")
    pprint(hparams)
    model = PEAWGAN(hparams)
    if forward_clip:
        for module in model.modules():
            module.register_forward_hook(forward_clip_hook)
            module.register_backward_hook(backward_trace_hook)

    tblogger = _make_logger(pyl_log_dir, model_n)

    # TODO add toggle if user wants to define directory to save
    plots_save_dir = os.path.join(pyl_log_dir, "plots")

    checkpoint_callback = ModelCheckPointWithPlots(
        save_top_k=-1, period=ckpt_period, verbose=True, numb_graphs=20, plot_dataset=True,
        loss_dir=pyl_log_dir, plot_dir=plots_save_dir, mmd=False, lcc=hparams.get("plot_lcc", False)
    )

    print("Starting training")
    trainer = Trainer(
        progress_bar_refresh_rate=1,
        max_epochs=epochs,
        # track_grad_norm=2, works, but needs a small fix to pytorch lightning (should use 0 tensor, not float)
        weights_summary="full",
        logger=tblogger,
        log_save_interval=1,
        checkpoint_callback=checkpoint_callback,
        overfit_pct=overfit_pct,
        num_nodes=1,
        gpus=_gpus_from_device(hyper["device"])
    )

    with pt.autograd.set_detect_anomaly(detect_anomaly):
        trainer.fit(model)

    # Save the final checkpoint into a fresh private temp directory and attach
    # it once as a Sacred artifact. The file is deliberately left on disk, in
    # case an observer reads it after add_artifact returns.
    ckpt_dir = tempfile.mkdtemp()
    fpath = os.path.join(ckpt_dir, "finalCheckpoint.ckpt")
    trainer.save_checkpoint(fpath)
    _run.add_artifact(fpath, "finalCheckpoint.ckpt")

    # Post-training plots
    os.makedirs(plots_save_dir, exist_ok=True)
    model.eval()

    # MMD
    plot_mmd_wrapper(hyper=hyper, current_epoch=trainer.current_epoch, plots_save_dir=plots_save_dir,
                     dataset_name=hyper["dataset"], dataset_used=model.train_set,
                     pyl_log_dir=pyl_log_dir, numb_g_eval=5000, numb_g_mmd=5000, model=model)


def _gpus_from_device(device):
    """Map a torch device string to the ``gpus`` argument of ``Trainer``.

    ``"cuda:N"`` maps to ``[N]``, a bare ``"cuda"`` maps to ``[0]``, and any
    device string not containing ``"cuda"`` maps to ``0`` (CPU training).
    """
    if "cuda" not in device:
        return 0
    parts = device.split(":")
    return [int(parts[1]) if len(parts) > 1 else 0]


def _make_logger(save_dir, name):
    """Return a ``TestTubeLogger`` if one can be created, else a ``TensorBoardLogger``."""
    try:
        # TestTubeLogger gives a nicer folder structure but depends on the
        # optional ``test_tube`` package, so using it is best-effort only.
        from pytorch_lightning.loggers import TestTubeLogger

        return TestTubeLogger(save_dir, name)
    except Exception as err:  # not a bare except: Ctrl-C / SystemExit propagate
        print(f"TestTubeLogger unavailable ({err!r}); falling back to TensorBoardLogger")
        return TensorBoardLogger(save_dir, name)
