import json
import os
import sys

# Root of the Thompson-sampling code base. The path is relative to the current
# working directory, so this script presumably has to be launched from one
# directory below the repo root — confirm. It is appended to sys.path so that
# the ts_main / ege_main / baseline imports below resolve. It is also
# substituted into the reagent file paths of the JSON config further down.
TS_BASE_DIR = ".."
sys.path.append(TS_BASE_DIR)

from ts_main import run_ts, parse_input_dict
from ege_main import run_ege_sh
from baseline import random_baseline_general, keep_pareto_optimal, exhaustively_enumerate_lib


def run_mo_random_trials(input_dict, num_random_cycles=10, num_warmup_trials=10, num_ts_iterations=50_000):
    """Run the random-sampling baseline several times with a Pareto-optimal filter.

    Each cycle calls ``random_baseline_general`` with ``keep_pareto_optimal``
    as the filter. Output goes to ``benchmark_data/mo_random_<k>.csv``, with
    ``k`` counting from 1.

    Args:
        input_dict: Parsed benchmark configuration; mutated in place
            (``num_warmup_trials`` and ``num_ts_iterations`` are overwritten).
        num_random_cycles: Number of independent baseline runs.
        num_warmup_trials: Value stored into ``input_dict["num_warmup_trials"]``.
        num_ts_iterations: Stored into the config and also used as the number
            of random trials per run.
    """
    input_dict.update(num_warmup_trials=num_warmup_trials, num_ts_iterations=num_ts_iterations)
    for cycle in range(1, num_random_cycles + 1):
        out_path = f"benchmark_data/mo_random_{cycle}.csv"
        random_baseline_general(
            input_dict,
            num_trials=num_ts_iterations,
            outfile_name=out_path,
            filter_func=keep_pareto_optimal,
        )


def run_momab_trials(input_dict, num_mo_ts_cycles=10, num_warmup_trials=10, num_ts_iterations=50_000,
                     mode="mo_maximize_TTPFTS", enable_pareto_logging=True,
                     *, output_dir="benchmark_data/MO_TTPFTS", start_index=11):
    """Run repeated multi-objective Thompson-sampling cycles with Pareto logging.

    Each cycle sets per-run filenames on ``input_dict`` and calls ``run_ts``:
    ``<output_dir>/<mode>_<k>.csv`` for results and
    ``<output_dir>/pareto_logs_<mode>_<k>.parquet`` for the Pareto log.
    ``k`` runs from ``start_index`` to ``start_index + num_mo_ts_cycles - 1``.

    Args:
        input_dict: Parsed benchmark configuration; mutated in place.
        num_mo_ts_cycles: Number of independent runs.
        num_warmup_trials: Stored into ``input_dict["num_warmup_trials"]``.
        num_ts_iterations: Stored into ``input_dict["num_ts_iterations"]``.
        mode: Stored into ``input_dict["ts_mode"]``; also part of filenames.
        enable_pareto_logging: Stored into ``input_dict["enable_pareto_logging"]``.
        output_dir: Directory for the per-run files; created if missing.
        start_index: First run number used in filenames. The default of 11
            reproduces the historical numbering, which presumably continues
            earlier batches of runs.
    """
    input_dict["num_warmup_trials"] = num_warmup_trials
    input_dict["num_ts_iterations"] = num_ts_iterations
    input_dict["ts_mode"] = mode
    input_dict["enable_pareto_logging"] = enable_pareto_logging
    # main() only creates the top-level benchmark_data/ directory. Make sure the
    # per-experiment subdirectory exists in case run_ts does not create it.
    os.makedirs(output_dir, exist_ok=True)
    for run_idx in range(start_index, start_index + num_mo_ts_cycles):
        input_dict["results_filename"] = f"{output_dir}/{mode}_{run_idx}.csv"
        input_dict["pareto_log_filename"] = f"{output_dir}/pareto_logs_{mode}_{run_idx}.parquet"
        run_ts(input_dict)


def run_TTPFTS_UQ_trails(input_dict, num_mo_ts_cycles=10, num_warmup_trials=10, num_ts_iterations=50_001,
                         mode="mo_maximize_TTPFTS", enable_posterior_logging=True,
                         *, output_dir="benchmark_data/TTPFTS_UQ", start_index=41):
    """Run repeated TTPFTS cycles with posterior logging enabled.

    The public name keeps its historical "trails" spelling (meaning "trials")
    so existing callers keep working.

    Each cycle sets per-run filenames on ``input_dict`` and calls ``run_ts``:
    ``<output_dir>/<mode>_<k>.csv`` for results and
    ``<output_dir>/posterior_logs_<mode>_e<k>`` for the posterior log.
    ``k`` runs from ``start_index`` to ``start_index + num_mo_ts_cycles - 1``.

    Args:
        input_dict: Parsed benchmark configuration; mutated in place.
        num_mo_ts_cycles: Number of independent runs.
        num_warmup_trials: Stored into ``input_dict["num_warmup_trials"]``.
        num_ts_iterations: Stored into ``input_dict["num_ts_iterations"]``.
            The default of 50_001 matches the JSON config but differs from the
            other run_* helpers (50_000).
        mode: Stored into ``input_dict["ts_mode"]``; also part of filenames.
        enable_posterior_logging: Stored into ``input_dict["enable_posterior_logging"]``.
        output_dir: Directory for the per-run files; created if missing.
        start_index: First run number used in filenames (default 41 reproduces
            the historical numbering).
    """
    input_dict["num_warmup_trials"] = num_warmup_trials
    input_dict["num_ts_iterations"] = num_ts_iterations
    input_dict["ts_mode"] = mode
    input_dict["enable_posterior_logging"] = enable_posterior_logging
    # main() only creates the top-level benchmark_data/ directory. Make sure the
    # per-experiment subdirectory exists in case run_ts does not create it.
    os.makedirs(output_dir, exist_ok=True)
    for run_idx in range(start_index, start_index + num_mo_ts_cycles):
        input_dict["results_filename"] = f"{output_dir}/{mode}_{run_idx}.csv"
        # NOTE(review): the "_e" before the run index and the missing extension
        # differ from the other log names. They are kept unchanged because
        # downstream analysis may rely on them; presumably the logger appends
        # its own suffix — confirm.
        input_dict["posterior_log_filename"] = f"{output_dir}/posterior_logs_{mode}_e{run_idx}"
        run_ts(input_dict)


def run_exhaustively_enumerate_lib(input_dict, output_filename="benchmark_data/quinazoline_exhaustive_library.csv"):
    """Enumerate the full reagent library once, reusing an existing output file.

    The work is delegated to ``exhaustively_enumerate_lib``. If
    ``output_filename`` already exists, nothing is done, so repeated runs of
    the script do not regenerate the library.

    Args:
        input_dict: Parsed benchmark configuration passed through unchanged.
        output_filename: Destination file. It is also the existence check, so
            both always refer to the same path.
    """
    if os.path.exists(output_filename):
        return
    # main() normally creates benchmark_data/. Create the parent here too, so
    # this helper also works when called on its own.
    parent_dir = os.path.dirname(output_filename)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    exhaustively_enumerate_lib(input_dict, output_filename=output_filename)


def run_ege_trials(input_dict, num_ege_cycles=10, num_iterations=50_000, mode="mo_maximize_EGE_SH"):
    """Run the EGE successive-halving benchmark several times.

    Configures ``input_dict`` for ``run_ege_sh``, then calls it once per cycle.
    Results go to ``benchmark_data/<mode>_<k>.csv``, with ``k`` counting from 1.

    Args:
        input_dict: Parsed benchmark configuration; mutated in place
            (``num_iterations``, ``mode``, ``log_filename`` and
            ``results_filename`` are overwritten).
        num_ege_cycles: Number of independent runs.
        num_iterations: Stored into ``input_dict["num_iterations"]``.
        mode: Stored into ``input_dict["mode"]``; also part of the filenames.
    """
    input_dict.update(num_iterations=num_iterations, mode=mode, log_filename="ege_logs.txt")
    for run_number in range(1, num_ege_cycles + 1):
        input_dict["results_filename"] = f"benchmark_data/{mode}_{run_number}.csv"
        run_ege_sh(input_dict)


# JSON text for the quinazoline benchmark configuration, parsed by main().
# - The three reagent files presumably map, in order, onto the three
#   '.'-separated reactant templates of reaction_smarts — confirm with ts_main.
# - The evaluator is LogPandFPEvaluator. The meaning of its arguments
#   (opt_logp, dev_logp, query_smiles) is defined in the evaluator — check
#   there.
# - The literal placeholder "TS_BASE_DIR" in the reagent paths is replaced
#   with the TS_BASE_DIR value via str.replace. This is a plain textual
#   substitution into JSON, so a TS_BASE_DIR containing backslashes or quotes
#   would corrupt the JSON.
# - num_warmup_trials, num_ts_iterations, ts_mode, results_filename and
#   log_filename are overwritten by the run_* helpers before each experiment.
logp_FP_quinazoline_json = """{
"reagent_file_list": [
        "TS_BASE_DIR/data/aminobenzoic_ok.smi",
        "TS_BASE_DIR/data/primary_amines_500.smi",
        "TS_BASE_DIR/data/carboxylic_acids_500.smi"
    ],
    "reaction_smarts": "N[c:4][c:3]C(O)=O.[#6:1][NH2].[#6:2]C(=O)[OH]>>[C:2]c1n[c:4][c:3]c(=O)n1[C:1]",
    "num_warmup_trials": 10,
    "num_ts_iterations": 50001,
    "evaluator_class_name": "LogPandFPEvaluator",
    "evaluator_arg": {"opt_logp" : 3.0, "dev_logp" : 1.0, "query_smiles" : "CCc1cccc2c(=O)n(C3CNC3)c([C@@H](C)N)nc12"},
    "ts_mode": "mo_maximize_TTPFTS",
    "log_filename": "mo_ts_logs.txt",
    "results_filename": "mo_ttpfts_results.csv"
}""".replace("TS_BASE_DIR", TS_BASE_DIR)


def main():
    """Parse the quinazoline benchmark config and run the enabled experiment.

    Experiments are switched on and off by (un)commenting the calls below.
    Currently only the TTPFTS uncertainty-quantification run is active.
    """
    os.makedirs("benchmark_data", exist_ok=True)
    config = json.loads(logp_FP_quinazoline_json)
    # The return value is ignored, so parse_input_dict presumably prepares
    # `config` in place — confirm in ts_main.
    parse_input_dict(config)
    # Other experiments — uncomment to run:
    # run_mo_random_trials(config)
    # run_momab_trials(config, mode="mo_maximize_TTPFTS", enable_pareto_logging=True, num_mo_ts_cycles=90)
    # run_exhaustively_enumerate_lib(config)
    # run_ege_trials(config)
    run_TTPFTS_UQ_trails(config, num_mo_ts_cycles=10)


# Run the benchmark only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
