#!/usr/bin/env python3
import ast
import decimal
import math
import re

import numpy as np
import scipy
import scipy.stats

# Experiment grid: one log file exists per (init scheme, activation, seed).
activations = ["sTanh", "SiLU", "Tanh", "Tanhshrink"]
init_schemes = ["sTanh", "default"]
seeds = list(range(1234, 1254))

# data[init][activation][seed] -> dict of metric histories plus "rank".
data = {}

# Number of leading lines (header and settings) skipped in every log file.
HEADER_LINES = 94

for init in init_schemes:
    data[init] = {}
    for activation in activations:
        data[init][activation] = {}
        for seed in seeds:
            # Only rows whose weighted test loss beats the best one seen so
            # far (starting from 1e4) are recorded, so the last entry of
            # every list belongs to the best row of the run.
            smallest_loss = 1e4
            record = {
                "E_loss": [],
                "F_loss": [],
                "E_RMSE_train": [],
                "E_RMSE_test": [],
                "F_RMSE_train": [],
                "F_RMSE_test": [],
                "rank": None,
            }
            data[init][activation][seed] = record
            with open(f"activation_function/{activation}_{init}_{seed}.out", "r") as file:
                # skip header and settings
                for _ in range(HEADER_LINES):
                    next(file)

                for line in file:
                    if line.startswith("Init done"):
                        # The line after "Init done" is a Python literal; the
                        # rank is the sum of its first element without the
                        # leading entry.
                        tmp = ast.literal_eval(next(file))
                        record["rank"] = sum(tmp[0][1:])
                        continue

                    # Keep every 10th epoch row and the "Final" row.
                    epoch = line.split("|")[0].strip()
                    is_tenth_epoch = epoch.isdigit() and int(epoch) % 10 == 0
                    if not (is_tenth_epoch or line.startswith("Final")):
                        continue

                    # NOTE(review): the column indices below assume the
                    # whitespace-separated layout of the training log --
                    # confirm against the code that writes these files.
                    line_split = re.findall(r"\S+\|?", line)
                    E_RMSE_test = float(line_split[10])
                    F_RMSE_test = float(line_split[14])
                    loss = E_RMSE_test**2 * 10.9**2 + F_RMSE_test**2
                    if loss < smallest_loss:
                        smallest_loss = loss
                        record["E_loss"].append(float(line_split[2]))
                        record["F_loss"].append(float(line_split[4]))
                        record["E_RMSE_train"].append(float(line_split[8]))
                        record["E_RMSE_test"].append(E_RMSE_test)
                        record["F_RMSE_train"].append(float(line_split[12]))
                        record["F_RMSE_test"].append(F_RMSE_test)


def round_up(number, decimals=0):
    """Round ``number`` up (towards +infinity) to ``decimals`` decimal places.

    The rounding is performed on the shortest decimal representation of the
    float, so a value that already has at most ``decimals`` places is
    returned unchanged.  Scaling with ``number * 10 ** decimals`` followed by
    ``math.ceil`` does not have that property: ``1.1 * 10`` evaluates to
    ``11.000000000000002`` in binary floating point, which would turn
    ``1.1`` into ``1.2``.

    Args:
        number: Real value to round (anything ``float()`` accepts).
        decimals: Number of decimal places to keep.  Negative values round
            to tens, hundreds, and so on.

    Returns:
        The rounded value as a ``float``.

    Raises:
        ValueError: If ``number`` is NaN.
        OverflowError: If ``number`` is infinite.
    """
    value = float(number)
    if not math.isfinite(value):
        # math.ceil raises ValueError for NaN and OverflowError for
        # infinities, which is what callers got before.
        return math.ceil(value)

    places = int(decimals)
    quantum = decimal.Decimal(1).scaleb(-places)
    # A float has at most 309 integer digits; leave room for the requested
    # fractional digits so quantize never exceeds the context precision.
    context = decimal.Context(prec=400 + abs(places))
    rounded = decimal.Decimal(repr(value)).quantize(
        quantum, rounding=decimal.ROUND_CEILING, context=context
    )
    # Adding 0.0 turns a possible -0.0 into 0.0.
    return float(rounded) + 0.0

# Summary table: for every (activation, init scheme) pair, report mean and
# spread over all seeds of the final recorded test RMSEs, the weighted test
# loss and the rank, then correlate mean rank with mean loss.
ranks_all = []
losses_all = []

# Per-seed weighted test losses of the sTanh-initialised "sTanh" and "SiLU"
# runs, keyed by activation name.  Must be a fresh dict: the log-parsing loop
# leaves ``tmp`` bound to the last literal it parsed, which cannot take
# string keys.
tmp = {}

row_format_header = "{:<25}" + "{:>20}" * 4
row_format_body = "{:<25}" + "{:>20}" * 4
print(row_format_header.format("Act. fct., init. scheme", "E RMSE test", "F RMSE test", "Test Loss", "Rank"))
for init in init_schemes:
    for activation in activations:
        runs = data[init][activation]
        # The last recorded entry of each history is the one kept for the run.
        E_RMSE_test_final = np.array([runs[seed]["E_RMSE_test"][-1] for seed in seeds])
        F_RMSE_test_final = np.array([runs[seed]["F_RMSE_test"][-1] for seed in seeds])
        ranks = [runs[seed]["rank"] for seed in seeds]
        # Same weighted combination of energy and force RMSE that the
        # parsing loop uses to pick the best row.
        test_loss = E_RMSE_test_final**2 * 10.9**2 + F_RMSE_test_final**2

        ranks_all.append(np.mean(ranks))
        losses_all.append(np.mean(test_loss))
        if init == "sTanh" and activation in ("sTanh", "SiLU"):
            tmp[activation] = test_loss
        print(row_format_body.format(
                f"{activation}, {init}",
                f"{np.mean(E_RMSE_test_final):.3f} ± {np.std(E_RMSE_test_final):.4f}",
                f"{np.mean(F_RMSE_test_final):.3f} ± {np.std(F_RMSE_test_final):.4f}",
                f"{np.mean(test_loss):.3f} ± {round_up(np.std(test_loss), 3)}",
                f"{np.mean(ranks):.1f} ± {round_up(np.std(ranks), 1)}",
            )
        )

# Rank correlation between the mean rank and the mean test loss across all
# (init scheme, activation) configurations.
print(scipy.stats.spearmanr(ranks_all, losses_all))

# For every seed, order the configurations by weighted test loss (dense
# ranking, 1 = smallest loss).  Configurations appear in the order
# init_schemes x activations.
average_rank = []
for seed in seeds:
    losses = []
    ranks = []  # collected alongside the losses; not read within this section
    for init in init_schemes:
        for activation in activations:
            run = data[init][activation][seed]
            E_RMSE = run["E_RMSE_test"][-1]
            F_RMSE = run["F_RMSE_test"][-1]
            losses.append(E_RMSE**2 * 10.9**2 + F_RMSE**2)
            ranks.append(run["rank"])
    average_rank.append(scipy.stats.rankdata(losses, method="dense"))
