from itertools import product
from fairpate_tabular.utils import write_results

def get_pd_query(args_dict, args_changed):
    """Build a query string that identifies one experiment configuration.

    For each experiment-setting key, the value is taken from ``args_changed``
    when the key is present there, and from ``args_dict`` otherwise. Every
    key contributes one ``key == value`` clause; string values are wrapped in
    double quotes and all clauses are joined with ``and``. The result is
    meant to be passed to a ``query`` method (presumably
    ``pandas.DataFrame.query`` -- confirm against the caller).

    Args:
        args_dict: Mapping holding a value for every setting key that is not
            overridden in ``args_changed``. Extra keys are ignored.
        args_changed: Mapping of overridden settings; takes precedence over
            ``args_dict``.

    Returns:
        The query string, with clauses in the fixed order of the settings
        tuple below.

    Raises:
        KeyError: If a setting key is present in neither mapping.
    """
    # A tuple rather than a set: set iteration order for strings depends on
    # hash randomisation, which made the clause order (and so the returned
    # string) vary between interpreter runs.
    experiment_settings_keys = (
        'dataset', 'num_teachers', 'threshold', 'fairness_threshold',
        'sigma_threshold', 'sigma_gnmax', 'budget', 'delta', 'seed',
    )

    def add_quotes(x):
        # Only strings need quoting to be read as literals in the query.
        return f'"{x}"' if isinstance(x, str) else x

    clauses = []
    for key in experiment_settings_keys:
        # Look up args_dict only when there is no override, so a key that is
        # missing from args_dict but overridden does not raise.
        value = args_changed[key] if key in args_changed else args_dict[key]
        clauses.append(f'{key} == {add_quotes(value)}')

    return ' and '.join(clauses)


def run_grid_search(args, run_experiment, results_db, non_None_list_args):
    """Run ``run_experiment`` for every combination of the listed settings.

    The Cartesian product of the values in ``non_None_list_args`` is
    enumerated. For each combination, the corresponding attributes are set on
    ``args`` (mutating it in place), and the experiment is skipped when
    ``results_db`` already holds a matching row. Otherwise the experiment is
    run and its results are passed to ``write_results``. A ``ValueError``
    raised while handling a combination is reported and an all-``None``
    result is recorded for it instead.

    Args:
        args: Attribute namespace of experiment settings; must support
            ``vars()`` and expose ``results_db_path``. Updated in place.
        run_experiment: Callable invoked as
            ``run_experiment(args, <no-op callable>, results_db)`` and
            expected to return seven values (validation accuracy, validation
            disparity, achieved epsilon, max number of queries, number of
            queries answered, test accuracy, test disparity).
        results_db: Object exposing ``query(expr).shape`` and
            ``to_parquet(path)`` (presumably a pandas DataFrame).
        non_None_list_args: Mapping from key name to an iterable of
            candidate values. Each key has every ``"list_"`` occurrence
            removed and surrounding whitespace stripped to obtain the
            attribute name set on ``args``.

    Returns:
        None.
    """
    all_experiments_values = product(*non_None_list_args.values())

    try:
        # The attribute names are identical for every combination, so derive
        # them once instead of building a parallel product of repeated keys.
        setting_names = [
            key.replace("list_", "").strip() for key in non_None_list_args
        ]

        for values in all_experiments_values:
            print("Running experiment with:", end=" ")
            args_changed = dict()
            for name, value in zip(setting_names, values):
                setattr(args, name, value)
                args_changed[name] = value
                print(f"{name}={value}", end=" ")
            print()

            try:
                query = get_pd_query(vars(args), args_changed)
                if results_db.query(query).shape[0] > 0:
                    print("\tResult Exists. Skipping...")
                    continue

                # Explicit unpacking enforces the seven-value contract; a
                # wrong-length return raises ValueError and is recorded below.
                (
                    student_model_validation_accuracy,
                    validation_dem_disparity,
                    achieved_eps,
                    max_num_query,
                    num_queries_answered,
                    student_model_test_accuracy,
                    test_dem_parity,
                ) = run_experiment(args, lambda *x, **y: None, results_db)

                write_results(
                    args,
                    student_model_validation_accuracy,
                    validation_dem_disparity,
                    achieved_eps,
                    max_num_query,
                    num_queries_answered,
                    student_model_test_accuracy,
                    test_dem_parity,
                )
            except ValueError as err:
                # Report the failure rather than discarding it, then record
                # an empty (all-None) result for this configuration.
                print(f"\tValueError: {err}. Recording empty result.")
                write_results(args, None, None, None, None, None, None, None)
    except KeyboardInterrupt:
        print("Keyboard Interrupt")
    finally:
        # Always persist the results database, including on interrupt/error.
        results_db.to_parquet(args.results_db_path)