import argparse
import numpy as np
import jax.numpy as jnp

from hierarchical_pous.config_setup import parse_args, init_distrbutive_env
from hierarchical_pous.visualization import snapshot_model_state


# Training Parameters
def get_training_params(args):
    """Assemble the training-loop settings for one run.

    Args:
        args: Runtime namespace. Only ``outer_iters`` and
            ``unjit_training_step`` are read from it.

    Returns:
        dict: Keys ``lr`` (per-level learning rates for ``coarse`` and
        ``fine``), ``nIter``, ``accuracy_bail``, ``log_interval`` and
        ``unjit_training_step``.
    """
    # Both hierarchy levels currently share the same learning rate.
    learning_rates = dict(coarse=1e-3, fine=1e-3)

    return dict(
        lr=learning_rates,
        nIter=args.outer_iters,
        # NOTE(review): -1 presumably disables the accuracy-based early
        # exit -- confirm against the training loop.
        accuracy_bail=-1,
        log_interval=1,
        unjit_training_step=args.unjit_training_step,
    )



def set_dirs(args):
    """Fill in the output-path attributes on ``args`` in place.

    ``save_state_dir`` and ``load_state_dir`` are set to ``None``; the
    metrics and log file paths are placed under ``args.results_dir_name``.

    Args:
        args: Runtime namespace; must provide ``results_dir_name``.

    Returns:
        None. ``args`` is mutated.
    """
    # Alternatives left by the author:
    #   save_state_dir -> f"{args.results_dir_name}/state"
    #   load_state_dir -> "results/exp1_state"
    args.save_state_dir = args.load_state_dir = None

    results_dir = args.results_dir_name
    args.metrics_file = f"{results_dir}/data.pkl"
    args.log_file = f"{results_dir}/run.log"


# Problem Parameters
def get_problem_params(args):
    """Return the fixed problem definition used by this experiment.

    Args:
        args: Runtime namespace. Accepted but not read; every value below
            is a constant.

    Returns:
        dict: Keys ``dim``, ``uniform_points``, ``n_samples_per_dim`` and
        ``gaussian_peak`` (itself a dict with ``mean`` and ``std``).
    """
    gaussian_peak = {"mean": 0.5, "std": 0.05}
    return {
        "dim": 1,
        "uniform_points": True,
        "n_samples_per_dim": 100,
        "gaussian_peak": gaussian_peak,
    }


# Model Parameters
def get_model_setup_params(args):
    """Build the network, solver and sigma-schedule configuration.

    Args:
        args: Runtime namespace. Reads ``gate_n_hidden``, ``poly_n_hidden``,
            ``poly_basis_size``, ``Pc``, ``Pf``, ``solver_type``,
            ``slv_max_iters``, ``solver_tol``, ``solver_reg``,
            ``solver_omega``, ``sigma_coop`` and ``sigma_comp``.

    Returns:
        dict: Keys ``gating`` and ``poly`` (each with ``coarse`` and ``fine``
        sub-configs), ``coef_slv_params``, ``sigma_schedule`` (constant
        optax schedules under ``coop`` and ``comp``) and ``dtype``.
    """
    # Kept function-local, as in the original.
    import jax
    import optax

    # Everything is configured in float64 here, independent of any flag on
    # ``args``.
    # NOTE(review): float64 arrays presumably need JAX x64 mode enabled
    # elsewhere -- confirm.
    precision = jax.numpy.float64
    param_precision = precision

    def gating_config(num_partitions):
        # The original literal listed "dtype" twice (``precision`` and then
        # ``param_precision``); Python keeps only the last value, so a single
        # entry with ``param_precision`` yields the same dict.
        # NOTE(review): if the gating network expects a separate key for the
        # parameter dtype, add it here under that key name.
        return {
            "n_hidden": args.gate_n_hidden,
            "n_layers": 0,
            "num_partitions": num_partitions,
            "activation": jax.nn.tanh,
            "embedding": lambda x: x,  # identity embedding
            "dtype": param_precision,
        }

    def poly_config(num_partitions):
        return {
            "basis_choice": "mlp",
            "n_hidden": args.poly_n_hidden,
            "n_layers": 1,
            "basis_size": args.poly_basis_size,
            "num_partitions": num_partitions,
            "dtype": param_precision,
        }

    def constant_sigma(value):
        return optax.constant_schedule(jax.numpy.array(value, dtype=precision))

    net_setup_params = {
        "gating": {
            "coarse": gating_config(args.Pc),
            "fine": gating_config(args.Pf),
        },
        "poly": {
            "coarse": poly_config(args.Pc),
            "fine": poly_config(args.Pf),
        },
        "coef_slv_params": {
            "slv_type": args.solver_type,
            "max_iter": args.slv_max_iters,
            "tol": args.solver_tol,
            "reg": args.solver_reg,
            "omega": args.solver_omega,
        },
        "sigma_schedule": {
            "coop": constant_sigma(args.sigma_coop),
            "comp": constant_sigma(args.sigma_comp),
        },
        "dtype": precision,
    }

    return net_setup_params



def main_experiment(args):
    """Run one training session for the given runtime configuration.

    Args:
        args: Runtime namespace. Its output-path attributes are rewritten
            in place by ``set_dirs`` before anything else happens.

    Returns:
        Whatever ``run_training_session`` returns (bound to ``model`` by
        the caller).
    """
    set_dirs(args)

    # This call HAS to happen before any other call to JAX, otherwise
    # jax.distributed.initialize will not initialize correctly and the code
    # will default to serial execution mode.
    distributed_env = init_distrbutive_env(args)

    # Deliberately imported only after the distributed environment is set up.
    from hierarchical_pous.training_runner import run_training_session

    return run_training_session(
        distributed_env,
        runtime_args=args,
        net_setup_params_getter=get_model_setup_params,
        training_params_getter=get_training_params,
        problem_params_getter=get_problem_params,
        base_seed=42,
    )

# Sweep over sigma_coop: one full experiment per value, with run ``i`` writing
# under results/sigma_{i}.
# NOTE(review): an earlier comment here read "sigma_coop values are hard-coded
# in models.py" -- confirm whether the value passed below is actually used.
# Further candidates left commented out by the author: 1e5, 1e4, 1e3, 1e2, 50.
for i, sigma_coop in enumerate((1e5,)):
    args = argparse.Namespace(
        # --- cluster and backend ---
        cluster_detection_method="auto",
        coordinator_address="localhost:1234",
        backend="gpu",
        # --- JAX configuration ---
        debug=False,
        disable_jit=False,
        profile=False,
        use_float64=True,
        # --- coefficient solver ---
        track_solver_iters=True,
        solver_type="iterative_block_distributed",
        slv_max_iters=2_000,
        solver_tol=1e-12,
        solver_reg=1e-4,
        solver_omega=1,
        # --- model size and training length ---
        Pc=4,
        Pf=2,
        gate_n_hidden=50,
        poly_n_hidden=20,
        poly_basis_size=20,
        outer_iters=5_000,
        unjit_training_step=False,
        sigma_comp=1.0,
        sigma_coop=sigma_coop,
        bench_mark_runs=0,
        # --- output ---
        # log_file, metrics_file, save_state_dir and load_state_dir are
        # placeholders: main_experiment() calls set_dirs(), which overwrites
        # them based on results_dir_name.
        log_file=None,
        metrics_file="results/data.pkl",
        results_dir_name=f"results/sigma_{i}",
        save_state_dir=None,
        load_state_dir=None,
    )

    model = main_experiment(args)

