from functools import partial
import jax
import jax.numpy as jnp
import optax

from flax.training.train_state import TrainState


def list_to_string(hidden_layers):
    """Encode a nested list of layer sizes as a compact string.

    Layer sizes within one model are joined with ``-`` and models are joined
    with ``_``, e.g. ``[[64, 64], [32]]`` -> ``"64-64_32"``.

    Every model contributes exactly one ``_``-separated segment, so the number
    of models is preserved even when a model has no layers (that model becomes
    an empty segment). The earlier hand-rolled version stripped the wrong
    separator in that case and silently merged models.

    Args:
        hidden_layers: Iterable of models, each an iterable of layer sizes
            (anything with a ``str`` representation).

    Returns:
        The encoded string; ``""`` when ``hidden_layers`` is empty.
    """
    return "_".join(
        "-".join(str(layer) for layer in model) for model in hidden_layers
    )


def string_to_list(hidden_layers):
    """Decode a ``"64-64_32"``-style string into nested lists of layer sizes.

    Models are separated by ``_`` and layer sizes within a model by ``-``,
    e.g. ``"64-64_32"`` -> ``[[64, 64], [32]]``.

    Args:
        hidden_layers: The encoded string.

    Returns:
        A list of models, each a list of ints. An empty string decodes to
        ``[]``; previously it raised ``ValueError`` from ``int("")``.

    Raises:
        ValueError: If any ``-``-separated segment is not an integer.
    """
    if not hidden_layers:
        return []
    return [
        [int(neurons) for neurons in model.split("-")]
        for model in hidden_layers.split("_")
    ]


def get_run_name(
    hidden_layers, criterion, target_update_frequency, min_steps_per_epoch, n_epochs
):
    """Compose a run identifier from the architecture and training settings.

    The architecture is encoded with ``list_to_string``. Each hyper-parameter
    is then appended with a short prefix: ``_c`` (criterion), ``_tuf``
    (target update frequency), ``_mspe`` (min steps per epoch) and ``_ne``
    (number of epochs).
    """
    architecture = list_to_string(hidden_layers)
    return (
        f"{architecture}"
        f"_c{criterion}"
        f"_tuf{target_update_frequency}"
        f"_mspe{min_steps_per_epoch}"
        f"_ne{n_epochs}"
    )


@partial(jax.jit, static_argnames=["apply_fn"])
def loss_fn(params, apply_fn, states, actions, targets):
    """Return the mean element-wise ``optax.l2_loss`` of predictions vs. targets.

    Predictions are obtained as ``apply_fn(params, states, actions)``.
    ``apply_fn`` is a static jit argument, so it must be hashable, and a new
    trace is compiled for each distinct function passed in.
    """
    predictions = apply_fn(params, states, actions)
    elementwise = optax.l2_loss(predictions, targets)
    return jnp.mean(elementwise)


def create_train_step(m_key, model, lr, sample, eps=1.5e-4):
    """Initialise ``model`` and build a jitted Adam training step for it.

    Args:
        m_key: PRNG key passed to ``model.init``.
        model: Presumably a Flax module. It must provide ``init`` and
            ``__apply_with_action__``; the latter becomes the train state's
            ``apply_fn`` and is called as ``apply_fn(params, states, actions)``.
        lr: Adam learning rate.
        sample: Example input forwarded to ``model.init`` to create the params.
        eps: Adam epsilon. Defaults to ``1.5e-4``, the value that was
            previously hard-coded here, so existing callers are unaffected.

    Returns:
        ``(train_step, train_state)``. ``train_step(train_state, states,
        actions, targets)`` performs one gradient update and returns
        ``(new_train_state, loss)``; ``train_state`` is the initial state.
    """
    tx = optax.adam(learning_rate=lr, eps=eps)
    params = model.init(m_key, sample)
    train_state = TrainState.create(
        apply_fn=model.__apply_with_action__, params=params, tx=tx
    )

    @jax.jit
    def train_step(_train_state, _states, _actions, _targets):
        # value_and_grad differentiates w.r.t. the first argument only (the
        # params). apply_fn is forwarded to loss_fn, which marks it as static.
        loss, grads = jax.value_and_grad(loss_fn)(
            _train_state.params, _train_state.apply_fn, _states, _actions, _targets
        )
        return _train_state.apply_gradients(grads=grads), loss

    return train_step, train_state
