
from typing import Any, Dict, Optional, Tuple, Union

import numpy as np
import chex
from flax import struct
import jax
from jax import lax
import jax.numpy as jnp
import int_box
from gymnax.environments import spaces
import environment
from utils import change_max_relator_length_of_presentation
from ac_moves import setup_actions, setup_prime_actions
from ast import literal_eval

@struct.dataclass
class EnvState(environment.EnvState):
    """Per-episode state carried between environment steps.

    The fields are populated by ``AC.reset_env`` and advanced by
    ``AC.step_env``.
    """

    # Presentation array; this is what the action functions transform and
    # what is returned as the observation. Annotated as an array *type*
    # (``jnp.array`` is a constructor function, not a type).
    x: jnp.ndarray
    # Index into the table of initial states this episode was reset from.
    idx: int
    # Number of steps taken since the last reset.
    time: int

@struct.dataclass
class EnvParams(environment.EnvParams):
    """Static configuration of the AC environment."""

    # Drives the size of the action space (3 * n_gen**2), the inclusive
    # observation bounds (-n_gen .. n_gen), and the terminal check (an
    # episode terminates when the state has exactly n_gen nonzero entries).
    n_gen: int = 2
    # Target length passed to change_max_relator_length_of_presentation when
    # initial states are loaded; the observation has max_length * n_gen entries.
    max_length: int = 64
    # An episode is truncated once state.time reaches this value.
    max_steps_in_episode: int = 200

class AC(environment.Environment):
    """Environment "AC-v0": a presentation array transformed by discrete moves.

    The state holds an integer array ``x`` (the presentation). Each step
    applies one of the pre-built action functions to ``x``. An episode
    terminates when ``x`` has exactly ``params.n_gen`` nonzero entries and is
    truncated after ``params.max_steps_in_episode`` steps.

    Args:
        n_gen: Stored in ``EnvParams.n_gen``.
        max_length: Stored in ``EnvParams.max_length``.
        max_steps_in_episode: Stored in ``EnvParams.max_steps_in_episode``.
        primed_actions: If true, actions are built with
            ``setup_prime_actions``; otherwise with ``setup_actions``.
        is_reward_sparse: Selects the sparse (1 on termination, else 0) or
            the dense reward function; see ``get_reward_fn``.
        initial_states_file: Stem of the text file under ``data/`` holding
            one Python-literal presentation per line.
    """

    def __init__(self, n_gen=2, max_length=64, max_steps_in_episode=200,
                 primed_actions=False, is_reward_sparse=False,
                 initial_states_file="all_presentations"):
        super().__init__()
        # NOTE: It is fine to keep attributes on classes used from JAX code,
        # but they must not be updated in any of the methods.
        # https://docs.jax.dev/en/latest/stateful-computations.html
        # Treat the class as a namespace of functions sharing a common
        # context, and its attributes as constants of that namespace.
        self.params = EnvParams(
            n_gen=n_gen,
            max_length=max_length,
            max_steps_in_episode=max_steps_in_episode,
        )
        if primed_actions:
            self._actions = setup_prime_actions(self.params)
        else:
            self._actions = setup_actions(self.params)
        self.init_states = self.initiate_states(self.params, initial_states_file)
        self.reward_fn = self.get_reward_fn(is_reward_sparse)
        print(f"Loading {len(self.init_states)} states..")

    def initiate_states(self, params: EnvParams, initial_states_file: str):
        """Load the table of initial presentations.

        Reads ``data/<initial_states_file>.txt``, parses every non-blank line
        as a Python literal, resizes each presentation with
        ``change_max_relator_length_of_presentation(state, params.max_length)``
        and stacks the results into a single array.

        Blank lines (e.g. a trailing empty line) are skipped instead of
        crashing inside ``literal_eval``.
        """
        with open(f"data/{initial_states_file}.txt", "r", encoding="utf-8") as file:
            stripped_lines = (line.strip() for line in file)
            initial_states = [literal_eval(line) for line in stripped_lines if line]
        initial_states = [
            change_max_relator_length_of_presentation(state, params.max_length)
            for state in initial_states
        ]
        return jnp.array(initial_states)

    def get_reward_fn(self, is_sparse: bool):
        """Return the reward function ``f(x, terminated)``.

        Sparse: ``terminated`` cast to an integer array (1 on termination,
        otherwise 0).

        Dense: ``-min(count_nonzero(x), 10)`` while not terminated and
        ``1000`` on termination.
        """
        if is_sparse:
            return lambda x, terminated: jnp.array(terminated, int)

        print("Using dense reward function. You may want to clip and normalize the reward.")
        return lambda x, terminated: (
            -jnp.clip(jnp.count_nonzero(x), 0, 10) * (1 - terminated)
            + 1000 * terminated
        )

    @property
    def default_params(self) -> EnvParams:
        """Parameters this instance was constructed with."""
        # NOTE: this differs from the usual gymnax behaviour of returning
        # EnvParams(). __init__ has to build _actions, which depend on
        # params.n_gen, so the params are created there and returned here.
        return self.params

    def _termination_flags(
        self, state: EnvState, params: EnvParams
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """Return ``(terminated, truncated)`` for ``state``.

        Single source of truth shared by ``step_env`` and ``is_terminal`` so
        the two cannot drift apart.
        """
        terminated = jnp.count_nonzero(state.x) == params.n_gen
        truncated = state.time >= params.max_steps_in_episode
        return terminated, truncated

    def step_env(
        self,
        key: chex.PRNGKey,
        state: EnvState,
        action: Union[int, float, chex.Array],
        params: EnvParams,
    ) -> Tuple[chex.Array, EnvState, jnp.ndarray, jnp.ndarray, Dict[Any, Any]]:
        """Performs step transitions in the environment.

        Applies the action function selected by ``action`` to ``state.x``,
        increments the step counter and evaluates termination / truncation on
        the *new* state.

        Returns:
            ``(obs, state, reward, done, info)`` where ``done`` is
            ``terminated or truncated`` and ``info`` carries ``discount``,
            ``terminated``, ``truncated``, ``idx`` and ``length`` (number of
            nonzero entries of the new ``x``).
        """
        state = EnvState(
            x=jax.lax.switch(action, self._actions, state.x),
            idx=state.idx,
            time=state.time + 1,
        )
        terminated, truncated = self._termination_flags(state, params)
        done = jnp.logical_or(terminated, truncated)
        reward = self.reward_fn(state.x, terminated)

        # NOTE(review): `discount` comes from the base class in the local
        # `environment` module. In upstream gymnax it is 0.0 when
        # `is_terminal` is true and 1.0 otherwise (the earlier comment here
        # stated the opposite) -- confirm against the local implementation.
        # Since `is_terminal` below also covers truncation, the discount would
        # then be 0 on truncated episodes as well.
        info = {
            "discount": self.discount(state, params),
            "terminated": terminated,
            "truncated": truncated,
            "idx": state.idx,
            "length": jnp.count_nonzero(state.x),
        }
        return (
            lax.stop_gradient(self.get_obs(state)),
            lax.stop_gradient(state),
            jnp.array(reward),
            done,
            info,
        )

    def is_terminal(self, state: EnvState, params: EnvParams) -> jnp.ndarray:
        """Check whether state transition is terminal.

        True when the episode is either terminated or truncated.
        """
        terminated, truncated = self._termination_flags(state, params)
        return jnp.logical_or(terminated, truncated)

    def reset_env(
        self, key: chex.PRNGKey, params: EnvParams, idx: int
    ) -> Tuple[chex.Array, EnvState]:
        """Performs resetting of environment.

        The episode starts from ``self.init_states[idx]`` with ``time`` set
        to 0; ``key`` and ``params`` are not used.
        """
        state = EnvState(x=self.init_states[idx], idx=idx, time=0)
        return self.get_obs(state), state

    def get_obs(self, state: EnvState, params=None, key=None) -> chex.Array:
        """Applies observation function to state (the observation is ``state.x``)."""
        return state.x

    @property
    def name(self) -> str:
        """Environment name."""
        return "AC-v0"

    @property
    def num_actions(self) -> int:
        """Number of actions possible in environment."""
        # The base-class property takes no params argument, so the value is
        # derived from the params this instance was built with. This matches
        # `action_space` and equals 12 for the default n_gen=2.
        return 3 * self.params.n_gen ** 2

    def action_space(self, params: Optional[EnvParams] = None) -> spaces.Discrete:
        """Action space of the environment.

        ``params`` is optional; when omitted, the instance's own params are
        used instead of failing on ``None``.
        """
        if params is None:
            params = self.params
        return spaces.Discrete(3 * params.n_gen ** 2)

    def observation_space(self, params: EnvParams) -> spaces.Box:
        """Observation space of the environment.

        A flat integer box of length ``max_length * n_gen`` with inclusive
        bounds ``[-n_gen, n_gen]``.
        """
        # NOTE: gymnax.spaces.Box targets floats (it samples with
        # random.uniform), hence the integer-valued IntBox. Sampling would
        # use random.randint, whose maxval is exclusive, so `high` may need
        # to be n_gen + 1 there.
        arr_len = params.max_length * params.n_gen
        # NOTE(review): `low` is a jax array while `high` is a numpy array;
        # kept as in the original because IntBox's expectations are not
        # visible from here -- consider making them consistent.
        low = jnp.ones(arr_len, dtype=jnp.int8) * (-params.n_gen)  # inclusive
        high = np.ones(arr_len, dtype=jnp.int8) * (params.n_gen)  # inclusive
        return int_box.IntBox(low, high, shape=(arr_len,), dtype=jnp.int8)