import sys
sys.path.append('/relnet')

from relnet.io.file_paths import FilePaths
from relnet.io.storage import EvaluationStorage

from itertools import product

import argparse

def _count_grid_combinations(agent_grid):
    """Return how many hyperparameter combinations a grid expands to.

    Args:
        agent_grid: mapping whose values are iterables of candidate settings.

    Returns:
        The size of the Cartesian product of all value iterables (1 for an
        empty mapping, 0 if any of the iterables is empty).
    """
    # Count lazily rather than materialising every combination in a list.
    return sum(1 for _ in product(*agent_grid.values()))


def _summarise_seeds(df, model_seeds, network_generator, objective_function):
    """Collect the best 'perf' value reached by every seed that has rows in ``df``.

    Args:
        df: evaluation curves returned by ``storage.fetch_eval_curves``.
            NOTE(review): indexed with boolean masks below, so presumably a
            pandas DataFrame -- confirm against EvaluationStorage.
        model_seeds: seeds to look for in the 'model_seed' column.
        network_generator: value to match in the 'network_generator' column.
        objective_function: value to match in the 'objective_function' column.

    Returns:
        A ``(best_per_seed, seed_summary)`` tuple: the minimum 'perf' of each
        seed that has at least one matching row, and the concatenated
        "<best perf> [<max timestep>; <seed>],  " fragments for those seeds.
    """
    best_per_seed = []
    fragments = []
    for seed in model_seeds:
        df_subset = df[(df['model_seed'] == seed) &
                       (df['network_generator'] == network_generator) &
                       (df['objective_function'] == objective_function)]
        if len(df_subset) == 0:
            # No evaluation rows recorded for this seed.
            continue

        best_perf = df_subset['perf'].min()
        training_step = df_subset['timestep'].max()
        fragments.append(f"{best_perf:.3f} [{training_step}; {seed}],  ")
        best_per_seed.append(best_perf)

    return best_per_seed, "".join(fragments)


def main():
    """Print the best validation performance recorded so far for an experiment.

    Reads ``--experiment_id`` from the command line, loads the experiment
    details through ``EvaluationStorage`` and, for every objective function /
    network generator pair, prints one line per hyperparameter combination
    that has evaluation data: the average over seeds of each seed's lowest
    'perf', the percentage of seeds with data, and the per-seed values.
    After the per-agent listings it prints the lowest such average per agent
    (``inf`` when no combination of that agent has data).
    """
    parser = argparse.ArgumentParser(description="Plain utility to check best validation losses so far.")
    parser.add_argument("--experiment_id", required=True, help="experiment id to use")
    args = parser.parse_args()

    experiment_id = args.experiment_id
    # Presumably an example experiment id: Chiesa_ssp_drgnr
    fp = FilePaths('/experiment_data', experiment_id, setup_directories=False)

    storage = EvaluationStorage(fp)
    experiment_details = storage.get_experiment_details(experiment_id)

    agent_names = list(experiment_details['agents'])
    experiment_conditions = experiment_details['experiment_conditions']
    objective_functions = experiment_details['objective_functions']
    network_generators = experiment_details['network_generators']
    model_seeds = experiment_conditions['experiment_params']['model_seeds']

    for objective_function in objective_functions:
        for network_generator in network_generators:
            print("=================")
            print(f"{network_generator},{objective_function}")
            print("=================")

            best_comb_vals = {}
            for agent_name in agent_names:
                agent_grid = experiment_conditions['hyperparam_grids'][objective_function][agent_name]
                num_hyperparam_combs = _count_grid_combinations(agent_grid)

                # One entry per combination that has data; joined at print time
                # so skipped combinations never leave a dangling newline.
                comb_lines = []
                best_comb_vals[agent_name] = float("inf")

                for comb in range(num_hyperparam_combs):
                    try:
                        # NOTE(review): the meaning of the positional False
                        # flag is not visible from this file.
                        df = storage.fetch_eval_curves(agent_name, comb, fp, objective_function, network_generator, model_seeds, False, nrows_to_skip=0)
                    except ValueError:
                        # Best effort: combinations whose curves cannot be
                        # fetched are simply left out of the report.
                        continue

                    if len(df) == 0:
                        continue

                    best_per_seed, seed_summary = _summarise_seeds(
                        df, model_seeds, network_generator, objective_function)
                    if not best_per_seed:
                        # Rows exist, but none for the requested seeds.
                        continue

                    comb_avg = sum(best_per_seed) / len(best_per_seed)
                    best_comb_vals[agent_name] = min(best_comb_vals[agent_name], comb_avg)

                    started_pct = len(best_per_seed) / len(model_seeds) * 100
                    comb_lines.append(
                        f"{comb}: <<{comb_avg:.4f}>> avg. training started: {started_pct:.2f}%. {seed_summary}")

                print("=================")
                print(f"<<{agent_name}>>")
                print("=================")
                print("\n".join(comb_lines))

            print("*****************")
            print("intermediary results")
            print("*****************")
            for agent_name in agent_names:
                print(f"{best_comb_vals[agent_name]:.4f}: <<{agent_name}>>")

# Run the report only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
