from sorcerun.git_utils import (
    is_dirty,
    get_repo,
    get_commit_hash,
    get_time_str,
    get_tree_hash,
)
import sys

from make_algorithm.baselines import DEFAULT_BASELINE_CONFIGS

# Locate the enclosing git repository; `repo` is reused further down for
# provenance metadata (commit hash, tree hash).
repo = get_repo()
ROOT = repo.working_dir

# The `globals` module imported next is expected at the repository root, so make
# the root importable. Only append when absent to avoid duplicate sys.path
# entries; the temporary flag is deleted to keep the module namespace unchanged.
_missing_from_path = ROOT not in sys.path
if _missing_from_path:
    sys.path.append(ROOT)
del _missing_from_path

from globals import SIGN, SQRT, INV, PROOT, MCTS, _sign, _sqrt, _inv, _proot


# Initial-matrix generator selection and its parameters. The parameters are
# interpreted by code outside this file (presumably `d` is the matrix
# dimension — confirm against the generator).
init_mat_name = "wishart"
init_mat_config = {
    "c": 0.5,
    "d": 1500,
    "eps": 1e-6,
}

# Run-level settings.
matrix_function = SQRT  # selects which entry of action_specs is searched
EXPLORE_K = 5  # forwarded into mcts_config
device, precision = "cuda", "double"
# NOTE(review): this is the *string* "False", not the bool False; a plain
# truthiness check downstream would treat it as enabled. Confirm how the
# consumer parses it before changing the type.
custom_loss = "False"

# Candidate actions for the MCTS search, keyed by matrix function; the entry
# for `matrix_function` is passed as `actions` in mcts_config.
# Each entry is (algorithm_name, theta_bounds). theta_bounds appears to be a
# pair of (lower, upper) lists with one value per tunable parameter — TODO
# confirm against the consumer. A None second element means "use the default
# theta bounds".
action_specs = {
    SIGN: [
        (_sign("ns"), ([0] * 2, [5] * 2)),
        (_sign("newton"), ([0], [40])),
        (_sign("quintic"), ([0] * 3, [5] * 3)),
        (_sign("halley"), ([0] * 3, [40] * 3)),
    ],
    INV: [
        (_inv("ns"), ([0] * 2, [5] * 2)),
        (_inv("ns_chebyshev"), ([0] * 3, [5] * 3)),
    ],
    SQRT: [
        (_sqrt("db"), ([0] * 2, [50] * 2)),
        (_sqrt("nsv"), ([0] * 2, [5] * 2)),
        (_sqrt("visser"), ([0] * 2, [10] * 2)),
        (_sqrt("visser_coupled"), ([0] * 2, [10] * 2)),
        # Currently disabled candidates:
        # (_sqrt("newton"), ([0] * 2, [10] * 2)),
        # (_sqrt("newton_coupled"), ([0] * 2, [10] * 2)),
        (_sqrt("couple"), None),
    ],
    PROOT: [
        (_proot("newton"), ([0] * 2, [10] * 2)),
        (_proot("visser"), ([0] * 2, [10] * 2)),
        (_proot("couple"), None),
        # Currently disabled candidate:
        # (_proot("swap"), None),
        (_proot("iannazzo"), ([0] * 2, [10] * 2)),
    ],
}

# Settings for the MCTS algorithm search; interpreted by code outside this file.
mcts_config = {
    # Search-policy constants (presumably UCB exploration weight and
    # progressive-widening exponent — confirm against the MCTS implementation).
    "c_ucb": 5,
    "alpha_pw": 0.3,
    # Tolerances / exploration.
    "epsilon": 1e-6,
    "EXPLORE_K": EXPLORE_K,
    "early_termination_epsilon": 1e-5,
    # Search size and logging (smaller debug budget: 3_000).
    "budget": 300_000,
    "print_every": 1_000,
    # NOTE(review): originally annotated "## INF" — possibly intended to be
    # effectively unlimited; confirm the intended value.
    "max_termination_count": 10,
    "tree_initial_capacity": 10_000,
    "device": device,
    # Candidate actions for the currently selected matrix function.
    "actions": action_specs[matrix_function],
    "initialize_with_baselines": True,
}

# Algorithm to run. Uncomment exactly one alternative below to run a baseline
# instead of the MCTS search.
algorithm_name = "mcts"
# Sign baselines:
# algorithm_name = _sign("newton")
# algorithm_name = _sign("ns")
# algorithm_name = _sign("scaled_newton")
# algorithm_name = _sign("scaled_ns")
# algorithm_name = _sign("halley")
# Inverse baselines:
# algorithm_name = _inv("ns")
# algorithm_name = _inv("ns_chebyshev")
# Square-root baselines:
# algorithm_name = _sqrt("db")
# algorithm_name = _sqrt("scaled_db")
# algorithm_name = _sqrt("nsv")
# algorithm_name = _sqrt("binom")
# algorithm_name = _sqrt("visser")
# algorithm_name = _sqrt("newton")


# Names containing the MCTS marker (substring test) use mcts_config; any other
# name must be a key of DEFAULT_BASELINE_CONFIGS, otherwise KeyError is raised.
algorithm_config = (
    mcts_config
    if MCTS in algorithm_name
    else DEFAULT_BASELINE_CONFIGS[algorithm_name]
)

# Experiment configuration assembled from the settings above, plus provenance
# metadata (commit / tree hash, timestamp, dirty flag) identifying the code
# state the run was produced from.
config = {
    "matrix_function": matrix_function,
    # Algorithm selection.
    "algorithm_name": algorithm_name,
    "algorithm_config": algorithm_config,
    # Initial matrix.
    "init_mat_name": init_mat_name,
    "init_mat_config": init_mat_config,
    # Provenance.
    "commit_hash": get_commit_hash(repo),
    "make_algorithm_hash": get_tree_hash(repo, "make_algorithm"),
    "time_str": get_time_str(),
    # Reuse the existing `repo` handle instead of re-discovering the
    # repository, so every provenance field reads the same repository object.
    "dirty": is_dirty(repo),
    "experiment_name": "make_algorithm",
    # Numerics / runtime.
    "precision": precision,
    "custom_loss": custom_loss,
    "device": device,
    #
    "seed": 42,
}
