from pathlib import Path
from typing import Dict, List

import matplotlib.pyplot as plt
from matplotlib.figure import Figure

from compute_result.output_manager.base import OutputManager
from compute_result.typing import Run, BudgetRun
from compute_result.utils import get_color, DEFAULT_COLOR_MAP, get_label_color_map


class ImageOutput(OutputManager):
    """Output manager that draws results with matplotlib and saves them as PDFs.

    Line plots accumulate on ``self.ax`` until :meth:`finish` writes the figure
    to disk; :meth:`print_bars` draws and saves a bar chart on the same axes.
    All files are written to the directory ``Path(__file__).parent.parent.parent``.
    """

    def __init__(self, fig: Figure = None):
        """Create the manager, drawing on ``fig`` or on a fresh pyplot figure.

        Args:
            fig: Optional existing figure. When given, a new axes spanning the
                whole figure (``[0, 0, 1, 1]``) is added to it; otherwise a new
                figure/axes pair is created with ``plt.subplots()``.
        """
        self.index = 0  # file-name suffix; incremented by finish()
        if fig is None:
            fig, ax = plt.subplots()
        else:
            # Use the argument: self.fig has not been assigned yet here.
            ax = fig.add_axes([0, 0, 1, 1])
        self.fig = fig
        self.ax = ax
        self.graph_name = ""  # file-name prefix used by finish()

    @staticmethod
    def _output_path(stem: str) -> Path:
        """Return ``<output dir>/<stem>.pdf`` (output dir = three parents up from this file)."""
        return Path(__file__).parent.parent.parent / f"{stem}.pdf"

    def finish(self):
        """Style the current line plot, save it as a PDF, and start a new figure.

        The file is named ``{graph_name}_{index}.pdf``, where ``graph_name`` is
        the value last stored by :meth:`log_normalize_progression`. After saving,
        the figure is closed, replaced by a fresh figure/axes pair, and
        ``index`` is incremented.
        """
        self.ax.set_xlabel("Budget")
        self.ax.set_ylabel("Value")
        self.fig.legend(fontsize=15)

        self.ax.set_yscale("log")
        self.ax.grid(True, which="both", linestyle="--", color="gray", alpha=0.7)

        self.fig.savefig(
            self._output_path(f"{self.graph_name}_{self.index}"), format="pdf"
        )
        # Close the finished figure; otherwise every call leaves another open
        # figure registered with pyplot that is never released.
        plt.close(self.fig)
        self.fig, self.ax = plt.subplots()
        self.index += 1

    def log_normalize_progression(
        self, results: BudgetRun, plot_name: str, graph_name: str
    ):
        """Add one labelled line to the current axes.

        Args:
            results: Iterable of ``(x, y)`` pairs; x values go on the horizontal
                axis and y values on the vertical axis.
            plot_name: Series label. Underscores are shown as spaces, and the
                cleaned label also selects the line colour via ``get_color``.
            graph_name: Stored as the file-name prefix later used by :meth:`finish`.
        """
        # Materialise once so a one-shot iterator (e.g. a generator) provides
        # both coordinates instead of leaving the y list empty.
        points = list(results)
        xs = [x for x, _ in points]
        ys = [y for _, y in points]
        label = plot_name.replace("_", " ")
        color = get_color(label, DEFAULT_COLOR_MAP)
        self.ax.plot(xs, ys, label=label, color=color)
        # Enable explicitly: a bare grid() call toggles visibility, so repeated
        # calls of this method would switch the grid on and off.
        self.ax.grid(True)
        self.graph_name = graph_name

    def print_image(
        self,
        run: Run,
        fig: Figure,
        index: int = None,
        graph_name: str = "",
        plot_name: str = "",
    ):
        """Save ``fig`` as ``{plot_name}_{graph_name}_{index}.pdf``.

        ``run`` is not used by this implementation.
        """
        fig.savefig(
            self._output_path(f"{plot_name}_{graph_name}_{index}"), format="pdf"
        )

    def print_bars(
        self, bars: Dict[str, float], std: List[float], graph_name: str = ""
    ):
        """Draw a bar chart of success rates sorted high-to-low and save it as a PDF.

        Args:
            bars: Mapping of algorithm name to success rate.
            std: Error-bar sizes, one per entry of ``bars`` in the dict's
                iteration order (assumed from the signature — confirm with
                callers). They are reordered together with the bars so each
                error bar stays on its own bar; ``None`` draws no error bars.
            graph_name: File-name prefix; the file is ``{graph_name}_{index}.pdf``.

        Note: unlike :meth:`finish`, this neither increments ``index`` nor
        replaces the figure.
        """
        names = list(bars)
        rates = list(bars.values())
        # Stable descending sort of positions, so names, rates and error bars
        # are all permuted identically (ties keep their original order).
        order = sorted(range(len(names)), key=lambda i: rates[i], reverse=True)
        algorithms_sorted = [names[i] for i in order]
        success_rates_sorted = [rates[i] for i in order]
        yerr = None if std is None else [std[i] for i in order]

        # pyplot.get_cmap: matplotlib.cm.get_cmap was removed in Matplotlib 3.9.
        colormap = get_label_color_map(algorithms_sorted, plt.get_cmap("tab10"))
        colors = [colormap[name] for name in algorithms_sorted]

        bar_container = self.ax.bar(
            algorithms_sorted,
            success_rates_sorted,
            yerr=yerr,
            error_kw={"elinewidth": 10, "capthick": 10},
            color=colors,
        )
        for bar in bar_container:
            height = bar.get_height()
            # Annotate on self.ax: plt.text targets pyplot's *current* axes,
            # which is not necessarily this manager's axes.
            self.ax.text(
                bar.get_x() + bar.get_width() / 2,
                height + 0.01,
                round(height, 2),
                ha="center",
                va="bottom",
            )

        self.ax.set_xlabel("Algorithms", fontsize=14)
        # Names are shown in the legend; hide the categorical x tick labels
        # without swapping in a FixedFormatter (which warns on this axis).
        self.ax.tick_params(axis="x", labelbottom=False)
        self.ax.tick_params(which="major", labelsize=14)
        self.ax.set_ylabel("Success rate", fontsize=14)
        self.ax.legend(bar_container, algorithms_sorted, loc="lower right")
        self.ax.grid(True)
        self.fig.tight_layout()

        self.fig.savefig(self._output_path(f"{graph_name}_{self.index}"), format="pdf")
