import os
import json

import pandas as pd
import numpy as np

from tqdm import tqdm

def _KEEP_ONLY_SIREN(config):
    """
    Filter function to only include SIREN models.
    """
    return config['model.architecture'] != "SIREN"

def load_configs(path, exclude_config=_KEEP_ONLY_SIREN):
    """
    Load the config and report of every finished run found under `path`.

    Each sub-directory of `path` that contains a file named 'ntk_comp.h5'
    is treated as one run: its 'config.json' is read, merged with its
    'report.json' (report keys override config keys on collision) and the
    run directory is stored under the 'path' key.

    Parameters
    ----------
    path : directory whose entries are scanned.
    exclude_config : callable taking the loaded config dict; the run is
        skipped when it returns a truthy value.

    Returns
    -------
    pandas.DataFrame with one row per run that was kept (empty if none).
    """
    configs = []
    for _dir in tqdm(os.listdir(path), desc='Load configs'):
        _dir_path = os.path.join(path, _dir)

        # Stray regular files inside `path` would make os.listdir raise
        # NotADirectoryError, so only directories are inspected.
        if not os.path.isdir(_dir_path):
            continue

        # Only runs that produced 'ntk_comp.h5' are considered.
        if 'ntk_comp.h5' not in os.listdir(_dir_path):
            continue

        with open(os.path.join(_dir_path, 'config.json'), 'r') as f:
            _config = json.load(f)

        # Filter on the config alone, before the report is read.
        if exclude_config(_config):
            continue

        with open(os.path.join(_dir_path, 'report.json'), 'r') as f:
            _report = json.load(f)

        _config.update(_report)
        _config['path'] = _dir_path

        configs.append(_config)

    return pd.DataFrame(configs)

def agg_func(x):
    """
    Format the mean and standard deviation of `x` as a LaTeX math string.

    The mean is computed with np.mean and the spread with np.std (called
    with its default arguments); both are rendered in scientific notation
    with three decimals and joined by a LaTeX plus-minus sign.
    """
    mn = np.mean(x)
    sn = np.std(x)
    # Raw f-string: a plain '\p' is an invalid escape sequence, which newer
    # Python versions warn about and will eventually reject.
    return rf'${mn:.3e} \pm {sn:.3e}$'

def config_lookup(config, df):
    """
    Build a boolean row mask selecting the rows of `df` that match `config`.

    A row is selected when, for every (key, value) pair in `config`, the
    column `df[key]` equals `value`. An empty `config` selects every row.

    Parameters
    ----------
    config : mapping of column name -> required value.
    df : pandas.DataFrame to search; each key must be one of its columns
        (otherwise the column lookup raises KeyError).

    Returns
    -------
    pandas.Series of booleans sharing the index of `df`, usable directly
    as `df[mask]`.
    """
    # The mask must carry df's own index: a default 0..n-1 index would be
    # aligned by label against df's index and give a wrong, mis-sized mask
    # whenever df is filtered, sorted or otherwise not range-indexed.
    mask = pd.Series(True, index=df.index, dtype=bool)
    for key, value in config.items():
        mask &= (df[key] == value)

    return mask

def concat_dfs(df_list, cfg_list):
    """
    Stack several dataframes, tagging each one with its config values.

    `df_list` and `cfg_list` are walked in lockstep. Every dataframe is
    copied (the inputs are left untouched) and, for each key/value pair of
    the matching config dict, a column named after the key is added holding
    that value in every row. The tagged copies are concatenated and the
    index is reset, which keeps the previous index as a regular column.
    """
    tagged_frames = []
    for settings, frame in zip(cfg_list, df_list):
        tagged = frame.copy()
        n_rows = len(tagged)
        for column, constant in settings.items():
            # One entry per row, so list-like values are stored as-is
            # instead of being spread across the rows.
            tagged[column] = [constant] * n_rows
        tagged_frames.append(tagged)

    stacked = pd.concat(tagged_frames)
    return stacked.reset_index()