import numpy as np
from sorcerun.git_utils import (
    is_dirty,
    get_repo,
    get_commit_hash,
    get_time_str,
    get_tree_hash,
)

import sys

# Put the git repository's working directory on sys.path (once).
# NOTE(review): presumably this is what makes the project-local imports below
# (make_algorithm, globals) resolvable regardless of the current directory —
# confirm. It must therefore stay ahead of those imports.
repo = get_repo()
ROOT = repo.working_dir
if ROOT not in sys.path:
    sys.path.append(ROOT)

from make_algorithm.baselines import ADAPTIVE_BASELINE_CONFIGS, DEFAULT_BASELINE_CONFIGS
from globals import SIGN, SQRT, INV, PROOT, MCTS, _sign, _sqrt, _inv, MATRIX_FUNCTIONS


# %%
def name_to_type(name):
    """Return the first entry of MATRIX_FUNCTIONS that is a prefix of `name`.

    Entries are tried in the iteration order of MATRIX_FUNCTIONS, so if several
    entries are prefixes of `name` the earliest one wins.

    Raises:
        ValueError: if no entry of MATRIX_FUNCTIONS is a prefix of `name`.
    """
    matched = next(
        (prefix for prefix in MATRIX_FUNCTIONS if name.startswith(prefix)),
        None,
    )
    if matched is None:
        raise ValueError(f"Can't extract matrix type from {name}")
    return matched


# %%
# Provenance for this grid, computed once at import time; make_config copies
# all four values into every config it builds.
commit_hash = get_commit_hash(repo)
time_str = get_time_str()
dirty = is_dirty(repo)
# One identifier shared by every config generated in this run.
grid_id = f"{time_str}--{commit_hash}--dirty={dirty}"

from make_algorithm.make_algo_config import (
    mcts_config,
    config,
    init_mat_config,
)


def make_config(
    algorithm_name,
    algorithm_config,
    init_mat_name,
    init_mat_size,
    repeat,
):
    """Build one run configuration on top of the shared base `config`.

    The matrix function is derived from `algorithm_name` via `name_to_type`.
    When the name is exactly `MCTS` (which carries no matrix-function prefix),
    it is derived instead from the first entry of
    `algorithm_config["actions"]` whose action is not "Initialization".

    Args:
        algorithm_name: Name of the algorithm; either starts with one of
            MATRIX_FUNCTIONS or equals MCTS.
        algorithm_config: Algorithm settings, stored as-is in the result. For
            MCTS it must provide "actions", an iterable of
            ``(action, theta_bounds)`` pairs.
        init_mat_name: Stored under "init_mat_name".
        init_mat_size: Stored as "d" in a copy of `init_mat_config`.
        repeat: Stored under "repeat".

    Returns:
        A copy of `config` updated with the matrix function, the init-matrix
        and algorithm settings, the git/time provenance of this grid, and
        `repeat`.

    Raises:
        ValueError: if no matrix function can be determined. This includes
            algorithm names that neither match MATRIX_FUNCTIONS nor equal
            MCTS (previously an `assert`, which `python -O` strips).
    """
    try:
        matrix_function = name_to_type(algorithm_name)
    except ValueError:
        # Only MCTS may lack a matrix-function prefix. Check explicitly rather
        # than with `assert`, so the validation still runs under `python -O`.
        if algorithm_name != MCTS:
            raise
        matrix_function = None

    if matrix_function is None:
        # MCTS: take the type of the first action that is not "Initialization".
        for action, _theta_bounds in algorithm_config["actions"]:
            if action != "Initialization":
                matrix_function = name_to_type(action)
                break

    if matrix_function is None:
        raise ValueError(f"Can't extract matrix type from {algorithm_name}")

    c = config.copy()
    ic = init_mat_config.copy()
    ic.update({"d": init_mat_size})

    c.update(
        {
            "matrix_function": matrix_function,
            # initial matrix
            "init_mat_name": init_mat_name,
            "init_mat_config": ic,
            # algorithm
            "algorithm_name": algorithm_name,
            "algorithm_config": algorithm_config,
            # provenance of this grid
            "commit_hash": commit_hash,
            "time_str": time_str,
            "dirty": dirty,
            "grid_id": grid_id,
            # repetition index
            "repeat": repeat,
        }
    )
    return c


# %% initial matrix distributions
# init_mat_names = sorted(list(MATRIX_DISTRIBUTIONS.keys()))
# Grid axes used by the adaptive-baseline and MCTS grids below. The default
# baseline configs do not use them: they take the init matrix name and size
# from `config` / `init_mat_config` and always use repeat=1.

# Number of repetitions per grid point; each gets `repeat` = 0..repeats-1.
repeats = 1

# Initial-matrix distributions to sweep; commented entries are disabled.
# "CIFAR" ignores `init_mat_sizes` and always uses size 3072.
init_mat_names = [
    # "CIFAR",
    # "lm_head",
    # "unif",
    "wishart",
    # "wishart_unif",
]
# Initial-matrix sizes to sweep (stored as "d" in each config's
# init_mat_config); commented entries are disabled.
init_mat_sizes = [
    # 500,
    # 2000,
    5000,
]


def _default_baseline_run(algo_name, algo_defaults):
    """Build the single run config for one entry of DEFAULT_BASELINE_CONFIGS.

    The init matrix name and size come from the base `config` and
    `init_mat_config`, not from a sweep, and the repeat index is fixed to 1.
    """
    return make_config(
        algorithm_name=algo_name,
        algorithm_config={
            **algo_defaults,
            "size": init_mat_config["d"],
            "device": config["device"],
        },
        init_mat_name=config["init_mat_name"],
        init_mat_size=init_mat_config["d"],
        repeat=1,
    )


# One config per default baseline algorithm.
baseline_only_configs = [
    _default_baseline_run(algo_name, algo_defaults)
    for algo_name, algo_defaults in DEFAULT_BASELINE_CONFIGS.items()
]
print(f"Number of baseline only configs: {len(baseline_only_configs)}")

def _adaptive_baseline_grid():
    """Build configs for every adaptive baseline over the init-matrix grid.

    Sweeps, outermost to innermost: algorithm, init matrix name, init matrix
    size, repeat index. "CIFAR" always uses size 3072 instead of
    `init_mat_sizes`.

    NOTE(review): the algorithm's "size" is taken from `init_mat_config["d"]`
    while the matrix itself uses the swept size; these differ whenever the
    swept size is not `init_mat_config["d"]` — confirm this is intended.
    """
    grid = []
    for algo_name, algo_defaults in ADAPTIVE_BASELINE_CONFIGS.items():
        for mat_name in init_mat_names:
            sizes = [3072] if mat_name == "CIFAR" else init_mat_sizes
            for mat_size in sizes:
                for rep in range(repeats):
                    algo_config = {
                        **algo_defaults,
                        "size": init_mat_config["d"],
                        "device": config["device"],
                    }
                    grid.append(
                        make_config(
                            algorithm_name=algo_name,
                            algorithm_config=algo_config,
                            init_mat_name=mat_name,
                            init_mat_size=mat_size,
                            repeat=rep,
                        )
                    )
    return grid


adaptive_baseline_only_configs = _adaptive_baseline_grid()
print(
    f"Number of adaptive baseline only configs: {len(adaptive_baseline_only_configs)}"
)

def _mcts_grid():
    """Build MCTS configs over the init-matrix grid.

    Sweeps, outermost to innermost: MCTS algorithm config (currently only
    `mcts_config`), init matrix name, init matrix size, repeat index.
    "CIFAR" always uses size 3072 instead of `init_mat_sizes`.
    """
    grid = []
    for mcts_variant in [mcts_config]:
        for mat_name in init_mat_names:
            sizes = [3072] if mat_name == "CIFAR" else init_mat_sizes
            for mat_size in sizes:
                for rep in range(repeats):
                    grid.append(
                        make_config(
                            algorithm_name=MCTS,
                            algorithm_config=mcts_variant,
                            init_mat_name=mat_name,
                            init_mat_size=mat_size,
                            repeat=rep,
                        )
                    )
    return grid


mcts_only_configs = _mcts_grid()


print(f"Number of MCTS only configs: {len(mcts_only_configs)}")

# Full grid before filtering; the MCTS configs are currently left out.
configs = [*baseline_only_configs, *adaptive_baseline_only_configs]  # + mcts_only_configs
print({c["matrix_function"] for c in configs})

# Seeds are assigned by position in the unfiltered grid (42, 43, ...), so a
# config keeps the same seed no matter which matrix functions are filtered out.
seeds = np.arange(42, 42 + len(configs))
# Keep only the SQRT configs, each extended with its seed as a plain int.
configs = [
    dict(cfg, seed=int(seed))
    for cfg, seed in zip(configs, seeds)
    if cfg["matrix_function"] == SQRT
]

# %%
print(f"Number of configs in make_algo_grid_config: {len(configs)}")
