from util.logger import logger

from typing import Optional, Tuple, List

from omegaconf import OmegaConf, DictConfig

from pathlib import Path

import matplotlib.pyplot as plt

from util.basic_util import (
    get_global_variable, 
    is_none, 
    get_true_value, 
    get_attr
)
from util.plot_util import (
    get_line_chart, 
    save_plot
)


def _split_turning_point_list(
    turning_point_list: List[Tuple[int, float]]
) -> Tuple[list, list]:
    """
    Split an iterable of ``(x, y)`` pairs into a list of x values and a list of y values.

    The input is iterated exactly once, so any iterable of pairs works.
    An empty input yields two empty lists.
    """
    x_list = []
    y_list = []

    for (x, y) in turning_point_list:
        x_list.append(x)
        y_list.append(y)

    return x_list, y_list


def line_chart_scaling_implement(
    cfg: DictConfig
):
    """
    Draw a line chart comparing the turning-point curves of several methods, then save it.

    One line is drawn per method, in the order ddpm, ddim, mcts_eps, ours. That order is
    also used for reading the config, logging, and the legend.

    For each method key, ``cfg["task"]["turning_point_list_str_dict"][key]`` is a string
    that is ``eval``-ed into an iterable of ``(x, y)`` pairs. The x values and y values
    become that line's coordinates. Marker, legend label and color come from
    ``cfg["task"]["line_chart"]["marker_dict" / "label_dict" / "color_dict"][key]``.

    Axis labels, font size, line width, y-limits and figure size are also read from
    ``cfg["task"]["line_chart"]``. The output location is read from ``cfg["task"]["save_plot"]``.

    Args:
        cfg: Config object; only entries under ``cfg["task"]`` are read here.

    Returns:
        None. The figure is written via ``save_plot`` and then closed.

    Raises:
        KeyError (or the config library's missing-key error): if an expected config
        entry or per-method dict key is absent.
    """
    # Methods in the order they are read, logged, plotted and listed in the legend.
    method_key_tuple = ("ddpm", "ddim", "mcts_eps", "ours")

    # ---------= [Basic Global Variables] =---------
    # NOTE(review): none of these values is used below. The lookups are kept only in case
    # callers rely on `get_global_variable` failing when a global is missing. Confirm, then drop.
    exp_name = get_global_variable("exp_name")
    start_time = get_global_variable("start_time")
    device = get_global_variable("device")
    seed = get_global_variable("seed")

    # ---------= [Turning Point List Str Dict] =---------
    logger(f"[Turning Point List Str Dict] Loading started. ")

    turning_point_list_str_cfg = cfg["task"]["turning_point_list_str_dict"]
    turning_point_list_str_dict = {
        method_key: get_true_value(turning_point_list_str_cfg[method_key])
        for method_key in method_key_tuple
    }

    for method_key in method_key_tuple:
        logger(f"    {method_key}_turning_point_list_str: {turning_point_list_str_dict[method_key]}")

    logger(
        f"[Turning Point List Str Dict] Loading finished. "
        "\n"
    )

    # ---------= [Line Chart] =---------
    logger(f"[Line Chart] Loading started. ")

    line_chart_cfg = cfg["task"]["line_chart"]

    figsize = get_true_value(line_chart_cfg["figsize"])
    logger(f"    figsize: {figsize}")

    marker_dict = get_true_value(line_chart_cfg["marker_dict"])
    logger(f"    marker_dict: {marker_dict}")

    label_dict = get_true_value(line_chart_cfg["label_dict"])
    logger(f"    label_dict: {label_dict}")

    color_dict = get_true_value(line_chart_cfg["color_dict"])
    logger(f"    color_dict: {color_dict}")

    xlabel = get_true_value(line_chart_cfg["label"]["xlabel"])
    ylabel = get_true_value(line_chart_cfg["label"]["ylabel"])
    fontsize = get_true_value(line_chart_cfg["label"]["fontsize"])

    logger(f"    xlabel: {xlabel}")
    logger(f"    ylabel: {ylabel}")
    logger(f"    fontsize: {fontsize}")

    line_width = get_true_value(line_chart_cfg["line_width"])
    logger(f"    line_width: {line_width}")

    y_lim = get_true_value(line_chart_cfg["y_lim"])
    logger(f"    y_lim: {y_lim}")

    logger(
        f"[Line Chart] Loading finished. "
        "\n"
    )

    # ---------= [Save Plot] =---------
    logger(f"[Save Plot] Loading started. ")

    save_plot_root_path = get_true_value(cfg["task"]["save_plot"]["save_plot_root_path"])
    save_plot_filename = get_true_value(cfg["task"]["save_plot"]["save_plot_filename"])

    logger(f"    save_plot_root_path: {save_plot_root_path}")
    logger(f"    save_plot_filename: {save_plot_filename}")

    logger(
        f"[Save Plot] Loading finished. "
        "\n"
    )

    # ---------= [All Components Loaded] =---------
    logger(
        f"All components loaded. "
        "\n"
    )

    # ---------= [Prepare Everything] =---------
    # NOTE(review): `eval` executes arbitrary Python taken from the config. If these strings
    # are always plain literals such as "[(1, 0.5), (2, 0.7)]", `ast.literal_eval` is a safer
    # replacement. `eval` is kept here so configs that use full expressions keep working.
    turning_point_list_dict = {}
    for method_key in method_key_tuple:
        turning_point_list_dict[method_key] = eval(turning_point_list_str_dict[method_key])

    figsize = tuple(figsize)

    # Per-method styling is resolved before the figure exists, so a missing key in
    # marker_dict / label_dict / color_dict fails without leaving a figure open.
    style_list = [
        (marker_dict[method_key], label_dict[method_key], color_dict[method_key])
        for method_key in method_key_tuple
    ]

    save_plot_root_path = Path(save_plot_root_path)

    # ---------= [Plot Line Chart] =---------
    fig, ax = plt.subplots(figsize = figsize)

    try:
        ax.set_xlabel(
            xlabel = xlabel,
            fontsize = fontsize
        )
        ax.set_ylabel(
            ylabel = ylabel,
            fontsize = fontsize
        )

        ax.set_ylim(y_lim[0], y_lim[1])

        for method_key, (marker, label, color) in zip(method_key_tuple, style_list):
            x_list, y_list = _split_turning_point_list(turning_point_list_dict[method_key])

            ax.plot(
                x_list, y_list,

                marker = marker,
                label = label,
                color = color,

                linewidth = line_width
            )

        ax.legend(
            loc = "lower right",
            fontsize = fontsize
        )

        ax.grid(True)

        fig.tight_layout()

        save_plot(
            fig = fig,

            save_plot_root_path = save_plot_root_path,
            save_plot_filename = save_plot_filename
        )
    finally:
        # pyplot holds a reference to every figure it creates until the figure is closed.
        # Without this, repeated calls accumulate open figures. Closing a figure that
        # pyplot no longer tracks is a no-op, so this is safe even if `save_plot` closed it.
        plt.close(fig)

    # `line_chart_scaling_implement()` done
    pass


def line_chart_scaling(
    cfg: DictConfig
):
    """
    Entry point for the scaling line-chart task.

    Thin wrapper that forwards ``cfg`` unchanged to ``line_chart_scaling_implement``,
    which does all of the loading, plotting and saving. Returns None.
    """
    line_chart_scaling_implement(cfg = cfg)
