from jax import numpy as jnp
from fairgym.envs.state import create_state

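# Image encoding
# shape is (num_groups, feat_bins, 4) as uint8
#
# red channel is pr_X, i.e. Pr(X | G)
# green channel is pr_Y1gX, i.e. Pr(Y=1 | X, G)
# blue and alpha channels both hold pr_G divided by its maximum,
# repeated across feat_bins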


def encode_obs(state):
    """
    encode current_dist -> image

    Reads ``pr_G``, ``pr_X`` and ``pr_Y1gX`` from ``state`` and packs them
    into a uint8 array of shape (num_groups, feat_bins, 4):

    * channel 0: ``pr_X``
    * channel 1: ``pr_Y1gX``
    * channels 2 and 3: ``pr_G`` divided by its maximum, repeated across
      the feat_bins axis

    Every channel is scaled by 255, rounded to the nearest integer and
    clipped to [0, 255] before the cast to uint8.
    """

    pr_G, pr_X, pr_Y1gX = state.pr_G, state.pr_X, state.pr_Y1gX

    # rescale so the largest group maps to the top of the uint8 range
    pr_G = pr_G / jnp.max(pr_G)

    num_groups, feat_bins = pr_X.shape
    # repeat the per-group value along the feat_bins axis
    z = pr_G.reshape((num_groups, 1)) * jnp.ones((1, feat_bins))

    scaled = jnp.stack([pr_X, pr_Y1gX, z, z], axis=-1) * 255

    # Round to nearest rather than letting the cast truncate: truncation
    # always errs downward (e.g. 0.999 -> 254) and turns tiny float errors
    # just below an integer into a full quantisation step. Clipping makes
    # out-of-range inputs saturate instead of relying on the cast.
    return jnp.clip(jnp.round(scaled), 0, 255).astype(jnp.uint8)


def decode_obs(observation):
    """
    decode image -> current_dist

    ``observation`` is indexed as (num_groups, feat_bins, 4), the layout
    written by ``encode_obs``. Channel 0 is read as ``pr_X``, channel 1 as
    ``pr_Y1gX`` and channel 2 as ``pr_G`` (taken from the first feature bin
    of each group); channel 3 is ignored.

    The decoded ``pr_G``, ``pr_X`` and ``pr_Y1gX`` are passed to
    ``create_state`` and its result is returned.
    """

    num_groups = observation.shape[0]

    # encode_obs scales by 255 (the uint8 maximum), so the inverse must
    # divide by 255; dividing by 256 would shrink every decoded
    # probability and make Pr(Y=1 | X, G) = 1 unrepresentable.
    scale = 255.0

    pr_X = observation[..., 0] / scale
    pr_Y1gX = observation[..., 1] / scale
    # the group channel is constant along feat_bins; read the first bin
    pr_G = observation[:, 0, 2] / scale

    # address error from lack of resolution by renormalizing
    # (this also undoes the divide-by-max applied to pr_G when encoding)
    pr_G = pr_G / jnp.sum(pr_G)
    pr_X = pr_X / jnp.sum(pr_X, axis=1).reshape((num_groups, 1))

    # conditional probabilities are not renormalized, only kept in [0, 1]
    pr_Y1gX = jnp.clip(pr_Y1gX, 0, 1)

    return create_state(pr_G, pr_X, pr_Y1gX)
