"""
Token-sequence generation environments written in JAX.

The design is inspired by the PGX environment implementations: https://github.com/sotetsuk/pgx
"""

from abc import ABC

import jax
import jax.numpy as jnp
import numpy as np
from flax import struct
from jax import Array, lax

from medium_rl.policy import random_action

# Scalar JAX booleans, used as default flag values (e.g. in State).
FALSE, TRUE = jnp.bool_(False), jnp.bool_(True)


@struct.dataclass
class State:
    """Per-episode state of a SequenceEnv, stored as a flax ``struct.dataclass`` (a JAX pytree).

    The defaults are length-1 placeholders; ``SequenceEnv.init`` supplies the real arrays.
    """

    # Token ids generated so far. SequenceEnv starts this as [CLS, PAD, ..., PAD] of length max_len.
    obs: Array = jnp.zeros(1, dtype=jnp.int32)  # Store sequence here
    # One int flag per token id for the next action: nonzero = legal, 0 = illegal.
    legal_action_mask: Array = jnp.ones(1, dtype=jnp.int32)
    # Set by SequenceEnv.step on the step whose action is EOS.
    terminating: Array = FALSE
    # Set on the step after `terminating`; once True, SequenceEnv.step no longer modifies `obs`.
    terminated: Array = FALSE
    timestep: Array = jnp.int32(0)  # Current step of env (starts at 0)


class SequenceEnv(ABC):
    """
    Abstract base class for sequence-based environments.

    An episode fills one integer array of length ``max_len`` (``State.obs``).
    Index 0 holds ``CLS`` and every other slot starts as ``PAD``. Each call to
    ``step`` writes one token (the action) at the next index. ``EOS`` is
    illegal until enough tokens have been written, and it is the only legal
    action for the last slot (index ``max_len - 1``).

    Subclasses are expected to set ``num_tokens`` and ``alphabet`` and to
    override ``get_rewards``.
    """

    obs_dtype = jnp.int32

    # To be defined by child classes
    num_tokens = 1  # Vocabulary size, including the special token ids below
    alphabet = []  # alphabet[token_id] -> string, used by vectorized_token_idx_to_str
    dict = {}  # Not used by this base class. NOTE(review): shadows the builtin name `dict` in the class body

    # Special token ids
    CLS = 0
    PAD = 1
    EOS = 2

    def __init__(self, min_len: int, max_len: int):
        """
        Args:
            min_len: EOS becomes legal only after the token at index ``min_len + 1``
                has been written. NOTE(review): this yields at least ``min_len + 1``
                regular tokens per sequence; confirm that is the intended meaning.
            max_len: Length of the observation array, including the leading CLS
                slot and the final slot reserved for EOS.
        """
        self.min_len = min_len
        self.max_len = max_len

        self.legal_action_mask = jnp.ones(self.num_tokens, dtype=jnp.int32)

        # Can't output CLS, PAD but can output EOS
        self.legal_action_mask = self.legal_action_mask.at[self.CLS].set(0)
        self.legal_action_mask = self.legal_action_mask.at[self.PAD].set(0)

        # Can't output EOS
        self.no_eos_legal_action_mask = self.legal_action_mask.at[self.EOS].set(0)

        # Can only output EOS
        self.eos_legal_action_mask = jnp.zeros(self.num_tokens, dtype=jnp.int32).at[self.EOS].set(1)

        # Initial observation: [CLS, PAD, PAD, ..., PAD]
        self.init_seq = self.PAD * jnp.ones(self.max_len, dtype=jnp.int32)
        self.init_seq = self.init_seq.at[0].set(self.CLS)
        # Batched, jitted entry points: reset_fn takes one PRNG key per env.
        self.reset_fn = jax.jit(jax.vmap(self.init))
        self.step_fn = jax.jit(jax.vmap(self.step))

    def init(self, *args, **kwargs) -> State:
        """Return the initial state ``[CLS, PAD, ..., PAD]``, with EOS not yet allowed.

        Extra positional/keyword arguments are ignored. ``reset_fn`` passes one
        PRNG key per environment.
        """
        state = State(
            obs=self.init_seq,
            legal_action_mask=self.no_eos_legal_action_mask,
            timestep=jnp.int32(0),
        )
        return state

    def step(self, state: State, action: Array, **kwargs) -> State:
        """Write ``action`` at the next sequence position and advance one step.

        An illegal action replaces itself with the first legal token. An action is
        illegal if it is masked out in ``state.legal_action_mask`` or lies
        outside ``[0, num_tokens)``.

        Termination is two-phase. The step that writes EOS returns
        ``terminating=True``, and the following step returns
        ``terminated=True``. Once terminated, the sequence is left unchanged
        and ``terminating`` stays False.

        Args:
            state: Current (unbatched) state.
            action: Scalar token id.
            **kwargs: Ignored.
        Returns:
            The next state.
        """
        curr_token_idx = state.timestep + 1  # +1 because index 0 holds CLS

        # If the previous step wrote EOS, this state is now terminated.
        terminated = state.terminated | state.terminating

        # Sanitize the action BEFORE deriving anything from it, so that the
        # termination flag always agrees with the token actually written. The
        # explicit range check is needed because JAX clamps/wraps dynamic gather
        # indices instead of raising, which would otherwise read the legality of
        # a different token.
        in_range = (action >= 0) & (action < self.num_tokens)
        is_legal = in_range & (state.legal_action_mask[action] > 0)
        # Play first legal action if illegal action is selected
        action = jnp.where(is_legal, action, state.legal_action_mask.argmax())

        # Only the step that actually writes EOS is "terminating". Using the new
        # `terminated` (not `state.terminated`) keeps a repeated EOS on the step
        # after the EOS step from re-flagging termination.
        terminating = (action == self.EOS) & ~terminated

        # Update sequence only if not terminated. Out-of-bounds writes (after the
        # last slot) are dropped by `.at[].set`.
        seq = jnp.where(terminated, state.obs, state.obs.at[curr_token_idx].set(action))

        # Legal actions for the NEXT step (which writes at curr_token_idx + 1).
        legal_action_mask = jnp.where(
            curr_token_idx > self.min_len,  # If past min_len, allow termination
            self.legal_action_mask,
            self.no_eos_legal_action_mask,
        )
        legal_action_mask = jnp.where(
            curr_token_idx + 1 == self.max_len - 1,  # If next idx is the last slot, must terminate
            self.eos_legal_action_mask,
            legal_action_mask,
        )

        return state.replace(
            obs=seq,
            legal_action_mask=legal_action_mask,
            terminating=terminating,
            terminated=terminated,
            timestep=state.timestep + 1,
        )

    def vectorized_token_idx_to_str(self, token_seqs: Array):
        """Decode a batch of token-id sequences into strings.

        Special tokens (``CLS``, ``PAD``, ``EOS``) are skipped. Every other id
        ``i`` maps to ``self.alphabet[i]``.

        Args:
            token_seqs: [B, T] array of token ids.
        Returns:
            A list of B strings.
        """
        special = {self.CLS, self.PAD, self.EOS}
        return [
            "".join(self.alphabet[idx] for idx in seq if idx not in special)
            for seq in np.asarray(token_seqs)
        ]

    def get_rewards(
        self,
        token_seq: Array,  # [B, T], batch of sequence of tokens
    ):
        """Compute rewards for a batch of token sequences.

        Meant to be overridden by subclasses. This base implementation returns None.
        """
        return None


def test_env(env: SequenceEnv):
    """Smoke-test ``env`` by rolling out 4 parallel episodes with random legal actions.

    Prints the batched state after reset, after the first step and at the end.
    It then prints the decoded sequences and their rewards. The rollout stops
    once every episode is terminated, or after at most 101 steps.
    """
    step_fn = env.step_fn

    def show(title, state):
        # Dump the parts of the batched state that are useful for eyeballing.
        print(title)
        print(state.obs)
        print(state.legal_action_mask)
        print(state.timestep, "\n")

    # Reset 4 environments, each with its own PRNG key.
    key = jax.random.PRNGKey(0)
    key, reset_key = jax.random.split(key)
    env_state = env.reset_fn(jax.random.split(reset_key, 4))
    show("Init", env_state)

    for step_count in range(101):
        if jnp.all(env_state.terminated):
            break
        key, action_key = jax.random.split(key)
        action = random_action(env_state.legal_action_mask, action_key)
        env_state = step_fn(env_state, action)
        if step_count == 0:
            show("After first step", env_state)

    show("Final", env_state)

    print("Rewards")
    print(env.vectorized_token_idx_to_str(env_state.obs))
    print(env.get_rewards(env_state.obs))
