from util.logger import logger

from typing import Optional, Tuple, List

from omegaconf import OmegaConf, DictConfig

from pathlib import Path

import matplotlib.pyplot as plt

from util.basic_util import (
    get_global_variable, 
    is_none, 
    get_true_value, 
    get_attr
)
from util.plot_util import (
    get_line_chart, 
    save_plot
)


def _plot_turning_point_list(
    ax,
    turning_point_list: List[Tuple[int, float]],
    marker,
    label,
    color,
    line_style,
    line_width
):
    """Plot one series of ``(x, y)`` turning points as a line on ``ax``.

    Args:
        ax: Object exposing a matplotlib-style ``plot`` method.
        turning_point_list: Iterable of ``(x, y)`` pairs; may be empty.
        marker: Forwarded to ``ax.plot`` as ``marker``.
        label: Forwarded to ``ax.plot`` as ``label`` (legend text).
        color: Forwarded to ``ax.plot`` as ``color``.
        line_style: Forwarded to ``ax.plot`` as ``linestyle``.
        line_width: Forwarded to ``ax.plot`` as ``linewidth``.
    """
    # Two comprehensions rather than `zip(*...)`, so an empty list still
    # produces two (empty) coordinate sequences instead of failing to unpack.
    x_list = [x for (x, _) in turning_point_list]
    y_list = [y for (_, y) in turning_point_list]

    ax.plot(
        x_list, y_list,

        marker = marker,
        label = label,
        color = color,
        linestyle = line_style,

        linewidth = line_width
    )


def line_chart_scaling_others_implement(
    cfg: DictConfig
):
    """Draw the two-axis scaling line chart described by ``cfg["task"]`` and save it.

    The left y-axis carries the ``pick_score`` series and a twin right y-axis
    carries the ``image_reward`` series. For each of the two metrics, one line
    is drawn per method key (``ddpm``, ``ddim``, ``mcts_eps``, ``ours``).
    Marker, legend label and color are looked up per method; line style is
    looked up per metric. A single combined legend is attached to the left
    axis, and the figure is handed to ``save_plot``.

    Config entries read (all under ``cfg["task"]``):
        ``turning_point_list_str_dict.{pick_score,image_reward}``: mapping
            from method key to a string holding a Python literal list of
            ``(x, y)`` pairs.
        ``line_chart.{figsize,marker_dict,label_dict,color_dict,
            line_style_dict,line_width,y_lim_dict}``.
        ``line_chart.label.{xlabel,ylabel_dict,fontsize}``.
        ``save_plot.{save_plot_root_path,save_plot_filename}``.

    Args:
        cfg: Configuration object indexed as shown above.

    Raises:
        ValueError, SyntaxError: If a turning-point string is not a valid
            Python literal (raised by ``ast.literal_eval``).
        KeyError: If a plain ``dict`` lookup for a method or metric key
            fails (other mapping types may raise their own error type).
    """
    # Function-scope import: `ast` is stdlib and only needed here.
    import ast

    # Plot order: every method on the pick-score axis first, then every
    # method on the image-reward axis. This fixes the legend entry order.
    metric_key_list = ("pick_score", "image_reward")
    method_key_list = ("ddpm", "ddim", "mcts_eps", "ours")

    task_cfg = cfg["task"]

    # ---------= [Turning Point List Str Dict] =---------
    logger(f"[Turning Point List Str Dict] Loading started. ")

    pick_score_turning_point_list_str_dict = get_true_value(task_cfg["turning_point_list_str_dict"]["pick_score"])
    image_reward_turning_point_list_str_dict = get_true_value(task_cfg["turning_point_list_str_dict"]["image_reward"])

    logger(f"    pick_score_turning_point_list_str_dict: {pick_score_turning_point_list_str_dict}")
    logger(f"    image_reward_turning_point_list_str_dict: {image_reward_turning_point_list_str_dict}")

    logger(
        f"[Turning Point List Str Dict] Loading finished. "
        "\n"
    )

    # ---------= [Line Chart] =---------
    logger(f"[Line Chart] Loading started. ")

    line_chart_cfg = task_cfg["line_chart"]

    figsize = get_true_value(line_chart_cfg["figsize"])

    logger(f"    figsize: {figsize}")

    marker_dict = get_true_value(line_chart_cfg["marker_dict"])

    logger(f"    marker_dict: {marker_dict}")

    label_dict = get_true_value(line_chart_cfg["label_dict"])

    logger(f"    label_dict: {label_dict}")

    color_dict = get_true_value(line_chart_cfg["color_dict"])

    logger(f"    color_dict: {color_dict}")

    line_style_dict = get_true_value(line_chart_cfg["line_style_dict"])

    logger(f"    line_style_dict: {line_style_dict}")

    xlabel = get_true_value(line_chart_cfg["label"]["xlabel"])
    ylabel_dict = get_true_value(line_chart_cfg["label"]["ylabel_dict"])
    fontsize = get_true_value(line_chart_cfg["label"]["fontsize"])

    logger(f"    xlabel: {xlabel}")
    logger(f"    ylabel_dict: {ylabel_dict}")
    logger(f"    fontsize: {fontsize}")

    line_width = get_true_value(line_chart_cfg["line_width"])

    logger(f"    line_width: {line_width}")

    y_lim_dict = get_true_value(line_chart_cfg["y_lim_dict"])

    logger(f"    y_lim_dict: {y_lim_dict}")

    logger(
        f"[Line Chart] Loading finished. "
        "\n"
    )

    # ---------= [Save Plot] =---------
    logger(f"[Save Plot] Loading started. ")

    save_plot_root_path = get_true_value(task_cfg["save_plot"]["save_plot_root_path"])
    save_plot_filename = get_true_value(task_cfg["save_plot"]["save_plot_filename"])

    logger(f"    save_plot_root_path: {save_plot_root_path}")
    logger(f"    save_plot_filename: {save_plot_filename}")

    logger(
        f"[Save Plot] Loading finished. "
        "\n"
    )

    # ---------= [All Components Loaded] =---------
    logger(
        f"All components loaded. "
        "\n"
    )

    # ---------= [Prepare Everything] =---------
    figsize = tuple(figsize)
    save_plot_root_path = Path(save_plot_root_path)

    turning_point_list_str_dict = {
        "pick_score": pick_score_turning_point_list_str_dict,
        "image_reward": image_reward_turning_point_list_str_dict,
    }

    # Resolve every series (data + style) before any figure is created, so a
    # missing key or a malformed turning-point string fails before matplotlib
    # state is allocated.
    series_list = []
    for metric_key in metric_key_list:
        for method_key in method_key_list:
            # NOTE: `ast.literal_eval` only accepts Python literals (lists,
            # tuples, numbers, ...). Unlike `eval`, it cannot execute
            # arbitrary code embedded in the config; a string containing a
            # non-literal expression raises instead of being evaluated.
            turning_point_list = ast.literal_eval(
                turning_point_list_str_dict[metric_key][method_key]
            )

            series_list.append({
                "metric_key": metric_key,
                "turning_point_list": turning_point_list,
                "marker": marker_dict[method_key],
                "label": label_dict[method_key],
                "color": color_dict[method_key],
                "line_style": line_style_dict[metric_key],
            })

    # ---------= [Plot Line Chart] =---------
    fig, ax_pick_score = plt.subplots(figsize = figsize)
    # Second y-axis sharing the same x-axis, used for the image-reward scale.
    ax_image_reward = ax_pick_score.twinx()

    ax_dict = {
        "pick_score": ax_pick_score,
        "image_reward": ax_image_reward,
    }

    ax_pick_score.set_xlabel(
        xlabel = xlabel,
        fontsize = fontsize
    )

    for metric_key in metric_key_list:
        ax = ax_dict[metric_key]

        ax.set_ylabel(
            ylabel = ylabel_dict[metric_key],
            fontsize = fontsize
        )

        y_lim = y_lim_dict[metric_key]
        ax.set_ylim(y_lim[0], y_lim[1])

    for series in series_list:
        _plot_turning_point_list(
            ax = ax_dict[series["metric_key"]],
            turning_point_list = series["turning_point_list"],
            marker = series["marker"],
            label = series["label"],
            color = series["color"],
            line_style = series["line_style"],
            line_width = line_width
        )

    # Each axis only knows its own lines; merge both sets of handles so one
    # legend lists all eight series.
    line_list_pick_score, label_list_pick_score = ax_pick_score.get_legend_handles_labels()
    line_list_image_reward, label_list_image_reward = ax_image_reward.get_legend_handles_labels()

    ax_pick_score.legend(
        line_list_pick_score + line_list_image_reward,
        label_list_pick_score + label_list_image_reward,

        loc = "lower right",
        ncol = 2,

        fontsize = fontsize
    )

    ax_pick_score.grid(True)

    fig.tight_layout()

    save_plot(
        fig = fig,

        save_plot_root_path = save_plot_root_path,
        save_plot_filename = save_plot_filename
    )


def line_chart_scaling_others(
    cfg: DictConfig
):
    """Task entry point: forward ``cfg`` unchanged to ``line_chart_scaling_others_implement``.

    Args:
        cfg: Configuration object passed straight through to the implementation.
    """
    line_chart_scaling_others_implement(cfg = cfg)
