from contextlib import nullcontext

import jax
import jax.numpy as jnp
from gym_classics.envs.abstract.gridworld import Gridworld
import numpy as np


def jax_device_context(force_cpu=False):
    """Return a context manager that optionally pins JAX to the CPU.

    Args:
        force_cpu: When true, the returned context sets JAX's default
            device to the first device reported by ``jax.devices('cpu')``.
            When false, a no-op ``nullcontext`` is returned, so JAX keeps
            its usual device selection.

    Returns:
        A context manager to be used in a ``with`` statement.
    """
    if force_cpu:
        first_cpu = jax.devices('cpu')[0]
        return jax.default_device(first_cpu)
    return nullcontext()


def onehot(index, size, dtype):
    """Build a NumPy one-hot array.

    Args:
        index: Position (anything valid for NumPy item assignment) to set.
        size: Shape passed to ``np.zeros``.
        dtype: NumPy dtype of the result.

    Returns:
        A new array of zeros of the given size and dtype, with 1.0 written
        at ``index``.
    """
    encoded = np.zeros(size, dtype)
    encoded[index] = 1.0
    return encoded


def print_vf(v, task):
    """Print the value function ``v`` for ``task``.

    If ``task.env`` is a ``Gridworld``, each state's value is placed into a
    grid of shape ``task.dims``. The cell is indexed by the coordinates
    returned from ``env._decode(state)``. The grid is transposed before
    printing, so printed rows follow the second decoded coordinate.
    For any other environment, ``v`` is printed as-is.

    Args:
        v: Indexable by the states yielded from ``env.states()``.
        task: Object exposing ``env`` and, for Gridworld envs, ``dims``
            (presumably the grid's (width, height); confirm against caller).
    """
    env = task.env

    if not isinstance(env, Gridworld):
        print(v)
        return

    grid = np.zeros(task.dims)
    for state in env.states():
        x_coord, y_coord = env._decode(state)
        grid[x_coord, y_coord] = v[state]
    # Transpose so the second coordinate runs down the printed rows.
    print(grid.T)
