from util.logger import logger

from typing import Optional, Tuple, Union, Dict, List, Set, Callable, TypeVar, Generic, NewType, Protocol

import hydra
from omegaconf import DictConfig, OmegaConf
import yaml

from util.basic_util import (
    pause, 
    set_global_variable_dict, 
    get_global_variable, set_global_variable
)


def test(
    cfg: DictConfig
):
    """Run the ``test`` task.

    Thin wrapper: imports ``task.test.test`` only when this function is
    called (the import is deferred to call time) and forwards ``cfg`` to it
    unchanged.
    """
    from task.test import test as run_test_task
    run_test_task(cfg)


def generate_init_latent(
    cfg: DictConfig
):
    """Run the ``generate_init_latent`` task.

    Thin wrapper: imports ``task.generate_init_latent.generate_init_latent``
    at call time and forwards ``cfg`` to it unchanged.
    """
    from task.generate_init_latent import generate_init_latent as run_generate
    run_generate(cfg)


def add_optimized_prompt(
    cfg: DictConfig
):
    """Run the ``add_optimized_prompt`` task.

    Thin wrapper: imports ``task.add_optimized_prompt.add_optimized_prompt``
    at call time and forwards ``cfg`` to it unchanged.
    """
    from task.add_optimized_prompt import add_optimized_prompt as run_add_prompt
    run_add_prompt(cfg)


def sample(
    cfg: DictConfig
):
    """Dispatch a ``sample-*`` task to one of the ``task.sample.t2i`` runners.

    The runner is selected by prefix-matching ``cfg["task"]["name"]``. It is
    imported only once its branch is chosen, and it receives the whole
    ``cfg``.

    Raises:
        NotImplementedError: If the task name matches no known sampling prefix.
    """
    name = cfg["task"]["name"]

    if name.startswith("sample-t2i-sample_scheduled"):
        from task.sample.t2i.sample_scheduled import sample_scheduled
        sample_scheduled(cfg)
        return

    if name.startswith("sample-t2i-run_sample_scheduled"):
        from task.sample.t2i.run_sample_scheduled import run_sample_scheduled
        run_sample_scheduled(cfg)
        return

    raise NotImplementedError(
        f"Unsupported task: `{name}`. "
    )


def baseline(
    cfg: DictConfig
):
    """Dispatch a ``baseline-*`` task to its runner.

    Only the ``baseline-run_z_sampling`` prefix is supported. Its runner is
    imported lazily and receives the whole ``cfg``.

    Raises:
        NotImplementedError: If the task name matches no known baseline prefix.
    """
    name = cfg["task"]["name"]

    if name.startswith("baseline-run_z_sampling"):
        from task.baseline.run_z_sampling.run import run_z_sampling
        run_z_sampling(cfg)
        return

    raise NotImplementedError(
        f"Unsupported task: `{name}`. "
    )


def cal_final_reward(
    cfg: DictConfig
):
    """Compute final rewards for either the baseline or "ours" results.

    The runner is selected by prefix-matching ``cfg["task"]["name"]`` against
    the ``cal_metric-cal_final_reward_baseline`` and
    ``cal_metric-cal_final_reward_ours`` prefixes. It is imported lazily.

    Raises:
        NotImplementedError: If neither final-reward prefix matches.
    """
    name = cfg["task"]["name"]

    if name.startswith("cal_metric-cal_final_reward_baseline"):
        from task.cal_metric.cal_final_reward_baseline import cal_final_reward_baseline
        cal_final_reward_baseline(cfg)
        return

    if name.startswith("cal_metric-cal_final_reward_ours"):
        from task.cal_metric.cal_final_reward_ours import cal_final_reward_ours
        cal_final_reward_ours(cfg)
        return

    raise NotImplementedError(
        f"Unsupported task: `{name}`. "
    )


def cal_metric(
    cfg: DictConfig
):
    """Dispatch a ``cal_metric-*`` task to the matching metric routine.

    Prefixes are tested in order. The ``cal_metric-cal_final_reward`` family
    is delegated to ``cal_final_reward`` in this module. The other routines
    are imported lazily from ``task.cal_metric``.

    Raises:
        NotImplementedError: If the task name matches no known metric prefix.
    """
    name = cfg["task"]["name"]

    if name.startswith("cal_metric-cal_mean_pairwise_distance"):
        from task.cal_metric.cal_mean_pairwise_distance import cal_mean_pairwise_distance
        cal_mean_pairwise_distance(cfg)
        return

    if name.startswith("cal_metric-cal_final_reward"):
        cal_final_reward(cfg)
        return

    if name.startswith("cal_metric-cal_reward_robustness"):
        from task.cal_metric.cal_reward_robustness import cal_reward_robustness
        cal_reward_robustness(cfg)
        return

    if name.startswith("cal_metric-cal_mscoco_fid"):
        from task.cal_metric.cal_mscoco_fid import cal_mscoco_fid
        cal_mscoco_fid(cfg)
        return

    raise NotImplementedError(
        f"Unsupported task: `{name}`. "
    )


def search(
    cfg: DictConfig
):
    """Dispatch a ``search-*`` task to its optimal-control search runner.

    The runner is selected by prefix-matching ``cfg["task"]["name"]``. It is
    imported lazily and receives the whole ``cfg``.

    Args:
        cfg: Task configuration. Only ``cfg["task"]["name"]`` is read here.

    Raises:
        NotImplementedError: If the task name matches no known search prefix.
    """
    task_name = cfg["task"]["name"]

    # "search-run_optimal_control_mcts" is a prefix of
    # "search-run_optimal_control_mcts_eps", so this single check already
    # routes the `_eps` variant to the same MCTS runner.
    if task_name.startswith("search-run_optimal_control_mcts"):
        from task.search.run_optimal_control_mcts import run_optimal_control_mcts
        run_optimal_control_mcts(cfg)
    elif task_name.startswith("search-run_optimal_control_bs"):
        from task.search.run_optimal_control_bs import run_optimal_control_bs
        run_optimal_control_bs(cfg)
    else:
        raise NotImplementedError(
            f"Unsupported task: `{task_name}`. "
        )


def display_result(
    cfg: DictConfig
):
    """Dispatch a ``display_result-*`` task to the matching display routine.

    Prefixes are tested in order. Each routine is imported lazily from
    ``task.display_result`` and receives the whole ``cfg``.

    Raises:
        NotImplementedError: If the task name matches no known display prefix.
    """
    name = cfg["task"]["name"]

    if name.startswith("display_result-display_result_ours"):
        from task.display_result.display_result_ours import display_result_ours
        display_result_ours(cfg)
        return

    if name.startswith("display_result-get_scaling_list"):
        from task.display_result.get_scaling_list import get_scaling_list
        get_scaling_list(cfg)
        return

    if name.startswith("display_result-display_result_baseline"):
        from task.display_result.display_result_baseline import display_result_baseline
        display_result_baseline(cfg)
        return

    if name.startswith("display_result-display_result_z_sampling"):
        from task.display_result.display_result_z_sampling import display_result_z_sampling
        display_result_z_sampling(cfg)
        return

    if name.startswith("display_result-get_baseline_scaling_line_chart_str"):
        from task.display_result.get_baseline_scaling_line_chart_str import get_baseline_scaling_line_chart_str
        get_baseline_scaling_line_chart_str(cfg)
        return

    raise NotImplementedError(
        f"Unsupported task: `{name}`. "
    )


def collect_result(
    cfg: DictConfig
):
    """Dispatch a ``collect_result-*`` task to the matching collector.

    Prefixes are tested in order. Each collector is imported lazily from
    ``task.collect_result`` and receives the whole ``cfg``.

    Raises:
        NotImplementedError: If the task name matches no known collector prefix.
    """
    name = cfg["task"]["name"]

    if name.startswith("collect_result-collect_result_baseline"):
        from task.collect_result.collect_result_baseline import collect_result_baseline
        collect_result_baseline(cfg)
        return

    if name.startswith("collect_result-collect_result_z_sampling"):
        from task.collect_result.collect_result_z_sampling import collect_result_z_sampling
        collect_result_z_sampling(cfg)
        return

    if name.startswith("collect_result-collect_result_scaling_z_sampling"):
        from task.collect_result.collect_result_scaling_z_sampling import collect_result_scaling_z_sampling
        collect_result_scaling_z_sampling(cfg)
        return

    if name.startswith("collect_result-collect_result_ours"):
        from task.collect_result.collect_result_ours import collect_result_ours
        collect_result_ours(cfg)
        return

    raise NotImplementedError(
        f"Unsupported task: `{name}`. "
    )


def plot(
    cfg: DictConfig
):
    """Dispatch a ``plot-*`` task to the matching plotting routine.

    Prefixes are tested in order. ``plot-line_chart_scaling_others`` must be
    checked before ``plot-line_chart_scaling``, because the latter is a prefix
    of the former. Each routine is imported lazily from ``task.plot``.

    Raises:
        NotImplementedError: If the task name matches no known plot prefix.
    """
    name = cfg["task"]["name"]

    if name.startswith("plot-scatter_mpd"):
        from task.plot.scatter_mpd import scatter_mpd
        scatter_mpd(cfg)
        return

    if name.startswith("plot-line_chart_scaling_others"):
        from task.plot.line_chart_scaling_others import line_chart_scaling_others
        line_chart_scaling_others(cfg)
        return

    if name.startswith("plot-line_chart_scaling"):
        from task.plot.line_chart_scaling import line_chart_scaling
        line_chart_scaling(cfg)
        return

    if name.startswith("plot-line_chart_zeta"):
        from task.plot.line_chart_zeta import line_chart_zeta
        line_chart_zeta(cfg)
        return

    raise NotImplementedError(
        f"Unsupported task: `{name}`. "
    )


# def do_ddim_inversion(
#     cfg: DictConfig
# ):
#     task_name = cfg["task"]["name"]
    
#     if task_name.startswith("do_ddim_inversion"):
#         from task.do_ddim_inversion.do_ddim_inversion import do_ddim_inversion
#         do_ddim_inversion(cfg)

#     else:
#         raise NotImplementedError(
#             f"Unsupported task: `{task_name}`. "
#         )

#     # `sample()` done
#     pass


def run_task(
    cfg: DictConfig
):
    """Route the configured task to its family-level dispatcher.

    ``cfg["task"]["name"]`` is matched by prefix against each task family in
    a fixed order. The first match receives the whole ``cfg``.

    Raises:
        NotImplementedError: If the task name matches no known family prefix.
    """
    name = cfg["task"]["name"]

    if name.startswith("generate_init_latent"):
        generate_init_latent(cfg)
        return
    if name.startswith("add_optimized_prompt"):
        add_optimized_prompt(cfg)
        return
    if name.startswith("sample"):
        sample(cfg)
        return
    if name.startswith("baseline"):
        baseline(cfg)
        return
    if name.startswith("cal_metric"):
        cal_metric(cfg)
        return
    if name.startswith("plot"):
        plot(cfg)
        return
    if name.startswith("search"):
        search(cfg)
        return
    if name.startswith("display_result"):
        display_result(cfg)
        return
    if name.startswith("collect_result"):
        collect_result(cfg)
        return
    if name.startswith("test"):
        test(cfg)
        return

    raise NotImplementedError(
        f"Unsupported task: `{name}`. "
    )


@hydra.main(version_base = None, config_path = "config", config_name = "cfg")
def main(
    cfg: DictConfig
):
    """Hydra entry point: resolve the config, publish it, then run the task.

    The incoming config is re-wrapped with ``OmegaConf.create`` and converted
    with ``OmegaConf.to_container(..., resolve = True)``. That resolved,
    plain-container object (not a ``DictConfig``) is what ``run_task``
    receives.
    """
    plain_cfg = OmegaConf.to_container(
        OmegaConf.create(cfg),
        resolve = True
    )

    # NOTE(review): presumably registers the config's entries as globals so
    # that `get_global_variable("exp_name")` below can read them; confirm in
    # `util.basic_util`.
    set_global_variable_dict(plain_cfg)

    experiment = get_global_variable("exp_name")
    logger(f"Start experiment `{experiment}`. ")
    run_task(plain_cfg)
    logger(f"Experiment `{experiment}` finished. ")


if __name__ == "__main__":
    # Script entry point. `main` is wrapped by `hydra.main`, so it is called
    # without arguments here; the decorator supplies the `cfg` argument.
    main()
    