import argparse
import numpy as np
import jax.numpy as jnp

from hierarchical_pous.config_setup import parse_args, init_distrbutive_env
from hierarchical_pous.visualization import plot_all_metrics
from hierarchical_pous.visualization import snapshot_model_state


from save import *




# Training Parameters
def get_training_params(args):
    """Return the optimiser / training-loop settings for the run.

    Parameters
    ----------
    args : argparse.Namespace
        Must provide ``outer_iters`` and ``unjit_training_step``.

    Returns
    -------
    dict
        ``lr`` (per-level learning rates for "coarse" and "fine"), ``nIter``,
        ``log_interval`` and ``unjit_training_step``.
    """
    per_level_lr = {level: 1e-3 for level in ("coarse", "fine")}
    return dict(
        lr=per_level_lr,
        nIter=args.outer_iters,
        log_interval=100,
        unjit_training_step=args.unjit_training_step,
    )


def set_dirs(args):
    """Set the output-path attributes on ``args`` in place.

    State saving/loading and the metrics file are disabled (set to ``None``);
    alternatives previously used were ``f"{args.results_dir_name}/state"``,
    ``"results/exp1_state"`` and ``f"{args.results_dir_name}/data.pkl"``.
    The run log always goes to ``<results_dir_name>/run.log``.

    Returns ``None``; ``args`` is mutated.
    """
    for disabled in ("save_state_dir", "load_state_dir", "metrics_file"):
        setattr(args, disabled, None)
    args.log_file = f"{args.results_dir_name}/run.log"


# Problem Parameters
def get_problem_params(args):
    """Return the fixed problem definition for this experiment.

    ``args`` is accepted so the signature matches the other ``get_*_params``
    getters, but it is not read.

    Returns
    -------
    dict
        ``dim``, ``uniform_points``, ``n_samples_per_dim`` and a
        ``gaussian_peak`` sub-dict with ``mean`` and ``std``.
    """
    peak = dict(mean=0.5, std=0.05)
    return dict(
        dim=1,
        uniform_points=True,
        n_samples_per_dim=100,
        gaussian_peak=peak,
    )


# Model Parameters
def get_model_setup_params(args):
    """Build the network / coefficient-solver configuration for the run.

    Parameters
    ----------
    args : argparse.Namespace
        Must provide ``gate_n_hidden``, ``poly_n_hidden``, ``poly_basis_size``,
        ``Pc``, ``Pf``, ``solver_type``, ``slv_max_iters``, ``solver_tol``,
        ``solver_reg``, ``solver_omega``, ``sigma_coop`` and ``sigma_comp``.

    Returns
    -------
    dict
        ``"gating"`` and ``"poly"`` sections (one config per ``"coarse"`` /
        ``"fine"`` level, with ``num_partitions`` taken from ``args.Pc`` /
        ``args.Pf``), ``"coef_slv_params"``, ``"sigma_schedule"`` (constant
        optax schedules) and a top-level ``"dtype"``.
    """
    import jax
    import optax

    # NOTE(review): precision is fixed to float64 irrespective of
    # args.use_float64; float64 arrays only materialise if jax_enable_x64 is
    # on (presumably handled by init_distrbutive_env) — confirm.
    precision = jax.numpy.float64
    # Separate name, presumably so parameter dtype can later diverge from
    # compute dtype; currently identical.
    param_precision = precision

    partitions = {"coarse": args.Pc, "fine": args.Pf}

    def gate_cfg(num_partitions):
        # Gating-network config for one level. Only a single "dtype" entry:
        # a dict literal that repeats a key keeps just its last value.
        return {
            "n_hidden": args.gate_n_hidden,
            "n_layers": 0,
            "num_partitions": num_partitions,
            "activation": jax.nn.tanh,
            "embedding": lambda x: x,
            "dtype": param_precision,
        }

    def poly_cfg(num_partitions):
        # MLP-basis polynomial config for one level.
        return {
            "basis_choice": "mlp",
            "n_hidden": args.poly_n_hidden,
            "n_layers": 1,
            "basis_size": args.poly_basis_size,
            "num_partitions": num_partitions,
            "dtype": param_precision,
        }

    return {
        "gating": {level: gate_cfg(n) for level, n in partitions.items()},
        "poly": {level: poly_cfg(n) for level, n in partitions.items()},
        "coef_slv_params": {
            "slv_type": args.solver_type,
            "max_iter": args.slv_max_iters,
            "tol": args.solver_tol,
            "reg": args.solver_reg,
            "omega": args.solver_omega,  # relaxation factor for iterative solvers
        },
        "sigma_schedule": {
            "coop": optax.constant_schedule(jax.numpy.array(args.sigma_coop, dtype=precision)),
            "comp": optax.constant_schedule(jax.numpy.array(args.sigma_comp, dtype=precision)),
        },
        "dtype": precision,
    }



def main_experiment(args):
    """Prepare output paths, initialise the JAX environment and build the model.

    Parameters
    ----------
    args : argparse.Namespace
        Runtime configuration; mutated in place by ``set_dirs``.

    Returns
    -------
    Whatever ``run_training_session`` returns (used as the model object by
    the driver code in this file).
    """
    set_dirs(args)

    # This call HAS to happen before any other call to JAX, otherwise
    # jax.distributed.initialize will not initialize correctly and the code will
    # default to serial execution mode.
    env = init_distrbutive_env(args)

    # Deferred import: only loaded once the environment above is initialised.
    from hierarchical_pous.training_runner import run_training_session

    return run_training_session(
        env,
        runtime_args=args,
        net_setup_params_getter=get_model_setup_params,
        training_params_getter=get_training_params,
        problem_params_getter=get_problem_params,
        base_seed=42,
    )



args = argparse.Namespace(
    # Cluster and backend
    cluster_detection_method='auto',
    coordinator_address='localhost:1234',
    backend='cpu',

    # JAX config
    debug=False,
    disable_jit=False,
    profile=False,
    use_float64=True,

    # Training config
    track_solver_iters=True,
    solver_type='iterative_block',
    slv_max_iters=10_000,
    solver_tol=1e-12,
    solver_reg=1e-4,
    solver_omega=1,

    Pc=3,
    Pf=2,
    gate_n_hidden=30,
    poly_n_hidden=30,
    poly_basis_size=30,
    # Becomes nIter for the session built by main_experiment; the actual
    # training is driven by the checkpoint loop below.
    outer_iters=0,
    unjit_training_step=False,
    sigma_comp=1.0,
    sigma_coop=1e6,

    bench_mark_runs=0,

    # Output config (set_dirs, called from main_experiment, overwrites
    # log_file, metrics_file, save_state_dir and load_state_dir)
    log_file=None,
    metrics_file=None,
    results_dir_name="results/test_run",
    save_state_dir=None,
    load_state_dir=None
)


# Cumulative iteration counts at which a snapshot is recorded. The first
# entry labels the model exactly as returned by main_experiment. Both the
# per-stage training lengths and the saved labels derive from this one list,
# so they cannot drift apart.
n_iters = [0, 1_000, 10_000, 50_000]
filepath = "snapshots.pkl"

snapshots = []

model = main_experiment(args)
snapshots.append(snapshot_model_state(model))  # initial

for done, target in zip(n_iters, n_iters[1:]):
    # Run only the additional iterations needed to reach the next
    # checkpoint; assumes train() continues from the current model state
    # rather than restarting — TODO confirm.
    model.training_params["nIter"] = target - done
    model.train(gpu=False, track_solver_iters=args.track_solver_iters)
    snapshots.append(snapshot_model_state(model))

save_snapshots_to_file(snapshots, n_iters, filepath)
