"""
Copyright (c) Meta Platforms, Inc. and affiliates.
All rights reserved.

This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""

from dataclasses import dataclass
from collections import namedtuple, OrderedDict
from functools import partial
from enum import IntEnum

import numpy as np
import jax
import jax.numpy as jnp
from jax import lax
from typing import Tuple, Optional
import chex
from flax import struct
from flax.core.frozen_dict import FrozenDict

from .common import EnvInstance, make_maze_map
from minimax.envs import environment, spaces
from minimax.envs.registration import register_ued


class SequentialActions(IntEnum):
    """Discrete action kinds for a sequential design action space.

    Values: skip=0, wall=1, goal=2, agent=3.

    NOTE(review): this enum is not referenced elsewhere in this module, and
    UEDMaze always stores use_seq_actions=False; presumably it is reserved
    for a sequential action mode — confirm before relying on it.
    """
    # Tuple unpacking in an Enum body registers each name in order.
    skip, wall, goal, agent = range(4)


@struct.dataclass
class EnvState:
    """Design state of a UEDMaze episode."""
    # One tile index per design step. When the designer does not choose the
    # agent direction, the last slot holds a randomly drawn, tile-scaled
    # direction written at reset (see UEDMaze.reset_env).
    encoding: chex.Array
    time: int  # number of design steps taken so far
    terminal: bool  # set from UEDMaze.is_terminal after each step


@struct.dataclass
class EnvParams:
    """Static configuration for UEDMaze.

    Note that UEDMaze.__init__ uses different defaults (height=13, width=13,
    noise_dim=16) from the defaults declared here.
    """
    height: int = 15  # grid height in tiles
    width: int = 15  # grid width in tiles
    n_walls: int = 25  # maximum number of wall placements
    # Length of the uniform noise vector added to observations; 0 disables it.
    noise_dim: int = 50
    # If True, placing onto an occupied tile is allowed, except that the
    # agent may not be placed on the goal.
    replace_wall_pos: bool = False
    # If True, there are always exactly n_walls wall steps; otherwise a wall
    # budget is decoded from encoding[0].
    fixed_n_wall_steps: bool = False
    # If True (requires fixed_n_wall_steps=False), the first wall position
    # also sets the wall budget instead of using a dedicated budget step.
    first_wall_pos_sets_budget: bool = False
    # Not read in this module; UEDMaze always stores False here.
    use_seq_actions: bool = False
    # If True, an extra final design step chooses the agent direction;
    # otherwise the direction is drawn randomly at reset.
    set_agent_dir: bool = False
    # If True, image observations are divided by 10.0 (float32 images).
    normalize_obs: bool = False
    singleton_seed: int = -1  # not read in this module


class UEDMaze(environment.Environment):
    """Environment in which a designer builds a maze level one tile at a time.

    Every design action is a tile index in [0, height*width). Depending on
    the parameters, the action sequence is read as:
      * an optional wall-budget step (encoding[0]),
      * wall positions,
      * the goal position,
      * the agent position,
      * optionally, the agent direction.
    The design is stored in EnvState.encoding. get_env_instance turns it into
    an EnvInstance (agent/goal positions, agent direction, wall map).
    """

    def __init__(
            self,
            height=13,
            width=13,
            n_walls=25,
            noise_dim=16,
            replace_wall_pos=False,
            fixed_n_wall_steps=False,
            first_wall_pos_sets_budget=False,
            use_seq_actions=False,
            set_agent_dir=False,
            normalize_obs=False,
    ):
        """
        Using the original action space requires ensuring proper handling
        of a sequence with trailing dones, e.g. dones: 0 0 0 0 1 1 1 1 1 ... 1.
        Advantages and value losses should only be computed where ~dones[0].
        """
        assert not (first_wall_pos_sets_budget and fixed_n_wall_steps), \
            'Setting first_wall_pos_sets_budget=True requires fixed_n_wall_steps=False.'

        super().__init__()

        self.n_tiles = height*width
        # Each action is the index of a tile on the height x width grid.
        self.action_set = jnp.arange(self.n_tiles)

        self.params = EnvParams(
            height=height,
            width=width,
            n_walls=n_walls,
            noise_dim=noise_dim,
            replace_wall_pos=replace_wall_pos,
            fixed_n_wall_steps=fixed_n_wall_steps,
            first_wall_pos_sets_budget=first_wall_pos_sets_budget,
            # NOTE(review): the use_seq_actions argument is accepted but
            # always stored as False. Nothing in this class reads it, so
            # sequential actions appear unimplemented; confirm before wiring
            # the argument through.
            use_seq_actions=False,
            set_agent_dir=set_agent_dir,
            normalize_obs=normalize_obs,
        )

    @staticmethod
    def align_kwargs(kwargs, other_kwargs):
        """Copy height/width from other_kwargs into kwargs (in place) and return kwargs."""
        kwargs.update(dict(
            height=other_kwargs['height'],
            width=other_kwargs['width'],
        ))

        return kwargs

    def _add_noise_to_obs(self, rng, obs):
        """Add a U[0, 1) 'noise' vector of length noise_dim to obs when noise_dim > 0."""
        if self.params.noise_dim > 0:
            noise = jax.random.uniform(rng, (self.params.noise_dim,))
            obs.update(dict(noise=noise))

        return obs

    def reset_env(
            self,
            key: chex.PRNGKey):
        """
        Prepares the environment state for a new design
        from a blank slate.

        Returns:
            (obs, state) for an all-zero encoding at time 0.
        """
        params = self.params
        noise_rng, dir_rng = jax.random.split(key)
        encoding = jnp.zeros((self._get_encoding_dim(),), dtype=jnp.uint32)

        if not params.set_agent_dir:
            # Draw a direction in {0, 1, 2, 3} and store it scaled to the
            # tile-index range. get_env_instance decodes it with
            # floor(4*enc[-1]/n_tiles).
            rand_dir = jax.random.randint(
                dir_rng, (), minval=0, maxval=4)
            tile_scale_dir = jnp.ceil(
                (rand_dir/4)*self.n_tiles).astype(jnp.uint32)
            encoding = encoding.at[-1].set(tile_scale_dir)

        state = EnvState(
            encoding=encoding,
            time=0,
            terminal=False,
        )

        obs = self._add_noise_to_obs(
            noise_rng,
            self.get_obs(state)
        )

        return obs, state

    def step_env(
            self,
            key: chex.PRNGKey,
            state: EnvState,
            action: int,
    ) -> Tuple[chex.Array, EnvState, float, bool, dict]:
        """
        Take a design step.

        Args:
            key: PRNG key, used to resample a free tile on collision and to
                draw observation noise.
            state: Current design state.
            action: A pos as an int from 0 to (height*width)-1.

        Returns:
            (obs, next_state, reward, done, info). reward is always 0 and
            info is empty.
        """
        params = self.params

        collision_rng, noise_rng = jax.random.split(key)

        # Layout of the encoding for this configuration:
        #   enc_offset     - encoding index of the first wall position
        #   goal_step_idx  - time step (== encoding index) of the goal
        #   agent_step_idx - time step (== encoding index) of the agent
        # These must match the indices used by get_env_instance.
        if params.fixed_n_wall_steps:
            enc_offset = 0
            goal_step_idx = params.n_walls
        else:
            max_n_walls = jnp.round(
                params.n_walls*state.encoding[0]/self.n_tiles).astype(jnp.uint32)

            if params.first_wall_pos_sets_budget:
                enc_offset = 0
                goal_step_idx = max_n_walls
            else:
                # encoding[0] is a dedicated wall-budget step.
                enc_offset = 1
                goal_step_idx = max_n_walls + 1
        agent_step_idx = goal_step_idx + 1

        # Position slots: n_walls walls + goal + agent.
        n_pos_slots = params.n_walls + 2
        encoding_pos = state.encoding[enc_offset:enc_offset + n_pos_slots]

        # Slot i is written at time step i + enc_offset, so it is filled
        # iff i + enc_offset < state.time.
        slot_taken = (
            jnp.arange(n_pos_slots) + enc_offset < state.time
        ).astype(jnp.float32)
        # A tile is unavailable (weight 0) iff some filled slot points at it.
        # Scatter-max does not depend on the order in which duplicate indices
        # are applied. Unfilled, zero-padded slots contribute 0, so they
        # cannot block tile 0.
        pos_dist = 1.0 - jnp.zeros(self.n_tiles).at[encoding_pos].max(slot_taken)
        all_pos = jnp.arange(self.n_tiles, dtype=jnp.uint32)

        # Track whether it is the last time step
        next_state = state.replace(time=state.time + 1)
        done = self.is_terminal(next_state)

        # With set_agent_dir=True the final step chooses the agent direction.
        # It is written to the last encoding slot and never counts as a collision.
        is_agent_dir_step = jnp.logical_and(
            params.set_agent_dir,
            done
        )

        # Only mark collision if replace_wall_pos=False OR the agent is placed over the goal
        collision = jnp.logical_and(
            pos_dist[action] < 1,
            jnp.logical_or(
                not params.replace_wall_pos,
                jnp.logical_and(  # agent pos cannot override goal
                    jnp.equal(state.time, agent_step_idx),
                    jnp.equal(state.encoding[goal_step_idx], action)
                )
            )
        )
        collision = (collision * (1-is_agent_dir_step)).astype(jnp.uint32)

        # On collision, replace the action with a tile drawn from those that
        # have nonzero weight in pos_dist.
        action = (1-collision)*action + \
            collision*jax.random.choice(collision_rng,
                                        all_pos, replace=False, p=pos_dist)

        enc_idx = (1-is_agent_dir_step)*state.time + is_agent_dir_step*(-1)
        encoding = state.encoding.at[enc_idx].set(action)

        next_state = next_state.replace(
            encoding=encoding,
            terminal=done
        )
        reward = 0

        obs = self._add_noise_to_obs(noise_rng, self.get_obs(next_state))

        return (
            lax.stop_gradient(obs),
            lax.stop_gradient(next_state),
            reward,
            done,
            {},
        )

    def get_env_instance(
        self,
        key: chex.PRNGKey,
        state: EnvState
    ) -> EnvInstance:
        """
        Converts internal encoding to an instance encoding that
        can be interpreted by the `set_to_instance` method of
        the paired Environment class.

        Args:
            key: Unused; kept for interface compatibility.
            state: Current design state.

        Returns:
            An EnvInstance with agent_pos, goal_pos, agent_dir_idx and a
            (height, width) boolean wall_map. An agent or goal that is not
            yet placed is put one step out of range at (width, height). Walls
            under the placed agent/goal are cleared.
        """
        params = self.params
        h = params.height
        w = params.width
        enc = state.encoding

        # === Extract agent_dir, agent_pos, and goal_pos ===
        # Num walls placed currently
        if params.fixed_n_wall_steps:
            n_walls = params.n_walls
            enc_len = self._get_encoding_dim()
            agent_pos_idx = enc_len-2  # Enc is full length
            goal_pos_idx = enc_len-3
        else:
            n_walls = jnp.round(
                params.n_walls*enc[0]/self.n_tiles
            ).astype(jnp.uint32)
            if params.first_wall_pos_sets_budget:
                enc_len = n_walls + 2  # [wall_pos] + len((goal, agent))
            else:
                # [wall_pos] + len((n_walls, goal, agent))
                enc_len = n_walls + 3
            # Positions are relative to n_walls when n_walls is variable.
            agent_pos_idx = enc_len-1
            goal_pos_idx = enc_len-2

        # Encoding index of the first wall position. It is 1 when encoding[0]
        # is a dedicated wall-budget step, otherwise 0.
        wall_start_time = int(
            not params.fixed_n_wall_steps
            and not params.first_wall_pos_sets_budget
        )
        wall_pos_idx = enc[wall_start_time:wall_start_time + params.n_walls]

        # Get agent + goal info (set agent/goal pos 1-step out of range if they are not yet placed)
        goal_placed = state.time > jnp.array([goal_pos_idx], dtype=jnp.uint32)
        goal_pos = \
            goal_placed*jnp.array([enc[goal_pos_idx] % w, enc[goal_pos_idx]//w], dtype=jnp.uint32) \
            + (~goal_placed)*jnp.array([w, h], dtype=jnp.uint32)

        agent_placed = state.time > jnp.array(
            [agent_pos_idx], dtype=jnp.uint32)
        agent_pos = \
            agent_placed*jnp.array([enc[agent_pos_idx] % w, enc[agent_pos_idx]//w], dtype=jnp.uint32) \
            + (~agent_placed)*jnp.array([w, h], dtype=jnp.uint32)

        # Inverse of the tile scaling applied to the direction in reset_env.
        agent_dir_idx = jnp.floor((4*enc[-1]/self.n_tiles)).astype(jnp.uint8)

        # Make wall map. Slot i is a placed wall iff it has been written
        # (i + wall_start_time < time) and lies within the wall budget.
        wall_values = jnp.arange(params.n_walls) + wall_start_time \
            < jnp.minimum(state.time, n_walls + wall_start_time)
        # Scatter-max marks a tile as a wall iff ANY placed slot points at it.
        # The result does not depend on how duplicate indices are ordered, so
        # zero-padded unplaced slots cannot clear a wall on tile 0.
        wall_map = jnp.zeros(h*w, dtype=jnp.int32).at[wall_pos_idx].max(
            wall_values.astype(jnp.int32)) > 0

        # Zero out walls where agent and goal reside
        agent_mask = agent_placed * \
            (~(jnp.arange(h*w) ==
             state.encoding[agent_pos_idx])) + ~agent_placed*wall_map
        goal_mask = goal_placed * \
            (~(jnp.arange(h*w) ==
             state.encoding[goal_pos_idx])) + ~goal_placed*wall_map
        wall_map = wall_map*agent_mask*goal_mask
        wall_map = wall_map.reshape(h, w)

        return EnvInstance(
            agent_pos=agent_pos,
            agent_dir_idx=agent_dir_idx,
            goal_pos=goal_pos,
            wall_map=wall_map
        )

    def is_terminal(self, state: EnvState) -> bool:
        """Done once max_episode_steps design steps are taken or state.terminal is set."""
        done_steps = state.time >= self.max_episode_steps()
        return jnp.logical_or(done_steps, state.terminal)

    def _get_post_terminal_obs(self, state: EnvState):
        """Zero-filled observation with the same keys as a regular observation."""
        dtype = jnp.float32 if self.params.normalize_obs else jnp.uint8
        image = jnp.zeros((
            self.params.height+2, self.params.width+2, 3), dtype=dtype
        )

        obs = OrderedDict(dict(
            image=image,
            time=state.time,
        ))
        # Match _add_noise_to_obs / observation_space: 'noise' exists only
        # when noise_dim > 0, so both observation pytrees share one structure.
        if self.params.noise_dim > 0:
            obs.update(dict(
                noise=jnp.zeros(self.params.noise_dim, dtype=jnp.float32),
            ))

        return obs

    def get_obs(self, state: EnvState):
        """Render the current design as an image observation (without noise)."""
        instance = self.get_env_instance(jax.random.PRNGKey(0), state)

        image = make_maze_map(
            self.params,
            instance.wall_map,
            instance.goal_pos,
            instance.agent_pos,
            instance.agent_dir_idx,
            pad_obs=False
        )

        if self.params.normalize_obs:
            image = image/10.0

        return OrderedDict(dict(
            image=image,
            time=state.time,
        ))

    @property
    def default_params(self):
        """A fresh EnvParams with class defaults (not this instance's self.params)."""
        return EnvParams()

    @property
    def name(self) -> str:
        """Environment name."""
        return "UEDMaze"

    @property
    def num_actions(self) -> int:
        """Number of actions possible in environment (one per tile)."""
        return len(self.action_set)

    def action_space(self) -> spaces.Discrete:
        """Action space of the environment: one action per grid tile."""
        params = self.params
        return spaces.Discrete(
            params.height*params.width,
            dtype=jnp.uint32
        )

    def observation_space(self) -> spaces.Dict:
        """Observation space of the environment."""
        params = self.params
        max_episode_steps = self.max_episode_steps()
        spaces_dict = {
            'image': spaces.Box(0, 255, (params.height+2, params.width+2, 3)),
            'time': spaces.Discrete(max_episode_steps),
        }
        if self.params.noise_dim > 0:
            spaces_dict.update({
                'noise': spaces.Box(0, 1, (self.params.noise_dim,))
            })
        return spaces.Dict(spaces_dict)

    def state_space(self) -> spaces.Dict:
        """State space of the environment."""
        encoding_dim = self._get_encoding_dim()
        max_episode_steps = self.max_episode_steps()
        return spaces.Dict({
            'encoding': spaces.Box(0, 255, (encoding_dim,)),
            'time': spaces.Discrete(max_episode_steps),
            "terminal": spaces.Discrete(2),
        })

    def _get_encoding_dim(self) -> int:
        """Encoding length: one slot per step, plus a direction slot if set at reset."""
        encoding_dim = self.max_episode_steps()
        if not self.params.set_agent_dir:
            encoding_dim += 1  # max steps is 1 less than full encoding dim

        return encoding_dim

    def max_episode_steps(self) -> int:
        """Number of design steps.

        The count is the walls plus goal and agent, plus one step for a
        dedicated wall-budget step (if any) and one step for choosing the
        agent direction (if set_agent_dir).
        """
        if self.params.fixed_n_wall_steps \
                or self.params.first_wall_pos_sets_budget:
            max_episode_steps = self.params.n_walls + 2
        else:
            max_episode_steps = self.params.n_walls + 3

        if self.params.set_agent_dir:
            max_episode_steps += 1

        return max_episode_steps


# Resolve this module's import path so the registry can locate UEDMaze.
# The loader's name is preferred because it stays correct under `python -m`.
# If the loader exposes neither attribute, fall back to __name__ instead of
# leaving module_path unbound (which would raise NameError below).
if hasattr(__loader__, 'name'):
    module_path = __loader__.name
elif hasattr(__loader__, 'fullname'):
    module_path = __loader__.fullname
else:
    module_path = __name__

register_ued(env_id='Maze', entry_point=module_path + ':UEDMaze')
