import os.path as osp

import ANONYMOUS.io as jio
import torch as th
from ANONYMOUS.logging import get_logger
from ANONYMOUS.utils import md5_encode_obj
from ANONYMOUS.utils.env import ANONYMOUS_ABC_getenv
from ANONYMOUStorch.trainer import step_lr
from ANONYMOUStorch.utils import as_numpy, no_grad_func
from torch.optim.lr_scheduler import StepLR

import torchsde
from viz.ps import generate_samples_loss, traj_plot, viz_field, viz_kde, viz_sample

from .fns import dump_info, loss2logz_info, loss2ess_info

logger = get_logger()

# pylint: disable= protected-access
@no_grad_func
def viz_sample_dist(trainer):
    """Draw samples from the trainer's model, record metrics, and dump the result.

    Calls ``generate_samples_loss`` with the trainer's ``mmodel``, ``train_set``,
    ``dt``, ``t_end`` and ``num_particle``. It then merges the log-Z and ESS
    summaries (``loss2logz_info`` / ``loss2ess_info``, converted with
    ``as_numpy``) into the returned ``info`` dict. That dict is pushed to
    ``trainer.cur_monitor``.

    ``trainer.latest_result`` is replaced by a dict holding the samples, the
    loss, the epoch/iteration counters, and every entry of ``info``.
    ``dump_info(trainer)`` is called last.

    Intended to be registered as a "val:start" event (see ``trainer_register``).
    """
    samples, loss, info = generate_samples_loss(
        trainer.mmodel,
        trainer.train_set,
        trainer.dt,
        trainer.t_end,
        trainer.num_particle,
    )

    # Both summaries are computed before either is merged, as in the original order.
    extra_summaries = (loss2logz_info(loss), loss2ess_info(loss))
    for summary in extra_summaries:
        info.update(as_numpy(summary))

    trainer.cur_monitor.update(info)

    # Assign first, then extend in place, so `info` keys override the base entries.
    trainer.latest_result = result = {
        "samples": as_numpy(samples),
        "loss": as_numpy(loss),
        "epoch_cnt": trainer.epoch_cnt,
        "iter_cnt": trainer.iter_cnt,
    }
    result.update(info)

    dump_info(trainer)


def check_op_traj_wrapper(cfg):
    """Build a trainer callback that integrates and plots SDE trajectories.

    The time grid is ``cfg.data.viz_traj_ts`` evenly spaced points on [0, 1].
    It is created once, on the current CUDA device, when this wrapper is called.

    Args:
        cfg: config object. ``cfg.data.viz_traj_ts`` and
            ``cfg.data.viz_traj_num`` are read.

    Returns:
        A callback ``plot_op_traj(trainer)`` (wrapped by ``no_grad_func``).
    """
    time_grid = th.linspace(0.0, 1.0, cfg.data.viz_traj_ts).cuda()

    @no_grad_func
    def plot_op_traj(trainer):
        # Integrate `viz_traj_num` zero-initialised states with Euler steps of
        # 1e-3, then plot only the first `model.ndim` coordinates.
        model = trainer.model
        step_idx = trainer.iter_cnt

        state_dim = model.ndim + model.nreg
        init_state = th.zeros((cfg.data.viz_traj_num, state_dim)).cuda()
        trajectory = torchsde.sdeint(
            model, init_state, time_grid, dt=1e-3, method="euler"
        )
        positions = trajectory[:, :, : model.ndim]

        out_name = f"traj-{step_idx:04d}.png"
        traj_plot(
            positions,
            xlabel="$x$",
            ylabel="$y$",
            title="traj",
            fsave=out_name,
        )

    return plot_op_traj


def viz_drift_field(trainer):
    """Plot the model's field via ``viz_field``.

    The figure is titled "ps field" and saved under a file name built from the
    zero-padded iteration counter.
    """
    net = trainer.model
    figure_path = f"field-{trainer.iter_cnt:04d}.png"
    viz_field(net, "ps field", figure_path)


def trainer_register(trainer, cfg):
    """Attach sampling/visualisation callbacks, the dump path, and the LR schedule.

    Args:
        trainer: the trainer object. Events are registered on it and several
            attributes are set (``dt``, ``t_end``, ``num_particle``, ``fdump``,
            ``lr_scheduler``, ``model.dataset``).
        cfg: config. Reads ``cfg.model.{dt, t_end, enable_temp}``,
            ``cfg.data.{num_particle, train_set._target_}``, ``cfg.name``,
            ``cfg.seed`` and ``cfg.optimizer.gamma``.

    Event order on "val:start" is significant: sampling runs before the LR step.
    """
    trainer.register_event("val:start", viz_sample_dist)

    # Sampling parameters consumed by viz_sample_dist.
    trainer.dt = cfg.model.dt
    trainer.t_end = cfg.model.t_end
    trainer.num_particle = cfg.data.num_particle

    # Results go to <proj_path>/data/<dataset class name>/<name>-<seed>.pkl
    dataset_cls_name = cfg.data.train_set._target_.rsplit(".", 1)[-1]
    dump_dir = osp.join(ANONYMOUS_ABC_getenv("proj_path"), "data", dataset_cls_name)
    jio.mkdir(dump_dir)
    trainer.fdump = osp.join(dump_dir, f"{cfg.name}-{cfg.seed:02d}.pkl")
    logger.info(f"SAVE PATH: {trainer.fdump}")

    # Optional visualisations, currently disabled:
    # trainer.register_event("val:start", viz_drift_field)
    # trainer.register_event("val:start", check_op_traj_wrapper(cfg))
    trainer.model.dataset = trainer.train_set

    # Multiply the LR by `gamma` on each scheduler step; the step is taken on "val:start".
    trainer.lr_scheduler = StepLR(
        trainer.optimizer, step_size=1, gamma=cfg.optimizer.gamma
    )
    trainer.register_event("val:start", step_lr)

    if not cfg.model.enable_temp:
        return
    from .fns import temp_adjust_wrapper

    trainer.register_event("step:start", temp_adjust_wrapper(cfg))
