import os
import pandas as pd
import numpy as np

def process_csv_files(directory, task_name, seeds, model, min_yf1=0.90):
    """Aggregate per-seed result CSVs into per-row mean and standard deviation.

    For each seed, reads
    ``{directory}/addmnist_{model}_{task_name}_{seed}_which_c_[-1]_.csv``.
    Missing files and empty files are skipped, as are runs whose first-row
    ``yf1`` value is below ``min_yf1``. The surviving tables are stacked with
    the seed as the outer index level, then aggregated row-by-row across seeds
    (row *i* of every qualifying file is combined with row *i* of the others).

    Args:
        directory: Folder containing the CSV files.
        task_name: Task identifier embedded in the file name.
        seeds: Iterable of seeds to look for.
        model: Model identifier embedded in the file name.
        min_yf1: Minimum first-row ``yf1`` a run needs in order to be
            included (defaults to 0.90, the previously hard-coded threshold).

    Returns:
        ``(mean_values, std_values)`` DataFrames indexed by the row position
        inside each CSV. The std uses pandas' default ``ddof=1``, so it is NaN
        when only one seed qualifies. Both frames are empty when no run
        qualifies.
    """
    seed_data = {}

    for seed in seeds:
        # Some experiments use a file name ending in "_optimal.csv" instead.
        file = os.path.join(directory, f"addmnist_{model}_{task_name}_{seed}_which_c_[-1]_.csv")

        if not os.path.exists(file):
            continue
        df = pd.read_csv(file)
        # A header-only file has no first row to check, so it cannot qualify.
        if df.empty or df["yf1"].iloc[0] < min_yf1:
            continue
        seed_data[seed] = df

    if not seed_data:
        # pd.concat raises on empty input; report and return empty results instead.
        print("\nNo runs found (or none passed the yf1 filter).")
        return pd.DataFrame(), pd.DataFrame()

    # Passing the dict makes its keys (the seeds) the outer index level "Seed".
    all_data = pd.concat(seed_data, names=["Seed"])

    # Level 1 is each CSV's own row index, so statistics are taken across seeds per row.
    grouped = all_data.groupby(level=1)
    mean_values = grouped.mean()
    std_values = grouped.std()

    print("\nMean values:")
    print(mean_values)

    print("\nStandard deviation values:")
    print(std_values)

    return mean_values, std_values


def _summarize_rows(label, rows_by_seed):
    """Print and return the across-seed mean/std of per-seed result rows, or None if empty."""
    if not rows_by_seed:
        return None
    df = pd.DataFrame.from_dict(rows_by_seed, orient="index")
    mean_values = df.mean()
    std_values = df.std()

    print(f"\n{label} - Mean values:")
    print(mean_values)

    print(f"\n{label} - Standard deviation values:")
    print(std_values)

    return mean_values, std_values


def process_csv_files_ood(directory, task_name, seeds, model="mnistcbm"):
    """Aggregate test-set and OOD-set metrics across seeds.

    For each seed, reads
    ``addmnist_{model}_{task_name}_{seed}_which_c_[3, 4, 1, 0, 2, 8, 9, 5, 6, 7]``
    ``_c_sup_1.0_k_sup_1.0_multi_linear_optimal_optimal.csv`` from
    ``directory``. Row 0 of that file is treated as the test-set result and
    row 1 as the OOD result. Files that are missing or have fewer than two
    rows are skipped. Each candidate path is printed as it is checked.

    Args:
        directory: Folder containing the CSV files.
        task_name: Task identifier embedded in the file name.
        seeds: Iterable of seeds to look for.
        model: Model identifier embedded in the file name (defaults to the
            previously hard-coded ``"mnistcbm"``).

    Returns:
        dict with keys ``"test"`` and ``"ood"``. Each value is a
        ``(mean, std)`` tuple of Series computed across seeds, or ``None``
        when no seed contributed to that split.
    """
    seed_data_test = {}
    seed_data_ood = {}

    for seed in seeds:
        file = os.path.join(
            directory,
            f"addmnist_{model}_{task_name}_{seed}_which_c_[3, 4, 1, 0, 2, 8, 9, 5, 6, 7]"
            "_c_sup_1.0_k_sup_1.0_multi_linear_optimal_optimal.csv",
        )

        print(file)
        if not os.path.exists(file):
            continue
        df = pd.read_csv(file)

        if len(df) >= 2:
            seed_data_test[seed] = df.iloc[0]
            seed_data_ood[seed] = df.iloc[1]

    return {
        "test": _summarize_rows("Test Set", seed_data_test),
        "ood": _summarize_rows("OOD Set", seed_data_ood),
    }

if __name__ == "__main__":
    # Run the analysis only when executed as a script, not when imported.
    # Alternative runs kept for reference:
    # process_csv_files("best_models_addmnist", "sumparity", [1011, 1213, 1415, 1617, 1819, 2021], "mnistsenn")
    # process_csv_files_ood("best_models_addmnist", "sumparityrigged", [3839, 4041, 4243, 4445, 4647, 4849])
    process_csv_files("best_models_addmnist", "sumparity", [1011, 1213, 1415, 1617, 1819], "mnistdpl")