import os
import jax
import jax.numpy as jnp
from flax.metrics import tensorboard

def make_summary_writer(logdir, config):
    """Create a TensorBoard summary writer and record the run's hyperparameters.

    Args:
      logdir: Directory the writer logs to; created (with parents) if missing.
      config: Hyperparameters passed unchanged to ``SummaryWriter.hparams``.

    Returns:
      A ``tensorboard.SummaryWriter`` writing to ``logdir``.
    """
    # exist_ok=True replaces the racy exists()-then-makedirs() check: another
    # process creating the directory in between no longer raises.
    os.makedirs(logdir, exist_ok=True)
    summary_writer = tensorboard.SummaryWriter(logdir)
    summary_writer.hparams(config)
    return summary_writer

def generate_perms(_key1, _key2, x, steps_per_epoch, batch_size, canary_index):
    """Generate one epoch of batch indices, guaranteeing the canary appears exactly once.

    A random permutation of ``x`` is drawn with ``_key2``. A slot inside the
    first ``steps_per_epoch * batch_size`` entries is drawn with ``_key1``,
    and the canary is placed there. If the permutation already contained the
    canary, the canary is swapped into that slot. This way no example is
    dropped and the canary is not duplicated.

    Args:
      _key1: PRNG key used to choose the canary's position.
      _key2: PRNG key used for the permutation.
      x: Anything ``jax.random.permutation`` accepts: an int ``n`` meaning
        ``arange(n)``, or an array to shuffle (presumably 1-D example indices).
      steps_per_epoch: Number of batches in the epoch.
      batch_size: Number of indices per batch.
      canary_index: Index value that must appear in this epoch.

    Returns:
      An array of shape ``(steps_per_epoch, batch_size)`` of indices.
    """
    num_used = steps_per_epoch * batch_size
    random_location = jax.random.choice(_key1, jnp.arange(num_used))
    perm = jax.random.permutation(_key2, x)

    # If the canary is already somewhere in the permutation, move the element
    # currently at random_location into the canary's old slot. Plain
    # overwriting would duplicate the canary and lose that element.
    is_canary = perm == canary_index
    canary_present = jnp.any(is_canary)
    canary_pos = jnp.argmax(is_canary.astype(jnp.int32))
    displaced = perm[random_location]
    perm = jnp.where(canary_present, perm.at[canary_pos].set(displaced), perm)

    perm = perm.at[random_location].set(canary_index)
    return perm[:num_used].reshape(steps_per_epoch, batch_size)
