# read in all result files from the following executions:
#python 02_paper_experiments/033_unimodal_omics.py --modality 'rna' --n_batches 3 --threshold_type 'absolute' --seed 0
#python 02_paper_experiments/033_unimodal_omics.py --modality 'rna' --n_batches 3 --threshold_type 'absolute' --seed 42
#python 02_paper_experiments/033_unimodal_omics.py --modality 'rna' --n_batches 3 --threshold_type 'absolute' --seed 9306
#python 02_paper_experiments/033_unimodal_omics.py --modality 'protein' --n_batches 3 --threshold_type 'absolute' --seed 0
#python 02_paper_experiments/033_unimodal_omics.py --modality 'protein' --n_batches 3 --threshold_type 'absolute' --seed 42
#python 02_paper_experiments/033_unimodal_omics.py --modality 'protein' --n_batches 3 --threshold_type 'absolute' --seed 9306

# report the mean +- SEM for each r square threshold and modality

import pandas as pd
import numpy as np

# Result files to read: one CSV per (modality, seed) run listed in the header.
result_files = [
    "03_results/paper_results/unimodal_omics/unimodal_omics_rna_noisy_n3batches_thresh-absolute_seed-0.csv",
    "03_results/paper_results/unimodal_omics/unimodal_omics_rna_noisy_n3batches_thresh-absolute_seed-42.csv",
    "03_results/paper_results/unimodal_omics/unimodal_omics_rna_noisy_n3batches_thresh-absolute_seed-9306.csv",
    "03_results/paper_results/unimodal_omics/unimodal_omics_protein_noisy_n3batches_thresh-absolute_seed-0.csv",
    "03_results/paper_results/unimodal_omics/unimodal_omics_protein_noisy_n3batches_thresh-absolute_seed-42.csv",
    "03_results/paper_results/unimodal_omics/unimodal_omics_protein_noisy_n3batches_thresh-absolute_seed-9306.csv",
]


def load_results(files):
    """Read every CSV in *files* and stack the rows into one DataFrame.

    The index is reset so rows from different files do not share labels.
    """
    return pd.concat([pd.read_csv(path) for path in files], ignore_index=True)


def summarize_results(results):
    """Return mean and SEM of ``final_ranks`` per modality and threshold.

    Args:
        results: DataFrame with ``modality``, ``r_square_threshold`` and
            ``final_ranks`` columns.

    Returns:
        DataFrame with one row per (modality, r_square_threshold) group and
        the columns ``mean_final_ranks`` and ``sem_final_ranks``.
    """
    # Use pandas' built-in 'sem' (ddof=1) instead of std / sqrt(len): the
    # built-in divides by the number of non-missing observations, whereas
    # len() also counts NaN rows and would shrink the reported error.
    return (
        results.groupby(['modality', 'r_square_threshold'])
        .agg(
            mean_final_ranks=('final_ranks', 'mean'),
            sem_final_ranks=('final_ranks', 'sem'),
        )
        .reset_index()
    )


def format_summary(summary):
    """Return one human-readable report line per row of *summary*."""
    return [
        f"Modality: {row['modality']}, R² Threshold: {row['r_square_threshold']}, Final Ranks: {row['mean_final_ranks']:.2f} ± {row['sem_final_ranks']:.2f}"
        for _, row in summary.iterrows()
    ]


if __name__ == "__main__":
    # Read and combine all result files
    all_results = load_results(result_files)

    # Group by modality and r_square_threshold, then calculate mean and SEM
    summary = summarize_results(all_results)

    # Print the summary
    print("Mean ± SEM for each r_square_threshold and modality:")
    for line in format_summary(summary):
        print(line)