
# Training Parameters
def get_training_params(args):
    """Return the training-loop settings for this run.

    Args:
        args: Namespace-like object; only ``outer_iters`` and
            ``unjit_training_step`` are read.

    Returns:
        A new dict with keys ``lr`` (per-level learning rates for "coarse"
        and "fine"), ``nIter``, ``accuracy_bail``, ``log_interval`` and
        ``unjit_training_step``.
    """
    return dict(
        lr=dict.fromkeys(("coarse", "fine"), 1e-3),
        nIter=args.outer_iters,
        # 1e-2 == 1%; presumably an early-exit accuracy threshold — confirm with the trainer.
        accuracy_bail=1e-2,
        log_interval=100,
        unjit_training_step=args.unjit_training_step,
    )


# Problem Parameters
def get_problem_params(args):
    """Return the problem-definition settings.

    Args:
        args: Namespace-like object; only ``problem_dim`` is read.

    Returns:
        A new dict with ``dim`` (copied from ``args.problem_dim``),
        ``gb_allowed`` and a ``gaussian_peak`` sub-dict holding its
        ``mean`` and ``std``.
    """
    peak = dict(mean=0.5, std=0.1)
    return dict(dim=args.problem_dim, gb_allowed=0.0009, gaussian_peak=peak)


# Model Parameters
def get_model_setup_params(args):
    """Assemble the network, coefficient-solver and sigma-schedule configuration.

    The "coarse" and "fine" levels share identical layer settings and differ
    only in their partition count (``args.Pc`` vs ``args.Pf``).

    Args:
        args: Namespace-like object providing ``gate_n_hidden``,
            ``poly_n_hidden``, ``poly_basis_size``, ``Pc``, ``Pf``,
            ``solver_type`` and ``outer_iters``.

    Returns:
        dict with keys "gating", "poly", "coef_slv_params",
        "sigma_schedule" and "dtype".
    """
    import jax
    import optax

    float_type = jax.numpy.float64
    # NOTE(review): JAX only honours float64 when x64 mode is enabled
    # (jax.config.update("jax_enable_x64", True)); otherwise arrays are
    # created as float32. Confirm the entry point enables it.
    weight_type = float_type

    def gating_level(partitions):
        # Gating-net config for one level; identity input embedding.
        return {
            "n_hidden": args.gate_n_hidden,
            "n_layers": 0,
            "num_partitions": partitions,
            "activation": jax.nn.tanh,
            "embedding": lambda x: x,
            "dtype": weight_type,
        }

    def poly_level(partitions):
        # MLP-basis polynomial config for one level.
        return {
            "basis_choice": "mlp",
            "n_hidden": args.poly_n_hidden,
            "dtype": weight_type,
            "n_layers": 1,
            "basis_size": args.poly_basis_size,
            "num_partitions": partitions,
        }

    solver = {
        "slv_type": args.solver_type,
        # NOTE(review): reuses the outer training iteration count as the
        # coefficient solver's iteration cap — confirm this is intended.
        "max_iter": args.outer_iters,
        "tol": 1e-12,
        "reg": 1e-4,
        "omega": 1.0,
    }

    schedules = {
        "coop": optax.constant_schedule(jax.numpy.array(1e5, dtype=float_type)),
        "comp": optax.constant_schedule(jax.numpy.array(1.0, dtype=float_type)),
    }

    return {
        "gating": {"coarse": gating_level(args.Pc), "fine": gating_level(args.Pf)},
        "poly": {"coarse": poly_level(args.Pc), "fine": poly_level(args.Pf)},
        "coef_slv_params": solver,
        "sigma_schedule": schedules,
        "dtype": float_type,
    }

def set_dirs(args):
    """Derive output paths from ``args.results_dir_name`` and store them on ``args``.

    Mutates ``args`` in place and returns None. Sets ``save_state_dir``,
    ``metrics_file`` and ``log_file`` under the results directory, and
    unconditionally resets ``load_state_dir`` to None.
    """
    root = args.results_dir_name
    derived = {
        "save_state_dir": f"{root}/state",
        # Always None here. NOTE(review): a path such as "results/exp1_state"
        # was previously noted here, presumably to resume from — confirm.
        "load_state_dir": None,
        "metrics_file": f"{root}/data.pkl",
        "log_file": f"{root}/run.log",
    }
    for attr, value in derived.items():
        setattr(args, attr, value)

