import os.path
import pprint as pp

import matplotlib.pyplot as plt
import numpy as np

from active_ranking import utils
from active_ranking.base.model import MessyRank
from active_ranking.base.utils import FunctionAsLabeler

if __name__ == '__main__':
    # Script entry point: run MessyRank on one named scenario and save the
    # diagnostic figures under results/figures/analysis/<scenario>/<learner>.
    import matplotlib
    from active_ranking.scenarios import inputs

    matplotlib.use("qt5agg")

    scenario_name = "scenario_8"

    # Scenarios are module-level attributes of `inputs`; use the public
    # attribute lookup rather than reaching into the module's __dict__.
    scenario = getattr(inputs, scenario_name)
    j_max = scenario["j_max"]
    d = scenario["d"]
    e_ = scenario["eta"]
    n_max = scenario["n_max"]
    n_0 = scenario["n_0"]

    learner = MessyRank(j_max=j_max, d=d, eta=FunctionAsLabeler(e_),
                        alternative_p_estimate=True)
    path = f"results/figures/analysis/{scenario_name}/{learner.name}"

    # exist_ok=True creates the directory tree only when needed and avoids
    # the check-then-create race of a separate os.path.exists() test.
    os.makedirs(path, exist_ok=True)

    learner.add_tracker(e_)
    # Alias kept so that snippets written against `self` (e.g. copied from
    # the learner's methods) can be run as-is in this script's namespace.
    self = learner
    # Alternative: learner.run_and_plot(n_0=n_0, n_max=n_max, step_mod=100)
    learner.run(n_0=n_0, n_max=n_max)
    print(utils.__cached_time__)
    print(utils.__cached_time_n_call__)
    learner.plot_ucb_lcb(partition_type="p_cells", plot_active_cell=True)
    plt.tight_layout()
    plt.savefig(f"{path}/ucb_lcb_p_cells.png")

    # ASSESS MODEL

    # Only the last entry of the recorded predictions is assessed here.
    y_prediction = np.array(list(learner.predictions.values())[-1])
    y_true = learner.y_test
    y_roc_true = np.array(learner.true_eta(learner.x_test))

    from active_ranking.experiments import plotting

    plotting.plot_true_roc_curves(y_true, y_prediction, y_roc_true)
    plt.savefig(f"{path}/roc_curve.png")

    plotting.plot_roc_epochs(learner)
    plt.savefig(f"{path}/roc_curve_t.png")

    plotting.plot_d1_norm(learner)
    plt.savefig(f"{path}/regret_d1.png")

    plotting.plot_d_inf_norm(learner)
    plt.savefig(f"{path}/regret_dinfty.png")

    plotting.plot_swarm_sampling(learner, animate=False)
    # No extension given: matplotlib picks the format from
    # rcParams["savefig.format"] and appends the matching extension.
    plt.savefig(f"{path}/sampling")

    # Optional animated export (needs
    # `from matplotlib.animation import PillowWriter`):
    # animation = plotting.plot_swarm_sampling(learner, animate=True)
    # writergif = PillowWriter(fps=300)
    # animation.save(f"{path}/sampling.gif", writer=writergif)

    # TEST
    pp.pprint(learner.partition.p_cells)

    # +++++++++++++++++++++++++++
    # explore messy rank
    # +++++++++++++++++++++++++++
