#!/usr/bin/env python3
import ast
import math
import re
from decimal import ROUND_CEILING, Decimal

import numpy as np

# Layer-size configurations, written as underscore-separated widths, in table
# order: vary the third entry, then the second, then one stand-alone setup.
sizes = (
    [f"52_58_{width}" for width in (20, 40, 60, 80, 100)]
    + [f"52_{width}_60" for width in (18, 38, 78, 98)]
    + ["102_61_44"]
)
# Twenty consecutive random seeds, i.e. twenty runs per configuration.
seeds = range(1234, 1234 + 20)
# Per-run parsed results, keyed first by size string and then by seed.
data = {}
# Number of model parameters per configuration, listed in the same order as ``sizes``.
num_param = dict(
    zip(sizes, (11509, 12709, 13909, 15109, 16309, 9389, 11649, 16169, 18429, 22990))
)

# Number of header/settings lines at the top of every ``.out`` log.
_HEADER_LINES = 94
# Weight on the squared energy RMSE when it is combined with the squared force
# RMSE into a single test loss.
_ENERGY_WEIGHT_SQ = 10.9**2
# Token index of each recorded metric within a tokenised log row.
_METRIC_COLUMNS = {
    "E_loss": 2,
    "F_loss": 4,
    "E_RMSE_train": 8,
    "E_RMSE_test": 10,
    "F_RMSE_train": 12,
    "F_RMSE_test": 14,
}
# Whitespace-separated tokens of a log row; compiled once, outside the loops.
_TOKEN_RE = re.compile(r"\S+\|?")

# For each (size, seed) run, record the metrics of every evaluated row that
# improves the combined test loss.  The last entry of each list therefore
# belongs to the best evaluated row.  Also record the "rank" value that follows
# the "Loaded" line.
for size in sizes:
    data[size] = {}
    for seed in seeds:
        record = {metric: [] for metric in _METRIC_COLUMNS}
        record["rank"] = None
        data[size][seed] = record
        # math.inf (not an arbitrary finite sentinel) guarantees the first
        # evaluated row is always recorded, so the lists are never left empty
        # merely because the losses are large.
        smallest_loss = math.inf
        path = f"model_size/trained_results/{size}_{seed}.out"
        with open(path, "r", encoding="utf-8") as file:
            # skip header and settings
            for _ in range(_HEADER_LINES):
                next(file)

            for line in file:
                if line.startswith("Loaded"):
                    # The value is read from the 2nd line after "Loaded";
                    # the 8 lines after that are skipped.
                    for _ in range(2):
                        line = next(file)
                    # NOTE(review): assumes the first element of the parsed
                    # literal is a sequence of numbers to be summed - confirm
                    # against the log format.
                    record["rank"] = sum(ast.literal_eval(line)[0])
                    for _ in range(8):
                        next(file)
                    continue
                epoch = line.split("|")[0].strip()
                # Evaluate every 10th epoch row plus the "Final" row.
                if (epoch.isdigit() and int(epoch) % 10 == 0) or line.startswith("Final"):
                    tokens = _TOKEN_RE.findall(line)
                    e_rmse_test = float(tokens[_METRIC_COLUMNS["E_RMSE_test"]])
                    f_rmse_test = float(tokens[_METRIC_COLUMNS["F_RMSE_test"]])
                    loss = e_rmse_test**2 * _ENERGY_WEIGHT_SQ + f_rmse_test**2
                    if loss < smallest_loss:
                        smallest_loss = loss
                        for metric, column in _METRIC_COLUMNS.items():
                            record[metric].append(float(tokens[column]))

def round_up(number, decimals=0):
    """Round *number* up (towards +infinity) to *decimals* decimal places.

    Used when reporting standard deviations, so a displayed spread is never
    rounded down.

    The rounding works on the shortest decimal representation of the float
    (its ``repr``) rather than on ``number * 10**decimals``.  Float
    multiplication can overshoot: ``1.1 * 100 == 110.00000000000001``, which
    made the naive ``math.ceil`` version return 1.11 for ``round_up(1.1, 2)``.

    Args:
        number: Value to round (a Python or NumPy real number).
        decimals: Number of decimal places to keep.  Must be an integer;
            negative values round to tens, hundreds, and so on.

    Returns:
        The rounded value as a ``float``.
    """
    # repr() gives the shortest string that round-trips to the same float, so
    # Decimal sees "1.1" rather than the binary approximation 1.1000000000000000888...
    exact = Decimal(repr(float(number)))
    # Shift the target digit to the units place, take the ceiling, then shift back.
    # scaleb/to_integral_value avoid the context-precision limits of quantize().
    ceiled = exact.scaleb(decimals).to_integral_value(rounding=ROUND_CEILING)
    return float(ceiled.scaleb(-decimals))

# Summary table: one row per layer configuration.  Each metric cell shows the
# mean ± (rounded-up) standard deviation across seeds of the last recorded
# value for each run.
row_format_body = "{:<25}" + "{:>20}" * 5
header = ("Layer sizes", "E RMSE test", "F RMSE test", "Test Loss", "Rank", "#Params")
print(row_format_body.format(*header))
for size in sizes:
    runs = [data[size][seed] for seed in seeds]
    e_test = np.array([run["E_RMSE_test"][-1] for run in runs])
    f_test = np.array([run["F_RMSE_test"][-1] for run in runs])
    # Combined test loss per run, weighted exactly as in the log parsing.
    test_loss = e_test**2 * 10.9**2 + f_test**2
    rank_values = [run["rank"] for run in runs]

    cells = (
        size.replace("_", ", "),
        f"{np.mean(e_test):.4f} ± {round_up(np.std(e_test), 4)}",
        f"{np.mean(f_test):.4f} ± {round_up(np.std(f_test), 4)}",
        f"{np.mean(test_loss):.3f} ± {round_up(np.std(test_loss), 3)}",
        f"{np.mean(rank_values):.1f} ± {round_up(np.std(rank_values), 3)}",
        f"{num_param[size]:,}",
    )
    print(row_format_body.format(*cells))
