#!/usr/bin/env python3

import re
import numpy as np
import scipy
import ast
import math

# Experiment grid: every activation function is combined with every weight
# initialization scheme, and each combination has one result file per seed.
seeds = range(1234, 1254)
activation_functions = ["SiLU", "ReLU", "Tanh", "Tanhshrink"]
weight_initialization_schemes = ["xavier", "default", "uniform"]

# Matches a decimal number such as "0.123". It is a raw string so that "\d" is
# a regex escape rather than an invalid string escape, and the dot is escaped
# so that it only matches a literal decimal point (not e.g. the "/" in "19/20").
_DECIMAL_NUMBER = re.compile(r"(\d+\.\d+)")

# data["<fct>_<init>"] maps each metric name to a list; one value is appended
# per matching line found in the result files, which are read in seed order.
data = {}
for fct in activation_functions:
    for init in weight_initialization_schemes:
        results = data[f"{fct}_{init}"] = {"accuracy": [], "rank": [], "loss": []}
        for seed in seeds:
            with open(f"results/activation_function/{fct}_{init}_{seed}.out") as file:
                for line in file:
                    if line.startswith("##################Effective Rank##################"):
                        # The header is followed by two lines: one that is not
                        # needed for this summary (skipped), then the rank
                        # value itself as a single float.
                        next(file)
                        results["rank"].append(float(next(file)))
                    elif line.startswith("Test accuracy"):
                        # The line must contain exactly two decimal numbers:
                        # the accuracy first, then the loss. Anything else
                        # raises ValueError on unpacking.
                        accuracy, loss = _DECIMAL_NUMBER.findall(line)
                        results["accuracy"].append(float(accuracy))
                        results["loss"].append(float(loss))


def round_up(number, decimals=0):
    """Round *number* up (towards +infinity) to *decimals* decimal places.

    The scaling is done in decimal arithmetic on the shortest repr of the
    float. Scaling in binary floating point can overshoot an integer
    (``0.07 * 100 == 7.000000000000001``), which would bump a value that
    already has the requested number of decimals up by one unit.

    Args:
        number: The value to round; anything accepted by ``float()``.
        decimals: Number of decimal places to keep. Negative values round to
            tens, hundreds, and so on.

    Returns:
        The rounded value as a ``float``.

    Raises:
        ValueError: If *number* is NaN.
        OverflowError: If *number* is infinite.
    """
    from decimal import ROUND_CEILING, Decimal

    if not math.isfinite(number):
        # Delegate to math.ceil so NaN / infinity raise the same exceptions
        # as they always did.
        return math.ceil(number)

    places = int(decimals)
    scaled = Decimal(repr(float(number))).scaleb(places)
    ceiled = scaled.to_integral_value(rounding=ROUND_CEILING)
    return float(ceiled.scaleb(-places))

# Summary table: one row per (activation function, initialization scheme)
# combination, showing "mean ± std" over the runs for each metric. The std is
# rounded up via round_up rather than rounded to nearest.
row_format_header = "{:<25}" + "{:>20}" * 3
row_format_body = "{:<25}" + "{:>20}" * 3
print(row_format_header.format("Act. fct., init. scheme", "Accuracy", "Test loss", "Rank"))
for init in weight_initialization_schemes:
    for fct in activation_functions:
        results = data[f"{fct}_{init}"]
        run_losses = np.array(results["loss"])
        # Size the mask from the data instead of hard-coding the number of
        # runs, so the table keeps working if the seed range changes.
        mask = np.ones(len(run_losses), dtype=bool)
        # skip the outlier for SiLU, Uniform (as detailed in the paper):
        # the run with the highest test loss is excluded from all three metrics
        if init == "uniform" and fct == "SiLU":
            mask[np.argmax(run_losses)] = False
        kept_accuracy = np.array(results["accuracy"])[mask]
        kept_loss = run_losses[mask]
        kept_rank = np.array(results["rank"])[mask]
        print(row_format_body.format(
                f"{fct}, {init}",
                f"{np.mean(kept_accuracy):.3f} ± {round_up(np.std(kept_accuracy), 3)}",
                f"{np.mean(kept_loss):.3f} ± {round_up(np.std(kept_loss), 3)}",
                f"{np.mean(kept_rank):.1f} ± {round_up(np.std(kept_rank), 1)}",
            )
        )

# Spearman correlation between rank and loss / accuracy. For every seed index
# the values of all (initialization scheme, activation function) combinations
# are collected and correlated; mean and std over the seeds are then printed.
config_keys = [
    f"{fct}_{init}"
    for init in weight_initialization_schemes
    for fct in activation_functions
]

correlation_loss = []
correlation_accuracy = []
for run in range(len(seeds)):
    run_losses = [data[key]["loss"][run] for key in config_keys]
    run_accuracies = [data[key]["accuracy"][run] for key in config_keys]
    run_ranks = [data[key]["rank"][run] for key in config_keys]

    loss_vs_rank = scipy.stats.spearmanr(run_losses, run_ranks)
    accuracy_vs_rank = scipy.stats.spearmanr(run_accuracies, run_ranks)
    correlation_loss.append(loss_vs_rank.statistic)
    correlation_accuracy.append(accuracy_vs_rank.statistic)

print(f"\nCorrelation between rank and loss: {np.mean(correlation_loss)} ± {np.std(correlation_loss)}")
print(f"Correlation between rank and accuracy: {np.mean(correlation_accuracy)} ± {np.std(correlation_accuracy)}")

