# init_state_fn=partial(init_train_state, state_pt, fn),


def init_train_state(
    pt_state: Dict,
    shard_params: Callable,
    batchnorm: bool,
    model: nn.Module,
    optimizer: optax.GradientTransformation,
    init_rng: jax.random.PRNGKey,
    batch: Tuple[jax.Array, ...],
) -> train_state.TrainState:
    """Initialise the model variables and wrap them in a train state.

    Args:
        pt_state: Pretrained state. Currently unused: the step that copied it
            into the freshly initialised parameters is disabled (see the TODO
            in the body).
        shard_params: Callable meant for that same loading step. Currently
            unused for the same reason.
        batchnorm: When true, build a ``BatchNormTrainState`` that also
            receives ``variables["batch_stats"]``; otherwise build a
            ``TrainState``.
        model: Module whose ``init`` creates the variables and whose ``apply``
            becomes the state's ``apply_fn``.
        optimizer: Gradient transformation passed to the state as ``tx``.
        init_rng: Key that is split three ways; the second sub-key is given to
            ``model.init`` and the third is stored on the state as ``key``.
        batch: Example inputs, unpacked positionally into ``model.init``.

    Returns:
        The newly created train state.
    """
    # Keep the three-way split so the sub-keys used below do not change; the
    # first sub-key is intentionally discarded.
    _, arg_rng, dropout_rng = jax.random.split(init_rng, 3)
    variables = model.init(arg_rng, *batch, train=False)

    # Round-trip the parameters through their flattened, dot-separated form.
    # TODO: re-enable pretrained loading here. The disabled code replaced each
    # flattened leaf with
    # ``model.load_torch_pt(pt_state, key, value, shard_params)``.
    flat_params = flax.traverse_util.flatten_dict(variables["params"], sep=".")
    variables["params"] = flax.traverse_util.unflatten_dict(flat_params, sep=".")

    # Create a State
    if batchnorm:
        state = BatchNormTrainState.create(
            apply_fn=model.apply,
            params=variables["params"],
            batch_stats=variables["batch_stats"],
            tx=optimizer,
            key=dropout_rng,
        )
    else:
        state = TrainState.create(
            apply_fn=model.apply,
            tx=optimizer,
            params=variables["params"],
            key=dropout_rng,
        )

    return state


@staticmethod
def load_torch_pt(shard_params, pt_state, flattened_p):
    for key_s, val_s in flattened_p.items():
        flattened_p[key_s] = load_torch_pt(
            MAPPING_PALIGEMMA, pt_state, key_s, val_s, shard_params
        )
    return flattened_p


@staticmethod
def to_np(pt_state, params):
    for key_s in params:
        if key_s == "lm_head.kernel":
            print("Ce dracu: ", params[key_s])
        tmp, torch_key = apply_trans(MAPPING_PALIGEMMA, key_s, pt_state)
        if tmp is not None:
            start_id = (0,) * len(params[key_s].shape)
            params[key_s] = jax.lax.dynamic_update_slice(
                params[key_s], jnp.asarray(tmp, copy=True), start_id
            )
            # params[key_s] = jnp.zeros_like(params[key_s]) + jnp.array(
            #     tmp, device=params[key_s].device()
            # )
            # params[key_s] = params[key_s]
    return params


def load_torch_pt(MAPPING, pt_state, key_s, val_s, shard_params):
    """Return the value for parameter ``key_s``, loaded from ``pt_state``.

    Expects a flattened parameter key and a torch ``pt_state``.

    Args:
        MAPPING: Maps a parameter key to either a ``pt_state`` key, or a tuple
            whose first item is that key and whose second item is a callable
            applied to the looked-up value. Per-layer entries use the
            placeholder ``residual_block_{x}`` on the parameter side and
            ``.{x}.`` on the ``pt_state`` side.
        pt_state: Pretrained values, indexed by key.
        key_s: Flattened name of the parameter being loaded.
        val_s: Current value of that parameter. Returned unchanged when
            ``key_s`` has no entry in ``MAPPING``.
        shard_params: Forwarded as the first argument of ``copy_array``.

    Returns:
        ``val_s`` itself when there is no mapping for the key, otherwise the
        result of ``copy_array(shard_params, loaded, val_s)``.

    Note:
        Lookups into ``pt_state`` are not guarded; a mapped key that is absent
        from ``pt_state`` propagates the lookup error.
    """
    # Collapse the concrete layer index into the "{x}" placeholder used by
    # MAPPING. A key that contains "residual_block_" without a numeric suffix
    # is looked up verbatim instead of failing on an unmatched pattern.
    layer_nr = None
    mapping_key = key_s
    found = re.search(r"residual_block_[0-9]+", key_s)
    if found is not None:
        block_name = found.group(0)
        layer_nr = block_name.split("_")[-1]
        mapping_key = key_s.replace(block_name, "residual_block_{x}")

    if mapping_key not in MAPPING:
        print(f"warning, {key_s} has no pretrained mapping")
        return val_s

    def torch_key_for(template):
        # Put the layer index back into the pretrained key when there is one.
        if layer_nr is None:
            return template
        return template.replace(".{x}.", f".{layer_nr}.")

    entry = MAPPING[mapping_key]
    if isinstance(entry, tuple):
        # Entry carries its own transform: (pretrained key, callable).
        transform = entry[1]
        loaded = transform(pt_state[torch_key_for(entry[0])])
    elif do_transpose(key_s):
        # Same transform for all Dense layers: transpose the stored value.
        loaded = pt_state[torch_key_for(entry)].T
    else:
        loaded = pt_state[torch_key_for(entry)]

    if isinstance(val_s, nn.Partitioned):
        # NOTE(review): copy=False is kept from the original code for
        # partitioned values; confirm it cannot raise on the target backend.
        loaded = jnp.asarray(loaded, copy=False)
    else:
        loaded = jnp.asarray(loaded)
    return copy_array(shard_params, loaded, val_s)
