import time
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
from absl import app, flags
from gen_neg_toy.utils import infinite_loader
from gen_neg_toy.utils.random import set_random_seed
from ml_collections.config_flags import config_flags
from tqdm.auto import tqdm

import gen_neg_toy
from gen_neg_toy import (
    data,
    dispatch_model_from_path,
    dispatch_model,
    script_utils,
)
from gen_neg_toy.ng_utils import (
    compute_infraction,
    compute_infraction_differentiable,
)
from gen_neg_toy.sde_lib import EDMSDE
from gen_neg_toy.evaluation import visualize_hist, visualize_scatter


# Command-line flags for evaluating a trained checkpoint.
FLAGS = flags.FLAGS
flags.DEFINE_string("checkpoint", None, "Checkpoint to load.")
# NOTE: DEFINE_list defaults to an empty list (not None) when the flag is unset.
flags.DEFINE_list("classifier", [], "List of classifier paths.")
flags.DEFINE_string("bridge_scale_schedule", None, "Guidance scale schedule.")
flags.DEFINE_integer("seed", None, "Random seed.")
flags.DEFINE_integer(
    "steps", 100, "Number of steps for ELBO computation and sample generation."
)
flags.DEFINE_integer("rho", 7, "rho parameter of the sampling procedure.")
flags.DEFINE_float("S_churn", 10.0, "S_churn parameter of the sampling procedure.")
# At least one trial is needed: the per-trial metrics are stacked/averaged.
flags.DEFINE_integer(
    "n_trials",
    20,
    "Number of trials to compute mean and std of metrics.",
    lower_bound=1,
)
flags.mark_flags_as_required(["checkpoint"])


def _estimate_elbos(model, loader, device, num_steps, n_trials):
    """Run ``script_utils.elbo_from_dataloader`` ``n_trials`` times.

    Repeating the estimate lets the caller report a mean and spread
    (presumably the estimate is stochastic -- TODO confirm).

    Args:
        model: Model forwarded to the ELBO helper.
        loader: DataLoader over the evaluation split.
        device: Device forwarded to the ELBO helper.
        num_steps: Number of steps used by the ELBO computation.
        n_trials: Number of repetitions; must be >= 1 because ``np.stack``
            rejects an empty list.

    Returns:
        ``np.ndarray`` stacking the per-trial values returned by the helper.
    """
    return np.stack(
        [
            script_utils.elbo_from_dataloader(
                model, loader, device=device, num_steps=num_steps
            )
            for _ in range(n_trials)
        ]
    )


def _measure_infractions(model, device, n_trials, *, n_samples, num_steps, rho, S_churn):
    """Draw ``n_trials`` batches of samples and measure their infractions.

    Args:
        model: Model forwarded to ``script_utils.draw_samples``.
        device: Device forwarded to the sampler.
        n_trials: Number of independent sample batches to draw.
        n_samples: Number of samples per batch.
        num_steps, rho, S_churn: Sampler parameters forwarded unchanged.

    Returns:
        A tuple ``(rates, dists, samples)``:
        - ``rates``: list with the percentage of infringing samples per trial.
        - ``dists``: ``compute_infraction_differentiable(..., norm_p=1)`` values
          of the infringing samples only, concatenated over all trials; an
          empty array when no trial produced an infraction.
        - ``samples``: the batch drawn in the last trial (used for plots).
    """
    rates = []
    dist_chunks = []
    samples = None
    for _ in range(n_trials):
        samples, _ = script_utils.draw_samples(
            model,
            n_samples=n_samples,
            device=device,
            S_churn=S_churn,
            num_steps=num_steps,
            rho=rho,
            verbose=True,
        )
        infraction = compute_infraction(samples).cpu().numpy()
        rates.append(infraction.astype("float").mean() * 100)
        # Coerce to bool so the array is used as a mask; 0/1 integers would
        # otherwise be interpreted by NumPy as integer (fancy) indices.
        mask = infraction.astype(bool)
        if mask.any():
            dist = compute_infraction_differentiable(samples, norm_p=1).cpu().numpy()
            dist_chunks.append(dist[mask])
    # np.concatenate raises on an empty list, which happens when no trial
    # produced a single infringing sample.
    dists = np.concatenate(dist_chunks) if dist_chunks else np.empty(0)
    return rates, dists, samples


@torch.no_grad()
def main(argv):
    """Evaluate a checkpoint: ELBO on the test split and sample infractions.

    Loads the model from ``--checkpoint`` (optionally with ``--classifier``
    paths and ``--bridge_scale_schedule``), estimates the ELBO
    ``--n_trials`` times, draws ``--n_trials`` batches of 20000 samples to
    measure the infraction rate, and saves a histogram and a scatter plot of
    the last batch under ``results/``.
    """
    del argv  # Unused: all configuration comes from FLAGS.

    if FLAGS.seed is not None:
        set_random_seed(FLAGS.seed)

    model, config = dispatch_model_from_path(
        FLAGS.checkpoint,
        # DEFINE_list defaults to [] (never None), so test for emptiness:
        # load strictly only when no classifier paths were given.
        strict=not FLAGS.classifier,
        classifier=FLAGS.classifier,
        bridge_scale_schedule=FLAGS.bridge_scale_schedule,
    )

    ## Load the dataset ##
    # Only the test split is used by this evaluation.
    _, test_set = data.get_datasets(config.data)
    val_loader = torch.utils.data.DataLoader(
        test_set, batch_size=config.training.batch_size, shuffle=False
    )

    ## Log some info ##
    print(f"Number of model parameters: {model.count_parameters():,}")

    ## ELBO computation ##
    print(f"Number of steps for ELBO and sample generation: {FLAGS.steps}")
    elbos = _estimate_elbos(
        model,
        val_loader,
        config.device,
        num_steps=FLAGS.steps,
        n_trials=FLAGS.n_trials,
    )
    print(f"ELBO = {elbos.mean()} +/- {elbos.std()}")

    ## Infraction computation ##
    rates, dists, samples = _measure_infractions(
        model,
        config.device,
        FLAGS.n_trials,
        n_samples=20000,
        num_steps=FLAGS.steps,
        rho=FLAGS.rho,
        S_churn=FLAGS.S_churn,
    )
    print(f"Infraction = {np.mean(rates)} +- {np.std(rates)} %")
    if dists.size:
        print(f"Infraction dist: {dists.mean()}, {dists.max()}")
    else:
        print("Infraction dist: no infringing samples")

    ## Visualization of the last batch of samples ##
    out_dir = Path("results")
    out_dir.mkdir(parents=True, exist_ok=True)
    visualize_hist(samples).savefig(out_dir / "hist.png")
    visualize_scatter(samples).savefig(out_dir / "scatter.png")


if __name__ == "__main__":
    # Hand control to absl, which parses the command-line FLAGS before
    # invoking main(argv).
    app.run(main)
