import pickle
import glob, os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch

from env.pomnist.vis_util import add_single_run_stat_to_ax, add_multi_run_stat_to_ax
from config import ConfigPOMNIST, ConfigTrafficJunction
from constants import *


def parse_cfg(cfg):
    """
    Builds the config object and the iterable parameters from a parsed config mapping.

    The common config is merged with the environment-specific config (entries of the latter win on
    key clashes), message sizes are normalized to tuples, and every merged entry is set as an attribute
    on a fresh ``ConfigPOMNIST`` / ``ConfigTrafficJunction`` instance.

    :param cfg: Config mapping (presumably loaded from "config.yaml") with the keys ENV_NAME, NUM_RUNS,
        COMMON_CONFIG and ENV_CONFIG
    :return: Tuple ``(num_runs, iter_dict, config)``; ``iter_dict`` maps every merged key whose value is a
        list to that list, ``config`` is the populated config object
    :raises ValueError: If the environment name is not supported
    """
    env_name = cfg[ENV_NAME]
    num_runs = cfg[NUM_RUNS]

    # Validate the environment name before using it as a key, so an unsupported name produces the
    # descriptive ValueError below instead of a bare KeyError from cfg[ENV_CONFIG][env_name].
    if env_name == POMNIST:
        config = ConfigPOMNIST()
    elif env_name == TRAFFIC_JUNCTION:
        config = ConfigTrafficJunction()
    else:
        raise ValueError(f'Invalid environment name "{env_name}"! '
                         f'Please check "config.yaml" and give a valid environment name.')

    env_config_dict = cfg[COMMON_CONFIG].copy()
    env_config_dict.update(cfg[ENV_CONFIG][env_name])

    # Normalize each message-size entry to a tuple: sequences are converted as-is (so an existing
    # tuple is not wrapped into a nested tuple), scalars become 1-tuples.
    env_config_dict[MSG_SIZES] = [
        tuple(size) if isinstance(size, (list, tuple)) else (size,)
        for size in env_config_dict[MSG_SIZES]
    ]

    # List-valued entries (including the normalized MSG_SIZES list) are collected in iter_dict,
    # presumably as the parameters to iterate over.
    iter_dict = {key: value for key, value in env_config_dict.items() if isinstance(value, list)}

    for key, value in env_config_dict.items():
        setattr(config, key, value)

    return num_runs, iter_dict, config


def get_run_stats(exp_folder):
    """
    Collects the run stats from all different configurations in the experiment folder.

    Every ``runs/<exp_folder>/<config_dir>/*.pkl`` file is loaded in order of modification time. Each
    pickle must contain a dict with the keys 'config', 'run_stats' and 'eval_stats'.

    :param exp_folder: Experiment folder name (under "runs") containing folders for different configurations
    :return: Tuple ``(all_run_stats, all_eval_stats, config, compared_params)``: the run stats, eval stats and
        configs of every file (in the same order), and the names of config attributes whose values differ
        between files (excluding "run_comment" and any attribute containing "_schedule"), for legends
    :raises FileNotFoundError: If no result files are found in the experiment folder
    """
    exp_dir = os.path.join("runs", exp_folder)
    file_dirs = glob.glob(os.path.join(exp_dir, "*", "*.pkl"))
    if not file_dirs:
        # Without this check the attribute comparison below fails with an opaque IndexError.
        raise FileNotFoundError(f'No result files (*.pkl) found in "{exp_dir}".')
    file_dirs.sort(key=os.path.getmtime)

    all_run_stats, all_eval_stats, config = [], [], []
    for file in file_dirs:
        # NOTE: pickle.load can execute arbitrary code; only load result files you produced yourself.
        with open(file, "rb") as f:
            file_content = pickle.load(f)

        run_config = file_content['config']
        run_stats = file_content['run_stats']
        eval_stats = file_content['eval_stats']
        all_run_stats.append(run_stats)
        all_eval_stats.append(eval_stats)
        config.append(run_config)

    # Get the parameters which have been changed in the experiments for the legend. Attributes are taken
    # from the union over all configs (first-seen order), so configs with differing attribute sets are
    # compared too instead of raising AttributeError.
    config_attrs = list(dict.fromkeys(attr for run_config in config for attr in vars(run_config)))
    missing = object()  # sentinel: an absent attribute never compares equal to a present one
    compared_params = []
    for attr in config_attrs:
        attr_values = [getattr(run_config, attr, missing) for run_config in config]
        if not all(value == attr_values[0] for value in attr_values):
            compared_params.append(attr)

    compared_params = [param for param in compared_params
                       if param != "run_comment" and "_schedule" not in param]

    return all_run_stats, all_eval_stats, config, compared_params


def visualize_experiment_results(all_run_stats, all_configs, compared_params=None, stat_label=AVG_RETURNS):
    """
    Visualizes the given stats per iteration (with std) of multiple or single runs in one figure.

    For every configuration, each iteration's stat is reduced with ``.mean()`` and the resulting
    (num_iterations, num_runs) array is handed to ``add_single_run_stat_to_ax`` (one run) or
    ``add_multi_run_stat_to_ax`` (several runs).

    :param all_run_stats: Run stats per configuration; each entry is either a single run (a list with one
        dict of stats per iteration) or several runs (a list of such lists)
    :param all_configs: The configs from all configurations, in the same order as ``all_run_stats``
    :param compared_params: The list of parameters changed in the experiments, to be used in legends
    :param stat_label: Key of the stat to be plotted; must have an entry in ``PLOT_LABEL_DICT``
    :return: The created matplotlib figure
    :raises ValueError: If ``stat_label`` has no plot labels in ``PLOT_LABEL_DICT``
    """
    # Validate first: otherwise an unknown label fails with a KeyError inside the loop, after a figure
    # has already been created. (An assert would also be stripped under ``python -O``.)
    if stat_label not in PLOT_LABEL_DICT:
        raise ValueError("Please add the stat and plot labels for requested results in 'constants.py'!")

    fig, ax = plt.subplots(1)

    for setting, setting_run_stats in enumerate(all_run_stats):
        config = all_configs[setting]
        # A single run is a list of per-iteration dicts; wrap it so that both layouts (including a
        # multi-run list that holds just one run) are processed the same way.
        runs = [setting_run_stats] if isinstance(setting_run_stats[0], dict) else setting_run_stats
        num_runs = len(runs)

        mean_stat = np.zeros((config.num_iterations, num_runs))
        for r, run_stats in enumerate(runs):
            for i in range(config.num_iterations):
                mean_stat[i, r] = run_stats[i][stat_label].mean()

        if num_runs == 1:
            add_single_run_stat_to_ax(mean_stat, config, ax, compared_params)
        else:
            add_multi_run_stat_to_ax(mean_stat, config, ax, compared_params)

    fig.suptitle(PLOT_LABEL_DICT[stat_label][TITLE])
    ax.set_xlabel(PLOT_LABEL_DICT[stat_label][XLABEL])
    ax.set_ylabel(PLOT_LABEL_DICT[stat_label][YLABEL])
    ax.legend()
    fig.tight_layout()
    plt.show()
    return fig


def _to_numpy(val):
    """Converts a torch tensor to a NumPy array (detached and moved to CPU); returns other values unchanged."""
    if isinstance(val, torch.Tensor):
        # .numpy() alone fails for tensors on a GPU or tensors that require grad.
        return val.detach().cpu().numpy()
    return val


def create_metrics_table(all_eval_stats, all_config):
    """
    Builds a table with the mean and standard deviation over runs of the evaluation metrics per configuration.

    :param all_eval_stats: Eval stats per configuration; each entry is either a single run (a dict mapping
        eval mode -> dict of metrics) or a list of such dicts, one per run
    :param all_config: The configs from all configurations, in the same order as ``all_eval_stats``
    :return: DataFrame with one row per (configuration, eval mode, metric) and the columns "Settings",
        "Metric_Name", "Message_Type", "Metric_MeanValue" and "Metric_StandardDeviation". A metric that
        is missing from every run gets NaN mean and standard deviation.
    """
    rows = []

    for setting, setting_eval_stats in enumerate(all_eval_stats):
        config = all_config[setting]
        # A single run is stored as a bare dict; wrap it so that both layouts (including a list that
        # holds just one run) are processed the same way.
        runs = [setting_eval_stats] if isinstance(setting_eval_stats, dict) else setting_eval_stats

        # base metrics
        metrics_list = ["agent_mean_return", "num_drops", "channel_util", "mean_selected_msg_size"]

        # environment-specific metrics
        if config.env_name == POMNIST:
            metrics_list += ["positive_listening", "positive_signaling", "negative_listening"]
        elif config.env_name == TRAFFIC_JUNCTION:
            metrics_list += ["traffic_junction_success"]

        # Eval modes are taken from the first run.
        for eval_mode in list(runs[0].keys()):
            for metric in metrics_list:
                run_metrics = [np.mean(_to_numpy(run[eval_mode][metric]))
                               for run in runs if metric in run[eval_mode]]
                if run_metrics:
                    metric_mean, metric_std = np.mean(run_metrics), np.std(run_metrics)
                else:
                    # Same NaN result as np.mean/np.std of an empty list, without their RuntimeWarning.
                    metric_mean = metric_std = np.nan

                rows.append({
                    "Settings": config.run_comment,
                    # The 'test' mode keeps the bare metric name; other modes are used as a prefix.
                    "Metric_Name": ('' if eval_mode == 'test' else f"{eval_mode}_") + metric,
                    "Message_Type": config.message_mode,
                    "Metric_MeanValue": metric_mean,
                    "Metric_StandardDeviation": metric_std,
                })

    # Build the frame once instead of concatenating row by row (quadratic).
    return pd.DataFrame(rows)
