import os
import pickle
from pathlib import Path

import pandas as pd

import cfg


# File name of the merged results stored in every results folder.
_RESULTS_FILENAME = 'results_merged.pkl'

# Default location of the legacy results (machine-specific absolute path).
_LEGACY_RESULTS_ROOT = 'C:\\Users\\blond\\PycharmProjects\\2025-01-16 Speedup factor\\speedup-factor-active-learning\\OLD\\exp'

# Columns kept in the combined frame (the ones used for figures).
_FIGURE_COLUMNS = [
    'dataset', 'dataset_type', 'random_seed', 'init_train_samples', 'add_train_samples',
    'max_train_samples', 'train_paradigm', 'weight_init', 'training', 'qm',
    'processing_time', 'processing_time_std', 'processing_time_sem', 'nr_classes',
    'classes', 'nr_training_samples',
    'evaluation_f1_macro', 'evaluation_f1_macro_std', 'evaluation_f1_macro_sem',
    'evaluation_ceiling_f1_macro', 'evaluation_ceiling_f1_macro_std',
    'evaluation_ceiling_f1_macro_sem',
    'a_inf', 'a_0_p1', 'a_0_p2', 'b_p1', 'b_p2', 'sf_p1', 'sf_p2', 'best_p', 'sf_best_p',
]

# Columns whose combined values decide whether two rows describe the same
# result; a later row with equal values overwrites the earlier one.
_KEY_COLUMNS = [
    'dataset', 'dataset_type', 'add_train_samples', 'max_train_samples',
    'train_paradigm', 'weight_init', 'training', 'qm', 'nr_classes', 'classes',
]


def _collect_result_paths(old_path) -> list:
    """Return the paths of all merged result files, legacy files first.

    Args:
        old_path: Root folder of the legacy results, or ``None`` to leave the
            legacy results out.

    Returns:
        List of ``Path`` objects. Legacy paths are listed unconditionally;
        folders found below ``cfg.path_exp`` are only listed when they
        actually contain a merged results file.
    """
    paths = []

    # Legacy results live in a fixed <label kind>/<dataset> layout.
    if old_path is not None:
        for folder in ['single_label', 'multi_label']:
            for dataset in ['carina', 'mscoco', 'reuters', 'scene']:
                paths.append(Path(old_path, folder, dataset, _RESULTS_FILENAME))

    # Recent results live in <cfg.path_exp>/<dataset>/<dataset type>.
    # Sorting keeps the processing order (and therefore which duplicate
    # wins) independent of the order the file system lists the folders in.
    for path_dataset in sorted(p for p in Path(cfg.path_exp).iterdir() if p.is_dir()):
        for path_type in sorted(p for p in path_dataset.iterdir() if p.is_dir()):
            path_results = path_type / _RESULTS_FILENAME
            if path_results.is_file():
                paths.append(path_results)
            else:
                # A folder without merged results should not abort the run.
                print(f'Skip folder without merged results: {path_type}')

    return paths


def _prepare_results(df_r: pd.DataFrame, path) -> pd.DataFrame:
    """Filter one results frame and bring it to the common column layout.

    Args:
        df_r: Results loaded from ``path``.
        path: Location the results were loaded from; its text decides which
            filter is applied.

    Returns:
        A frame restricted to ``_FIGURE_COLUMNS``.

    Raises:
        ValueError: If the frame has no ``dataset_type`` column and the type
            cannot be derived from ``path``.
        KeyError: If a required column is missing and has no default.
    """
    path_str = str(path)
    # Reset per file so a value from another file or loop can never leak in.
    dataset_type = None

    # From single label data keep only results of all classes.
    if 'single_label' in path_str or '2k' in path_str:
        dataset_type = '2k'
        df_r['max_train_samples'] = 2000
        df_r = df_r[df_r['nr_classes'] == df_r['nr_classes'].max()]
        df_r = df_r.reset_index(drop=True)
    # From multi label data keep only results of all classes or 1 class.
    elif 'multi_label' in path_str or 'complete' in path_str:
        dataset_type = 'complete'
        df_r = df_r[(df_r['nr_classes'] == df_r['nr_classes'].max()) | (df_r['nr_classes'] == 1)]
        df_r = df_r.reset_index(drop=True)

    # Fill in columns that are absent from the frame.
    if 'dataset_type' not in df_r.columns:
        if dataset_type is None:
            raise ValueError(
                f"Cannot determine 'dataset_type' for {path}: the results have no "
                f"such column and the path names neither a '2k' nor a 'complete' data set."
            )
        df_r['dataset_type'] = dataset_type

    for column, default in (('train_paradigm', 'sl'), ('weight_init', 'tl'), ('training', 'frozen')):
        if column not in df_r.columns:
            df_r[column] = default

    if 'qm' not in df_r.columns and 'sampling_method' in df_r.columns:
        df_r['qm'] = df_r['sampling_method']

    if 'classes' not in df_r.columns:
        df_r['classes'] = 'ALL'

    # Keep only cols that are used for figures.
    return df_r[_FIGURE_COLUMNS]


def _upsert_rows(df: pd.DataFrame, df_r: pd.DataFrame) -> pd.DataFrame:
    """Insert the rows of ``df_r`` into ``df``, overwriting matching rows.

    A row matches when all ``_KEY_COLUMNS`` are equal; the first matching row
    of ``df`` is replaced, otherwise the row is appended.

    Returns:
        The updated frame (``df`` may also be modified in place).
    """
    for _, row in df_r.iterrows():
        if not df.empty:
            # Boolean mask of rows where all key column values match,
            # i.e. the result already exists from somewhere else.
            match = (df[_KEY_COLUMNS] == row[_KEY_COLUMNS]).all(axis=1)
            if match.any():
                index_to_overwrite = match[match].index[0]
                df.loc[index_to_overwrite] = row
                continue

        # Nothing to overwrite (or nothing collected yet): append the row.
        df = pd.concat([df, row.to_frame().T], ignore_index=True)

    return df


def create(old_path=_LEGACY_RESULTS_ROOT) -> None:
    """Combine all merged results into one frame and save it as a pickle.

    The legacy results below ``old_path`` and the results found below
    ``cfg.path_exp`` are loaded one after another, filtered, brought to a
    common column layout and combined. Results loaded later overwrite
    earlier rows with the same key columns. The combined frame is written to
    ``results.pkl`` inside ``cfg.path_exp``.

    Args:
        old_path: Root folder of the legacy results. Pass ``None`` to combine
            only the results below ``cfg.path_exp``.

    Raises:
        FileNotFoundError: If a legacy results file does not exist.
        ValueError: If the data set type of a results file cannot be
            determined.
    """
    df = pd.DataFrame()

    for path in _collect_result_paths(old_path):
        print(f'Create one df: {path}')
        # NOTE: unpickling can execute arbitrary code; only load result
        # files that come from a trusted source.
        df_r = pd.read_pickle(path)
        df = _upsert_rows(df, _prepare_results(df_r, path))

    # save df
    df.to_pickle(Path(cfg.path_exp, 'results.pkl'))


# Build the combined results file only when this module is run as a script,
# not when it is imported.
if __name__ == '__main__':
    create()