import argparse
import os
from typing import Dict, List

from research.flamby.visualization_scripts.average_performance_configs import fed_isic_file_names_to_info


def process_results_dict(results_lines: List[str]) -> Dict[str, float]:
    """
    Parse lines of the form "key:value" into a dictionary mapping each key to a float.

    Each non-blank line is split on ":"; the text before the first colon is used verbatim as the key and the text
    between the first and second colon (or the end of the line) is converted with float(). Lines that are empty or
    contain only whitespace are skipped so that stray blank lines in a results file do not abort the parse.

    Args:
        results_lines (List[str]): Raw lines, typically as returned by readlines() on a results file.

    Returns:
        Dict[str, float]: Mapping from each key to its parsed value. If a key appears more than once, the last
            occurrence wins.

    Raises:
        IndexError: If a non-blank line contains no ":" separator.
        ValueError: If the value portion of a line cannot be converted to a float.
    """
    results_dict: Dict[str, float] = {}
    for results_line in results_lines:
        # Blank lines (e.g. an empty line at the end of the file) carry no data; skip them instead of failing.
        if not results_line.strip():
            continue
        split_line = results_line.split(":")
        # float() tolerates surrounding whitespace and the trailing newline on the value.
        results_dict[split_line[0]] = float(split_line[1])
    return results_dict


def process_results_to_matlab_string(
    chart_method_names: List[str], chart_means: List[float], chart_std_devs: List[float]
) -> None:
    """
    Print MATLAB variable definitions for the method names, means and standard deviations.

    Three assignments are printed, separated by blank lines: a cell array ``method`` holding the single-quoted
    method names, a row vector ``means`` and a row vector ``std_devs``. Numeric values are rendered with str().

    Args:
        chart_method_names (List[str]): Names placed, in order, into the ``method`` cell array.
        chart_means (List[float]): Values placed, in order, into the ``means`` vector.
        chart_std_devs (List[float]): Values placed, in order, into the ``std_devs`` vector.
    """
    quoted_names = "', '".join(chart_method_names)
    means_csv = ", ".join(map(str, chart_means))
    std_devs_csv = ", ".join(map(str, chart_std_devs))

    # The empty entries produce the blank lines between assignments and the final trailing newline.
    matlab_lines = [
        f"method = {{'{quoted_names}'}};",
        "",
        f"means = [{means_csv}];",
        "",
        f"std_devs = [{std_devs_csv}];",
        "",
    ]
    print("\n".join(matlab_lines))


def main(results_dir: str) -> None:
    """
    Collect the configured means and standard deviations from the result files and print them for MATLAB.

    Every entry of fed_isic_file_names_to_info is unpacked as (file name, chart method name, (mean keys,
    std-dev keys)). The named file is read from results_dir and parsed with process_results_dict; for each paired
    mean/std-dev key the corresponding values are appended to the running lists. The accumulated lists are then
    handed to process_results_to_matlab_string.

    Args:
        results_dir (str): Directory containing the result files named in fed_isic_file_names_to_info.
    """
    method_labels = []
    mean_values = []
    std_dev_values = []
    for file_name, chart_method_name, (mean_info_keys, std_dev_info_keys) in fed_isic_file_names_to_info:
        method_labels.append(chart_method_name)
        results_path = os.path.join(results_dir, file_name)
        with open(results_path, "r") as file_handle:
            parsed_results = process_results_dict(file_handle.readlines())
        # Keys are consumed pairwise; zip stops at the shorter of the two key sequences.
        for mean_key, std_dev_key in zip(mean_info_keys, std_dev_info_keys):
            mean_values.append(parsed_results[mean_key])
            std_dev_values.append(parsed_results[std_dev_key])
    process_results_to_matlab_string(method_labels, mean_values, std_dev_values)


if __name__ == "__main__":
    # Command-line entry point: take the directory holding the result logs and print the MATLAB snippet.
    # The description shown by --help now reflects what this script does (it previously read
    # "Local Training Main", which does not describe this script).
    parser = argparse.ArgumentParser(description="Print Fed-ISIC results as MATLAB variable definitions")
    parser.add_argument(
        "--results_dir",
        action="store",
        type=str,
        help="Path to the results logs generated by the evaluate_on_holdout.py script",
        required=True,
    )
    args = parser.parse_args()
    main(args.results_dir)
