
import matplotlib.pyplot as plt
import numpy as np
from ..ng_utils import compute_infraction


def visualize_hist(samples, figsize=(5, 5), bins=200):
    """Plot a 2-D histogram of ``samples`` over the square [-1.1, 1.1]^2.

    Black reference lines are drawn at -1, -0.5, 0, 0.5 and 1 on both axes.

    Args:
        samples: An ``(N, 2)`` array of points. A ``np.ndarray`` is used
            as-is. Objects exposing ``.detach()`` (presumably torch tensors)
            are detached, moved to CPU and converted. Anything else is passed
            through ``np.asarray``.
        figsize: Matplotlib figure size in inches.
        bins: Number of histogram bins per axis.

    Returns:
        The matplotlib ``Figure`` containing the plot. It is registered with
        pyplot, so the caller is responsible for closing it.

    Raises:
        ValueError: If ``samples`` is not two-dimensional with 2 columns.
    """
    if not isinstance(samples, np.ndarray):
        if hasattr(samples, "detach"):
            # Tensor.numpy() refuses tensors that require grad or live on an
            # accelerator, so detach and move to host memory first.
            samples = samples.detach().cpu().numpy()
        else:
            samples = np.asarray(samples)
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if samples.ndim != 2 or samples.shape[-1] != 2:
        raise ValueError(
            f"expected samples of shape (N, 2), got {samples.shape}"
        )

    fig, ax = plt.subplots(1, 1, figsize=figsize)
    ax.hist2d(
        samples[:, 0],
        samples[:, 1],
        range=[[-1.1, 1.1], [-1.1, 1.1]],
        bins=bins,
    )
    # Same positions as np.arange(-1, 1.1, 0.5): -1, -0.5, 0, 0.5, 1.
    grid_positions = np.linspace(-1.0, 1.0, 5)
    for pos in grid_positions:
        ax.axhline(pos, color="black")
        ax.axvline(pos, color="black")
    fig.tight_layout()
    return fig


def visualize_scatter(samples, figsize=(5, 5)):
    """Scatter-plot ``samples``, coloring infracting and non-infracting points.

    Points flagged by ``compute_infraction`` are drawn in red and the rest in
    green. Faint reference lines are drawn at -1, -0.5, 0, 0.5 and 1 on both
    axes.

    Args:
        samples: An ``(N, 2)`` array of points. A ``np.ndarray`` is used
            as-is. Objects exposing ``.detach()`` (presumably torch tensors)
            are detached, moved to CPU and converted. Anything else is passed
            through ``np.asarray``.
        figsize: Matplotlib figure size in inches.

    Returns:
        The matplotlib ``Figure`` containing the plot. It is registered with
        pyplot, so the caller is responsible for closing it.

    Raises:
        ValueError: If ``samples`` is not two-dimensional with 2 columns.
    """
    if not isinstance(samples, np.ndarray):
        if hasattr(samples, "detach"):
            # Tensor.numpy() refuses tensors that require grad or live on an
            # accelerator, so detach and move to host memory first.
            samples = samples.detach().cpu().numpy()
        else:
            samples = np.asarray(samples)
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if samples.ndim != 2 or samples.shape[-1] != 2:
        raise ValueError(
            f"expected samples of shape (N, 2), got {samples.shape}"
        )

    # NOTE(review): assumes compute_infraction returns one truthy/falsy flag
    # per sample. Coercing to bool ensures `~` is a logical NOT and the
    # indexing below is a boolean mask rather than integer fancy indexing.
    infraction_mask = np.asarray(compute_infraction(samples), dtype=bool)
    infracting_samples = samples[infraction_mask]
    non_infracting_samples = samples[~infraction_mask]

    fig, ax = plt.subplots(1, 1, figsize=figsize)

    # Same positions as np.arange(-1, 1.1, 0.5): -1, -0.5, 0, 0.5, 1.
    grid_positions = np.linspace(-1.0, 1.0, 5)
    for pos in grid_positions:
        ax.axhline(pos, color="black", alpha=0.3)
    for pos in grid_positions:
        ax.axvline(pos, color="black", alpha=0.3)
    ax.scatter(infracting_samples[:, 0], infracting_samples[:, 1], color="red", s=1, alpha=0.8)
    ax.scatter(non_infracting_samples[:, 0], non_infracting_samples[:, 1], color="green", s=1, alpha=0.8)

    # Match visualize_hist so both figures have consistent margins.
    fig.tight_layout()
    return fig