from util.logger import logger

from typing import Optional

from omegaconf import OmegaConf, DictConfig

from pathlib import Path

from util.basic_util import (
    get_global_variable, 
    is_none, 
    get_true_value, 
    get_attr
)
from util.plot_util import (
    save_plot, 
    get_scatter
)


def scatter_mpd_implement(
    cfg: DictConfig
):
    """Scatter-plot MPD (prompt) against MPD (all) for every search style and SD type.

    For each search style and each SD type ("sd-turbo", "sd"), the YAML file
    whose path is configured under ``cfg["task"]["search_style"]`` is loaded and
    its ``mpd_prompt`` / ``mpd_all`` entries become one scatter point. Markers
    are picked per style and colors per SD type from ``cfg["task"]["scatter"]``.
    The figure is written with ``save_plot`` into
    ``cfg["task"]["save_plot"]["save_plot_root_path"]``.

    Args:
        cfg: Config providing the ``task.search_style``, ``task.scatter`` and
            ``task.save_plot`` sections read below.

    Returns:
        None. The plot is saved as a side effect.
    """
    # ---------= [Search Style] =---------
    logger("[Search Style] Loading started. ")

    search_style_cfg = cfg["task"]["search_style"]

    # Each of these is expected to map sd_type ("sd-turbo" / "sd") -> YAML path.
    no_eta_eps_dict = get_true_value(search_style_cfg["init_noise_style"]["no_eta_eps"])
    fixed_eps_fixed_eta_dict = get_true_value(search_style_cfg["init_noise_style"]["fixed_eps-fixed_eta"])
    fixed_eps_random_eta_dict = get_true_value(search_style_cfg["init_noise_style"]["fixed_eps-random_eta"])

    logger(f"    no_eta_eps_dict: {no_eta_eps_dict}")
    logger(f"    fixed_eps_fixed_eta_dict: {fixed_eps_fixed_eta_dict}")
    logger(f"    fixed_eps_random_eta_dict: {fixed_eps_random_eta_dict}")

    fixed_init_noise_fixed_eta_dict = get_true_value(search_style_cfg["eps_style"]["fixed_init_noise-fixed_eta"])

    logger(f"    fixed_init_noise_fixed_eta_dict: {fixed_init_noise_fixed_eta_dict}")

    fixed_init_noise_fixed_eps_dict = get_true_value(search_style_cfg["eta_style"]["fixed_init_noise-fixed_eps"])

    logger(f"    fixed_init_noise_fixed_eps_dict: {fixed_init_noise_fixed_eps_dict}")

    logger(
        "[Search Style] Loading finished. "
        "\n"
    )

    # ---------= [Scatter] =---------
    logger("[Scatter] Loading started. ")

    figsize = get_true_value(cfg["task"]["scatter"]["figsize"])

    logger(f"    figsize: {figsize}")

    marker_dict = get_true_value(cfg["task"]["scatter"]["marker_dict"])

    logger(f"    marker_dict: {marker_dict}")

    color_dict = get_true_value(cfg["task"]["scatter"]["color_dict"])

    logger(f"    color_dict: {color_dict}")

    logger(
        "[Scatter] Loading finished. "
        "\n"
    )

    # ---------= [Save Plot] =---------
    logger("[Save Plot] Loading started. ")

    save_plot_root_path = get_true_value(cfg["task"]["save_plot"]["save_plot_root_path"])

    logger(f"    save_plot_root_path: {save_plot_root_path}")

    logger(
        "[Save Plot] Loading finished. "
        "\n"
    )

    # ---------= [All Components Loaded] =---------
    logger(
        "All components loaded. "
        "\n"
    )

    # ---------= [Prepare Save Plot Root Path] =---------
    save_plot_root_path = Path(save_plot_root_path)

    # ---------= [Prepare Scatter] =---------
    figsize = tuple(figsize)

    sd_type_list = ["sd-turbo", "sd"]

    # Explicit style -> {sd_type: yaml_path} mapping. A name-derived lookup
    # (e.g. via ``locals()``) cannot work here: the "init_noise_style" entry's
    # paths live in ``no_eta_eps_dict``. Insertion order fixes plotting order.
    style_path_dict = {
        "init_noise_style": no_eta_eps_dict,
        "fixed_eps-fixed_eta": fixed_eps_fixed_eta_dict,
        "fixed_eps-random_eta": fixed_eps_random_eta_dict,

        "fixed_init_noise-fixed_eta": fixed_init_noise_fixed_eta_dict,

        "fixed_init_noise-fixed_eps": fixed_init_noise_fixed_eps_dict
    }

    label_dict = {
        "init_noise_style": "Noise-0",
        "fixed_eps-fixed_eta": "Noise-1",
        "fixed_eps-random_eta": "Noise-2",

        "fixed_init_noise-fixed_eta": "Eps",

        "fixed_init_noise-fixed_eps": "Eta"
    }

    # Shown in the legend label next to each style name.
    num_inference_step_dict = {
        "sd-turbo": 2,
        "sd": 20
    }

    point_list = []
    label_list = []
    marker_list = []
    color_list = []

    for style, path_dict in style_path_dict.items():
        for sd_type in sd_type_list:
            point_list.append(_load_mpd_point(path_dict[sd_type]))
            label_list.append(f"{label_dict[style]} ({num_inference_step_dict[sd_type]})")
            # Marker encodes the search style, color encodes the SD type.
            marker_list.append(marker_dict[style])
            color_list.append(color_dict[sd_type])

    # ---------= [Plot Scatter] =---------
    fig, _ = get_scatter(
        figsize = figsize,

        point_list = point_list,

        label_list = label_list,

        marker_list = marker_list,

        color_list = color_list,

        plot_title = None,
        plot_x_label = "MPD (prompt)",
        plot_y_label = "MPD (all)",

        show_grid = True,

        show_legend = True,
        legend_num_col = 2
    )

    # NOTE(review): the filename mentions "lpips" although both axes show MPD;
    # possibly copied from a sibling LPIPS script — confirm before renaming.
    save_plot(
        fig = fig,

        save_plot_root_path = save_plot_root_path,
        save_plot_filename = "mpd_lpips.png"
    )


def _load_mpd_point(yaml_path):
    """Load one MPD result YAML and return its ``(mpd_prompt, mpd_all)`` pair."""
    # to_container turns the loaded OmegaConf node into plain Python dict/list
    # values, so the returned entries are not OmegaConf objects.
    metric_dict = OmegaConf.to_container(OmegaConf.load(yaml_path), resolve=True)
    return (metric_dict["mpd_prompt"], metric_dict["mpd_all"])


def scatter_mpd(
    cfg: DictConfig
):
    """Public entry point for the MPD scatter plot.

    Delegates all loading, plotting and saving to ``scatter_mpd_implement``.

    Args:
        cfg: Config forwarded unchanged to ``scatter_mpd_implement``.
    """
    scatter_mpd_implement(cfg=cfg)
