import time
import yaml
import itertools
import argparse

from experiments.helpers import *
from main import run, multi_run


def run_experiment_from_config(cfg, exp_dir):
    """
    Runs the experiment for every combination of the swept parameter values given in the config and stores the
    stats of each combination in its own sub-folder of the experiment folder.

    :param cfg: The loaded config; it is indexed with ENV_NAME and handed to parse_cfg
    :param exp_dir: The folder of the whole experiment; one sub-folder per setting is created inside it
    :raises ValueError: If the number of runs given in the config is smaller than 1
    """
    env_name = cfg[ENV_NAME]
    num_runs, iter_dict, config = parse_cfg(cfg)

    # Validate before any folder is created or any run is started. An explicit raise is used instead of an
    # assert, since asserts are removed when Python runs with -O
    if num_runs < 1:
        raise ValueError("Number of runs must be >= 1!")

    config.set_default_schedules()

    # Every element of the product holds one value per key of iter_dict, in the order of the keys
    param_names = list(iter_dict.keys())

    for setting in itertools.product(*iter_dict.values()):
        for key, value in zip(param_names, setting):
            if key == SPLITS:
                # The splits entry is a pair which is stored in two separate config attributes
                config.x_splits, config.y_splits = value
            else:
                setattr(config, key, value)
        config.msg_decoding_mask_size = len(config.msg_sizes)

        config.run_comment = ""
        if env_name == POMNIST:
            config.run_comment = f"{config.x_splits}x{config.y_splits}"

        # Note: The run comment should allow to distinguish different configurations and is used as a directory name
        config.run_comment += f"_msgSizes={config.msg_sizes}_{config.num_iterations}iter" \
                              f"_{config.message_mode}"\
                              f"_{config.comm_channel_type}"\
                              f"{f'_commSize={config.comm_channel_size}' if config.comm_channel_size else ''}"\
                              f"{'_random_msg_size' if config.force_random_msg_size_selection else ''}"

        if config.comm_channel_type == ChannelType.SelectiveChannel:
            config.run_comment += f"_agents_{'_'.join(map(str, config.comm_channel_selective_allowed_agents))}"

        # Create folder for the current setting (given in detail in config.run_comment)
        folder_name = os.path.join(exp_dir, config.run_comment)
        os.makedirs(folder_name, exist_ok=True)

        # Run the experiment; the folder of the current setting is passed on as the log directory
        if num_runs == 1:
            all_run_stats, all_eval_stats, all_log_dirs = run(config, log_dir=folder_name)
        else:
            # The eval plots are switched off when the same setting is run several times
            config.log_eval_plots = False
            all_run_stats, all_eval_stats, all_log_dirs = multi_run(config, num_runs, base_log_dir=folder_name)

        # Save the config together with the run and eval stats in the all_stats.pkl file of the setting
        with open(os.path.join(folder_name, "all_stats.pkl"), "wb") as f:
            file_content = dict(
                config=config,
                run_stats=all_run_stats,
                eval_stats=all_eval_stats
            )
            pickle.dump(file_content, f)


def main(cfg_dir, exp_folder=None, stat_labels=AVG_RETURNS):
    """
    Runs the experiment with given configurations and generates plots for the experiment for given stat labels (eg.
    avg_return, avg_q_loss..). If an experiment folder is given, no experiment is run and only the plots and the
    metrics table are generated from the results stored in that folder.

    :param cfg_dir: The path to the config file
    :param exp_folder: The name of an existing experiment folder inside "runs"; if given, the config is ignored
    :param stat_labels: The stats to be plotted; either a single label or an iterable of labels
    """
    # A single label (such as the default value) is wrapped in a list, since iterating over a bare string
    # would yield its characters instead of the label itself
    if isinstance(stat_labels, str):
        stat_labels = [stat_labels]

    if exp_folder is not None:
        print("Ignoring the config, since the experiment folder is given!")
        exp_dir = os.path.join("runs", exp_folder)
    else:
        with open(cfg_dir) as f:
            # NOTE(review): FullLoader is not meant for untrusted input; only load config files you trust
            cfg = yaml.load(f, Loader=yaml.FullLoader)

        # Create a folder for the whole experiment, prefixed with the current timestamp
        exp_folder = f"{int(time.time())}_" + cfg[EXP_TAG]
        exp_dir = os.path.join("runs", exp_folder)
        os.makedirs(exp_dir, exist_ok=True)

        run_experiment_from_config(cfg, exp_dir)

    exp_run_stats, exp_eval_stats, exp_config, compared_params = get_run_stats(exp_folder)

    # The metrics table does not depend on the stat label, so it is created and written only once
    df_metrics = create_metrics_table(exp_eval_stats, exp_config)
    df_metrics.to_csv(os.path.join(exp_dir, "metrics_table.csv"))

    # One plot per stat label, saved as <stat_label>.png in the experiment folder
    for stat_label in stat_labels:
        fig = visualize_experiment_results(exp_run_stats, exp_config, compared_params, stat_label=stat_label)
        fig.savefig(os.path.join(exp_dir, f"{stat_label}.png"))


if __name__ == '__main__':
    # Command line entry point: reads the config path and the optional experiment folder and hands them to main
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--config-file", "--c",
        default="experiments/config.yaml",
        help="the path to the config file",
    )
    cli.add_argument(
        "--experiment-folder", "--ef",
        help="folder name of the experiment results to generate the plot for",
    )
    options = cli.parse_args()

    # Only the average returns are plotted
    main(options.config_file, options.experiment_folder, [AVG_RETURNS])
