import re
import csv
import os

# Dataset sub-directories under the results folder; each gets its own group
# of rows in the comparison table.
dataset_name = [
    'CESM2',
    'data',
]

# Models to compare, in the left-to-right column order of the output table.
models = [
    'IDOL', 'TDRL', 'CARD', 'FITS',
    'MICN', 'iTransformer', 'TimesNet', 'Autoformer',
]

# Experiment settings used as the file-name prefix of each result log,
# presumably "<seq_len>_<pred_len>" — confirm against the training scripts.
lens = ['96_96', '96_192', '96_336']

# Root directory containing the per-run result logs.
results_path = './IDOL_results'
# Path of the CSV comparison table this script writes.
save_path = 'txt_compare.csv'
# Seeds each experiment was repeated with. A (dataset, horizon, model) cell is
# reported only when every seed produced a parsable result file.
SEEDS = (2024, 2022, 2023)

# Datasets whose averaged metrics are multiplied by 100 before rounding.
# NOTE(review): none of the configured datasets currently matches, and the
# reason for the scaling is not visible here — confirm before relying on it.
SCALED_DATASETS = ('1_train',)

# A float literal, including scientific notation such as "1.2e-05". The old
# pattern ``[0-9.]+`` silently truncated such values to their mantissa.
_NUMBER = r'\s*([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?)'


def _last_metric(text, metric):
    """Return the last number logged as ``<metric>:<number>`` in *text*.

    The negative lookbehind keeps ``mse`` from also matching inside ``rmse``.
    Returns None when the metric does not appear at all.
    """
    pattern = r'(?<![A-Za-z])' + re.escape(metric) + ':' + _NUMBER
    matches = re.findall(pattern, text)
    return float(matches[-1]) if matches else None


def _read_seed_metrics(path):
    """Return ``(mse, mae)`` parsed from the log at *path*, or None if either is absent."""
    with open(path, 'r', encoding='utf-8') as f:
        text = f.read()
    mse = _last_metric(text, 'mse')
    mae = _last_metric(text, 'mae')
    if mse is None or mae is None:
        return None
    return mse, mae


def _average_cell(dataset, model, horizon, results_dir, seeds=SEEDS):
    """Return the ``[MSE, MAE]`` cell pair for one model, averaged over *seeds*.

    Each seed is read from ``<results_dir>/<dataset>/<model>/<horizon>_<seed>.txt``.
    If any seed is missing or unparsable the pair is ``['', '']`` so the table
    shows a blank cell rather than a partial average.
    """
    mses, maes = [], []
    for seed in seeds:
        path = os.path.join(results_dir, dataset, model, f'{horizon}_{seed}.txt')
        if not os.path.exists(path):
            print(f'No such file or directory: {dataset}_{model}_{horizon}_{seed}')
            continue
        metrics = _read_seed_metrics(path)
        if metrics is None:
            # Previously an IndexError here aborted the whole run.
            print(f'No mse/mae found in: {path}')
            continue
        mses.append(metrics[0])
        maes.append(metrics[1])

    # Require every seed; `not mses` also avoids dividing by zero for empty *seeds*.
    if not mses or len(mses) != len(seeds):
        return ['', '']

    avg_mse = sum(mses) / len(mses)
    avg_mae = sum(maes) / len(maes)
    if dataset in SCALED_DATASETS:
        avg_mse *= 100
        avg_mae *= 100
    return [round(avg_mse, 3), round(avg_mae, 3)]


def write_comparison(out_path, datasets, model_names, horizons, results_dir, seeds=SEEDS):
    """Write the seed-averaged MSE/MAE comparison table to *out_path* as CSV.

    Layout: a header row in which each model name spans two columns, a
    sub-header of ``MSE,MAE`` pairs, then one row per (dataset, horizon).
    The file is opened once and written with ``csv.writer``. Unlike the old
    hand-built rows, fields carry no stray leading space or trailing comma.
    """
    with open(out_path, 'w', newline='', encoding='utf-8') as csvfile:
        writer = csv.writer(csvfile, lineterminator='\n')

        header = ['dataset', 'pred_len']
        for model in model_names:
            header += [model, '']
        writer.writerow(header)
        writer.writerow(['', ''] + ['MSE', 'MAE'] * len(model_names))

        for dataset in datasets:
            for horizon in horizons:
                row = [dataset, horizon]
                for model in model_names:
                    row += _average_cell(dataset, model, horizon, results_dir, seeds)
                writer.writerow(row)


if __name__ == '__main__':
    write_comparison(save_path, dataset_name, models, lens, results_path)
    print('OK')