import os
import sys
from typing import Optional, Tuple

import chex
import jax
import jax.numpy as jnp
from gymnax.environments import environment, spaces
from jax import lax
# Put the directory one level above this file on ``sys.path`` so that the
# absolute ``Craftax.…`` imports below resolve when run from a checkout.
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
sys.path.append(project_root)
from Craftax.craftax.environment_base.environment_bases import EnvironmentNoAutoReset
from Craftax.craftax.craftax_classic.envs.common import compute_score
from Craftax.craftax.craftax_classic.constants import *
from Craftax.craftax.craftax_classic.game_logic import craftax_step, is_game_over
from Craftax.craftax.craftax_classic.envs.craftax_state import (
    EnvState,
    EnvParams,
    StaticEnvParams,
)
from Craftax.craftax.craftax_classic.text_goal import PixelsTextGoal
from Craftax.craftax.craftax_classic.renderer import render_craftax_pixels
from Craftax.craftax.craftax_classic.world_gen import generate_world

class CraftaxClassicPixelsGoalsEnvNoAutoReset(EnvironmentNoAutoReset):
    """Craftax-Classic env (no auto-reset) observing a goal embedding of the pixel view.

    Every observation is computed by rendering the agent's pixel view in JAX
    and passing it, together with the full ``EnvState``, to
    ``PixelsTextGoal.get_pixels_goal`` on the host through
    ``jax.experimental.io_callback``.
    """

    # Declared length of the embedding returned by the host callback. The
    # callback result spec and ``observation_space`` both use this value so
    # they cannot drift apart.
    EMBEDDING_DIM = 512

    def __init__(self, static_env_params: StaticEnvParams = None):
        """Create the environment.

        Args:
            static_env_params: Static environment parameters; when ``None``,
                ``default_static_params()`` is used.
        """
        super().__init__()

        if static_env_params is None:
            static_env_params = self.default_static_params()
        self.static_env_params = static_env_params
        # Host-side generator invoked from ``get_obs`` to build the embedding.
        self.pixels_goal_generator = PixelsTextGoal()

    @property
    def default_params(self) -> EnvParams:
        """Default (non-static) environment parameters."""
        return EnvParams()

    @staticmethod
    def default_static_params() -> StaticEnvParams:
        """Default static environment parameters."""
        return StaticEnvParams()

    def step_env(
        self, key: chex.PRNGKey, state: EnvState, action: int, params: EnvParams
    ) -> Tuple[chex.Array, EnvState, float, bool, dict]:
        """Advance the game by one action; never resets on episode end.

        Returns:
            ``(obs, state, reward, done, info)``. ``obs`` is the goal
            embedding of the new state. ``info`` is the result of
            ``compute_score`` with an added ``"discount"`` entry.
        """
        state, reward = craftax_step(key, state, action, params, self.static_env_params)

        done = self.is_terminal(state, params)
        info = compute_score(state, done)
        info["discount"] = self.discount(state, params)

        return (
            lax.stop_gradient(self.get_obs(state)),
            lax.stop_gradient(state),
            reward,
            done,
            info,
        )

    def reset_env(
        self, rng: chex.PRNGKey, params: EnvParams
    ) -> Tuple[chex.Array, EnvState]:
        """Generate a fresh world and return ``(obs, state)`` for it."""
        state = generate_world(rng, params, self.static_env_params)

        return self.get_obs(state), state

    def get_obs(self, state: EnvState) -> chex.Array:
        """Return the goal embedding for ``state`` as a float32 vector.

        The agent's pixel view is rendered in JAX and scaled by 1/255. It is
        then handed, with ``state``, to
        ``PixelsTextGoal.get_pixels_goal``, which runs on the host inside an
        ``io_callback``.
        """
        import numpy as np

        pixels = render_craftax_pixels(state, BLOCK_PIXEL_SIZE_AGENT) / 255.0
        # Abstract result description; no array is allocated at trace time.
        result_spec = jax.ShapeDtypeStruct((self.EMBEDDING_DIM,), jnp.float32)

        def callback(state, pixels):
            goal = self.pixels_goal_generator.get_pixels_goal(state, pixels)
            # JAX rejects callback results whose shape/dtype differ from the
            # declared spec, so coerce to a flat float32 numpy array.
            return np.asarray(goal, dtype=np.float32).reshape(result_spec.shape)

        return jax.experimental.io_callback(
            callback,
            result_spec,
            state,
            pixels,
            # Ordered IO callbacks cannot be vmapped, which breaks batched
            # rollouts. NOTE(review): this assumes get_pixels_goal has no
            # order-dependent side effects; confirm.
            ordered=False,
        )

    def is_terminal(self, state: EnvState, params: EnvParams) -> bool:
        """Whether the episode is over, as decided by ``is_game_over``."""
        return is_game_over(state, params)

    @property
    def name(self) -> str:
        """Environment identifier.

        NOTE(review): the name does not mention goals/embeddings; confirm it
        does not collide with a plain-pixels env of the same name.
        """
        return "Craftax-Classic-Pixels-NoAutoReset-v1"

    @property
    def num_actions(self) -> int:
        """Number of discrete actions (17)."""
        return 17

    def action_space(self, params: Optional[EnvParams] = None) -> spaces.Discrete:
        """Discrete space over the 17 actions."""
        return spaces.Discrete(17)

    def observation_space(self, params: EnvParams) -> spaces.Box:
        """Float32 box of shape ``(EMBEDDING_DIM,)``.

        NOTE(review): the declared [0, 1] bounds are not enforced on the
        embedding produced by ``PixelsTextGoal``; confirm they actually hold.
        """
        return spaces.Box(
            0.0,
            1.0,
            (self.EMBEDDING_DIM,),
            dtype=jnp.float32,
        )


class CraftaxClassicPixelsGoalsEnv(environment.Environment):
    """Craftax-Classic gymnax env whose observation is a goal embedding of the pixel view.

    Every observation is computed by rendering the agent's pixel view in JAX
    and passing it, together with the full ``EnvState``, to
    ``PixelsTextGoal.get_pixels_goal`` on the host through
    ``jax.pure_callback``.
    """

    # Declared length of the embedding returned by the host callback. The
    # callback result spec and ``observation_space`` both use this value so
    # they cannot drift apart.
    EMBEDDING_DIM = 512

    def __init__(self, static_env_params: StaticEnvParams = None):
        """Create the environment.

        Args:
            static_env_params: Static environment parameters; when ``None``,
                ``default_static_params()`` is used.
        """
        # Last state recorded by step_env/reset_env; see get_state.
        self.state = None
        super().__init__()

        if static_env_params is None:
            static_env_params = self.default_static_params()
        self.static_env_params = static_env_params
        # Host-side generator invoked from ``get_obs`` to build the embedding.
        self.pixels_goal_generator = PixelsTextGoal()

    def get_state(self):
        """Return the last ``EnvState`` assigned by ``step_env``/``reset_env``.

        NOTE(review): those methods store the state as a Python side effect.
        If they execute under ``jax.jit``, the stored value is a tracer from
        the trace rather than a concrete state. It is also not refreshed when
        the cached compiled function re-runs. Confirm how callers use this.
        """
        return self.state

    @property
    def default_params(self) -> EnvParams:
        """Default (non-static) environment parameters."""
        return EnvParams()

    @staticmethod
    def default_static_params() -> StaticEnvParams:
        """Default static environment parameters."""
        return StaticEnvParams()

    def step_env(
        self, key: chex.PRNGKey, state: EnvState, action: int, params: EnvParams
    ) -> Tuple[chex.Array, EnvState, float, bool, dict]:
        """Advance the game by one action.

        Returns:
            ``(obs, state, reward, done, info)``. ``obs`` is the goal
            embedding of the new state. ``info`` is the result of
            ``compute_score`` with an added ``"discount"`` entry.
        """
        state, reward = craftax_step(key, state, action, params, self.static_env_params)

        done = self.is_terminal(state, params)
        info = compute_score(state, done)
        info["discount"] = self.discount(state, params)
        # Python-side record for get_state (see the caveat there).
        self.state = state

        return (
            lax.stop_gradient(self.get_obs(state)),
            lax.stop_gradient(state),
            reward,
            done,
            info,
        )

    def reset_env(
        self, rng: chex.PRNGKey, params: EnvParams
    ) -> Tuple[chex.Array, EnvState]:
        """Generate a fresh world and return ``(obs, state)`` for it."""
        state = generate_world(rng, params, self.static_env_params)
        # Python-side record for get_state (see the caveat there).
        self.state = state
        return self.get_obs(state), state

    def get_obs(self, state: EnvState) -> chex.Array:
        """Return the goal embedding for ``state`` as a float32 vector.

        The agent's pixel view is rendered in JAX and scaled by 1/255. It is
        then handed, with ``state``, to
        ``PixelsTextGoal.get_pixels_goal``, which runs on the host inside a
        ``pure_callback``.
        """
        import numpy as np

        pixels = render_craftax_pixels(state, BLOCK_PIXEL_SIZE_AGENT) / 255.0
        # Abstract result description; no array is allocated at trace time.
        result_spec = jax.ShapeDtypeStruct((self.EMBEDDING_DIM,), jnp.float32)

        def callback(state, pixels):
            goal = self.pixels_goal_generator.get_pixels_goal(state, pixels)
            # JAX rejects callback results whose shape/dtype differ from the
            # declared spec, so coerce to a flat float32 numpy array.
            return np.asarray(goal, dtype=np.float32).reshape(result_spec.shape)

        return jax.pure_callback(
            callback,
            result_spec,
            state,
            pixels,
        )

    def is_terminal(self, state: EnvState, params: EnvParams) -> bool:
        """Whether the episode is over, as decided by ``is_game_over``."""
        return is_game_over(state, params)

    @property
    def name(self) -> str:
        """Environment identifier.

        NOTE(review): the name does not mention goals/embeddings; confirm it
        does not collide with a plain-pixels env of the same name.
        """
        return "Craftax-Classic-Pixels-v1"

    @property
    def num_actions(self) -> int:
        """Number of discrete actions (17)."""
        return 17

    def action_space(self, params: Optional[EnvParams] = None) -> spaces.Discrete:
        """Discrete space over the 17 actions."""
        return spaces.Discrete(17)

    def observation_space(self, params: EnvParams) -> spaces.Box:
        """Float32 box of shape ``(EMBEDDING_DIM,)``.

        NOTE(review): the declared [0, 1] bounds are not enforced on the
        embedding produced by ``PixelsTextGoal``; confirm they actually hold.
        """
        return spaces.Box(
            0.0,
            1.0,
            (self.EMBEDDING_DIM,),
            dtype=jnp.float32,
        )

