import time
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
from absl import app, flags
from common_utils.misc import infinite_loader
from common_utils.random import RNG, set_random_seed
from ml_collections.config_flags import config_flags
from tqdm.auto import tqdm
from decimal import Decimal

import gen_neg_toy
from gen_neg_toy import data, dispatch_model_from_path, dispatch_model, script_utils
from gen_neg_toy.ng_utils import compute_infraction, compute_infraction_differentiable
from gen_neg_toy.sde_lib import EDMSDE
from gen_neg_toy.evaluation import visualize_hist, visualize_scatter


# Command-line interface of the script (absl flags).
FLAGS = flags.FLAGS
flags.DEFINE_string(
    name="checkpoint",
    default=None,
    help="Checkpoint to load.",
)
flags.DEFINE_integer(
    name="seed",
    default=None,
    help="Random seed.",
)
# NOTE(review): `steps` is not read anywhere in the visible part of this
# script; it is kept so existing command lines that pass it keep working.
flags.DEFINE_integer(
    name="steps",
    default=100,
    help="Number of steps for ELBO computation.",
)
# The script cannot do anything without a checkpoint to visualize.
flags.mark_flag_as_required("checkpoint")


@torch.no_grad()
def plot_classifier(classifier, sigma, device, ax=None):
    """Draw a filled contour plot of the classifier output at noise level ``sigma``.

    The classifier is evaluated on a regular 200 x 200 grid. Plot coordinates
    span [-3, 3] on both axes; the points actually passed to the classifier
    (and to ``compute_infraction_differentiable``) are those coordinates
    divided by 2.

    Args:
        classifier: Callable invoked as ``classifier(points, sigmas)`` that
            also provides ``out_to_p(output, sigma)``. ``out_to_p`` presumably
            maps the raw output to a probability in [0, 1] (the colour scale
            is fixed to that range) -- confirm against the classifier class.
        sigma: Noise level passed to the classifier for every grid point.
        device: Torch device the grid batches are moved to before evaluation.
        ax: Optional matplotlib axes to draw on. When omitted, a new figure
            with a colorbar is created.

    Returns:
        The newly created figure when ``ax`` is ``None``; otherwise ``None``
        (the caller already owns the figure the axes belong to).
    """
    n_grid = 200
    chunk_size = 8196  # number of grid points evaluated per forward pass

    grid_x = np.linspace(-3.0, 3.0, n_grid)
    grid_y = np.linspace(-3.0, 3.0, n_grid)
    X, Y = np.meshgrid(grid_x, grid_y)

    # Model-space coordinates are the plot coordinates halved.
    XY = np.stack([X.flatten(), Y.flatten()], axis=1) / 2

    # Evaluate in chunks so the whole grid never has to sit on the device at once.
    outputs = [
        classifier(batch.to(device), torch.ones(len(batch)).to(device) * sigma)
        for batch in torch.as_tensor(XY).float().split(chunk_size)
    ]
    Z = classifier.out_to_p(torch.cat(outputs), sigma).float()
    Z = Z.flatten().cpu().numpy().reshape(n_grid, n_grid)

    # Grid points whose infraction measure exceeds 3 * sigma, mapped back to
    # plot coordinates; they are overlaid in gray below.
    dist = compute_infraction_differentiable(XY, norm_p=2)
    blocked = XY[dist > 3 * sigma] * 2

    created_figure = ax is None
    if created_figure:
        fig, ax = plt.subplots()

    cs = ax.contourf(X, Y, Z, levels=20, cmap="bwr", vmin=0, vmax=1)

    # Reference lines at -2, -1, 0, 1, 2 in plot coordinates.
    for position in np.arange(-1, 1.1, 0.5) * 2:
        ax.axhline(position, color="black", alpha=0.3)
        ax.axvline(position, color="black", alpha=0.3)

    ax.scatter(blocked[:, 0], blocked[:, 1], color="gray", alpha=0.03)

    if created_figure:
        fig.colorbar(cs)
        return fig
    return None


@torch.no_grad()
def main(argv):
    """Plot the checkpointed classifier for a range of noise levels and save the grid.

    Loads the classifier from ``FLAGS.checkpoint``, draws one
    ``plot_classifier`` panel per sigma on an 8 x 8 grid of axes, and writes
    the figure to ``results/classifier.png`` (creating ``results/`` if needed).

    Args:
        argv: Positional command-line arguments forwarded by ``app.run``;
            not used.
    """
    if FLAGS.seed is not None:
        set_random_seed(FLAGS.seed)

    classifier, _ = gen_neg_toy.classifier.dispatch_model_from_path(FLAGS.checkpoint)

    ## Load the dataset ##
    # NOTE(review): neither the datasets nor the loaders are consumed below.
    # They are kept because `data.get_datasets` may have side effects that are
    # not visible from this file -- confirm and remove if it has none.
    train_set, test_set = data.get_datasets(
        "checkerboard", train_set_size=1000
    )
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=8192, shuffle=True, pin_memory=True
    )
    val_loader = torch.utils.data.DataLoader(
        test_set, batch_size=8192, shuffle=False
    )

    fig, axes = plt.subplots(8, 8, figsize=(8*1.5, 8*1.5))
    axes = axes.flatten()

    # 100 log-spaced noise levels between 0.002 and 80, truncated to the
    # number of panels: only the smallest `len(axes)` values are plotted, so
    # the upper end of the range is never reached.
    sigmas = np.logspace(np.log(0.002), np.log(80), 100, base=np.e)[:len(axes)]

    # Materialize the pairs so tqdm knows the total and can show progress.
    for sigma, ax in tqdm(list(zip(sigmas, axes))):
        plot_classifier(classifier, sigma, device="cuda", ax=ax)
        ax.get_xaxis().set_ticks([])
        ax.get_yaxis().set_ticks([])
        ax.set_xlabel("%.2e" % Decimal(str(sigma)))
    fig.tight_layout()

    # savefig does not create missing parent directories, so make sure the
    # output directory exists before writing.
    output_path = Path("results") / "classifier.png"
    output_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(output_path)


# Script entry point: absl parses the command-line flags defined above and
# then calls `main` with the remaining arguments.
if __name__ == "__main__":
    app.run(main)
