"""RL baselines for the button pushing environment. This policy is a
waypoint-following policy. The waypoints are automatically generated with a
simple hand-written routine, and are such that if the ant follows them the
current task will be solved. These waypoints are always given to the policy
(even during evaluation).

The policy is trained in a multi-task fashion, where tasks correspond to
different password configurations.
"""
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Optional
import math
import random
import time

from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.callbacks import EveryNTimesteps
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.callbacks import CallbackList
from stable_baselines3.common.logger import configure
from stable_baselines3.common.vec_env.vec_normalize import VecNormalize
from stable_baselines3 import SAC
import gym
import gym.spaces
import numpy as np
import torch

from environment import get_simulation
from pylic.predicates import SolverFailedException


class TimeoutCallback(BaseCallback):
    """Abort training once a wall-clock deadline has passed.

    Note that `max_time_s` is an *absolute* deadline in seconds since the
    epoch (it is compared directly against `time.time()`), not a duration.
    """

    def __init__(self, max_time_s: float, verbose=0):
        super().__init__(verbose)
        self.max_time_s = max_time_s

    def _on_step(self) -> bool:
        """return: (bool) If the callback returns False, training is aborted
        early. False is returned once `time.time()` exceeds the deadline."""
        deadline_passed = time.time() > self.max_time_s
        return not deadline_passed


def distance(p1: tuple[float, float], p2: tuple[float, float]) -> float:
    """Return the Euclidean distance between two points in the plane.

    Only the first two coordinates of each point are used. `math.hypot`
    is used instead of squaring and summing by hand, which would overflow
    to `inf` for very large coordinates.
    """
    return math.hypot(p1[0] - p2[0], p1[1] - p2[1])


class BaseAntPasswordEnv(gym.Env, ABC):
    """Base ant password button pushing environment.

    Waypoints (`current_target_position`) lead through the buttons of
    `target_password`, in order, and then to the goal. An episode is a
    success once the torso is inside the goal region (`is_success`).
    Observations contain the torso's horizontal position, the actuator
    state, the current waypoint and the elapsed fraction of the episode.
    Subclasses define the reward through `current_raw_reward`.
    """

    def __init__(
            self,
            target_password: tuple[int, ...],
            button_n: int,
            actuator_n: int,
            episode_timestep_n: int,
            sub_step_s: float,
            password_so_far_encoding_size: int,
            ):
        """
        It must be the case that `button_n <= password_so_far_encoding_size`.

        Args:
            target_password: indices of the buttons to press, in order.
            button_n: number of buttons in the simulation.
            actuator_n: dimension of the action vector.
            episode_timestep_n: number of steps after which an episode ends.
            sub_step_s: forwarded to `get_simulation`.
            password_so_far_encoding_size: length of the vector returned by
                `password_so_far_encoding`.
        """
        assert button_n <= password_so_far_encoding_size
        self.password_so_far_encoding_size = password_so_far_encoding_size
        self.target_password = target_password
        self.button_n = button_n
        self.actuator_n = actuator_n
        self.episode_timestep_n = episode_timestep_n
        self.sub_step_s = sub_step_s
        self.simulation = self._make_simulation()
        self.current_t = 0
        # The observation vector is built once here only to read its shape.
        self.observation_space = gym.spaces.Box(
            low=-10.0,
            high=10.0,
            shape=self.current_observation.shape,
            dtype=np.float32,
        )
        self.action_space = gym.spaces.Box(
            low=-1.0,
            high=1.0,
            shape=(self.actuator_n,),
            dtype=np.float32,
        )
        # Reference value used by `current_reward` to rescale rewards.
        self.initial_raw_reward = self.current_raw_reward

    def _make_simulation(self):
        """Build a fresh simulation for this task (animation_fps=None)."""
        return get_simulation(
            animation_fps=None,
            password=self.target_password,
            num_buttons=self.button_n,
            sub_step_s=self.sub_step_s,
        )

    def reset(self):
        """Start a new episode on a freshly built simulation and return the
        first observation."""
        # NOTE(review): `initial_raw_reward` is computed once in `__init__`
        # and not refreshed here; this assumes `get_simulation` yields the
        # same initial state every time -- confirm.
        self.current_t = 0
        self.simulation = self._make_simulation()
        return self.current_observation

    def render(self, mode: str = "human"):
        """Return `simulation.physics.render()`; `mode` is ignored."""
        return self.simulation.physics.render()

    @property
    @abstractmethod
    def current_raw_reward(self) -> float:
        raise NotImplementedError()

    @property
    def current_target_position(self) -> tuple[float, float]:
        """Return the position of the next waypoint.

        - If the buttons pressed so far equal the target password, or the
          target password is empty, the waypoint is the goal.
        - Otherwise it is the target-password button at the first index
          where the pressed sequence and the target password disagree. For a
          correct but incomplete prefix this is the next button to press.
        - If every target button was pressed correctly and extra buttons
          were pressed afterwards, the last target button is returned.
        """
        target = tuple(self.target_password)
        pressed = tuple(self.password_so_far)
        if pressed == target or not target:
            return self.goal_position

        # Length of the longest common prefix. When `pressed` is a correct
        # but incomplete prefix of `target`, this is len(pressed), i.e. the
        # index of the next button still to be pressed.
        i = 0
        common_n = min(len(pressed), len(target))
        while i < common_n and pressed[i] == target[i]:
            i += 1
        # Only reachable past the end when the whole target was pressed and
        # then extra buttons: point at the last target button.
        i = min(i, len(target) - 1)

        next_button = self.simulation.buttons[target[i]]
        return (next_button.x, next_button.y)

    @property
    def current_reward(self) -> float:
        """Raw reward divided by twice the magnitude of the initial raw reward.

        If the initial raw reward is exactly zero the scale is undefined, so
        the raw reward is returned unscaled instead of raising
        ZeroDivisionError.
        """
        scale = abs(2*self.initial_raw_reward)
        if scale == 0.0:
            return self.current_raw_reward
        return self.current_raw_reward/scale

    @property
    def goal_position(self) -> tuple[float, float]:
        return (self.simulation.goal.x, self.simulation.goal.y)

    @property
    def distance_to_goal(self) -> float:
        """Distance from the torso to the edge of the goal circle (0 inside)."""
        pos = torch.tensor(self.simulation.torso_horizontal_position)
        goal = torch.tensor(self.goal_position)
        goal_radius = self.simulation.goal.r
        gap = (pos-goal).norm()-goal_radius
        return max(0.0, gap.item())

    @property
    def is_success(self) -> bool:
        return self.distance_to_goal <= 0.0

    @property
    def is_done(self) -> bool:
        return self.current_t >= self.episode_timestep_n or self.is_success

    @property
    def password_so_far(self) -> list[int]:
        return self.simulation.activated_buttons

    @property
    def password_so_far_encoding(self) -> list[float]:
        """Return a multi-hot encoding of the buttons pressed so far.

        Entry `b` is 1.0 if button `b` has been pressed and 0.0 otherwise;
        the order in which buttons were pressed is not encoded. This feature
        is currently not included in `current_observation`.
        """
        encoding = [0.0 for _ in range(self.password_so_far_encoding_size)]
        for button_i in self.password_so_far:
            encoding[button_i] = 1.0
        return encoding

    @property
    def current_observation(self) -> np.ndarray:
        """Flat float32 vector: torso position, actuator state, current
        waypoint and the elapsed fraction of the episode."""
        observation = [
            *(self.simulation.torso_horizontal_position),
            *(self.simulation.actuator_state),
            *(self.current_target_position),
            #*(self.password_so_far_encoding),
            self.current_t/self.episode_timestep_n,
        ]
        obs = np.array(
            observation,
            dtype=np.float32,
        ).flatten()
        return obs

    def step(self, action: np.ndarray):
        """Advance the simulation with `action`.

        Every position returned by `simulation.step` that lies strictly
        inside a button's radius activates that button. Returns the
        old-style gym tuple `(observation, reward, done, info)`.
        """
        ant_positions = self.simulation.step([action])

        # Button-platform logic internal model
        for ant_position in ant_positions:
            for i, button in enumerate(self.simulation.buttons):
                px, py, pr = button.x, button.y, button.r
                if distance(ant_position, (px, py)) < pr:
                    self.simulation.activate_button(i)

        self.current_t += 1
        info = dict()
        return (
            self.current_observation,
            self.current_reward,
            self.is_done,
            info,
        )

    def close(self):
        pass


class TightlyGuidedAntPasswordEnv(BaseAntPasswordEnv):
    """Button pushing environment that tightly guides the agent to the target
    password.

    Episodes also end as soon as the pressed buttons stop being a prefix of
    the target password. The reward is shaped by the distance to the current
    waypoint plus a bonus that grows with the number of buttons pressed.
    """

    @property
    def is_done(self) -> bool:
        # `>=` (not `>`) so training episodes have the same maximum length as
        # the base environment's episodes; otherwise one extra step is run
        # and the time feature in the observation exceeds 1.
        return any((
            self.current_t >= self.episode_timestep_n,
            self.is_success,
            self.password_prefix_is_wrong,
        ))

    @property
    def password_prefix_is_wrong(self) -> bool:
        """True if the buttons pressed so far are not a prefix of the target
        password (a wrong button, or more buttons than the password has)."""
        pressed = tuple(self.password_so_far)
        return pressed != tuple(self.target_password[:len(pressed)])

    @property
    def current_raw_reward(self) -> float:
        """Shaped reward for the current state.

        Returns 300.0 once the goal is reached. Otherwise returns
        `-distance_to_waypoint + progress_bonus + penalty`. The progress
        bonus is `len(target) * len(pressed) / len(target)`, or 0 for an
        empty target password. The penalty is -5.0 while the pressed buttons
        are not a prefix of the target password.
        """
        # Ref. bipedal walker
        if self.is_success:
            return 300.0

        # If we already pushed the wrong button, penalize
        if self.password_prefix_is_wrong:
            penalty = -5.0
        else:
            penalty = 0.0

        # Else, give reward proportional to the distance to the next button
        # in the password
        pos = torch.tensor(self.simulation.torso_horizontal_position)
        target_pos = torch.tensor(self.current_target_position)
        distance_to_target = (pos-target_pos).norm()

        # Provide bonus for moving through the password
        alpha = len(self.target_password)
        if len(self.target_password) > 0:
            password_reward = len(self.password_so_far)/len(self.target_password)
        else:
            password_reward = 0.0
        return -distance_to_target.item() + alpha*password_reward + penalty


class HierarchicalButtonPreTrainingEnv(gym.Env):
    """Pre-training environment that samples a random task on every reset.

    Each episode runs in a fresh `TightlyGuidedAntPasswordEnv` with a random
    number of buttons and a random password over those buttons, so the
    waypoints the policy is shown change from episode to episode.
    """

    def __init__(
            self,
            max_button_n: int,
            actuator_n: int,
            episode_timestep_n: int,
            sub_step_s: float,
            password_so_far_encoding_size: int,
            ):
        """
        It must be the case that
        `max_button_n <= password_so_far_encoding_size`.
        """
        assert max_button_n <= password_so_far_encoding_size
        self.max_button_n = max_button_n
        self.actuator_n = actuator_n
        self.episode_timestep_n = episode_timestep_n
        self.sub_step_s = sub_step_s
        self.password_so_far_encoding_size = password_so_far_encoding_size
        self.env = self._make_random_env()
        # The spaces do not depend on the sampled password, so the first
        # sampled task's spaces are reused for all later tasks.
        self.observation_space = self.env.observation_space
        self.action_space = self.env.action_space

    def _make_random_env(self) -> TightlyGuidedAntPasswordEnv:
        """Build a guided environment for a freshly sampled password."""
        target_password, button_n = self.get_random_target_password()
        return TightlyGuidedAntPasswordEnv(
            target_password=target_password,
            button_n=button_n,
            actuator_n=self.actuator_n,
            episode_timestep_n=self.episode_timestep_n,
            sub_step_s=self.sub_step_s,
            password_so_far_encoding_size=self.password_so_far_encoding_size,
        )

    def reset(self):
        """Sample a new task and return its first observation."""
        # Release the previous task's environment before replacing it.
        self.env.close()
        self.env = self._make_random_env()
        # NOTE: the new env's __init__ already built a simulation; reset()
        # builds another one, as the original code did.
        return self.env.reset()

    def get_random_target_password(self) -> tuple[tuple[int, ...], int]:
        """Return a random `(target_password, button_n)` pair.

        `button_n` is uniform in `[0, max_button_n]`. The password is a
        random sequence of distinct button indices whose length is uniform
        in `[0, button_n]`, so it may be empty.
        """
        button_n = random.randint(0, self.max_button_n)
        password_len = random.randint(0, button_n)
        # With k == 0 `random.sample` returns [] without drawing numbers.
        password = random.sample(range(button_n), k=password_len)
        return tuple(password), button_n

    def step(self, action: np.ndarray):
        return self.env.step(action)

    def render(self, mode: str = "human"):
        return self.env.render()

    def close(self):
        self.env.close()


class AntPasswordEnvEval(BaseAntPasswordEnv):
    """Button pushing evaluation environment.

    The reward is the negated distance to the goal region, which is 0 once
    the goal is reached. It is reported as-is, without the base class's
    rescaling by the initial reward.
    """
    @property
    def current_raw_reward(self) -> float:
        """Negated distance from the torso to the goal region."""
        remaining = self.distance_to_goal
        return -remaining

    @property
    def current_reward(self) -> float:
        """Unscaled reward; identical to `current_raw_reward`."""
        raw = self.current_raw_reward
        return raw


def rl_tight_guidance_solver(
        model_path: Path,
        vec_normalize_path: Path,
        target_password: tuple[int, ...],
        episode_timestep_n: int,
        actuator_n: int,
        button_n: int,
        rl_timestep_train_freq_n: int,
        rl_grad_steps_per_train: int,
        train_timestep_n: int,
        worker_n: int,
        seed: int,
        eval_freq: int,
        timeout_s: float,
        sub_step_s: float,
        password_so_far_encoding_size: int,
        verbose: bool,
        ) -> torch.Tensor:
    """Unroll a pre-trained SAC policy on the evaluation task.

    The policy at `model_path` is loaded together with the `VecNormalize`
    statistics at `vec_normalize_path` (if given). It is then executed
    deterministically for at most `episode_timestep_n` steps.

    `rl_timestep_train_freq_n`, `rl_grad_steps_per_train`,
    `train_timestep_n`, `worker_n`, `seed`, `eval_freq`, `timeout_s` and
    `verbose` are accepted for interface compatibility but are unused.

    Returns:
        Tensor of the actions taken, one entry per executed step.

    Raises:
        SolverFailedException: if no step reached the goal. The evaluation
            reward is the negated distance to the goal, so success means a
            reward of 0. The executed actions are attached as
            `final_parameters`.
    """
    def make_eval_env():
        return AntPasswordEnvEval(
            target_password=target_password,
            button_n=button_n,
            actuator_n=actuator_n,
            episode_timestep_n=episode_timestep_n,
            sub_step_s=sub_step_s,
            password_so_far_encoding_size=password_so_far_encoding_size,
        )

    # Wrap in VecNormalize, which tracks statistics to normalize
    # observations and rewards, so the policy sees inputs normalized the same
    # way as during training.
    eval_env = SubprocVecEnv([make_eval_env])
    if vec_normalize_path is not None:
        eval_env = VecNormalize.load(str(vec_normalize_path), eval_env)
        print(f"Loaded {vec_normalize_path}")
        # Freeze the loaded statistics: they must not drift while the policy
        # is being evaluated.
        eval_env.training = False
    else:
        eval_env = VecNormalize(eval_env)

    # Always close the env (and its worker subprocess), even on failure.
    try:
        # Load model
        model = SAC.load(str(model_path), eval_env)
        print(f"Loaded {model_path}")

        # Extract trajectory
        # The policy is executed in deterministic mode.
        actions = list()
        eval_rewards = list()
        observation = eval_env.reset()
        for _ in range(episode_timestep_n):
            action, _ = model.predict(observation, deterministic=True)
            actions.append(action.tolist())
            observation, _, done, _ = eval_env.step(action)
            # Reward of the single env before VecNormalize scaling.
            eval_rewards.append(eval_env.get_original_reward()[0])
            if done:
                break

        # Evaluation rewards are <= 0 and equal 0 only inside the goal, so
        # success means at least one reward is not negative. An empty rollout
        # (episode_timestep_n <= 0) is treated as a failure.
        if not eval_rewards or max(eval_rewards) < 0.0:
            raise SolverFailedException(
                "Did not find solution!",
                final_parameters=torch.tensor(actions)
            )
    finally:
        eval_env.close()
    return torch.tensor(actions)


def pretrain_policy(
        pretrained_model_path: Optional[Path],
        replay_buffer_path: Optional[Path],
        vec_normalize_path: Optional[Path],
        episode_timestep_n: int,
        actuator_n: int,
        max_button_n: int,
        rl_timestep_train_freq_n: int,
        rl_grad_steps_per_train: int,
        train_timestep_n: int,
        worker_n: int,
        seed: int,
        eval_freq: int,
        timeout_s: float,
        sub_step_s: float,
        password_so_far_encoding_size: int,
        log_path: Path,
        verbose: bool,
        ) -> SAC:
    """Pre-train a SAC policy on randomly sampled password tasks.

    Training runs on `HierarchicalButtonPreTrainingEnv` for up to
    `train_timestep_n` timesteps. It stops early once `timeout_s` seconds
    of wall-clock time have passed (checked every `eval_freq` timesteps).

    If `worker_n` is not None, `worker_n` subprocess workers are used and
    wrapped in `VecNormalize` (loaded from `vec_normalize_path` if given).
    Training can resume from `pretrained_model_path` and
    `replay_buffer_path`. Logs are written to `log_path` (stdout, JSON and
    TensorBoard).

    `rl_timestep_train_freq_n` and `rl_grad_steps_per_train` are currently
    unused (the corresponding SAC arguments are commented out).

    Returns:
        The trained model. The training and evaluation environments are
        closed before returning, and also if training raises.
    """
    def make_env():
        return HierarchicalButtonPreTrainingEnv(
            max_button_n=max_button_n,
            actuator_n=actuator_n,
            episode_timestep_n=episode_timestep_n,
            sub_step_s=sub_step_s,
            password_so_far_encoding_size=password_so_far_encoding_size,
        )

    def make_vec_env():
        """One in-process env, or `worker_n` subprocess workers."""
        if worker_n is None:
            return make_env()
        return SubprocVecEnv([make_env] * worker_n)

    # Helper function to wrap in VecNormalize, which tracks statistics
    # to normalize observations and rewards. Only subprocess (vectorized)
    # environments are wrapped.
    def normalize_wrap(env):
        if isinstance(env, SubprocVecEnv):
            if vec_normalize_path is not None:
                env = VecNormalize.load(str(vec_normalize_path), env)
                print(f"Loaded {vec_normalize_path}")
            else:
                env = VecNormalize(env)
        return env

    # Check env (the checked instance is not reused, so release it)
    checked_env = make_env()
    check_env(checked_env)
    checked_env.close()

    # Instantiate environment
    env = normalize_wrap(make_vec_env())
    eval_env = None
    try:
        # Instantiate the RL algorithm
        if pretrained_model_path is None:
            model = SAC(
                "MlpPolicy",
                env,
                #train_freq=(rl_timestep_train_freq_n, "step"),
                #gradient_steps=rl_grad_steps_per_train,
                #action_noise=NormalActionNoise(noise_mean, noise_sigma),
                #use_sde=True,
                verbose=2,
                seed=seed,
            )
        else:
            model = SAC.load(str(pretrained_model_path), env)
            print(f"Loaded {pretrained_model_path}")

        # Load replay buffer
        if replay_buffer_path is not None:
            model.load_replay_buffer(replay_buffer_path)
            print(f"Loaded {replay_buffer_path}")

        # Setup logger
        new_logger = configure(str(log_path), ["stdout", "json", "tensorboard"])
        model.set_logger(new_logger)

        # Periodic deterministic evaluation (frequency set by `eval_freq`),
        # logged to `log_path`.
        eval_env = normalize_wrap(make_vec_env())
        if isinstance(eval_env, VecNormalize):
            # Per SB3 VecNormalize guidance: do not update the normalization
            # statistics at test time, and report unnormalized rewards.
            eval_env.training = False
            eval_env.norm_reward = False
        eval_callback = EvalCallback(
            eval_env,
            verbose=verbose,
            deterministic=True,
            eval_freq=eval_freq,
            log_path=str(log_path),
        )
        # Stop training once the absolute wall-clock deadline has passed.
        max_time_s = time.time() + timeout_s
        timeout_callback = EveryNTimesteps(
            n_steps=eval_freq,
            callback=TimeoutCallback(max_time_s=max_time_s),
        )

        # Optimize model
        model.learn(
            total_timesteps=train_timestep_n,
            callback=CallbackList([eval_callback, timeout_callback]),
            log_interval=eval_freq,
            reset_num_timesteps=False,
        )
    finally:
        # Clean up environment resources (including worker subprocesses)
        env.close()
        if eval_env is not None:
            eval_env.close()
    return model
