import argparse
import os
from typing import Dict, List

from research.flamby.visualization_scripts.model_generalization_configs import fed_isic_file_names_to_info


def process_results_dict(results_lines: List[str]) -> Dict[str, float]:
    """Parse ``key: value`` result lines into a dictionary of floats.

    Each non-blank line is split on its LAST colon, so keys that themselves
    contain colons are kept intact. The text after the colon is converted with
    ``float``, which tolerates surrounding whitespace such as the trailing
    newline left by ``readlines``. Blank or whitespace-only lines are skipped.
    If a key appears more than once, the last occurrence wins.

    Args:
        results_lines: Lines read from a results log file.

    Returns:
        Mapping from the text before the last colon (not stripped) to its float value.

    Raises:
        ValueError: If a non-blank line contains no colon, or its value is not a valid float.
    """
    results_dict: Dict[str, float] = {}
    for results_line in results_lines:
        if not results_line.strip():
            # Tolerate empty lines (e.g. a trailing blank line) instead of crashing on them.
            continue
        key, separator, value = results_line.rpartition(":")
        if not separator:
            raise ValueError(f"Malformed results line (expected 'key: value'): {results_line!r}")
        results_dict[key] = float(value)
    return results_dict


def process_results_to_matlab_string(
    chart_method_names: List[str], chart_means: List[List[float]], chart_variable_names: List[str]
) -> str:
    """Build, print, and return MATLAB code defining the chart data.

    The generated text contains, in order:
      * one ``<variable_name> = [m1, m2, ...];`` line per entry of ``chart_means``;
      * a matrix ``C = [<var1>; <var2>; ...];`` stacking those variables as rows;
      * a cell array ``method = {'<name1>', '<name2>', ...};`` of method labels.

    Single quotes in method names are doubled, which is MATLAB's escape for a
    quote inside a char literal. An empty method list yields ``method = {};``.

    Args:
        chart_method_names: Labels for the methods being charted.
        chart_means: One list of mean values per variable, aligned with
            ``chart_variable_names``.
        chart_variable_names: MATLAB variable names, one per entry of ``chart_means``.

    Returns:
        The generated MATLAB source. The same text is also printed to stdout.

    Raises:
        ValueError: If ``chart_means`` and ``chart_variable_names`` differ in length.
            Silently truncating would make ``C`` reference undefined variables.
    """
    if len(chart_means) != len(chart_variable_names):
        raise ValueError(
            f"chart_means has {len(chart_means)} entries but chart_variable_names has "
            f"{len(chart_variable_names)}; they must be the same length"
        )

    variable_lines = [
        f"{variable_name} = [{', '.join(str(mean) for mean in means)}];\n"
        for means, variable_name in zip(chart_means, chart_variable_names)
    ]
    c_string = f"C = [{'; '.join(chart_variable_names)}];\n"

    if chart_method_names:
        # MATLAB escapes a single quote inside a char literal by doubling it.
        escaped_method_names = [name.replace("'", "''") for name in chart_method_names]
        method_string = "method = {'" + "', '".join(escaped_method_names) + "'};\n"
    else:
        method_string = "method = {};\n"

    out_string = "".join(variable_lines) + f"\n{c_string}\n" + method_string
    print(out_string)
    return out_string


def main(results_dir: str) -> None:
    """Load each configured results file and print the MATLAB chart code.

    For every ``(file_name, chart_method_name, variable_name, mean_info_keys)``
    entry in ``fed_isic_file_names_to_info``, the file
    ``os.path.join(results_dir, file_name)`` is parsed with
    ``process_results_dict``. The values stored under ``mean_info_keys``, in that
    order, become that method's row of means. The collected data is then
    printed via ``process_results_to_matlab_string``.

    Args:
        results_dir: Directory containing the results logs.

    Raises:
        OSError: If a results file cannot be opened.
        KeyError: If a results file lacks one of the expected ``mean_info_keys``.
    """
    chart_method_names: List[str] = []
    chart_means: List[List[float]] = []
    chart_variable_names: List[str] = []
    for file_name, chart_method_name, variable_name, mean_info_keys in fed_isic_file_names_to_info:
        file_path = os.path.join(results_dir, file_name)
        # Read the lines and close the file before parsing; explicit encoding avoids
        # depending on the platform's default locale encoding.
        with open(file_path, "r", encoding="utf-8") as file_handle:
            results_lines = file_handle.readlines()
        results_dict = process_results_dict(results_lines)

        missing_keys = [key for key in mean_info_keys if key not in results_dict]
        if missing_keys:
            raise KeyError(f"Results file {file_path} is missing expected keys: {missing_keys}")

        chart_method_names.append(chart_method_name)
        chart_variable_names.append(variable_name)
        chart_means.append([results_dict[key] for key in mean_info_keys])
    process_results_to_matlab_string(chart_method_names, chart_means, chart_variable_names)


if __name__ == "__main__":
    # Command-line entry point: parse the results directory and print the MATLAB chart code.
    parser = argparse.ArgumentParser(
        description="Generate MATLAB chart code from Fed-ISIC holdout evaluation results"
    )
    parser.add_argument(
        "--results_dir",
        action="store",
        type=str,
        help="Path to the directory containing the results logs generated by the evaluate_on_holdout.py script",
        required=True,
    )
    args = parser.parse_args()
    main(args.results_dir)
