
import torch
import matplotlib.pyplot as plt
import numpy as np


def _to_numpy_points(points):
    """Return *points* as a NumPy array; torch tensors are detached and moved to CPU first."""
    if isinstance(points, torch.Tensor):
        return points.detach().cpu().numpy()
    # Anything array-like (NumPy arrays, nested lists) is accepted as-is.
    return np.asarray(points)


def plot_2D_data_scatter(
    target=None,
    samples=None,
    init_dist=None,
    N_ref=50_000,
    N_plot=10_000,
    xlim=(-2.5, 3.0),
    ylim=(-2.0, 2.5),
    title=None,
    full_plot=True,
    ax=None,
):
    """
    2D scatter plot for target and samples.

    Draws up to three layers on one axis, bottom to top: reference points
    drawn from ``target`` (gray), ``samples`` (red) and ``init_dist`` (blue,
    only when ``full_plot`` is true). Only columns 0 and 1 of each point set
    are plotted.

    Args:
        target: Optional object with a ``sample(n)`` method returning ``n``
            points (torch tensor or array-like with at least 2 columns).
        samples: Optional point set (torch tensor or array-like); only the
            first ``N_plot`` rows are drawn.
        init_dist: Optional point set of initial points (torch tensor or
            array-like); only the first ``N_plot`` rows are drawn, and only
            when ``full_plot`` is true.
        N_ref: Number of reference points requested from ``target.sample``.
        N_plot: Maximum number of rows taken from ``samples`` / ``init_dist``.
        xlim, ylim: Axis limits passed to ``ax.set_xlim`` / ``ax.set_ylim``.
        title: Optional axis title.
        full_plot: If false, the ``init_dist`` layer is omitted and the axis
            ticks are removed.
        ax: Axis to draw on; a new 4x4 figure is created when ``None``.

    Returns:
        The axis that was drawn on.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(4, 4))

    # Slice before converting so a large (possibly GPU) tensor is not copied in full.
    with torch.no_grad():
        X_ref = _to_numpy_points(target.sample(N_ref)) if target is not None else None
        X_s = _to_numpy_points(samples[:N_plot]) if samples is not None else None
        # The initial layer is never drawn without full_plot, so don't convert it then.
        X_init = (
            _to_numpy_points(init_dist[:N_plot])
            if init_dist is not None and full_plot
            else None
        )

    # Order matters for overlap: reference at the bottom, initial points on top.
    layers = (
        (X_ref, {"s": 3, "alpha": 0.25, "color": "gray"}),
        (X_s, {"s": 6, "alpha": 0.6, "color": "red"}),
        (X_init, {"s": 6, "alpha": 0.6, "color": "blue"}),
    )
    for points, style in layers:
        if points is not None:
            ax.scatter(points[:, 0], points[:, 1], **style)

    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    ax.set_aspect("equal")

    if not full_plot:
        ax.set_xticks([])
        ax.set_yticks([])

    if title is not None:
        ax.set_title(title)

    return ax
