import os.path
import pprint as pp

import matplotlib
import matplotlib.pyplot as plt
import numpy as np

from active_ranking.experiments import experiment
from active_ranking.experiments import plotting

# This script only writes figures to disk: use the non-interactive Agg
# backend and turn interactive mode off.
matplotlib.use("agg")
plt.ioff()

# Name of the scenario whose sampling is additionally rendered as a GIF.
ANIMATED_SCENARIO = "scenario_1"


def _iter_scenarios(namespace):
    """Yield ``(name, config)`` for each attribute of *namespace* whose name
    contains ``"scenario"``.

    A snapshot of the attributes is taken first so the iteration is not
    affected by anything the caller does between items.
    """
    for name, config in list(vars(namespace).items()):
        if "scenario" in name:
            yield str(name), config


def _save_analysis_figures(learner, path, scenario_name):
    """Draw every diagnostic figure for *learner* and save it under *path*.

    :param learner: model instance built from ``experiment.models``; must
        expose ``plot_ucb_lcb``, ``predictions``, ``y_test``, ``x_test`` and
        ``true_eta``.
    :param path: existing output directory.
    :param scenario_name: name of the scenario; when it equals
        ``ANIMATED_SCENARIO`` an animated sampling GIF is also written.
    """
    learner.plot_ucb_lcb(partition_type="p_cells", plot_active_cell=True)
    plt.title("$P$-cells")
    plt.tight_layout()
    plt.savefig(os.path.join(path, "ucb_lcb_p_cells.png"))

    # ASSESS MODEL: compare the last recorded predictions against the test
    # labels and the true eta evaluated on the test inputs.
    # NOTE(review): assumes `predictions` preserves insertion order so the
    # last value is the final round — confirm against the learner class.
    y_prediction = np.array(list(learner.predictions.values())[-1])
    y_true = learner.y_test
    y_roc_true = np.array(learner.true_eta(learner.x_test))

    plotting.plot_true_roc_curves(y_true, y_prediction, y_roc_true)
    plt.savefig(os.path.join(path, "roc_curve.png"))

    plotting.plot_roc_epochs(learner)
    plt.savefig(os.path.join(path, "roc_curve_t.png"))

    plotting.plot_d1_norm(learner)
    plt.savefig(os.path.join(path, "regret_d1.png"))

    plotting.plot_d_inf_norm(learner)
    plt.savefig(os.path.join(path, "regret_dinfty.png"))

    plotting.plot_swarm_sampling(learner, animate=False)
    # Explicit extension: without one the format depends on rcParams.
    plt.savefig(os.path.join(path, "sampling.png"))

    # Compare the scenario *name*: the scenario config itself is a dict and
    # would never equal a string.
    if scenario_name == ANIMATED_SCENARIO:
        from matplotlib.animation import PillowWriter

        animation = plotting.plot_swarm_sampling(learner, animate=True)
        animation.save(os.path.join(path, "sampling.gif"),
                       writer=PillowWriter(fps=300))


if __name__ == '__main__':

    from active_ranking.scenarios import inputs

    for scenario_name, scenario in _iter_scenarios(inputs):
        j_max = scenario["j_max"]
        d = scenario["d"]
        e_ = scenario["eta"]
        n_max = scenario["n_max"]
        n_0 = scenario["n_0"]

        for model in experiment.models:
            learner = model(n_0, n_max, j_max, d, e_)
            path = os.path.join("results", "figures", "analysis",
                                scenario_name, learner.name)
            os.makedirs(path, exist_ok=True)

            _save_analysis_figures(learner, path, scenario_name)

            # TEST: dump the learner's final P-cell partition for inspection.
            pp.pprint(learner.partition.p_cells)

            # Release this learner's figures; otherwise every model of every
            # scenario keeps its figures alive until the script exits.
            plt.close("all")
