import matplotlib.pyplot as plt
import numpy as np
import torch as th
from ANONYMOUS.logging import wandb_img
from ANONYMOUStorch.trainer import step_lr
from ANONYMOUStorch.utils import as_numpy, no_grad_func
from torch.optim.lr_scheduler import StepLR

import torchsde
from viz.ou import generate_samples, traj_plot


# pylint: disable=invalid-name
def check_op_traj_wrapper(cfg):
    """Build a hook that integrates the SDE model and plots sample trajectories.

    The time grid ``linspace(0, cfg.model.t_end, cfg.data.viz_traj_ts)`` is
    created once, on the GPU, when this wrapper is called.

    Args:
        cfg: config object read for ``model.t_end``, ``data.viz_traj_ts`` and
            (inside the hook) ``data.viz_traj_num``.

    Returns:
        ``plot_op_traj(trainer)``, decorated with ``no_grad_func``.
    """
    time_grid = th.linspace(0.0, cfg.model.t_end, cfg.data.viz_traj_ts).cuda()

    @no_grad_func
    def plot_op_traj(trainer):
        """Integrate ``trainer.model`` from a zero state, save ``<iter>.png`` and log it."""
        sde = trainer.model
        step = trainer.iter_cnt
        state_dim = sde.ndim + sde.nreg
        init_state = th.zeros((cfg.data.viz_traj_num, state_dim)).cuda()
        # torchsde.sdeint returns a tensor laid out as (len(ts), batch, state).
        path = torchsde.sdeint(sde, init_state, time_grid, dt=0.01)
        # Keep only the first ``ndim`` state channels; the trailing ``nreg``
        # channels are presumably auxiliary/regulariser states — TODO confirm.
        visible = path[:, :, : sde.ndim]
        image_file = f"{step:04d}.png"
        traj_plot(
            time_grid,
            visible,
            xlabel="$t$",
            ylabel="$Y_t$",
            title=f"Iter {step:04d}",
            fsave=image_file,
        )
        wandb_img("traj", image_file, step)

    return plot_op_traj


def viz_sample_dist(trainer, *, x_range=(-4.5, 4.5), n_grid=100, n_bins=100):
    """Plot sampled marginals against the ground-truth and energy-based densities.

    Samples are drawn with ``generate_samples(trainer.mmodel)`` and shown as a
    normalised histogram. On a uniform grid over ``x_range``, the plot overlays
    ``trainer.train_set.pdf`` and ``trainer.train_set.energy_unpdf``, the latter
    normalised to integrate to ~1. The figure is saved to ``y-<iter>.png`` and
    logged to wandb under the ``"y1"`` key.

    Args:
        trainer: trainer exposing ``mmodel``, ``iter_cnt`` and ``train_set``
            (with ``pdf`` and ``energy_unpdf`` methods).
        x_range: ``(low, high)`` interval on which the reference densities are
            evaluated; also used as the plot's x-limits.
        n_grid: number of grid points in ``x_range``; must be at least 2.
        n_bins: number of histogram bins for the generated samples.

    Raises:
        ValueError: if ``n_grid < 2`` (grid spacing would be undefined).
    """
    x_low, x_high = x_range
    if n_grid < 2:
        raise ValueError(f"n_grid must be >= 2, got {n_grid}")

    samples = generate_samples(trainer.mmodel)
    n_iter = trainer.iter_cnt
    # NOTE(review): assumes `samples` is something np.histogram accepts
    # (e.g. a 1-D numpy array) — confirm against generate_samples.
    density, bin_edges = np.histogram(samples, n_bins, density=True)
    # Plot at bin centres; plotting at the right edges shifts the curve by
    # half a bin width.
    bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:])

    query_x = th.linspace(x_low, x_high, n_grid).cuda()
    query_pdf = trainer.train_set.pdf(query_x)
    query_energy_unpdf = trainer.train_set.energy_unpdf(query_x)
    # Riemann-sum normalisation so that sum(p) * dx == 1. A linspace with
    # n_grid points has spacing (high - low) / (n_grid - 1), not / n_grid.
    dx = (x_high - x_low) / (n_grid - 1)
    query_energy_pdf = query_energy_unpdf / (th.sum(query_energy_unpdf) * dx)

    fname = f"y-{n_iter:04d}.png"
    fig, ax = plt.subplots(1, 1, figsize=(7, 7))
    try:
        ax.plot(bin_centers, density, label="sampled")
        np_x = as_numpy(query_x)
        np_gt_p, np_e_p = as_numpy([query_pdf, query_energy_pdf])
        ax.plot(np_x, np_gt_p, label="gt")
        ax.plot(np_x, np_e_p, label="energy_pdf")
        ax.set_xlim(np_x[0], np_x[-1])
        ax.set_ylim(0, 1.5 * np.max(np_gt_p))
        ax.legend()
        fig.savefig(fname)
    finally:
        # Always release the pyplot figure, even if plotting/saving fails.
        plt.close(fig)
    wandb_img("y1", fname, n_iter)


def trainer_register(trainer, cfg):
    """Attach visualisation, optional temperature, and LR-scheduling hooks.

    Hooks on ``val:start`` are registered in this order: the trajectory plot,
    the sample-distribution plot, then ``step_lr``. If ``cfg.model.enable_temp``
    is truthy, ``temp_adjust_wrapper(cfg)`` is registered on ``step:start``.
    A ``StepLR`` scheduler (``step_size=1``, ``gamma=cfg.optimizer.gamma``) over
    ``trainer.optimizer`` is stored as ``trainer.lr_scheduler``; ``step_lr`` is
    presumably what advances it — TODO confirm.
    """
    validation_hooks = (check_op_traj_wrapper(cfg), viz_sample_dist)
    for hook in validation_hooks:
        trainer.register_event("val:start", hook)

    if cfg.model.enable_temp:
        # Imported only when temperature adjustment is enabled.
        from .fns import temp_adjust_wrapper

        trainer.register_event("step:start", temp_adjust_wrapper(cfg))

    trainer.lr_scheduler = StepLR(
        trainer.optimizer, step_size=1, gamma=cfg.optimizer.gamma
    )
    trainer.register_event("val:start", step_lr)
