""" Accuracy comparision

Copyright (c) 2025 Anonymous Authors
"""
import os
import pandas as pd
from collections import defaultdict


def get_accuracy_logging_list():
    """Return the pair of file-name lists used by the accuracy logging step.

    Returns:
        A 2-tuple ``(train_logging_list, results_logging_list)``. The first
        list holds the single entry ``'last.pth.tar'``; the second is empty.
        Fresh lists are built on every call, so callers may mutate them.
    """
    results_files = list()
    train_files = ['last.pth.tar']
    return train_files, results_files


def _accuracy_comparison(file_group, merging_column, comparing_column, output_file_table, output_file_latex):
    total_rows = []
    for group_name, file_list in file_group.items():
        df_list = [pd.read_csv(cur) for cur in file_list]
        # Extract the last epoch from each DataFrame
        group_df = pd.concat([df.iloc[-1].to_frame().T for df in df_list], ignore_index=True)
        assert group_df[merging_column].nunique() == 1, f"{merging_column} is not unique"
        group_df[comparing_column] = group_df[comparing_column].astype(float)
        if len(group_df[comparing_column]) == 1:
            new_row = {"name": group_name, "mean": group_df[comparing_column].mean()}
        else:
            new_row = {"name": group_name, "mean": group_df[comparing_column].mean(), "std": group_df[comparing_column].std()}
        total_rows.append(new_row)
    final_df = pd.DataFrame(total_rows)

    # Save to CSV or TXT
    final_df.to_csv(output_file_table)

    # save latex format into TXT
    final_df.to_latex(output_file_latex, float_format="%.2f")
    
    return


def accuracy_comparison(results_list, output_directory, accuracy_comparison):
    """Group run directories by experiment and write a final top-1 summary.

    Each entry of ``results_list`` is a run directory expected to contain a
    ``summary.csv``. A directory named ``<name>-<digits>`` is grouped under
    ``<name>``; the trailing digits are presumably the seed. Any other folder
    name forms its own group. Every group must contain the same number of
    runs. The per-group mean/std of ``eval_top1`` at the last row (whose
    ``epoch`` must agree within a group) is written to ``accuracy.csv`` and
    ``accuracy.tex`` in ``output_directory``.

    Args:
        results_list: iterable of run directory paths. A trailing path
            separator is tolerated.
        output_directory: directory receiving the output files. It is not
            created here.
        accuracy_comparison: unused; kept only for call-site compatibility.
            Note that it shadows this function's own name inside the body.

    Raises:
        AssertionError: if the groups do not all hold the same number of runs
            (an empty ``results_list`` also trips this check), or if a group's
            final epochs disagree.
    """
    file_name = 'summary.csv'
    merging_column = 'epoch'
    comparing_column = 'eval_top1'

    file_group = defaultdict(list)
    for cur_path in results_list:
        # normpath drops a trailing separator. Without it, basename("runs/exp-1/")
        # is "" and every such run would collapse into one bogus group.
        last_folder = os.path.basename(os.path.normpath(cur_path))
        name_seed = last_folder.rsplit('-', 1)
        summary_path = os.path.join(cur_path, file_name)
        if len(name_seed) == 2 and name_seed[1].isdigit():
            key = name_seed[0]
        else:
            key = last_folder
        file_group[key].append(summary_path)

    lengths = [len(v) for v in file_group.values()]
    assert len(set(lengths)) == 1, "Accuracy comparison : expected all experiment includes same amount of seeds"

    output_file_table = os.path.join(output_directory, 'accuracy.csv')
    output_file_latex = os.path.join(output_directory, 'accuracy.tex')
    _accuracy_comparison(file_group, merging_column, comparing_column, output_file_table, output_file_latex)
    