import os

import numpy as np
from dpdt import dpdt
from dpdt.utils import CartAIGSelector
from dpdt.utils.feature_selectors import KBest
from data import *
import matplotlib.pyplot as plt

# Load the Iris data once (the original script called get_iris_data() twice
# and discarded the first result).
S, Y = get_iris_data()

# Attribute-information-gain selectors to compare, in the same order as the
# plot labels used below: KBest(1000) is labelled "Exhaustive", KBest(k) for
# k in (55, 10, 5) as "TOP k", KBest(1) as "CART (TOP 1)", and
# CartAIGSelector(d) as "DPDT-d".
selectors = [
    KBest(1000),
    KBest(55),
    KBest(10),
    KBest(5),
    KBest(1),
    CartAIGSelector(2),
    CartAIGSelector(3),
    CartAIGSelector(4),
]

# states[i] / acc[i] hold the MDP state count and the scores returned by
# dpdt for selectors[i].
states, acc = [], []
for aig_fn in selectors:
    # Every run uses the same settings; only the selector changes.
    # zetas=np.zeros(1) passes a single zeta value of 0.0 (presumably the
    # complexity/accuracy trade-off — confirm against the dpdt docs).
    scores_dp, depths_dp, nodes_dp, time_, state = dpdt(
        S,
        Y,
        max_depth=3,
        aig_fn=aig_fn,
        zetas=np.zeros(1),
        save_policy=False,
        verbose=False,
        count_states=True,
    )
    states.append(state)
    acc.append(scores_dp)


# Legend labels, one per dpdt run, in the order the runs were appended to
# `states` / `acc`.
names = [
    "Exhaustive",
    "TOP 55",
    "TOP 10",
    "TOP 5",
    "CART (TOP 1)",
    "DPDT-2",
    "DPDT-3",
    "DPDT-4",
]
# One marker per entry of `names`: diamond for the exhaustive run, crosses for
# the TOP-k runs, a plus for the CART (TOP 1) run and circles for the DPDT-k
# runs.
markers = ["d", "x", "x", "x", "P", "o", "o", "o"]

# Pairing below is purely positional, so refuse to plot mismatched inputs
# rather than silently dropping or mislabelling points.
if not (len(names) == len(markers) == len(states) == len(acc)):
    raise ValueError(
        "expected one result per configuration: "
        f"{len(names)} names, {len(markers)} markers, "
        f"{len(states)} state counts, {len(acc)} scores"
    )

# One scatter call per configuration, in run order, so each gets the next
# colour from matplotlib's default colour cycle.
for name, marker, n_states, score in zip(names, markers, states, acc):
    plt.scatter(n_states, score, marker=marker, s=100, label=name)

plt.legend(loc="center right")
plt.grid()
plt.xlabel("Number of MDP states", fontdict={"size": 14})
plt.ylabel("Accuracy", fontdict={"size": 14})
plt.tick_params(axis="x", labelsize=14)
plt.tick_params(axis="y", labelsize=12)
plt.xscale("log")  # log scale on the state-count axis

# savefig does not create missing directories; make sure the target exists.
os.makedirs("plots", exist_ok=True)
plt.savefig("plots/selectors.pdf")
plt.close()  # release the figure once it has been written
