import sys
sys.path.append('../src/')

import os
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

import torch
import pandas as pd

from tqdm import tqdm

import image_analysis
from pd_utils import load_configs

# Directory of training runs; its configs are read with load_configs().
DATAPATH = "/srv/thetis2/gc453/iclr_ntk_runs/"
# Directory the residuals are read from and the pcorr scores are written to.
SAVEPATH = "/srv/thetis2/gc453/ntk_run_analysis/"
# Each residual is reshaped to a SIDELENGTH x SIDELENGTH grid before scoring.
SIDELENGTH = 256

# Config columns kept (in this order) alongside each run's score. Must stay a
# list: pandas treats a tuple key as a single column label.
_ALL_FILTERS = [
    'model.architecture', 'model.hidden_features', 'model.hidden_layers',
    'dset', 'model.architecture.scale', 'seed',
]





def main():
    """Score every validation residual and save the scores with run metadata.

    Loads the run configs from ``DATAPATH`` and the residuals stored in
    ``SAVEPATH/validation_residuals.pth``. Each residual's absolute value is
    reshaped to a ``SIDELENGTH`` x ``SIDELENGTH`` grid and scored with
    ``image_analysis.average_pcorrelation_score``. Writes one JSON record per
    run (its ``_ALL_FILTERS`` columns plus ``pcorr_scores``) to
    ``SAVEPATH/validation_pcorr_scores.json``.

    Raises:
        ValueError: if the number of residuals differs from the number of
            config rows, since scores are attached to configs by position.
    """
    config_df = load_configs(DATAPATH)
    # Only the identifying columns are needed in the output.
    determinism_df = config_df[_ALL_FILTERS].copy()

    # map_location='cpu' allows residuals saved from a GPU run to be loaded on
    # a CPU-only machine.
    residuals = torch.load(
        os.path.join(SAVEPATH, 'validation_residuals.pth'), map_location='cpu'
    )

    # Scores are matched to configs purely by position. This assumes the
    # residuals were saved in the same order load_configs returns runs
    # (TODO confirm); at least fail loudly if the counts disagree.
    if len(residuals) != len(determinism_df):
        raise ValueError(
            f"Found {len(residuals)} residuals but {len(determinism_df)} "
            f"configs; cannot align scores with runs."
        )

    scores = []
    for rv in tqdm(residuals, total=len(residuals)):
        # Presumably each residual has SIDELENGTH**2 elements; reshape raises
        # otherwise.
        sig = abs(rv).reshape(SIDELENGTH, SIDELENGTH)
        # NOTE(review): meaning of the second argument (15) is defined in
        # image_analysis -- TODO document it there or name it here.
        scores.append(image_analysis.average_pcorrelation_score(sig, 15))
    determinism_df['pcorr_scores'] = scores

    res_save_path = os.path.join(SAVEPATH, 'validation_pcorr_scores.json')
    determinism_df.to_json(res_save_path, orient='records', indent=2)


if __name__ == "__main__":
    main()

