import numpy as np
import jax

from flax.training import orbax_utils
import orbax.checkpoint
from models import uPDNet


def get_model_from_name(rng, model_name='uPDNet', im_size_test=256, im_size=32, channels=3, num_iter=100):
    """Build a train model and a test model of the requested architecture.

    Two random masks are drawn, each of shape ``(1, 1, size, size)`` with
    entries sampled from ``{0, 1}`` with equal probability. One is
    ``im_size_test`` x ``im_size_test`` for the test model and one is
    ``im_size`` x ``im_size`` for the train model. Each mask is passed to
    its own ``uPDNet`` instance.

    Args:
        rng: JAX PRNG key. It is split internally so the two masks come from
            independent subkeys.
        model_name: Architecture name. Only ``'uPDNet'`` is supported.
        im_size_test: Spatial size of the test-time mask/model.
        im_size: Spatial size of the training mask/model.
        channels: Forwarded to ``uPDNet`` as ``channels``.
        num_iter: Forwarded to ``uPDNet`` as ``num_iter``.

    Returns:
        Tuple ``(model, model_test, mask, learning_rate)``, where ``mask`` is
        the training mask and ``learning_rate`` is the fixed value ``1e-3``.

    Raises:
        NotImplementedError: If ``model_name`` is not ``'uPDNet'``.
    """
    # Validate once, before doing any random sampling or model construction.
    if model_name != 'uPDNet':
        raise NotImplementedError(f"Unsupported model name: {model_name!r}")

    # JAX keys must not be reused across draws; derive one subkey per mask.
    rng_test, rng_train = jax.random.split(rng)
    probs = np.asarray([0.5, 0.5])

    # Initialisation test
    mask_test = jax.random.choice(rng_test, 2, shape=(1, 1, im_size_test, im_size_test), p=probs)
    model_test = uPDNet(mask=mask_test, im_size=im_size_test, channels=channels, num_iter=num_iter)

    # Initialisation train
    mask = jax.random.choice(rng_train, 2, shape=(1, 1, im_size, im_size), p=probs)
    model = uPDNet(mask=mask, im_size=im_size, channels=channels, num_iter=num_iter)
    learning_rate = 1e-3

    return model, model_test, mask, learning_rate


def model_apply(state_model, state_param, input_images, type_operator):
    """Run ``state_model.apply_fn`` using the parameters stored on ``state_param``.

    The parameters are wrapped as ``{'params': state_param.params}``. They are
    passed along with ``input_images`` and ``type_operator`` positionally, and
    whatever ``apply_fn`` returns is handed back unchanged.
    """
    forward = state_model.apply_fn
    variables = {'params': state_param.params}
    return forward(variables, input_images, type_operator)


def save_model(ckpt, save_path):
    """Write the pytree ``ckpt`` to ``save_path`` with an Orbax PyTree checkpointer.

    The save is issued with ``force=True``, so existing files at
    ``save_path`` are overwritten.
    """
    checkpointer = orbax.checkpoint.PyTreeCheckpointer()
    # Orbax save arguments derived from the structure of ckpt itself.
    target_save_args = orbax_utils.save_args_from_target(ckpt)
    checkpointer.save(
        save_path,
        ckpt,
        save_args=target_save_args,
        force=True,  # overwrite whatever already lives at save_path
    )
