import numpy as np
from matplotlib import pyplot as plt

from src.evaluation import plot_discount_archive


def visualize_discount_points(discount_train_info, discount_archive, ax,
                              domain_cfg):
    """Draw the discount archive and overlay all four feature groups.

    Args:
        discount_train_info: Mapping with the keys ``"new_features"``,
            ``"empty_features"``, ``"non_empty_features"`` and
            ``"same_features"``, or ``None``. Each value is indexed as
            ``[:, 0]`` / ``[:, 1]``, so it must be a 2-D array with at least
            two columns. When ``None``, empty ``(0, 2)`` arrays are used and
            no points are drawn.
        discount_archive: Forwarded unchanged to ``plot_discount_archive``.
        ax: Target axes; must provide ``scatter``, ``legend`` and
            ``set_title`` (presumably a matplotlib ``Axes``).
        domain_cfg: Forwarded unchanged to ``plot_discount_archive``.
    """
    # One row per scatter layer: (info key, marker size, color, marker, label).
    # The row order is the draw order and the legend order.
    layers = (
        ("new_features", 50, "cornflowerblue", ".", "Emitter Samples"),
        ("empty_features", 40, "yellow", "^", "Empty Points"),
        ("non_empty_features", 36, "red", "^", "Non Empty Points"),
        ("same_features", 40, "lime", "x", "Same"),
    )

    # Resolve every feature array before any drawing happens, so a missing
    # key fails before the axes are touched.
    if discount_train_info is not None:
        point_sets = [discount_train_info[key] for key, *_ in layers]
    else:
        point_sets = [np.empty((0, 2)) for _ in layers]

    # Background first, so the scatter layers are drawn on top of it.
    plot_discount_archive(discount_archive, ax, domain_cfg)

    for points, (_, size, color, marker, label) in zip(point_sets, layers):
        ax.scatter(points[:, 0],
                   points[:, 1],
                   s=size,
                   c=color,
                   marker=marker,
                   label=label,
                   alpha=0.7)

    # Anchor the legend's lower-left corner to the right of the axes area.
    ax.legend(loc='lower left', bbox_to_anchor=(1.25, 0.), borderaxespad=0.0)
    ax.set_title("Discount Points")


def visualize_discount_points_2(discount_train_info, discount_archive, ax,
                                domain_cfg):
    """Draw the discount archive with emitter samples and empty points only.

    This is a reduced variant of the four-group plot: it scatters only the
    ``"new_features"`` and ``"empty_features"`` groups, with smaller markers.

    Args:
        discount_train_info: Mapping with the keys ``"new_features"`` and
            ``"empty_features"``, or ``None``. Only these two keys are read.
            Each value is indexed as ``[:, 0]`` / ``[:, 1]``, so it must be
            a 2-D array with at least two columns. When ``None``, empty
            ``(0, 2)`` arrays are used and no points are drawn.
        discount_archive: Forwarded unchanged to ``plot_discount_archive``.
        ax: Target axes; must provide ``scatter``, ``legend`` and
            ``set_title`` (presumably a matplotlib ``Axes``).
        domain_cfg: Forwarded unchanged to ``plot_discount_archive``.
    """
    if discount_train_info is None:
        # Zero-row placeholders keep the scatter calls (and therefore the
        # legend entries) uniform even when there is nothing to plot.
        new_features = np.empty((0, 2))
        empty_features = np.empty((0, 2))
    else:
        new_features = discount_train_info["new_features"]
        empty_features = discount_train_info["empty_features"]

    # Background first, so the scatter layers are drawn on top of it.
    plot_discount_archive(discount_archive, ax, domain_cfg)

    ax.scatter(new_features[:, 0],
               new_features[:, 1],
               s=10,
               c="cornflowerblue",
               marker=".",
               label="Emitter Samples",
               alpha=0.7)
    ax.scatter(empty_features[:, 0],
               empty_features[:, 1],
               s=8,
               c="yellow",
               marker="^",
               label="Empty Points",
               alpha=0.7)

    # Anchor the legend's lower-left corner to the right of the axes area.
    ax.legend(loc='lower left', bbox_to_anchor=(1.5, 0.), borderaxespad=0.0)
    ax.set_title("Discount Points")
