from typing import Dict

import numpy as np

# Candidate "main" metrics in priority order: get_best_metrics walks this
# list and compares runs on the first name that is present in the current
# metrics dict.
_MAIN_METRICS = [
    "accuracy_f1",
    "pearson_spearman",
    "matthews_correlation",
    "accuracy",
]


def get_best_metrics(metrics, best_metrics):
    """
    Chooses between current metrics and the best metrics seen so far

    The comparison uses the first name from ``_MAIN_METRICS`` that is
    present in ``metrics``; the current metrics win only when their value
    is strictly greater than the stored best one.

    Args:
        metrics: dict with current metrics
        best_metrics: dict with best metrics so far, or ``None`` if there
            is nothing to compare with yet
    Returns:
        Tuple ``(metrics_to_keep, is_improved)``
    Raises:
        Exception: if none of the main metrics is present in ``metrics``
    """
    # Nothing stored yet: the current metrics become the best ones.
    if best_metrics is None:
        return metrics, True

    available = metrics.keys()
    key = next((name for name in _MAIN_METRICS if name in available), None)
    if key is None:
        raise Exception(f"Unknown metrics: {list(metrics.keys())}")

    # Strict comparison: a tie keeps the previously stored best metrics.
    improved = best_metrics[key] < metrics[key]
    return (metrics, True) if improved else (best_metrics, False)


def average_metric_over_loaders(
    epoch_metrics: Dict[str, Dict[str, float]],
    pattern: str,
    dataset_name: str,
) -> Dict[str, float]:
    """
    Averages each metric over the loaders whose name contains ``pattern``

    For some datasets a combined metric is added on top of the averages:
    ``accuracy_f1`` (mean of ``accuracy`` and ``f1``) for ``mrpc``, ``qqp``,
    ``cb``, ``multirc`` and ``record``; ``pearson_spearman`` (mean of
    ``pearson`` and ``spearmanr``) for ``stsb``.

    Args:
        epoch_metrics: epoch metrics in format metrics[metric_name][loader_name]
        pattern: substring to look for in loader names, for example ``valid``
        dataset_name: string with dataset name
    Returns:
        Dict with val metrics. A metric that has no loader matching
        ``pattern`` is reported as ``nan``.
    Raises:
        KeyError: if a metric required for the combined metric of
            ``dataset_name`` is missing from ``epoch_metrics``
    """
    val_metrics = {}
    for metric_key, values_dict in epoch_metrics.items():
        to_average = [
            metric_value
            for loader_key, metric_value in values_dict.items()
            if pattern in loader_key
        ]
        # np.mean([]) also yields nan but emits "Mean of empty slice"
        # RuntimeWarning; make the no-matching-loader case explicit and quiet.
        if to_average:
            val_metrics[metric_key] = float(np.mean(to_average))
        else:
            val_metrics[metric_key] = float("nan")

    if dataset_name in ["mrpc", "qqp", "cb", "multirc", "record"]:
        accuracy = val_metrics["accuracy"]
        f1 = val_metrics["f1"]
        val_metrics["accuracy_f1"] = float(np.mean([accuracy, f1]))
    elif dataset_name == "stsb":
        pearson = val_metrics["pearson"]
        spearman = val_metrics["spearmanr"]
        val_metrics["pearson_spearman"] = float(np.mean([pearson, spearman]))

    return val_metrics


def flatten_metrics(epoch_metrics: Dict[str, Dict[str, float]]) -> Dict[str, float]:
    """
    Flattens nested epoch metrics into a single-level dict
    Args:
        epoch_metrics: epoch metrics in format metrics[metric_name][loader_name]

    Returns:
        Flattened dict in format metric[{loader_name}_{metric_name}]
    """
    # Outer loop over metrics, inner loop over loaders; if two pairs produce
    # the same flat key, the later one wins.
    return {
        f"{loader_name}_{metric_name}": value
        for metric_name, per_loader in epoch_metrics.items()
        for loader_name, value in per_loader.items()
    }
