import numpy as np

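'''
This file contains the raw test metrics across different random seeds (for the TUAB dataset downstream task) and
different cross-validation folds (for the Neonate dataset downstream task), and calculates their mean and standard deviation.
Only the uncommented set of metric lists below is summarized; uncomment a different set to summarize another model/dataset.'''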


####LARGE MODEL#####
# TUAB METRICS
# test_acc =     [0.8102418780326843, 0.8091338872909546, 0.8119173049926758]
# test_aupr = [0.8893928527832031 ,0.8853086829185486 ,0.8905146718025208]
# test_auroc = [ 0.8821651935577393, 0.8773523569107056, 0.883364200592041]
# test_balanced_acc = [0.8033280372619629, 0.8042320013046265, 0.8043872117996216]
# test_loss_epoch = [0.4913586974143982, 0.503070056438446, 0.48808521032333374]

# NEONATE METRICS
# test_acc = [ 0.8585270047187805, 0.9203433990478516, 0.9180701971054077, 0.9084784388542175]
# test_aupr =[0.6988531351089478, 0.7146546840667725, 0.830643892288208, 0.6192806959152222]
# test_auroc = [0.8587494492530823, 0.9171226024627686, 0.9269993305206299, 0.8510891199111938]
# test_balanced_acc = [0.7747777700424194, 0.8328449726104736, 0.8244763612747192, 0.6948388814926147]
# test_loss_epoch = [0.43342915177345276, 0.33717605471611023, 0.34202757477760315, 0.3721396028995514]

##### SMALL MODEL #####
# TUAB
# test_acc = [0.8072287440299988, 0.8087691068649292, 0.8112687468528748]
# test_aupr = [0.8911492228507996, 0.8904352188110352, 0.8920863270759583]
# test_auroc = [0.8864097595214844, 0.8810165524482727, 0.8874328136444092]
# test_balanced_acc = [0.7990596294403076, 0.8058445453643799, 0.8074133396148682]
# test_loss_epoch = [0.48105403780937195, 0.5040305256843567, 0.48118656873703003]

# NEONATE
# Active metric set: small model on the Neonate task, one value per
# cross-validation fold (4 folds). These module-level names are the ones the
# summary below reads; to summarize another run, comment these out and
# uncomment one of the sets above instead.
test_acc = [0.9055035710334778, 0.9303014278411865, 0.9330905079841614, 0.852310299873352]
test_aupr = [0.5775789618492126, 0.8505375981330872, 0.7070022225379944, 0.6753787994384766]
test_auroc = [0.8274600505828857, 0.9343030452728271, 0.9203506708145142, 0.8514599800109863]
test_balanced_acc = [0.6717950105667114, 0.823998212814331, 0.8286746144294739, 0.7590943574905396]
test_loss_epoch = [0.3826697766780853, 0.3266799747943878, 0.32024645805358887, 0.4443998634815216]


# Function to calculate mean and std
def calculate_mean_std(data, *, ddof=0):
    """Return the mean and standard deviation of a sequence of metric values.

    Args:
        data: Array-like of numeric values, e.g. one metric value per random
            seed or per cross-validation fold.
        ddof: Delta degrees of freedom forwarded to ``np.std``. The default
            ``0`` gives the population standard deviation, which matches the
            previous behaviour of this function. Pass ``ddof=1`` for the
            sample standard deviation, the usual choice when reporting
            spread over a small number of seeds or folds.

    Returns:
        Tuple ``(mean, std)`` as NumPy floating-point scalars. An empty
        ``data`` yields NaN for both values, and NumPy emits a RuntimeWarning.
    """
    mean = np.mean(data)
    # ddof=0 -> divide by N (population); ddof=1 -> divide by N-1 (sample).
    std = np.std(data, ddof=ddof)
    return mean, std

# Map each printable metric name to its list of per-seed / per-fold values.
metrics = dict(
    test_acc=test_acc,
    test_aupr=test_aupr,
    test_auroc=test_auroc,
    test_balanced_acc=test_balanced_acc,
    test_loss_epoch=test_loss_epoch,
)

# Print one summary line per metric: mean and standard deviation (as computed
# by calculate_mean_std), each rounded to 4 decimal places.
for metric_name in metrics:
    data = metrics[metric_name]
    mean, std = calculate_mean_std(data)
    print(f"{metric_name} - Mean: {mean:.4f}, Std: {std:.4f}")
