import numpy as np
import copy


def make_window_open_policy():
    """Create a stateful scripted policy that steers through three waypoints.

    The waypoints are fixed offsets from ``obs_dict["state_desired_goal"]``
    and are visited in order. The returned closure remembers which waypoint
    it is heading for across calls and never resets, so call this factory
    again to start over.

    Returns:
        A callable ``policy(obs_dict)`` that returns a length-4 NumPy array.
        The first three entries point from the first three entries of
        ``obs_dict["state_observation"]`` (presumably the end-effector
        position) toward the current waypoint, scaled to length 0.6. The
        fourth entry is 0. Gaussian noise with standard deviation 0.1 is
        added to all four entries.
    """
    mode = 0  # index of the waypoint currently being approached

    def policy(obs_dict):
        nonlocal mode
        # The observation is only read, never written, so no defensive copy
        # of the (possibly large) dict is needed.
        eef_pos = obs_dict["state_observation"][:3]
        obs_pos = obs_dict["state_desired_goal"]
        waypoints = (
            obs_pos + np.array([-0.3, -0.2, 0]),    # pre-handle
            obs_pos + np.array([-0.26, -0.03, 0]),  # handle
            obs_pos + np.array([-0.02, -0.03, 0]),  # final
        )

        # Move on once within 0.1 of the current waypoint. Several waypoints
        # can be passed in a single call; the last one is never left.
        while mode < 2 and np.linalg.norm(waypoints[mode] - eef_pos) < 0.1:
            mode += 1

        action = np.zeros(4)
        action[:3] = waypoints[mode] - eef_pos
        distance = np.linalg.norm(action[:3])
        # Already exactly on the waypoint: keep a zero command rather than
        # dividing 0 by 0, which would fill the action with NaN.
        if distance > 0:
            action[:3] /= distance
            action[:3] *= 0.6

        return action + np.random.randn(4) * 0.1

    return policy


def make_window_close_policy():
    """Create a stateful scripted policy that steers through three waypoints.

    The waypoints are fixed offsets from ``obs_dict["state_desired_goal"]``
    and are visited in order. The returned closure remembers which waypoint
    it is heading for across calls and never resets, so call this factory
    again to start over.

    Returns:
        A callable ``policy(obs_dict)`` that returns a length-4 NumPy array.
        The first three entries point from the first three entries of
        ``obs_dict["state_observation"]`` (presumably the end-effector
        position) toward the current waypoint, scaled to length 0.6. The
        fourth entry is 0. Gaussian noise with standard deviation 0.1 is
        added to all four entries.
    """
    mode = 0  # index of the waypoint currently being approached

    def policy(obs_dict):
        nonlocal mode
        # The observation is only read, never written, so no defensive copy
        # of the (possibly large) dict is needed.
        eef_pos = obs_dict["state_observation"][:3]
        obs_pos = obs_dict["state_desired_goal"]
        waypoints = (
            obs_pos + np.array([0.27, -0.2, 0]),   # pre-handle
            obs_pos + np.array([0.25, -0.03, 0]),  # handle
            obs_pos + np.array([0.02, -0.03, 0]),  # final
        )

        # Move on once within 0.1 of the current waypoint. Several waypoints
        # can be passed in a single call; the last one is never left.
        while mode < 2 and np.linalg.norm(waypoints[mode] - eef_pos) < 0.1:
            mode += 1

        action = np.zeros(4)
        action[:3] = waypoints[mode] - eef_pos
        distance = np.linalg.norm(action[:3])
        # Already exactly on the waypoint: keep a zero command rather than
        # dividing 0 by 0, which would fill the action with NaN.
        if distance > 0:
            action[:3] /= distance
            action[:3] *= 0.6

        return action + np.random.randn(4) * 0.1

    return policy


def make_sweep_policy():
    """Create a stateful scripted policy with three sequential stages.

    Stage 0 heads for a point offset by ``[-0.1, 0, 0]`` from
    ``state_observation[22:25]`` (presumably the object position; confirm
    against the environment's observation layout). Stage 1 heads for that
    position itself. Stage 2 heads for ``state_desired_goal``. The stage
    counter lives in the closure and never resets, so call this factory
    again to start over.

    Returns:
        A callable ``policy(obs_dict)`` that returns a length-4 NumPy array.
        The first three entries point from ``state_observation[:3]`` toward
        the current target, scaled to length 0.6, plus Gaussian noise with
        standard deviation 0.1. The fourth entry is 0, 0.2 or 1 for stages
        0, 1 and 2 respectively and receives no noise.
    """
    mode = 0  # current stage: 0 = approach offset point, 1 = close in, 2 = go to goal
    fourth_by_mode = (0, 0.2, 1)  # fourth action entry per stage (presumably gripper effort)

    def policy(obs_dict):
        nonlocal mode
        # The observation is only read, never written, so no defensive copy
        # of the (possibly large) dict is needed.
        state = obs_dict["state_observation"]
        eef_pos = state[:3]
        obs_pos = state[22:25]
        goal_pos = obs_dict["state_desired_goal"]
        pre_obs_pos = obs_pos + np.array([-0.1, 0, 0])

        # Stage transitions; both can fire within a single call.
        if mode == 0 and np.linalg.norm(pre_obs_pos - eef_pos) < 0.1:
            mode = 1
        if mode == 1 and np.linalg.norm(obs_pos - eef_pos) < 0.05:
            mode = 2

        target = (pre_obs_pos, obs_pos, goal_pos)[mode]
        action = np.zeros(4)
        action[:3] = target - eef_pos
        distance = np.linalg.norm(action[:3])
        # Already exactly on the target: keep a zero command rather than
        # dividing 0 by 0, which would fill the action with NaN.
        if distance > 0:
            action[:3] /= distance
            action[:3] *= 0.6
        action[3] = fourth_by_mode[mode]

        action[:3] += np.random.randn(3) * 0.1
        return action

    return policy


def reach_policy(obs_dict):
    """Return a noisy action heading toward a point 0.05 above the goal.

    Args:
        obs_dict: Mapping with ``"state_observation"`` (its first three
            entries are used as the current position, presumably of the
            end effector) and ``"state_desired_goal"`` (a 3-vector). The
            mapping and its arrays are not modified.

    Returns:
        A length-4 NumPy array. The first three entries point from the
        current position toward the goal with 0.05 added to its third
        coordinate, scaled to length 0.6, plus Gaussian noise with standard
        deviation 0.1. The fourth entry is 1 and receives no noise.
    """
    eef_pos = obs_dict["state_observation"][:3]
    # Copy only the goal vector so the caller's observation is left untouched.
    target = np.array(obs_dict["state_desired_goal"], dtype=float)
    target[2] += 0.05

    action = np.zeros(4)
    action[:3] = target - eef_pos
    distance = np.linalg.norm(action[:3])
    # Already exactly on the target: keep a zero command rather than
    # dividing 0 by 0, which would fill the action with NaN.
    if distance > 0:
        action[:3] /= distance
        action[:3] *= 0.6
    action[3] = 1
    action[:3] += np.random.randn(3) * 0.1
    return action


def make_policy(env_id):
    """Select the scripted policy that matches ``env_id``.

    Rules are checked in this order: an id containing ``"window-open"``,
    an id containing ``"window-close"``, the exact id ``"sweep-v2"``, then
    an id containing ``"reach"``. The first three build a fresh stateful
    policy via their factory; the last returns ``reach_policy`` itself.

    Args:
        env_id: Environment identifier string.

    Returns:
        A callable taking an observation dict and returning an action.

    Raises:
        ValueError: If no rule matches ``env_id``.
    """
    if "window-open" in env_id:
        return make_window_open_policy()

    if "window-close" in env_id:
        return make_window_close_policy()

    # Exact comparison here, unlike the substring checks around it.
    if env_id == "sweep-v2":
        return make_sweep_policy()

    if "reach" in env_id:
        return reach_policy

    raise ValueError(f"Unknown env_id: {env_id}")
