"""

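Compare the impact of train/test using exact/approximate methods.

Runs the RVRT test grid listed in exps/trte_rvrt/test.cfg and prints
PSNR/SSIM/ST-RRED averages grouped by (sigma, vid_name).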


"""


# -- sys --
import os
import numpy as np
import pandas as pd
from easydict import EasyDict as edict

# -- testing --
from dev_basics.trte import test,bench

# -- plotting --
# import stnls_paper
# from stnls_paper import plots

# -- caching results --
import cache_io


def _summarize(results, afields, gfields):
    """Average metric columns per group and return a flat table.

    Each cell of an ``afields`` column is first reduced to a scalar with
    ``np.mean`` (presumably a per-frame list/array of scores -- confirm
    against ``test.run``). Rows sharing the same ``gfields`` values are then
    averaged together.

    Returns a new DataFrame with columns ``gfields + afields``; the input is
    not modified.
    """
    # One scalar per experiment row (assign returns a copy, no in-place edit).
    per_exp = results.assign(**{f: results[f].apply(np.mean) for f in afields})
    grouped = per_exp.groupby(gfields).agg({f: "mean" for f in afields})
    return grouped.reset_index()[gfields + afields]


def main():
    """Run the RVRT test-experiment grid via cache_io and print a summary.

    Reads experiment configs from ``exps/trte_rvrt/test.cfg``, runs each one
    through ``dev_basics.trte.test.run`` (dispatch mode ``"slurm"``), then
    prints the mean PSNR/SSIM/ST-RRED grouped by noise level and video name.
    Returns early (after printing the count) when there are no results.
    """
    # -- start info --
    pid = os.getpid()
    print("PID: ", pid)

    # -- get/run experiments --
    # Passed as reset/skip_dne/no_config_check to cache_io below;
    # NOTE(review): presumably forces the caches to be rebuilt -- confirm.
    refresh = True
    cache_name = ".cache_io/trte_rvrt/test"

    def clear_fxn(num, cfg):
        # Never request clearing of an individual experiment's cached result.
        return False

    read_test = cache_io.read_test_config.run
    exps = read_test("exps/trte_rvrt/test.cfg",
                     ".cache_io_exps/trte_rvrt/test",
                     reset=refresh, skip_dne=refresh)
    exps, uuids = cache_io.get_uuids(exps, cache_name,
                                     reset=refresh, no_config_check=refresh)
    print(len(exps))

    # -- run exps --
    results = cache_io.run_exps(exps, test.run, uuids=uuids,
                                name=cache_name,
                                version="v1", skip_loop=False,
                                clear_fxn=clear_fxn,
                                clear=False, enable_dispatch="slurm",
                                records_fn=".cache_io_pkl/trte_rvrt/test.pkl",
                                records_reload=True, to_records_fast=False)

    # -- get bench -- (optional timing summary; uncomment to enable)
    # bench.print_summary(exps[0],(1,4,4,256,256))

    # -- view --
    print(len(results))
    # len() rather than truthiness: a DataFrame's truth value is ambiguous.
    if len(results) == 0:
        return
    afields = ["psnrs", "ssims", "strred"]
    gfields = ["sigma", "vid_name"]  # optionally also: "pretrained_path"
    summary = _summarize(results, afields, gfields)
    print(len(summary))
    print(summary)
    print(summary["psnrs"].mean(), summary["ssims"].mean())

if __name__ == "__main__":
    # Script entry point: only run the experiment sweep when executed directly,
    # not when this module is imported.
    main()
