# Copyright (C) king.com Ltd 2025
# License: Apache 2.0

import gymnasium as gym
from gymnasium import spaces
import numpy as np
import os
import pickle
import matplotlib.pyplot as plt


class ContinuousTerminationEnv(gym.Env):
    """
    Continuous 2D environment where the agent must move to a goal and explicitly terminate.

    The agent starts at the origin and moves inside the square
    ``[-world_size, world_size]^2``. The goal is placed at polar coordinates
    ``(0.9 * target_radius, target_angle_rad)``. An action is ``[dx, dy, terminate]``:
    the ``(dx, dy)`` part is rescaled to a fixed step length of ``action_step_size``
    and the episode terminates as soon as ``terminate > 0.9``. Terminating within
    ``goal_tolerance`` of the goal counts as a success.

    ``step`` returns the dense (shaped) reward; the sparse reward is reported in
    ``info["sparse_reward"]``. Every transition is also appended to an internal
    trajectory log, which is only cleared by ``save_trajectory``.

    Args:
        world_size: Half side-length of the square world; positions are clipped to it.
        max_steps: Number of steps after which ``step`` reports ``truncated=True``.
        goal_tolerance: Distance to the goal below which terminating is a success.
        with_relative_dist_obs: Include ``goal - agent`` (2 values) in the observation.
        with_abs_pos_obs: Include the agent position (2 values) in the observation.
        with_abs_goal_obs: Include the goal position (2 values) in the observation.
        target_angle: Goal angle, used directly as radians.
        target_radius: Goal radius; must be one of ``possible_discrete_radii``.
        random_angle: If true, each reset draws the angle from
            ``possible_discrete_angles`` (multiples of pi).
        random_radius: If true, each reset draws the radius from
            ``possible_discrete_radii``.
        render_mode: Stored on the instance; not otherwise used by this class.
    """

    def __init__(self,
                 world_size=7.0,
                 max_steps=100,
                 goal_tolerance=0.5,
                 with_relative_dist_obs=False,
                 with_abs_pos_obs=True,
                 with_abs_goal_obs=False,
                 target_angle=0,
                 target_radius=2.9,
                 random_angle=False,
                 random_radius=False,
                 render_mode=None
                 ):
        super().__init__()

        self.world_size = world_size
        self.grid_size = world_size  # alias of world_size, used for the render axis limits
        self.max_steps = max_steps
        self.goal_tolerance = goal_tolerance
        self.with_relative_dist_obs = with_relative_dist_obs
        self.with_abs_pos_obs = with_abs_pos_obs
        self.with_abs_goal_obs = with_abs_goal_obs
        self.target_angle_rad = target_angle
        self.target_radius = target_radius
        self.random_angle = random_angle
        self.random_radius = random_radius
        self.possible_discrete_radii = [0.9, 1.9, 2.9]
        # Angles are stored as fractions of pi; reset() multiplies the drawn value by np.pi.
        self.possible_discrete_angles = [0, 0.1, 0.15, 0.2, 0.3]
        self.render_mode = render_mode
        assert self.target_radius in self.possible_discrete_radii, "target_radius should be one of the possible discrete radii"

        # Define continuous action space: dx, dy, and terminate action [dx, dy, terminate]
        self.action_step_size = 0.1
        action_low = np.array([-self.action_step_size, -self.action_step_size, 0.0], dtype=np.float32)
        action_high = np.array([self.action_step_size, self.action_step_size, 1.0], dtype=np.float32)
        self.action_space = spaces.Box(low=action_low, high=action_high, dtype=np.float32)
        self.single_action_space = spaces.Box(low=action_low, high=action_high, dtype=np.float32)

        # Observation layout (each part optional, in this order):
        # agent (x, y), goal (x, y), relative offset goal - agent (dx, dy).
        # NOTE(review): the relative offset can exceed +/- world_size (agent and goal on
        # opposite sides of the world), so it may fall outside the declared bounds.
        n_obs = 2 * sum((bool(with_abs_pos_obs), bool(with_abs_goal_obs), bool(with_relative_dist_obs)))
        self.observation_space = spaces.Box(low=-self.world_size, high=self.world_size, shape=(n_obs,), dtype=np.float32)

        self._clear_trajectory()

        self.success = 0
        self.current_action = np.array([0, 0, 0], dtype=np.float32)
        self.agent_position = np.array([0.0, 0.0], dtype=np.float32)
        self.reset()

    def _clear_trajectory(self):
        """Reset all per-transition logging buffers to empty."""
        self.traj_states = []
        self.traj_actions = []
        self.traj_rewards = []
        self.traj_dones = []
        self.traj_infos = []
        self.traj_agent_poses = []
        self.traj_len = 0

    def _distance_to_goal(self):
        """Return the Euclidean distance between the agent and the goal."""
        return np.linalg.norm(self.agent_position - self.goal_position)

    def _optimal_steps(self):
        """Return the minimum number of fixed-length steps from the origin to the goal."""
        return np.ceil(np.linalg.norm(self.goal_position) / self.action_step_size)

    def _info(self):
        """Build the info dict describing the current state, goal and both reward signals."""
        return {
            "agent_pos": self.agent_position.copy(),
            "goal_pos": self.goal_position.copy(),
            "steps_taken": self.steps_taken,
            "success": self.success,
            "target_angle_rad": self.target_angle_rad,
            "target_radius": self.target_radius,
            "dense_reward": self._dense_reward(),
            "sparse_reward": self._sparse_reward(),
            "optimal_steps": self._optimal_steps(),
        }

    def reset(self, seed=None, options=None):
        """
        Start a new episode with the agent at the origin.

        Args:
            seed: Forwarded to ``gym.Env.reset`` to seed ``self.np_random``, which is
                the generator used for the random radius/angle draws below.
            options: Optional dict of ``attribute_name -> value``. Every entry whose
                key is an existing attribute and whose value is not ``None`` is set on
                the environment (after the random draws, so it overrides them) before
                the goal position is computed. Other entries are skipped with a message.

        Returns:
            Tuple ``(observation, info)``.
        """
        super().reset(seed=seed)

        self.current_action = np.array([0, 0, 0], dtype=np.float32)
        self.agent_position = np.array([0.0, 0.0], dtype=np.float32)

        # Draw from the environment's own seeded generator so `seed` makes the
        # sampled goal reproducible (the global np.random ignores it).
        if self.random_radius:
            self.target_radius = self.np_random.choice(self.possible_discrete_radii)

        if self.random_angle:
            self.target_angle_rad = self.np_random.choice(self.possible_discrete_angles) * np.pi

        if options is not None:
            for key, val in options.items():
                known = hasattr(self, key)
                if not known:
                    print(f"Given reset option {key} not found as attribute on environment, skipping...")
                if val is None:
                    print(f"Given reset option {key} has None value, skipping...")
                if known and val is not None:
                    setattr(self, key, val)

        # The goal sits at 90% of the nominal target radius.
        radius = self.target_radius * 0.9
        goal_x = radius * np.cos(self.target_angle_rad)
        goal_y = radius * np.sin(self.target_angle_rad)
        self.goal_position = np.array([goal_x, goal_y], dtype=np.float32)

        self.steps_taken = 0
        self.terminated = False

        obs = self._get_observation()
        return obs, self._info()

    def _get_observation(self):
        """Assemble the float32 observation vector from the enabled components."""
        parts = []
        if self.with_abs_pos_obs:
            parts.extend(self.agent_position)

        if self.with_abs_goal_obs:
            parts.extend(self.goal_position)

        if self.with_relative_dist_obs:
            parts.extend(self.goal_position - self.agent_position)

        return np.array(parts, dtype=np.float32)

    def _dense_reward(self):
        """
        Return the shaped reward for the current state and ``self.current_action``.

        Before the final 0.1 scaling: a constant -0.01 step cost minus the distance to
        the goal; if the action terminates (``terminate > 0.9``), +10 when within
        ``goal_tolerance``, otherwise an extra penalty of ``10 * distance ** 2``.
        """
        _, _, terminate_action = self.current_action

        distance = self._distance_to_goal()
        reward = -0.01 - distance

        if terminate_action > 0.9:
            if distance < self.goal_tolerance:
                reward += 10.0
            else:
                reward -= distance ** 2 * 10

        return reward * 0.1

    def _sparse_reward(self):
        """
        Return the sparse reward for the current state and ``self.current_action``.

        Zero unless the action terminates (``terminate > 0.9``). On a successful
        termination the bonus is ``10 * optimal_steps / steps_taken`` clipped to
        ``[0, 10]``; on a failed one the reward is ``-10 * distance ** 2``.
        """
        _, _, terminate_action = self.current_action

        if not terminate_action > 0.9:
            return 0

        distance = self._distance_to_goal()
        if distance >= self.goal_tolerance:
            return 0 - distance ** 2 * 10

        success_bonus = 10.0
        # Discount the bonus by how many steps were used relative to the optimum.
        # max(..., 1) avoids a division by zero if no step has been counted yet.
        efficiency = self._optimal_steps() / max(self.steps_taken, 1)
        return 0 + np.clip(success_bonus * efficiency, 0, success_bonus)

    def step(self, action):
        """
        Apply one action ``[dx, dy, terminate]`` and log the transition.

        If ``terminate > 0.9`` the episode terminates and the agent does not move;
        otherwise the agent moves a fixed distance ``action_step_size`` in the
        direction of ``(dx, dy)`` and is clipped to the world bounds.

        Args:
            action: Sequence of three numbers ``[dx, dy, terminate]``.

        Returns:
            Tuple ``(observation, dense_reward, terminated, truncated, info)``.

        Raises:
            RuntimeError: If the episode has already terminated.
        """
        # Check first so a rejected call leaves the environment state untouched.
        if self.terminated:
            raise RuntimeError("Episode has terminated. Please reset the environment.")

        old_obs = self._get_observation()
        self.current_action = action
        self.steps_taken += 1

        dx, dy, terminate_action = action
        move_action = np.array([dx, dy], dtype=np.float32)

        # Termination is decided on the position *before* any movement.
        terminated = bool(terminate_action > 0.9)
        self.success = int(terminated and self._distance_to_goal() < self.goal_tolerance)

        # Normalise the move direction, then scale it to the fixed step length.
        # The epsilon keeps a zero move at zero instead of dividing by zero.
        move_action /= np.linalg.norm(move_action) + 1e-6
        move_action *= self.action_step_size

        if not terminated:
            self.agent_position[:] = np.clip(self.agent_position + move_action,
                                             -self.world_size, self.world_size)

        truncated = bool(self.steps_taken >= self.max_steps)

        self.terminated = terminated
        obs = self._get_observation()
        reward = self._dense_reward()
        info = self._info()

        assert terminated or info["sparse_reward"] == 0, "Sparse reward should only be nonzero at last step in episode..."

        # Log the transition; the action is copied so later in-place changes by the
        # caller cannot alter the recorded trajectory.
        self.traj_states.append(old_obs)
        self.traj_actions.append(np.array(action))
        self.traj_rewards.append(reward)
        self.traj_dones.append(terminated)
        self.traj_infos.append(info)
        self.traj_len += 1
        self.traj_agent_poses.append(self.agent_position.copy())

        return obs, reward, terminated, truncated, info

    def render(self, mode='human', axs=(), save_path="", show=True, close=True):
        """
        Plot the agent, the goal and the logged trajectory with matplotlib.

        Args:
            mode: Unused; kept for interface compatibility.
            axs: Optional sequence of exactly two axes; the first one is drawn on.
                If empty, a new 2-row figure is created.
            save_path: If non-empty, the figure is saved to this path (parent
                directories are created when the path contains any).
            show: Whether to call ``plt.show()``.
            close: Whether to close the figure afterwards. When ``axs`` is supplied,
                this closes the figure that owns ``axs[0]``.
        """
        if len(axs) == 0:
            fig, axs = plt.subplots(nrows=2, ncols=1, height_ratios=[8, 1], figsize=(7, 7))
        else:
            assert len(axs) == 2, "Expecting 2 axes for rendering, one for env, one for obs heatmap"
            # Bind the owning figure so saving/closing below works for caller-supplied axes.
            fig = axs[0].figure

        env_ax = axs[0]
        limit = self.grid_size + 0.2
        env_ax.set_xlim(-limit, limit)
        env_ax.set_ylim(-limit, limit)

        env_ax.plot(self.agent_position[0], self.agent_position[1], 'd', color='r', label='agent')
        env_ax.plot(self.goal_position[0], self.goal_position[1], 'o', color='g', label='goal')

        # Draw one segment between each pair of consecutive logged agent positions.
        for start, end in zip(self.traj_agent_poses, self.traj_agent_poses[1:]):
            env_ax.plot([start[0], end[0]], [start[1], end[1]], color='k')

        env_ax.legend()
        env_ax.set_title(f"step {self.steps_taken}, reward_sum: {np.round(np.sum(self.traj_rewards), 2)}, target angle (rad): {np.round(self.target_angle_rad, 2)}")

        fig.tight_layout()

        if save_path:
            print(f"Saving render img at {save_path}")
            save_dir = os.path.dirname(save_path)
            # A bare filename has an empty dirname, which os.makedirs would reject.
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)
            fig.savefig(save_path)

        if show:
            plt.show()

        if close:
            plt.close(fig)

    def save_trajectory(self, save_path):
        """
        Pickle the logged trajectory to ``save_path`` and clear the log.

        The pickle holds a dict with keys ``observations``, ``actions``, ``rewards``,
        ``terminals`` and ``infos``, each with one entry per logged step. Shapes are
        sanity-checked with assertions first, so an empty log fails the checks.

        Args:
            save_path: Destination file; parent directories are created when the
                path contains any.
        """
        save_dir = os.path.dirname(save_path)
        # A bare filename has an empty dirname, which os.makedirs would reject.
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)

        # get all data into shape [ep_len, dim]
        states = np.array(self.traj_states)
        assert states.ndim == 2, "States should be 2D"
        assert states.shape[0] == self.traj_len, "States should contain self.num_steps elements.."
        assert states.shape[1] == self.observation_space.shape[0], f"State dim seems incorrect. Should be {self.observation_space.shape[0]} but got {states.shape[1]}"

        actions = np.array(self.traj_actions)
        assert actions.ndim == 2, "Actions should be 2D"
        assert actions.shape[0] == states.shape[0], "Actions and states should have the same length"
        assert actions.shape[1] == self.action_space.shape[0], f"Action dim seems incorrect. Should be {self.action_space.shape[0]} but got {actions.shape[1]}"

        rewards = np.array(self.traj_rewards)
        assert rewards.ndim == 1, "Rewards should be 1D"
        assert rewards.shape[0] == states.shape[0], "Rewards and states should have the same length"

        dones = np.array(self.traj_dones)
        assert dones.ndim == 1, "Dones should be 1D"
        assert dones.sum() <= 1, "There should be at most one done in the trajectory, otherwise something is wrong here"

        infos = np.array(self.traj_infos)
        assert infos.ndim == 1
        assert infos.shape[0] == states.shape[0]

        with open(save_path, "wb") as f:
            pickle.dump({
                "observations": states,
                "actions": actions,
                "rewards": rewards,
                "terminals": dones,
                "infos": infos
            }, f)

        # reset trajectory logging lists
        self._clear_trajectory()
