"""Maze environment as a `gym.Env`."""
from typing import TypeVar
from typing import Callable
from typing import Union
from trajectory_dynamics import State
from trajectory_dynamics import get_trajectory
from tasks import Goal
from tasks import TaskImpulse
import stable_baselines3.common.base_class
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv
from stable_baselines3.common.callbacks import EveryNTimesteps
from stable_baselines3.common.callbacks import BaseCallback
import numpy.typing as npt
import gym
import gym.spaces
import numpy as np
import torch
import enum
import time

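# Restrict OpenBLAS to a single thread to avoid the "pthread_create: Resource
# temporarily unavailable" failure described at:
# https://stackoverflow.com/questions/52026652/openblas-blas-thread-init-pthread-create-resource-temporarily-unavailable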
import os
os.environ['OPENBLAS_NUM_THREADS'] = '1'


# Type variable bound to the Stable-Baselines3 `BaseAlgorithm` class; used in
# this module to annotate the return type of the algorithm factory callable.
SB3Algorithm = TypeVar(
    "SB3Algorithm",
    bound=stable_baselines3.common.base_class.BaseAlgorithm
)


def get_coordinate_scale(state: State) -> float:
    """Return the factor that maps the largest segment coordinate to 1.

    The endpoints `p1` and `p2` of every segment in `state.segments` are
    gathered, and the reciprocal of the largest absolute coordinate value is
    returned. Multiplying a coordinate by this factor is how observations
    are scaled elsewhere in this module.
    """
    # One row per segment: the coordinates of p1 followed by those of p2
    endpoint_rows = []
    for segment in state.segments:
        endpoint_rows.append(list(segment.p1) + list(segment.p2))

    all_coordinates = torch.tensor(endpoint_rows).flatten()
    largest_magnitude = all_coordinates.abs().max().item()
    return 1/largest_magnitude


def get_observation(
        state: State,
        t: int,
        T: int,
        coordinate_scale: float,
        goal: Goal,
        ) -> npt.NDArray[np.float32]:
    """Convert the state to a `float32` numpy array clipped to [-10, 10].

    The returned vector has eight entries, in this order:

    * x and y of the direction from the marble to the goal (unit length,
      unless the marble is within 1e-6 of the goal, in which case the raw
      offset is used),
    * the marble-to-goal distance multiplied by `coordinate_scale`,
    * x and y of the marble position multiplied by `coordinate_scale`,
    * x and y of the marble velocity multiplied by `coordinate_scale`,
    * the elapsed fraction of the episode, `t/T`.
    """
    marble = state.marble
    goal_offset = goal.position-marble.position
    goal_offset_norm = goal_offset.norm()

    # Only normalize when the offset is not (numerically) zero
    goal_direction = goal_offset
    if goal_offset_norm > 1e-6:
        goal_direction = goal_offset/goal_offset_norm

    features = [
        # Direction towards the goal
        goal_direction[0],
        goal_direction[1],
        # Scaled distance to the goal
        goal_offset_norm*coordinate_scale,
        # Scaled absolute position
        marble.position[0]*coordinate_scale,
        marble.position[1]*coordinate_scale,
        # Scaled absolute velocity
        marble.velocity[0]*coordinate_scale,
        marble.velocity[1]*coordinate_scale,
        # Elapsed fraction of the episode
        t/T,
    ]
    observation = np.array(features, dtype=np.float32)
    return observation.clip(min=-10, max=10)


def get_assisted_observation(
        state: State,
        t: int,
        T: int,
        coordinate_scale: float,
        goal: Goal,
        checkpoints: tuple[tuple[float, float], ...],
        target_checkpoint_i: int,
        ) -> npt.NDArray[np.float32]:
    """Convert the state to a `float32` numpy array clipped to [-10, 10],
    including information about the targeted checkpoint.

    The result is the vector returned by `get_observation` followed by three
    entries describing `checkpoints[target_checkpoint_i]`: x and y of the
    direction from the marble to that checkpoint (unit length, unless the
    marble is within 1e-6 of it) and the distance to it multiplied by
    `coordinate_scale`.
    """
    base_features = get_observation(state, t, T, coordinate_scale, goal)

    checkpoint_position = torch.tensor(checkpoints[target_checkpoint_i])
    checkpoint_offset = checkpoint_position-state.marble.position
    checkpoint_offset_norm = checkpoint_offset.norm()

    # Only normalize when the offset is not (numerically) zero
    checkpoint_direction = checkpoint_offset
    if checkpoint_offset_norm > 1e-6:
        checkpoint_direction = checkpoint_offset/checkpoint_offset_norm

    features = [
        *base_features,
        # Direction towards the targeted checkpoint
        checkpoint_direction[0],
        checkpoint_direction[1],
        # Scaled distance to the targeted checkpoint
        checkpoint_offset_norm*coordinate_scale,
    ]
    return np.array(features, dtype=np.float32).clip(min=-10, max=10)


def get_reward(state: State, next_state: State, goal: Goal) -> float:
    """Return how much closer to the goal the marble got in the transition.

    The value is the distance to `goal.position` in `state` minus the
    distance in `next_state`; it is positive when the marble approached the
    goal and negative when it moved away.
    """
    distance_before, distance_after = (
        (snapshot.marble.position-goal.position).norm().item()
        for snapshot in (state, next_state)
    )
    return distance_before-distance_after


def get_eval_reward(state: State, goal: Goal) -> float:
    """Return the goal radius minus the marble's distance to the goal.

    The result is positive if and only if the marble is strictly inside the
    goal circle.
    """
    offset_from_goal = state.marble.position-goal.position
    return goal.radius-offset_from_goal.norm().item()


def is_done(state: State, goal: Goal, t: int, T: int) -> bool:
    """Return whether the episode finished.

    An episode is finished once `t` has reached the step budget `T`, or when
    `get_eval_reward` is positive for `state`.
    """
    return t >= T or get_eval_reward(state, goal) > 0


def get_assisted_reward(
        action: np.ndarray,
        state: State,
        next_state: State,
        checkpoints: tuple[tuple[float, float], ...],
        target_checkpoint_i: int,
        goal: Goal,
        coordinate_scale: float,
        ) -> float:
    """Return a reward based on the target checkpoint.

    * `-0.5` when the action is (almost) null, i.e. its norm is below 0.01;
    * `20.0` when the marble in `next_state` is closer than `goal.radius` to
      `checkpoints[target_checkpoint_i]`;
    * `-0.3` minus the distance to that checkpoint otherwise.

    `state` and `coordinate_scale` are accepted but not used.
    """
    # Inspired by PRIMAL: Pathfinding via Reinforcement and Imitation Learning
    action_magnitude = torch.tensor(action).norm()
    if action_magnitude < 0.01:
        # Penalize standing still
        return -0.5

    checkpoint_position = torch.tensor(checkpoints[target_checkpoint_i])
    marble_position = next_state.marble.position
    checkpoint_distance = (marble_position-checkpoint_position).norm().item()

    if checkpoint_distance >= goal.radius:
        # Step cost plus a penalty growing with the remaining distance
        return -0.3 - checkpoint_distance
    return 20.0


class MazeEnv(gym.Env):
    """Maze task as a `gym.Env`.

    Actions are 2-vectors with components in [-1, 1]; each one is handed to
    `get_trajectory` as a single-step action sequence. Observations are the
    vectors built by `get_observation`, bounded to [-10, 10]. An episode
    ends when `is_done` reports that the goal was reached or that
    `restart_timestep_n` steps were taken.
    """

    def __init__(
            self,
            initial_state: State,
            restart_timestep_n: int,  # steps before resetting to initial state
            goal: Goal,
            ):
        self.initial_state = initial_state
        self.restart_timestep_n = restart_timestep_n
        self.action_space = gym.spaces.Box(
            low=-1.0,
            high=1.0,
            shape=(2,),
            dtype=np.float32,
        )
        self.coordinate_scale = get_coordinate_scale(initial_state)
        # The episode state must be set before the observation space is
        # built, because its shape is taken from an actual observation.
        self.current_t = 0
        self.current_state = initial_state
        self.goal = goal
        self.observation_space = gym.spaces.Box(
            low=-10.0,
            high=10.0,
            shape=self.get_current_observation().shape,
            dtype=np.float32,
        )

    def get_current_observation(self):
        """Return the observation for the current state and time step."""
        return get_observation(
            state=self.current_state,
            # Use the real step counter so the time feature advances
            t=self.current_t,
            T=self.restart_timestep_n,
            coordinate_scale=self.coordinate_scale,
            goal=self.goal,
        )

    def reset(self):
        """Restart the episode from the initial state and return the first
        observation."""
        self.current_t = 0
        self.current_state = self.initial_state
        return self.get_current_observation()

    def step(self, action: np.ndarray):
        """Apply `action` and return `(observation, reward, done, info)`.

        The observation and the `done` flag describe the state reached
        after the action; the reward is computed for the transition from
        the previous state to that state.
        """
        previous_state = self.current_state
        # Last state of the trajectory produced by this single action
        next_state = get_trajectory(
            actions=torch.tensor(action, dtype=torch.float).unsqueeze(0),
            state=previous_state,
        )[-1]

        # Advance the episode first, so that everything returned below refers
        # to the post-action state and step counter.
        self.current_state = next_state
        self.current_t += 1

        reward = get_reward(
            state=previous_state,
            next_state=next_state,
            goal=self.goal,
        )
        done = is_done(
            state=next_state,
            goal=self.goal,
            t=self.current_t,
            T=self.restart_timestep_n,
        )
        observation = self.get_current_observation()
        info = dict()
        return observation, reward, done, info


class MazeEnvEval(MazeEnv):
    """Evaluation variant of `MazeEnv` whose rewards are positive only when
    the task was completed.

    On steps where the parent reports `done`, the reward is
    `get_eval_reward` of the current state; on every other step it is the
    constant -10.
    """

    def step(self, action: npt.NDArray):
        observation, _training_reward, done, info = super().step(action)
        if not done:
            return observation, -10, done, info
        success_margin = get_eval_reward(
            state=self.current_state,
            goal=self.goal,
        )
        return observation, success_margin, done, info


class AssistedMazeEnv(MazeEnv):
    """Assisted version of the Maze environment, where a list of checkpoints
    is provided.

    The environment tracks the index of the checkpoint currently targeted.
    Observations are built by `get_assisted_observation` and rewards by
    `get_assisted_reward`, both relative to that checkpoint. The index
    advances whenever the marble gets closer than `checkpoint_sat_distance`
    to the targeted checkpoint.
    """

    def __init__(
            self,
            initial_state: State,
            restart_timestep_n: int,  # steps before resetting to initial state
            goal: Goal,
            checkpoints: tuple[tuple[float, float], ...],
            checkpoint_sat_distance: float,
            ):
        # The checkpoint attributes must exist before the parent constructor
        # runs, because it calls `get_current_observation`.
        self.checkpoints = checkpoints
        self.target_checkpoint_i = 0
        self.checkpoint_sat_distance = checkpoint_sat_distance
        super().__init__(
            initial_state=initial_state,
            restart_timestep_n=restart_timestep_n,
            goal=goal,
        )

    def get_current_observation(self):
        """Return the assisted observation for the current state, time step
        and targeted checkpoint."""
        return get_assisted_observation(
            state=self.current_state,
            # Use the real step counter so the time feature advances
            t=self.current_t,
            T=self.restart_timestep_n,
            coordinate_scale=self.coordinate_scale,
            goal=self.goal,
            checkpoints=self.checkpoints,
            target_checkpoint_i=self.target_checkpoint_i,
        )

    def reset(self):
        """Restart the episode, targeting the first checkpoint again."""
        self.target_checkpoint_i = 0
        return super().reset()

    def step(self, action: npt.NDArray):
        """Apply `action` and return `(observation, reward, done, info)`.

        The reward refers to the checkpoint that was targeted while the
        action was taken; the observation refers to the checkpoint targeted
        for the next action.
        """
        previous_state = self.current_state
        # The parent advances the state and the step counter; its observation
        # and reward are replaced by the assisted versions below.
        _, _, done, info = super().step(action)

        # Compute assisted reward (against the checkpoint targeted so far)
        reward = get_assisted_reward(
            action=action,
            state=previous_state,
            next_state=self.current_state,
            checkpoints=self.checkpoints,
            target_checkpoint_i=self.target_checkpoint_i,
            goal=self.goal,
            coordinate_scale=self.coordinate_scale,
        )

        # Update checkpoint
        checkpoint = self.checkpoints[self.target_checkpoint_i]
        position = self.current_state.marble.position
        dist_to_checkpoint = (position-torch.tensor(checkpoint)).norm()
        if dist_to_checkpoint < self.checkpoint_sat_distance:
            # NOTE(review): the index advances by two, so every other
            # checkpoint is skipped -- confirm this is intended.
            self.target_checkpoint_i += 2
            # Never move past the last checkpoint
            if self.target_checkpoint_i >= len(self.checkpoints):
                self.target_checkpoint_i = len(self.checkpoints)-1

        # Build the observation only after the checkpoint update, so that it
        # points at the checkpoint the next action will be rewarded against.
        observation = self.get_current_observation()
        return observation, reward, done, info


class AssistedMazeEnvEval(AssistedMazeEnv):
    """Evaluation variant of `AssistedMazeEnv` whose rewards are positive
    only when the task was completed.

    On steps where the parent reports `done`, the reward is
    `get_eval_reward` of the current state; on every other step it is 0.
    """

    def step(self, action: npt.NDArray):
        observation, _training_reward, done, info = super().step(action)
        if not done:
            return observation, 0, done, info
        success_margin = get_eval_reward(
            state=self.current_state,
            goal=self.goal,
        )
        return observation, success_margin, done, info


class WrappingStrategy(enum.Enum):
    """Closed set with two members, `Unassisted` and `Assisted`.

    The values are the consecutive integers 1 and 2.
    """

    Unassisted = 1
    Assisted = 2


class TimeoutCallback(BaseCallback):
    """Callback that aborts training once a wall-clock deadline has passed.

    `max_time_s` is compared directly against `time.time()`, so it is an
    absolute timestamp in seconds rather than a duration.
    """

    def __init__(self, max_time_s: float, verbose=0):
        super().__init__(verbose)
        self.max_time_s = max_time_s

    def _on_step(self) -> bool:
        """return: (bool) If the callback returns False, training is aborted
        early."""
        deadline_passed = time.time() > self.max_time_s
        return not deadline_passed


def _rollout_policy(model, eval_env, max_step_n: int) -> tuple[list, list]:
    """Run the deterministic policy of `model` on `eval_env` for one episode.

    The episode is reset first and stops after `max_step_n` steps or as soon
    as the environment reports `done`. Returns the list of actions taken
    (as plain lists) and the list of rewards returned by `eval_env.step`.
    """
    actions = list()
    eval_rewards = list()
    observation = eval_env.reset()
    for _ in range(max_step_n):
        action, _ = model.predict(observation, deterministic=True)
        actions.append(action.tolist())
        observation, eval_reward, done, _ = eval_env.step(action)
        eval_rewards.append(eval_reward)
        if done:
            break
    return actions, eval_rewards


def custom_maze_rl_solver(
        task: TaskImpulse,
        restart_timestep_n: int,  # how many times to step before resetting
        train_timestep_n: int,
        get_sb_algorithm: Callable[
            [gym.Env | SubprocVecEnv],
            SB3Algorithm,
            ],
        worker_n: Union[int, None],  # None for no vectorization
        timeout_s: float,
        timeout_check_freq: int,
        verbose: bool,
        ) -> tuple[list[tuple[float, torch.Tensor]], bool]:
    """Train an RL policy on the assisted maze task and log its rollouts.

    Training is split into 10 rounds. Before the first round and after each
    round the deterministic policy is rolled out on an evaluation
    environment, and the elapsed time and the actions taken are appended to
    the log. Training stops early once a rollout yields a positive
    evaluation reward.

    Args:
        task: provides the initial state, the goal circle and the
            checkpoints of the maze.
        restart_timestep_n: episode length, also used as the maximum length
            of every evaluation rollout.
        train_timestep_n: `total_timesteps` passed to `model.learn` in every
            round.
        get_sb_algorithm: factory that receives the training environment
            and returns the algorithm to train.
        worker_n: number of subprocess environments; `None` or 1 trains on a
            single in-process environment.
        timeout_s: total wall-clock budget in seconds, split evenly across
            the rounds.
        timeout_check_freq: how often (in time steps) the timeout is
            checked; also passed as `log_interval` to `model.learn`.
        verbose: currently unused; progress messages are always printed.

    Returns:
        A pair `(log, solved)`: `log` is the list of `(elapsed seconds,
        actions tensor)` entries, and `solved` tells whether some rollout
        obtained a positive evaluation reward.
    """
    def make_env():
        return AssistedMazeEnv(
            initial_state=task.initial_state,
            restart_timestep_n=restart_timestep_n,
            goal=task.goal_circle,
            checkpoints=task.checkpoints,
            checkpoint_sat_distance=task.goal_circle.radius,
        )

    # Check env
    check_env(make_env())

    # Environment used to evaluate the policy. Positive rewards indicate that
    # the goal was reached during the rollout.
    eval_env = AssistedMazeEnvEval(
        initial_state=task.initial_state,
        restart_timestep_n=restart_timestep_n,
        goal=task.goal_circle,
        checkpoints=task.checkpoints,
        checkpoint_sat_distance=task.goal_circle.radius,
    )

    # Instantiate environment
    print(f"RL worker_n: {worker_n}")
    if worker_n is None or worker_n == 1:
        env = make_env()
    else:
        env = SubprocVecEnv([make_env for _ in range(worker_n)])

    log_times = 10
    log = list()

    # Everything below may raise (training, prediction, ...): make sure the
    # environments, and any worker subprocesses, are released in every case.
    try:
        # Instantiate the RL algorithm
        model = get_sb_algorithm(env)

        start_t = time.time()

        # Extract trajectory with the initial policy
        actions, _ = _rollout_policy(model, eval_env, restart_timestep_n)
        log.append((time.time()-start_t, torch.tensor(actions)))

        for restart_i in range(log_times):
            print(f"Training run {restart_i}/{log_times}")
            # Each round gets an equal share of the budget. True division
            # keeps fractional seconds, so short timeouts do not collapse to
            # a zero budget. TimeoutCallback expects an absolute deadline.
            deadline = time.time() + timeout_s/log_times
            timeout_callback = EveryNTimesteps(
                n_steps=timeout_check_freq,
                callback=TimeoutCallback(max_time_s=deadline),
            )

            # Optimize model
            model.learn(
                total_timesteps=train_timestep_n,
                callback=timeout_callback,
                log_interval=timeout_check_freq,
            )

            # Extract trajectory
            actions, eval_rewards = _rollout_policy(
                model, eval_env, restart_timestep_n
            )
            log.append((time.time()-start_t, torch.tensor(actions)))

            # Verify at least one eval reward is greater than zero
            print(f"Eval reward: {max(eval_rewards)} (goal is > 0)")
            if max(eval_rewards) > 0.0:
                return log, True

        return log, False
    finally:
        # Clean up environment resources
        env.close()
        eval_env.close()
