# REACHABILITY MODEL


class ReachabilityModel:
    """Estimate goal accessibility and goal-to-goal reachability.

    Accessibility is modelled by an ensemble of One-Class SVMs, one per
    landmark slot, each paired with the axis-aligned bounding box of the goals
    it was fitted on (see ``infer``). A goal is accessible when it lies inside
    at least one box and that box's OCSVM labels it an inlier (``predict == 1``).

    Reachability from ``g1`` to ``g2`` (see ``is_reachable_``) additionally
    rolls out ``trajectory_model``: the transition counts as reachable when
    every predicted trajectory point is accessible and the predicted endpoint
    lands within ``success_threshold`` of ``g2`` inside the low-level arena.
    """

    def __init__(self,
        agent,
        trajectory_model,
        goal_space_low,
        goal_space_high,
        landmarks_info,
        success_threshold,
        ll_arena_low,
        ll_arena_high,
        relative_goals,
        work_dir
    ):
        self.agent = agent
        self.work_dir = work_dir
        # Callable applied to float32 tensor batches; its output is indexed as
        # (batch, time, goal_dim) in is_reachable_.
        self.trajectory_model = trajectory_model
        self.goal_space_low = goal_space_low
        self.goal_space_high = goal_space_high
        self.goal_dim = len(self.goal_space_low)
        self.ll_arena_low = ll_arena_low
        self.ll_arena_high = ll_arena_high
        # When True, trajectory_model is queried with g2 - g1 and its output is
        # shifted back by g1.
        self.relative_goals = relative_goals
        self.ocsvm_ensemble = []
        self.bounding_boxes = []
        self.success_threshold = success_threshold
        self.landmarks_info = landmarks_info
        # Largest batch fed to trajectory_model in a single call.
        self.max_batch_size = 10000
        # OneClassSVM hyper-parameters used by infer().
        self.nu = 0.01
        self.gamma = 0.1
        # Optional auxiliary network persisted by save()/load(); this class
        # never creates it itself.
        self.disagreement_model = None
        # The attributes below are not referenced elsewhere in this class;
        # kept for compatibility with external code.
        self.ll_arena_radius = 5.
        self.accessibility_density_threshold = 0.01
        self.pdf = None
        self.nn_ensemble = []
        self.nn_optims = []
        self.pdf_resolution_imag = 100j
        self.pdf_resolution = int(self.pdf_resolution_imag.imag)

    def is_initialized(self):
        """Return True once at least one accessibility region has been fitted."""
        return len(self.ocsvm_ensemble) > 0

    def infer(self, achieved_goals, landmark_idx, diagram_save_path=None):
        """Fit (or refit) the accessibility region for one landmark.

        Args:
            achieved_goals: array of shape (n_samples, goal_dim) used to fit a
                OneClassSVM and its bounding box.
            landmark_idx: slot to update. If it is >= the current ensemble
                size, a new region is appended at the end of the ensemble
                (NOTE(review): at index len(ensemble), not necessarily at
                landmark_idx); otherwise the region at landmark_idx is replaced.
            diagram_save_path: if truthy, a plot of every region plus the
                landmark graph is written here. The plot evaluates the OCSVMs
                on a 2-column grid, so it assumes a 2-D goal space.
        """
        X_train = achieved_goals

        new_ocsvm = OneClassSVM(gamma=self.gamma, kernel="rbf", nu=self.nu)
        new_ocsvm.fit(X_train)

        new_bounding_box = {
            'min': np.min(achieved_goals, axis=0),
            'max': np.max(achieved_goals, axis=0)
        }

        if landmark_idx >= len(self.ocsvm_ensemble):
            self.ocsvm_ensemble.append(new_ocsvm)
            self.bounding_boxes.append(new_bounding_box)
        else:
            self.ocsvm_ensemble[landmark_idx] = new_ocsvm
            self.bounding_boxes[landmark_idx] = new_bounding_box

        if diagram_save_path:
            self._save_diagram(diagram_save_path)

    def _save_diagram(self, diagram_save_path):
        """Plot all OCSVM regions and the landmark reachability graph to a file."""
        fig, ax = plt.subplots(figsize=(9, 6))

        # The evaluation grid only depends on the goal-space bounds: build it once.
        xx, yy = np.meshgrid(
            np.linspace(self.goal_space_low[0], self.goal_space_high[0], 50),
            np.linspace(self.goal_space_low[1], self.goal_space_high[1], 50),
        )
        X = np.concatenate([xx.ravel().reshape(-1, 1), yy.ravel().reshape(-1, 1)], axis=1)

        for ocsvm in self.ocsvm_ensemble:
            max_score = ocsvm.decision_function(X).max()
            # contourf needs strictly increasing levels; a region with no
            # positive score on the grid has nothing to draw.
            if max_score <= 0:
                continue
            DecisionBoundaryDisplay.from_estimator(
                ocsvm,
                X,
                response_method="decision_function",
                plot_method="contourf",
                ax=ax,
                colors="palevioletred",
                levels=[0, max_score],
            )

        # The landmark graph does not depend on any particular OCSVM, so draw it once.
        mat = self.agent.ppo_meta_agent.landmark_reachability_matrix
        if mat is not None:
            reachability_matrix_draw = copy.deepcopy(mat)
            # Self-loops are not drawn.
            for i in range(len(reachability_matrix_draw)):
                reachability_matrix_draw[i, i] = 0.
            # Entries equal to 1 become directed edges row -> column.
            rows_draw, cols_draw = np.where(reachability_matrix_draw == 1.)
            gr_draw = nx.DiGraph()
            for i, lm_info in enumerate(self.landmarks_info):
                gr_draw.add_node(i, pos=lm_info['landmark'])
            gr_draw.add_edges_from(zip(rows_draw.tolist(), cols_draw.tolist()))

            pos_draw = nx.get_node_attributes(gr_draw, 'pos')
            node_sizes = [40] * len(gr_draw)
            nx.draw_networkx_nodes(gr_draw, pos_draw, node_size=node_sizes, node_color="indigo", ax=ax)
            nx.draw_networkx_edges(
                gr_draw,
                pos_draw,
                node_size=node_sizes,
                arrowstyle="->",
                arrowsize=10,
                width=2,
                ax=ax,
            )

        ax.set(
            xlim=(self.goal_space_low[0], self.goal_space_high[0]),
            ylim=(self.goal_space_low[1], self.goal_space_high[1]),
        )

        fig.savefig(diagram_save_path)
        # Release the figure: pyplot keeps every open figure alive otherwise,
        # and infer() is called repeatedly.
        plt.close(fig)

    def is_accessible(self, goals):
        """Return a boolean mask marking goals that lie in any accessible region.

        Args:
            goals: array of shape (n_goals, goal_dim).

        Returns:
            Boolean array of shape (n_goals,).
        """
        goals = np.asarray(goals)
        accessible = np.zeros(len(goals), dtype=bool)

        for ocsvm, box in zip(self.ocsvm_ensemble, self.bounding_boxes):
            # Inclusive bounds: the box is the min/max of the training goals,
            # so the extremal training goals themselves must count as inside.
            in_box = np.all((goals >= box['min']) & (goals <= box['max']), axis=1)
            # Only query the OCSVM for goals not already known to be accessible;
            # membership in ANY region is all that matters.
            to_check = np.flatnonzero(in_box & ~accessible)
            if to_check.size == 0:
                continue
            accessible[to_check] = ocsvm.predict(goals[to_check]) == 1

        return accessible

    def is_reachable(self, g1, g2, print_bb_info=False):
        """Return only the (unreachable, reachable) probability pairs of ``is_reachable_``."""
        res = self.is_reachable_(g1, g2, print_bb_info=print_bb_info)
        return res[0]

    def is_reachable_(self, g1, g2, print_bb_info=False, shield=False):
        """Vectorised reachability estimate from each ``g1`` to each ``g2``.

        Args:
            g1, g2: arrays of shape (n, goal_dim) or a single (goal_dim,) goal.
            print_bb_info, shield: accepted for interface compatibility; unused.

        Returns:
            ``(probs, accessible_prob, endpoint_prob, info)`` where ``probs`` is
            (n, 2) holding ``[1 - reachable, reachable]``. ``accessible_prob``
            is 1 when every predicted trajectory point is accessible.
            ``endpoint_prob`` is 1 when the predicted endpoint is within
            ``success_threshold`` of g2 and inside the arena. ``info`` holds
            accessibility diagnostics. It is empty when no g2 lies inside the
            low-level arena around its g1; in that case every query is
            reported unreachable.
        """
        g1 = np.asarray(g1)
        g2 = np.asarray(g2)
        if g1.ndim < 2:
            g1 = np.expand_dims(g1, axis=0)
        if g2.ndim < 2:
            g2 = np.expand_dims(g2, axis=0)

        if self.relative_goals:
            g2 = g2 - g1
            low = self.ll_arena_low
            high = self.ll_arena_high
        else:
            low = g1 + self.ll_arena_low
            high = g1 + self.ll_arena_high
        n_queries = len(g2)

        # Only targets strictly inside the low-level arena are evaluated.
        viable_idx_mask = np.multiply(np.prod(g2 < high, axis=1), np.prod(g2 > low, axis=1))

        if np.sum(viable_idx_mask) == 0:
            reachable_prob = np.stack([np.ones((n_queries,)), np.zeros((n_queries,))], axis=-1)
            return reachable_prob, np.zeros((n_queries,)), np.zeros((n_queries,)), {}

        viable_idx = np.nonzero(viable_idx_mask)
        viable_g1 = g1[viable_idx]
        viable_g2 = g2[viable_idx]
        viable_low = low if self.relative_goals else low[viable_idx]
        viable_high = high if self.relative_goals else high[viable_idx]

        if self.relative_goals:
            trajectory_model_in = torch.FloatTensor(viable_g2)
        else:
            trajectory_model_in = torch.FloatTensor(np.concatenate([viable_g1, viable_g2], axis=1))

        # Batched inference; gradients are never needed here.
        predicted_trajectory_batches = []
        with torch.no_grad():
            for start in range(0, len(trajectory_model_in), self.max_batch_size):
                in_batch = trajectory_model_in[start:start + self.max_batch_size, :]
                predicted_trajectory_batches.append(
                    self.trajectory_model(in_batch).detach().cpu().numpy()
                )
        predicted_trajectory = np.concatenate(predicted_trajectory_batches, axis=0)

        if self.relative_goals:
            # Back to absolute goals.
            predicted_trajectory = predicted_trajectory + np.expand_dims(viable_g1, axis=1)
            viable_g2 = viable_g1 + viable_g2
            viable_low = viable_g1 + viable_low
            viable_high = viable_g1 + viable_high

        predicted_endpoints = predicted_trajectory[:, -1, :]
        n_traj, traj_len, goal_dim = predicted_trajectory.shape
        test_traj = np.copy(predicted_trajectory[:100, :, :])

        # Check every predicted point of every trajectory in a single call.
        accessible_points = self.is_accessible(
            np.reshape(predicted_trajectory, (n_traj * traj_len, goal_dim))
        )
        total_accessible_ratio = np.sum(accessible_points) / len(accessible_points)
        accessible_goals = np.reshape(accessible_points, (n_traj, traj_len))
        accessible_ratios = accessible_goals.mean(axis=1)

        info = {
            'total_accessible_ratio': total_accessible_ratio,
            'accessible_ratios': accessible_ratios,
            'mean_accessible_ratio': np.mean(accessible_ratios),
            'test_traj': test_traj
        }

        # A trajectory is accessible only if all of its points are.
        accessible_prob = np.prod(accessible_goals, axis=1)

        endpoint_success_prob = np.linalg.norm(predicted_endpoints - viable_g2, axis=1) < self.success_threshold
        endpoint_success_prob = np.multiply(endpoint_success_prob, np.prod(viable_g2 < viable_high, axis=1))
        endpoint_success_prob = np.multiply(endpoint_success_prob, np.prod(viable_g2 > viable_low, axis=1))

        reachable_prob = np.multiply(accessible_prob, endpoint_success_prob)

        # Scatter viable results back; non-viable queries stay at 0.
        full_accessible_prob = np.zeros((n_queries,))
        full_accessible_prob[viable_idx] = accessible_prob
        full_endpoint_prob = np.zeros((n_queries,))
        full_endpoint_prob[viable_idx] = endpoint_success_prob
        full_reachable_prob = np.zeros((n_queries,))
        full_reachable_prob[viable_idx] = reachable_prob

        unreachable_prob = 1. - full_reachable_prob
        return np.stack([unreachable_prob, full_reachable_prob], axis=-1), full_accessible_prob, full_endpoint_prob, info

    def saved_model_available(self, load_dir):
        """Return True if ``load_dir`` contains a saved OCSVM ensemble."""
        return os.path.isfile(os.path.join(load_dir, 'ocsvm_accessibility_model.pkl'))

    def save(self, save_dir):
        """Persist the OCSVM ensemble and bounding boxes (and the disagreement model, if set)."""
        with open(os.path.join(save_dir, 'ocsvm_accessibility_model.pkl'), 'wb') as f:
            pickle.dump(
                {'models': self.ocsvm_ensemble, 'bounding_boxes': self.bounding_boxes},
                f
            )
        # Nothing in this class creates disagreement_model, so only write it when present.
        disagreement_model = getattr(self, 'disagreement_model', None)
        if disagreement_model is not None:
            torch.save(disagreement_model, os.path.join(save_dir, 'nn_accessibility_model.pth'))

    def load(self, load_dir):
        """Restore state written by ``save``.

        NOTE: pickle / torch.load can execute arbitrary code; only load
        trusted directories.
        """
        with open(os.path.join(load_dir, 'ocsvm_accessibility_model.pkl'), 'rb') as f:
            data = pickle.load(f)
        self.ocsvm_ensemble = data['models']
        self.bounding_boxes = data['bounding_boxes']
        nn_path = os.path.join(load_dir, 'nn_accessibility_model.pth')
        self.disagreement_model = torch.load(nn_path) if os.path.isfile(nn_path) else None
        
        
        
        

# EXPLORATION FUNCTION
        
        
def explore(
    landmarks_info,
    unf_landmarks_info,
    traj_data,
    num_eps,
    work_dir,
    env: gym.Env,
    agent: Union[SACMetaAgent, PPOMetaAgent, HERMetaAgent],
    reachability_model_dict,
    accessibility_model,
    trajectory_model,
    exploration_histogram,
    global_low,
    global_high,
    ll_arena_low,
    ll_arena_high,
    explorable_low,
    explorable_high,
    n_accessibility_points,
    attempted_goals,
    success_threshold,
    meta_period,
    get_achieved_goal,
    get_desired_goal,
    round_idx,
    global_steps,
    eval_step,
    evaluate_fn,
    cfg=None
):
    """Run exploration episodes starting from the unfinished landmarks.

    Each episode picks the last entry of ``unf_landmarks_info`` for which the
    meta agent returns a non-None action (falling back to the last unfinished
    landmark). It steers towards that landmark and, once within
    ``success_threshold`` of it, switches to random-walk sub-goals via
    ``explore_or_exploit``. Observations collected after reaching the landmark
    are appended to ``traj_data[landmark_idx]`` (``'obs'``, ``'segments'``,
    ``'directly_reachable_obs'``). Counters in ``landmarks_info`` are updated
    in place.

    Returns:
        ``(traj_data, reachability_data, explored_from_landmarks,
        round_no_reachable_landmarks_count, global_steps, attempted_goals)``.
        ``reachability_data`` holds aligned lists ``start_goals``, ``end_goals``
        and ``reached_labels`` (whether each meta step's end goal was reached).
    """

    def _achieved(o):
        # get_achieved_goal works on batches; wrap a single observation.
        return get_achieved_goal(np.expand_dims(o, axis=0)).squeeze()

    landmarks = [lm_info['landmark'] for lm_info in landmarks_info]
    reachability_model = reachability_model_dict['model']
    reachability_model.train()

    goal_dim = len(ll_arena_low)
    seed = None

    step = 0
    trial = 0
    round_no_reachable_landmarks_count = 0
    explored_from_landmarks = set()

    all_start_goals = []
    all_end_goals = []
    all_reached_labels = []

    while True:
        ep_obs = []
        ep_segments = []
        ep_directly_reachable_obs = []

        episode_step = 0
        np.random.seed(seed)
        obs, info = env.reset()
        current_segment = [_achieved(obs)]

        agent.ppo_meta_agent.set_random_actions(False)

        # Unfinished landmarks the meta agent can currently steer towards.
        reachable_ids = [
            lm_info['id'] for lm_info in unf_landmarks_info
            if agent.ppo_meta_agent.towards_landmark(obs, lm_info['id'], default_action=False, initial_step=True) is not None
        ]
        if reachable_ids:
            explore_from_landmark_idx = landmarks_info[reachable_ids[-1]]['id']
        else:
            round_no_reachable_landmarks_count += 1
            explore_from_landmark_idx = unf_landmarks_info[-1]['id']

        terminated = False
        truncated = False
        meta_action = None
        ag_origin = None
        init_orientation = None

        reached_latest_landmark = False
        exploration_started = False
        exploration_meta_steps = 0
        desired_coord = None
        prev_desired_coord = None

        steps_since_last_control = 0
        meta_step = 0
        needs_correction = False
        failed_meta_steps = 0
        failure_tolerance = 2

        ep_start_goals = []
        ep_end_goals = []
        ep_reached_labels = []

        while not terminated and not truncated:
            achieved_goal = _achieved(obs)

            # The meta level takes control at episode start, every meta_period
            # steps, after repeated sub-goal failures, or once the current
            # sub-goal is reached.
            meta_control_step = bool(
                episode_step == 0
                or steps_since_last_control == meta_period
                or needs_correction
                or (desired_coord is not None and np.linalg.norm(achieved_goal - desired_coord) < success_threshold)
            )

            if meta_control_step:
                steps_since_last_control = 0

                ag_origin = np.copy(achieved_goal)
                print('ag origin:', ag_origin)

                if np.linalg.norm(ag_origin - landmarks[explore_from_landmark_idx]) < success_threshold:
                    reached_latest_landmark = True

                # Edge case: already at the landmark on reset, so the reset observation counts.
                if episode_step == 0 and reached_latest_landmark:
                    ep_obs.append(obs)

                if needs_correction:
                    # Too many failed sub-goals: pick a random sub-goal in the local arena.
                    coord = True
                    meta_action = np.random.uniform(ag_origin + ll_arena_low, ag_origin + ll_arena_high)
                    print(f'----------------------------------------------------------------------------------------------------Used correction! Meta step: {meta_step}')
                else:
                    meta_action, coord, exploration_started = explore_or_exploit(
                        episode_step == 0,
                        reached_latest_landmark,
                        explore_from_landmark_idx,
                        accessibility_model,
                        exploration_meta_steps,
                        landmarks_info,
                        agent,
                        goal_dim,
                        ag_origin,
                        obs,
                        global_low,
                        global_high,
                        ll_arena_low,
                        ll_arena_high,
                        explorable_low,
                        explorable_high,
                        exploration_started,
                        step + 1,
                        n_accessibility_points,
                        attempted_goals,
                        exploration_histogram
                    )

                # Non-coordinate actions: slot 0 selects the environment's desired
                # goal, slot k > 0 selects landmarks[k - 1].
                if coord:
                    desired_coord = meta_action
                elif meta_action[0] == 1:
                    desired_coord = get_desired_goal(np.expand_dims(obs, axis=0)).squeeze()
                else:
                    desired_coord = landmarks[np.argmax(meta_action) - 1]

                ep_start_goals.append(ag_origin)
                ep_end_goals.append(desired_coord)

                # Label the previous meta step: did it end within threshold of its target?
                if prev_desired_coord is not None:
                    ep_reached_labels.append(
                        np.linalg.norm(ag_origin - prev_desired_coord) < success_threshold
                    )
                prev_desired_coord = desired_coord

                print(f'meta step: {meta_step}, meta action: {meta_action}')
                if exploration_started:
                    exploration_meta_steps += 1

                if len(current_segment) > 1 and reached_latest_landmark:
                    ep_segments.append(copy.deepcopy(current_segment))
                current_segment = [achieved_goal]

                if exploration_meta_steps == 1:
                    ep_directly_reachable_obs.append(obs)

                needs_correction = False
                meta_step += 1

            if global_steps > 0 and global_steps % eval_step == 0:
                evaluate_fn(global_steps)

            action = agent.act(obs, meta_action, achieved_goal_origin=ag_origin, initial_orientation=init_orientation, coord=coord, should_print=meta_control_step)
            next_obs, reward, terminated, truncated, info = env.step(action)

            landmarks_info[explore_from_landmark_idx]['total_exploration_steps'] += 1
            if reached_latest_landmark:
                landmarks_info[explore_from_landmark_idx]['useful_exploration_steps'] += 1

            # Sub-goal failure check on the last low-level step of a meta
            # period, measured on the post-step achieved goal.
            next_ag = _achieved(next_obs)
            if steps_since_last_control == meta_period - 1 and np.linalg.norm(next_ag - desired_coord) > success_threshold:
                failed_meta_steps += 1
            if failed_meta_steps >= failure_tolerance:
                needs_correction = True
                failed_meta_steps = 0

            if reached_latest_landmark:
                ep_obs.append(next_obs)
            current_segment.append(next_ag)
            if exploration_meta_steps == 1:
                ep_directly_reachable_obs.append(next_obs)

            obs = next_obs
            step += 1
            episode_step += 1
            global_steps += 1
            steps_since_last_control += 1

        # The final meta step has no successor to label it; drop it so the
        # start/end goals stay aligned with ep_reached_labels.
        all_start_goals.extend(ep_start_goals[:-1])
        all_end_goals.extend(ep_end_goals[:-1])
        all_reached_labels.extend(ep_reached_labels)

        landmarks_info[explore_from_landmark_idx]['n_obs_collected'] += len(ep_obs)
        if landmarks_info[explore_from_landmark_idx]['n_obs_collected'] > 0:
            explored_from_landmarks.add(explore_from_landmark_idx)
            landmark_traj = traj_data[explore_from_landmark_idx]
            landmark_traj['obs'].extend(ep_obs)
            landmark_traj['segments'].extend(ep_segments)
            landmark_traj['directly_reachable_obs'].extend(ep_directly_reachable_obs)

        trial += 1
        # NOTE(review): `collect_full_trajectories` is not a parameter;
        # presumably a module-level flag defined elsewhere. If it is falsy this
        # loop never exits -- confirm.
        if collect_full_trajectories and trial == num_eps:
            break

    reachability_data = {
        'start_goals': all_start_goals,
        'end_goals': all_end_goals,
        'reached_labels': all_reached_labels
    }

    return traj_data, reachability_data, explored_from_landmarks, round_no_reachable_landmarks_count, global_steps, attempted_goals








def explore_or_exploit(
        initial_step,
        reached_latest_landmark,
        explore_from_landmark_idx,
        accessibility_model,
        exploration_meta_steps,
        landmarks_info,
        agent,
        goal_dim,
        ag_origin,
        obs,
        global_low,
        global_high,
        ll_arena_low,
        ll_arena_high,
        explorable_low,
        explorable_high,
        exploration_started,
        n_collected,
        n_accessibility_points,
        attempted_goals,
        exploration_histogram
):
    """Choose the next meta action: random-walk exploration or steering to a landmark.

    Random-walk exploration (a coordinate sub-goal from
    ``get_random_walk_meta_action``) is used when the latest landmark has been
    reached, when there is no landmark to explore from, or while the
    accessibility model has not been initialised. Otherwise the PPO meta agent
    provides an action towards ``explore_from_landmark_idx``.

    Returns:
        ``(meta_action, coord, exploration_started)``. ``coord`` is True when
        ``meta_action`` is a coordinate sub-goal. ``exploration_started`` is
        forced to True on the exploration path and passed through unchanged
        otherwise.
    """
    should_explore = (
        reached_latest_landmark
        or explore_from_landmark_idx is None
        or not accessibility_model.is_initialized()
    )

    if not should_explore:
        # Exploit: keep steering towards the landmark we explore from.
        meta_agent = agent.ppo_meta_agent
        meta_agent.set_random_actions(False)
        landmark_action = meta_agent.towards_landmark(obs, explore_from_landmark_idx, initial_step=initial_step)
        return landmark_action, False, exploration_started

    if reached_latest_landmark and explore_from_landmark_idx is not None:
        explore_from_lm = landmarks_info[explore_from_landmark_idx]['landmark']
        print(f'--------------------------------------------------------------------------------------------Reached latest landmark {explore_from_lm}, doing random walk now')
    else:
        print('No reachable landmark, doing random walk now')

    random_goal = get_random_walk_meta_action(
        agent,
        goal_dim,
        ag_origin,
        reached_latest_landmark,
        explore_from_landmark_idx,
        landmarks_info,
        accessibility_model,
        global_low,
        global_high,
        ll_arena_low,
        ll_arena_high,
        explorable_low,
        explorable_high,
        n_collected,
        n_accessibility_points,
        attempted_goals,
        exploration_histogram
    )
    return random_goal, True, True



def get_random_walk_meta_action(
        agent,
        goal_dim,
        ag_origin,
        reached_latest_landmark,
        explore_from_landmark_idx,
        landmarks_info,
        accessibility_model,
        global_low,
        global_high,
        ll_arena_low,
        ll_arena_high,
        explorable_low,
        explorable_high,
        n_collected,
        n_accessibility_points,
        attempted_goals,
        exploration_histogram
):
    """Sample a random coordinate sub-goal around ``ag_origin``.

    The sub-goal is drawn uniformly from the intersection of the local arena
    ``[ag_origin + ll_arena_low, ag_origin + ll_arena_high]`` and the global
    bounds ``[global_low, global_high]``. If ``reached_latest_landmark`` is
    set, the sample is also appended to ``attempted_goals`` (mutated in
    place). The remaining parameters are accepted for interface compatibility
    and are not used.
    """
    agent.ppo_meta_agent.set_random_actions(False)

    # Element-wise intersection of the local arena and the global box.
    sample_low = np.maximum(ag_origin + ll_arena_low, global_low)
    sample_high = np.minimum(ag_origin + ll_arena_high, global_high)
    sub_goal = np.random.uniform(sample_low, sample_high)

    if reached_latest_landmark:
        attempted_goals.append(sub_goal)
    return sub_goal






def cartesian_product(*arrays):
    """Return every combination of elements from the given 1-D arrays.

    Rows are ordered with the last array varying fastest. The result has
    shape ``(prod(len(a) for a in arrays), len(arrays))`` and the common
    dtype of the inputs.
    """
    common_dtype = np.result_type(*arrays)
    # 'ij' indexing keeps the first array on the slowest-varying axis.
    grids = np.meshgrid(*arrays, indexing='ij')
    stacked = np.stack(grids, axis=-1).astype(common_dtype, copy=False)
    return stacked.reshape(-1, len(arrays))






# LANDMARK GENERATION FUNCTION
    
    
def generate_frontier_landmark(
    starting_landmark_idx,
    traj_data,
    attempted_goals,
    work_dir,
    round_idx,
    cycle_idx,
    success_threshold,
    update_novel_props,
    reachability_model_dict,
    accessibility_model,
    landmarks_info,
    low, # must be finite
    high, # must be finite
    ll_arena_low,
    ll_arena_high,
    meta_period,
    thresh_prob,
    get_achieved_goal
):
    
    reachability_model = reachability_model_dict['model']
    reachability_model.eval()
    landmarks = [lm_info['landmark'] for lm_info in landmarks_info]
    
    all_segments = traj_data['segments']  
    step = 1
    selected_obs = [np.array(segment[::step]) for segment in all_segments]
    selected_obs = np.concatenate(selected_obs, axis=0)
    all_achieved_goals = get_achieved_goal(selected_obs)
    coord_dim = all_achieved_goals.shape[1]
    
    

    def find_novel_achieved_goals(achieved_goals):  
        reachables_list = []
        accessibles_list = []
        endpoint_reachables_list = []
        for i, landmark in enumerate(landmarks):
            coord_dim = achieved_goals.shape[1]
            model_in = np.zeros((len(achieved_goals), coord_dim * 2))
            model_in[:, :coord_dim] = landmark
            model_in[:, coord_dim:] = achieved_goals
            
            reachable, accessible, endpoint_reachable, _ = accessibility_model.is_reachable_(model_in[:, :coord_dim], model_in[:, coord_dim:])
            
            accessible = (accessible >= thresh_prob).squeeze()
            endpoint_reachable = (endpoint_reachable >= thresh_prob).squeeze()
            accessibles_list.append(accessible)
            endpoint_reachables_list.append(endpoint_reachable)
            
            reachable = (reachable[:, 1] >= thresh_prob).squeeze()
            reachables_list.append(reachable)
            
        reachables_arr = np.stack(reachables_list)
        reachables_sum = np.sum(reachables_arr, axis=0)
        
        novel = achieved_goals[np.where(reachables_sum == 0)].tolist()
        not_novel = achieved_goals[np.where(reachables_sum != 0)].tolist()
        
        return novel, not_novel
    
    novel_ag_list, not_novel_ag_list = find_novel_achieved_goals(all_achieved_goals)
    
    wall_low = np.array([0.25, 0.1, -0.2])
    wall_high = np.array([0.35, 0.8, 0.2])
    wall_data = np.random.uniform(wall_low, wall_high, size=(1000, 3))
    
    if len(low) == 3:
        fig = plt.figure()
        ax = fig.add_subplot(projection='3d')
        data = np.stack(landmarks)
        ax.scatter(
            data[:, 0],
            data[:, 1],
            data[:, 2]
        )
        ax.scatter(
            wall_data[:, 0],
            wall_data[:, 1],
            wall_data[:, 2]
        )
        plt.savefig(os.path.join(work_dir, f'r{round_idx}_lm_{starting_landmark_idx}_c{cycle_idx}_lms.jpg'))
        
        data_novel = novel_ag_list
        data_not_novel = not_novel_ag_list

        if len(data_novel) > 0 or len(data_not_novel) > 0:
            fig = plt.figure()
            ax = fig.add_subplot(projection='3d')
            if len(data_not_novel) > 0:
                data_not_novel = np.stack(data_not_novel)
                ax.scatter(
                    data_not_novel[:, 0],
                    data_not_novel[:, 1],
                    data_not_novel[:, 2]
                )
            if len(data_novel) > 0:
                data_novel = np.stack(data_novel)
                ax.scatter(
                    data_novel[:, 0],
                    data_novel[:, 1],
                    data_novel[:, 2]
                )
            ax.scatter(
                data[:, 0],
                data[:, 1],
                data[:, 2]
            )
            ax.scatter(
                wall_data[:, 0],
                wall_data[:, 1],
                wall_data[:, 2]
            )
            plt.savefig(os.path.join(work_dir, f'r{round_idx}_lm_{starting_landmark_idx}_c{cycle_idx}_scatter_all.jpg'))
            
    
    save_wip_landmarks = len(low) == 2
    
    if save_wip_landmarks:
        
        plt.figure()
        if len(not_novel_ag_list) > 0:
            plt.scatter(np.stack(not_novel_ag_list)[:, 0], np.stack(not_novel_ag_list)[:, 1], alpha=0.3)
        if len(novel_ag_list) > 0:
            plt.scatter(np.stack(novel_ag_list)[:, 0], np.stack(novel_ag_list)[:, 1], alpha=0.3)
        if len(attempted_goals) > 0:
            plt.scatter(np.stack(attempted_goals)[:, 0], np.stack(attempted_goals)[:, 1], alpha=1.)
        plt.xlim(low[0], high[0])
        plt.ylim(low[1], high[1])
        plt.savefig(os.path.join(work_dir, f'r{round_idx}_lm_{starting_landmark_idx}_c{cycle_idx}_scatter_all.jpg'))
    
    novel_ag = set(tuple(elem) for elem in novel_ag_list)
    novel_prop = len(novel_ag) / len(all_achieved_goals)
    
    
    novel_prop_not_improving = False
    if update_novel_props:
        if 'novel_prop_list' not in landmarks_info[starting_landmark_idx]:
            landmarks_info[starting_landmark_idx]['novel_prop_list'] = [novel_prop]
        else:
            landmarks_info[starting_landmark_idx]['novel_prop_list'].append(novel_prop)
            prop_list = landmarks_info[starting_landmark_idx]['novel_prop_list']
            novel_prop_not_improving = len(prop_list) > 1 and np.abs(prop_list[-1] - prop_list[-2]) < 0.005
    
    if novel_prop == 0.:
        return None, False
    
    coords = []
    novel_props = []
    for segment in all_segments:
        novel = []
        for ag in segment:
            if tuple(ag.tolist()) in novel_ag:
                novel.append(True)
            else:
                novel.append(False)
        novel = np.array(novel)
        for i, ag in enumerate(segment):
            novel_prop = np.sum(novel[i:]) / len(novel[i:])
            if all(ag > low) and all(ag < high):
                coords.append(ag)
                novel_props.append(novel_prop)
    coords = np.stack(coords)
    cl = np.zeros_like(coords)
    cl[:, :] = landmarks[starting_landmark_idx]
    
    novel_props = np.array(novel_props)
    
    

    n = 50
    
    coord_dim = len(low)
    thresh = thresh_prob
    
    ff_in_shape = tuple([n] * coord_dim)
    ff = np.zeros(ff_in_shape)
    counts = np.zeros_like(ff)
    for i, (coord, novel_prop) in enumerate(zip(coords, novel_props)):
        bin_idx = tuple(((coord - low) * n / (high - low)).astype(dtype=np.uint8))
        ff[bin_idx] += novel_prop
        counts[bin_idx] += 1
    if np.isnan(np.sum(ff)):
        print('isnan 0')
        raise Exception()
    old_old_sum = np.copy(np.sum(ff))
    old_old_ff = np.copy(ff)
    
    
    rp = np.zeros_like(ff)
    ac = np.zeros_like(ff)
    rp_coords = []
    
    bin_sizes = np.divide(high - low, np.array(rp.shape))
    binned_coord_arrs = []
    for i_dim in range(coord_dim):
        bin_size = bin_sizes[i_dim]
        binned_coord_arrs.append(
            np.array([low[i_dim] + (i_bin + 0.5) * bin_size for i_bin in range(n)])
        )
    binned_coord_arrs = tuple(binned_coord_arrs)
    
    rp_coords = cartesian_product(*binned_coord_arrs)
    
    rp_cl = np.zeros_like(rp_coords)
    rp_cl[:, :] = landmarks[starting_landmark_idx]
            
    coords_reachable_probs = accessibility_model.is_reachable(rp_cl, rp_coords)
    coords_accessible_probs = accessibility_model.is_accessible(rp_coords)
    
    max_reachable = np.max(coords_reachable_probs, axis=0).squeeze()[1]
    should_thresh = max_reachable >= thresh
    
    for i, (coord, reachable_prob, accessible_prob) in enumerate(zip(rp_coords, coords_reachable_probs, coords_accessible_probs)):
        bin_idx = tuple(((coord - low) * n / (high - low)).astype(dtype=np.uint8))
        rp[bin_idx] = reachable_prob[1]
        ac[bin_idx] = accessible_prob
    
    ff = np.multiply(ac, ff)
    ff = np.multiply(rp, ff)
    ff = np.divide(ff, counts, out=np.zeros_like(ff), where=counts!=0)
    
            
    if save_wip_landmarks:
        
        x = np.linspace(low[0], high[0], n)
        y = np.linspace(low[1], high[1], n)
            
    if np.isnan(np.sum(ff)):
        raise Exception()
    ff_sum = np.copy(np.sum(ff))
    
    
    if ff_sum == 0.:
        return None, False
    
    ff = ff / np.sum(ff)
    if np.isnan(np.sum(ff)):
        if save_wip_landmarks:
            
            plt.figure()
            plt.scatter(all_achieved_goals[:, 0], all_achieved_goals[:, 1], alpha=0.3)
            plt.title('AG to generate lm')
            plt.savefig(os.path.join(work_dir, f'error_ag.jpg'))
            
            plt.figure()
            plt.contourf(x, y, ff.T)
            plt.axis('scaled')
            plt.colorbar()
            plt.savefig(os.path.join(work_dir, f'error_contour_should_thresh_{should_thresh}.jpg'))
            
            plt.figure()
            plt.contourf(x, y, (counts/np.max(counts)).T)
            plt.axis('scaled')
            plt.colorbar()
            plt.savefig(os.path.join(work_dir, f'error_counts.jpg'))
        raise Exception()
    
    

    best_new_landmark_bin = np.unravel_index(np.argmax(ff, axis=None), ff.shape)
    
    def sample_index(p):
        i = np.random.choice(np.arange(p.size), p=p.ravel())
        return np.unravel_index(i, p.shape)
    
    backup_new_landmark_bins = np.stack([sample_index(ff) for _ in range(4)])
    backup_new_landmark_scores = np.array([ff[tuple(bin_idx)] for bin_idx in backup_new_landmark_bins])
    sort_idx = np.argsort(backup_new_landmark_scores)
    backup_new_landmark_bins = backup_new_landmark_bins[sort_idx]
    best_new_landmark_bin = np.expand_dims(best_new_landmark_bin, axis=0)
    new_landmark_bins = np.concatenate([best_new_landmark_bin, backup_new_landmark_bins], axis=0)
    new_landmarks = []
    for new_landmark_bin in new_landmark_bins:
        bin_size = np.divide((high - low), np.array(ff.shape))
        new_landmark = low + np.multiply(new_landmark_bin + 0.5, bin_size)
        new_landmarks.append(new_landmark)
    
    best_new_landmark = new_landmarks[0]
    novel_ag_arr = np.array(list(novel_ag))
    reachable, accessible, endpoint_reachable, info = accessibility_model.is_reachable_(np.tile(best_new_landmark, (len(novel_ag_arr), 1)), novel_ag_arr)
    is_reachable = np.argmax(reachable, axis=1)
    
    
    
    accessible_novel_points = novel_ag_arr[np.where(accessible)]
    not_accessible_novel_points = novel_ag_arr[np.where(1 - accessible)]
    
    n_reachable_novel_points = np.sum(is_reachable)
    n_accessible_novel_points = np.sum(accessible)
    n_endpoint_reachable_novel_points = np.sum(endpoint_reachable)
    new_reachable_prop = n_reachable_novel_points / len(all_achieved_goals)
    
    
    if new_reachable_prop < 0.001:
        return None, False
        
    print(f'Generated {len(new_landmarks)} new landmark possibilities: {new_landmarks}')
    
    if save_wip_landmarks:
        
        
        
        
        plt.figure()
        plt.contourf(x, y, ff.T)
        plt.axis('scaled')
        plt.colorbar()
        plt.savefig(os.path.join(work_dir, f'r{round_idx}_lm_{starting_landmark_idx}_c{cycle_idx}_contour_should_thresh_{should_thresh}.jpg'))
    
    return new_landmarks







# MAIN DRIVER FUNCTION

def learn_landmarks_progressively(
    work_dir,
    env: gym.Env,
    test_env: gym.Env,
    agent: Union[SACMetaAgent, PPOMetaAgent, HERMetaAgent],
    eval_step,
    algo,
    primitive_agents,
    meta_period,
    get_achieved_goal,
    get_desired_goal,
    ll_arena_low,
    ll_arena_high,
    explorable_low,
    explorable_high,
    relative_goals,
    success_threshold,
    min_expl_steps,
    video_recorder: VideoRecorder,
    cfg=None
) -> list:
    """Grow a set of landmarks round by round, starting from the reset state.

    Each round:

    1. ``explore`` collects trajectories from the unfinished landmarks.
    2. The highest-indexed landmark explored this round is considered for
       expansion once it has gathered enough observations since its last
       checkpoint (``initial_min_expl_steps`` for the initial landmark,
       ``min_expl_steps`` otherwise, scaled by ``cycle + 1``).
    3. The accessibility model is refit on that landmark's achieved goals and
       ``generate_frontier_landmark`` proposes candidate landmarks.
    4. Each candidate is validated with one rollout (``_run_candidate_trial``);
       the first one that passes becomes a new landmark. If none passes, the
       landmark's cycle counter is incremented and its exploration buffer is
       reset; if no candidates are produced at all, the landmark is marked
       finished.

    The loop stops when every landmark is finished or after
    ``max_rounds - 1`` rounds. Progress is written into ``work_dir``
    (``success_rates.npy``, ``reachability_failures.jpg``,
    ``landmarks_info.pkl`` and the saved accessibility model).

    Note: the ``get_achieved_goal`` argument is overridden based on
    ``env.name``; ``video_recorder`` is currently unused.

    Raises:
        ValueError: if ``env.name`` is not an mpe / ant / fetch environment.
        Exception: if no trajectory model exists at the skill's
            ``trajectory_model_path``.

    Returns:
        The list of learned landmark coordinates.
    """
    low = env.goal_low
    high = env.goal_high

    # The achieved-goal slice of the flat observation depends on the
    # environment family; this deliberately replaces the passed-in extractor.
    if 'mpe' in env.name:
        get_achieved_goal = lambda x: x[:, 2:4]
    elif 'ant' in env.name:
        get_achieved_goal = lambda x: x[:, :2]
    elif 'fetch' in env.name:
        get_achieved_goal = lambda x: x[:, :3]
    else:
        raise ValueError(f'Unsupported environment for landmark learning: {env.name}')

    # Goal dimensionality sizes the reachability network input.
    goal_space = getattr(env, 'discrete_goal_space', env.goal_space)
    if isinstance(goal_space, Discrete):
        goal_shape = (int(goal_space.n),)
    else:
        goal_shape = goal_space.shape

    sg_model_path = os.path.join(work_dir, 'sg_model.pth')
    if os.path.isfile(sg_model_path):
        reachability_model = torch.load(sg_model_path)
    else:
        # Input: (start goal, end goal) concatenated; output: log-probabilities
        # over two classes.
        reachability_model = nn.Sequential(
            nn.Linear(goal_shape[0] * 2, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 2),
            nn.LogSoftmax(dim=-1),
        )

    trajectory_model_path = agent.ppo_meta_agent.primitive_agents[0].skills[0].trajectory_model_path
    if os.path.isfile(trajectory_model_path):
        trajectory_model = torch.load(trajectory_model_path)
    else:
        raise Exception(f'Trajectory model not found at {trajectory_model_path}')

    n_initial_landmarks = 1
    max_rounds = 1000000
    rnd = 0
    landmarks_info = []

    # Initial landmarks are the achieved goals of freshly reset episodes.
    for i in range(n_initial_landmarks):
        obs, _ = env.reset()
        initial_landmark = get_achieved_goal(np.expand_dims(obs, axis=0)).squeeze()
        initial_landmark_info = _new_landmark_info(i, initial_landmark, None)
        landmarks_info.append(initial_landmark_info)
        agent.ppo_meta_agent.add_landmark(initial_landmark_info)

    landmarks = [lm_info['landmark'] for lm_info in landmarks_info]

    accessibility_model = AccessibilityModel(
        agent,
        trajectory_model,
        low,
        high,
        landmarks_info,
        success_threshold,
        ll_arena_low,
        ll_arena_high,
        relative_goals,
        work_dir
    )
    if accessibility_model.saved_model_available(work_dir):
        accessibility_model.load(work_dir)

    reachability_optim = SGD(reachability_model.parameters(), lr=0.001)
    reachability_model_dict = {'model': reachability_model, 'optim': reachability_optim}

    success_rates_data = []

    def evaluate_fn(steps_so_far):
        """Evaluate on ``test_env`` and append (steps, success rate) to success_rates.npy."""
        _, success_rate, _ = evaluate(
            work_dir,
            test_env,
            10,
            meta_period,
            None,
            algo,
            primitive_agents,
            None,
            ll_arena_low,
            ll_arena_high,
            explorable_low,
            explorable_high,
            relative_goals,
            success_threshold,
            work_dir,
            cfg=cfg
        )
        success_rates_data.append(np.array([float(steps_so_far), success_rate]))
        with open(os.path.join(work_dir, 'success_rates.npy'), 'wb') as f:
            np.save(f, np.stack(success_rates_data))

    agent.reset()
    agent.ppo_meta_agent.set_reachability_models(None, None, None, reachability_model, accessibility_model, None)

    train_eps_per_cycle = 1
    n_accessibility_points = min_expl_steps * 0.
    # The initial landmark needs more exploration before it can be expanded.
    initial_min_expl_steps = 1.5 * min_expl_steps
    no_reachable_landmarks_count = 0
    n_trials = 1

    # Per-landmark exploration buffers, keyed by landmark id.
    train_model_traj_data = {
        key: _empty_traj_data()
        for key, val in enumerate(landmarks_info)
        if not val['finished']
    }

    global_steps = 0
    attempted_goals = []
    exploration_histogram = None
    global_reachability_failures = []

    # ----------------MAIN LOOP------------------

    while rnd < max_rounds - 1:
        print(f'**********************************************************************Round {rnd}')
        # Each round we try to add one new landmark to our set.

        unf_landmarks_info = [lm_info for lm_info in landmarks_info if not lm_info['finished']]

        (
            train_model_traj_data,
            _reachability_data,
            explored_from_landmarks,
            round_no_reachable_landmarks_count,
            global_steps,
            attempted_goals,
        ) = explore(
            landmarks_info,
            unf_landmarks_info,
            train_model_traj_data,
            train_eps_per_cycle,
            work_dir,
            env,
            agent,
            reachability_model_dict,
            accessibility_model,
            trajectory_model,
            exploration_histogram,
            low,
            high,
            ll_arena_low,
            ll_arena_high,
            explorable_low,
            explorable_high,
            n_accessibility_points,
            attempted_goals,
            success_threshold,
            meta_period,
            get_achieved_goal,
            get_desired_goal,
            rnd,
            global_steps,
            eval_step,
            evaluate_fn,
            cfg
        )

        no_reachable_landmarks_count += round_no_reachable_landmarks_count

        finished_landmarks = []
        new_landmarks = []
        incremented_cycles_landmarks = []
        expl_landmark_ids = []

        # Only the most recent (highest-id) landmark explored this round is
        # considered for expansion.
        if len(explored_from_landmarks) == 0:
            landmark_ids_to_expand = []
        else:
            landmark_ids_to_expand = [np.max(list(explored_from_landmarks))]

        for landmark_id in landmark_ids_to_expand:
            expl_data = train_model_traj_data[landmark_id]
            lm_info = landmarks_info[landmark_id]

            expl_data_since_checkpoint = {
                'obs': expl_data['obs'][expl_data['obs_latest_checkpoint_idx']:],
                'segments': expl_data['segments'][expl_data['segments_latest_checkpoint_idx']:],
                'directly_reachable_obs': expl_data['directly_reachable_obs'][expl_data['dr_obs_latest_checkpoint_idx']:],
            }

            # Required number of fresh observations grows with each failed cycle.
            if lm_info['generated_from_landmark_id'] is None:
                min_steps = initial_min_expl_steps
            else:
                min_steps = min_expl_steps
            n_new_obs = lm_info['n_obs_collected'] - expl_data['obs_latest_checkpoint_idx']
            if n_new_obs < min_steps * (lm_info['cycle'] + 1):
                continue

            round_reachability_failures = 0

            expl_achieved_goals = get_achieved_goal(np.stack(expl_data['obs'])).squeeze()

            # A diagram save path is only passed for goal spaces of dimension <= 2.
            if len(low) > 2:
                accessibility_model.infer(expl_achieved_goals, landmark_id)
            else:
                accessibility_model.infer(
                    expl_achieved_goals,
                    landmark_id,
                    diagram_save_path=os.path.join(work_dir, f'r{rnd}_svm_accessibility.jpg'),
                )
            accessibility_model.save(work_dir)

            thresh_prob = 0.98
            max_thresh_prob = 0.98
            thresh_prob_increment = 0.01
            found_reachable_landmark = False
            landmark_gen_attempt = 0
            candidate_landmark = None

            while thresh_prob <= max_thresh_prob and not found_reachable_landmark:

                print('*********************************************************************************** Thresh prob:', thresh_prob)

                raw_candidates = generate_frontier_landmark(
                    landmark_id,
                    # NOTE(review): assumed to be the second positional
                    # parameter (data since the last checkpoint); confirm
                    # against generate_frontier_landmark's signature.
                    expl_data_since_checkpoint,
                    expl_data,
                    attempted_goals,
                    work_dir,
                    rnd,
                    lm_info['cycle'],
                    success_threshold,
                    landmark_gen_attempt == 0,
                    reachability_model_dict,
                    accessibility_model,
                    landmarks_info,
                    explorable_low,
                    explorable_high,
                    ll_arena_low,
                    ll_arena_high,
                    meta_period,
                    thresh_prob,
                    get_achieved_goal
                )
                print('candidates:', raw_candidates)
                candidate_landmarks = _normalize_candidates(raw_candidates)

                if candidate_landmarks is None:
                    # Nothing left to expand from this landmark.
                    lm_info['finished'] = True
                    finished_landmarks.append(landmark_id)
                    break

                candidate_landmark = None
                for candidate_landmark in candidate_landmarks:
                    successful_trials = 0
                    for _ in range(n_trials):
                        success, global_steps = _run_candidate_trial(
                            env,
                            agent,
                            landmarks,
                            landmark_id,
                            candidate_landmark,
                            meta_period,
                            success_threshold,
                            get_achieved_goal,
                            get_desired_goal,
                            eval_step,
                            evaluate_fn,
                            global_steps,
                        )
                        if success:
                            successful_trials += 1
                        else:
                            round_reachability_failures += 1

                    if successful_trials == n_trials:
                        new_landmarks.append(np.copy(candidate_landmark))
                        expl_landmark_ids.append(landmark_id)
                        lm_info['cycle'] = 0

                        # Checkpoint the buffers so the next expansion only
                        # looks at data gathered after this landmark was added.
                        expl_data['obs_latest_checkpoint_idx'] = len(expl_data['obs'])
                        expl_data['segments_latest_checkpoint_idx'] = len(expl_data['segments'])
                        expl_data['dr_obs_latest_checkpoint_idx'] = len(expl_data['directly_reachable_obs'])

                        found_reachable_landmark = True
                        break

                thresh_prob += thresh_prob_increment
                landmark_gen_attempt += 1

            if not found_reachable_landmark and not lm_info['finished']:
                # No candidate passed: record the last one tried, demand more
                # exploration next time, and start a fresh buffer.
                lm_info['candidate'] = candidate_landmark
                lm_info['cycle'] += 1
                incremented_cycles_landmarks.append(lm_info)
                train_model_traj_data[landmark_id] = _empty_traj_data()

            attempted_goals = []

            global_reachability_failures.append(round_reachability_failures)
            fig = plt.figure()
            plt.plot(np.array(global_reachability_failures))
            plt.savefig(os.path.join(work_dir, 'reachability_failures.jpg'))
            # Close explicitly: one figure per round would otherwise accumulate.
            plt.close(fig)

        print('*********************************************************************************************************************')
        print('*********************************************************************************************************************')
        print(f'Total steps so far: {global_steps}')
        print(f'New landmarks for round {rnd}: {new_landmarks}')
        print(f'Landmarks with incremented cycles for round {rnd}: {incremented_cycles_landmarks}')
        print(f'Finished landmarks for round {rnd}: {[landmarks_info[lm_id] for lm_id in finished_landmarks]}')
        print('Landmark cycles:', [lm_info['cycle'] for lm_info in landmarks_info])
        print('*********************************************************************************************************************')
        print('*********************************************************************************************************************')

        for new_landmark, expl_landmark_id in zip(new_landmarks, expl_landmark_ids):
            new_landmark_id = len(landmarks_info)
            new_landmark_info = _new_landmark_info(new_landmark_id, new_landmark, expl_landmark_id)
            landmarks_info.append(new_landmark_info)
            train_model_traj_data[new_landmark_id] = _empty_traj_data()
            agent.ppo_meta_agent.add_landmark(new_landmark_info)
            if expl_landmark_id is not None:
                agent.ppo_meta_agent.report_landmark_reachability(expl_landmark_id, new_landmark_id)

        landmarks = [lm_info['landmark'] for lm_info in landmarks_info]

        for lm_id in finished_landmarks:
            train_model_traj_data.pop(lm_id, None)

        if all(lm_info['finished'] for lm_info in landmarks_info):
            break

        if accessibility_model.is_initialized():
            agent.ppo_meta_agent.update_landmark_graph()

        with open(os.path.join(work_dir, 'landmarks_info.pkl'), 'wb') as file:
            pickle.dump(landmarks_info, file)

        rnd += 1

    print(f'Total training steps: {global_steps}')
    print(f'Progressively learned landmarks: {landmarks}')
    return landmarks


def _empty_traj_data():
    """Return an empty per-landmark exploration buffer with zeroed checkpoints."""
    return {
        'obs': [],
        'segments': [],
        'directly_reachable_obs': [],
        'obs_latest_checkpoint_idx': 0,
        'segments_latest_checkpoint_idx': 0,
        'dr_obs_latest_checkpoint_idx': 0,
    }


def _new_landmark_info(landmark_id, landmark, generated_from_landmark_id):
    """Return a fresh bookkeeping record for a landmark (coordinates are copied)."""
    return {
        'id': landmark_id,
        'landmark': np.copy(landmark),
        'finished': False,
        'cycle': 0,
        'n_obs_collected': 0,
        'generated_from_landmark_id': generated_from_landmark_id,
        'total_exploration_steps': 0,
        'useful_exploration_steps': 0,
    }


def _normalize_candidates(result):
    """Map every "no candidate" result of ``generate_frontier_landmark`` to ``None``.

    That function returns a list of candidate coordinates on success but
    ``(None, False)`` on its early-exit paths; iterating that tuple would feed
    ``None`` / ``False`` into the trial rollout as if they were coordinates.
    """
    if result is None:
        return None
    if isinstance(result, tuple) and len(result) > 0 and result[0] is None:
        return None
    return result


def _run_candidate_trial(
    env,
    agent,
    landmarks,
    landmark_id,
    candidate_landmark,
    meta_period,
    success_threshold,
    get_achieved_goal,
    get_desired_goal,
    eval_step,
    evaluate_fn,
    global_steps,
):
    """Roll out one episode to test whether ``candidate_landmark`` is reachable.

    The meta-agent is steered towards ``landmarks[landmark_id]`` via
    ``agent.ppo_meta_agent.towards_landmark``; once that landmark has been
    reached, the candidate coordinate itself is issued as the meta-action.

    Args:
        landmarks: Current list of landmark coordinates.
        landmark_id: Index of the landmark the candidate was generated from.
        candidate_landmark: Coordinate under test.
        global_steps: Environment steps so far; incremented once per env step
            and used to trigger ``evaluate_fn`` every ``eval_step`` steps.

    Returns:
        ``(success, global_steps)``. ``success`` is True only if the candidate
        was reached and at most one meta-step was spent after the starting
        landmark was first reached.
    """
    obs, _ = env.reset()
    terminated = truncated = False
    episode_step = 0
    steps_since_last_control = 0
    meta_steps = 0
    reached_latest_landmark = False
    reached_candidate = False
    desired_coord = None
    init_orientation = None
    start_landmark = landmarks[landmark_id]

    while not (terminated or truncated):
        achieved_goal = get_achieved_goal(np.expand_dims(obs, axis=0)).squeeze()

        # The meta-controller acts at episode start, every `meta_period` env
        # steps, and whenever the current sub-goal has been reached.
        meta_control_step = (
            episode_step == 0
            or steps_since_last_control == meta_period
            or (
                desired_coord is not None
                and np.linalg.norm(achieved_goal - desired_coord) < success_threshold
            )
        )

        if meta_control_step:
            steps_since_last_control = 0
            ag_origin = np.copy(achieved_goal)
            dist_from_start = np.linalg.norm(achieved_goal - start_landmark)
            dist_from_candidate = np.linalg.norm(achieved_goal - candidate_landmark)
            print('------------------------')
            print('step:', episode_step)
            print('meta steps:', meta_steps)
            print('ag:', achieved_goal)
            print('full obs:', obs)
            print('explore from landmark:', start_landmark)
            print('candidate:', candidate_landmark)
            print('dist from explore_from:', dist_from_start)
            print('dist from candidate:', dist_from_candidate)

            if dist_from_start < success_threshold:
                reached_latest_landmark = True
            if dist_from_candidate < success_threshold:
                reached_candidate = True
                break

            if reached_latest_landmark:
                # Head straight for the candidate coordinate.
                coord = True
                meta_action = candidate_landmark
            else:
                coord = False
                agent.ppo_meta_agent.set_random_actions(False)
                meta_action = agent.ppo_meta_agent.towards_landmark(obs, landmark_id, initial_step=episode_step == 0)
            # Meta-steps are only counted from the moment the start landmark is reached.
            if reached_latest_landmark:
                meta_steps += 1

            if coord:
                desired_coord = meta_action
            elif meta_action[0] == 1:
                # Index 0 of the meta-action selects the desired goal itself.
                desired_coord = get_desired_goal(np.expand_dims(obs, axis=0)).squeeze()
            else:
                desired_coord = landmarks[np.argmax(meta_action) - 1]

            print('ma:', meta_action)
            print('ag origin:', ag_origin)

        if global_steps > 0 and global_steps % eval_step == 0:
            evaluate_fn(global_steps)

        action = agent.act(obs, meta_action, achieved_goal_origin=ag_origin, initial_orientation=init_orientation, coord=coord)
        obs, _reward, terminated, truncated, _info = env.step(action)
        episode_step += 1
        global_steps += 1
        steps_since_last_control += 1

    # Episodes that end without reaching the candidate must not count as success.
    success = reached_candidate and meta_steps <= 1
    return success, global_steps






# PLANNER


def _first_hop_on_path(predecessors, start_node_idx, target_node_idx):
    """Return the node that follows ``start_node_idx`` on the path to ``target_node_idx``.

    ``predecessors`` is the 1-D array returned by
    ``scipy.sparse.csgraph.shortest_path(..., indices=start_node_idx,
    return_predecessors=True)``; scipy marks "no predecessor" with -9999.
    Returns None when the target has no predecessor, i.e. it is unreachable
    from the start node.
    """
    no_predecessor = -9999
    node_idx = target_node_idx
    predecessor_idx = predecessors[node_idx]
    if predecessor_idx == no_predecessor:
        return None
    # Walk backwards from the target until the node before us is the start node.
    while predecessor_idx != start_node_idx and predecessor_idx != no_predecessor:
        node_idx = predecessor_idx
        predecessor_idx = predecessors[node_idx]
    return node_idx


def select_action(self, state, landmark_subgoal_id=None, batched=False, evaluate=False, default_action=True):
    """Choose the next high-level target by planning on the landmark graph.

    Targets are encoded as node indices: 0 is the goal (the desired goal from
    ``self.get_desired_goal``, or ``self.landmarks[landmark_subgoal_id]`` when a
    landmark subgoal is given) and ``i + 1`` is ``self.landmarks[i]``. If node 0
    is chosen while a landmark subgoal is active, it is reported as the index of
    the landmark lying within ``success_threshold`` of that subgoal instead.

    Planning uses a directed graph built from ``self.lm_reachability_matrix``
    (landmark -> landmark) plus ``self.accessibility_model.is_reachable``
    queries between the landmarks, the goal and, when the agent is not at a
    landmark, the achieved goal. A query becomes an edge when column 1 of the
    model output is ``>= self.thresh_prob``. The first hop of the scipy
    shortest path from the current node to the goal is returned.

    Args:
        state: observation; a leading axis is added before it is passed to
            ``get_achieved_goal`` / ``get_desired_goal`` (presumably a single,
            unbatched observation — confirm with callers).
        landmark_subgoal_id: optional index into ``self.landmarks`` to plan
            towards instead of the desired goal.
        batched: only used in ``random_actions`` mode, where the sampled
            action gets a leading axis.
        evaluate: unused.
        default_action: when the goal is unreachable, return the node of
            ``[goal] + landmarks`` closest to the achieved goal (True) or
            None (False).

    Returns:
        ``np.int64`` node index; a random node index while the accessibility
        model is not initialized; an ``action_space`` sample in
        ``random_actions`` mode; or None (see ``default_action``).

    Side effects:
        Sets ``self.current_landmark`` on the planning path (-1 when the
        achieved goal is not within ``success_threshold`` of any landmark).

    Raises:
        ValueError: if no landmark lies within ``success_threshold`` of the
            landmark subgoal when remapping node 0.
    """
    landmark_subgoal = self.landmarks[landmark_subgoal_id] if landmark_subgoal_id is not None else None
    achieved_goal = self.get_achieved_goal(np.expand_dims(state, axis=0)).squeeze()

    if not self.accessibility_model.is_initialized():
        # No reachability model yet: pick the goal (0) or any landmark (1..n) uniformly.
        return np.int64(np.random.randint(len(self.landmarks) + 1))

    if self.random_actions:
        action = self.action_space.sample()
        if batched:
            action = np.expand_dims(action, axis=0)
        return action

    if landmark_subgoal is not None:
        goal = landmark_subgoal
    else:
        goal = self.get_desired_goal(np.expand_dims(state, axis=0)).squeeze()

    # What (if any) is the landmark we are currently at?
    at_landmark_list = [
        np.linalg.norm(achieved_goal - landmark) < self.success_threshold
        for landmark in self.landmarks
    ]
    n_at_landmarks = np.sum(at_landmark_list)
    if n_at_landmarks == 0:
        self.current_landmark = -1
    elif landmark_subgoal_id is not None and at_landmark_list[landmark_subgoal_id]:
        self.current_landmark = landmark_subgoal_id
    elif n_at_landmarks > 1:
        possible_lms = np.flatnonzero(at_landmark_list)
        # Do not choose the landmark we were already at if there's an alternative
        # (for the edge case of very close landmarks).
        possible_lms = possible_lms[possible_lms != self.current_landmark]
        # Most recently generated landmark breaks ties.
        self.current_landmark = possible_lms[-1]
    else:
        self.current_landmark = np.argmax(at_landmark_list)

    n_landmarks = len(self.landmarks)
    lms = np.stack(self.landmarks)
    goal_rows = np.tile(goal, (n_landmarks, 1))

    if self.current_landmark != -1:
        # Node layout: 0 = goal, i + 1 = landmark i; start at the current landmark.
        model_in_from = np.concatenate([goal_rows, lms], axis=0)
        model_in_to = np.concatenate([lms, goal_rows], axis=0)
        model_out = self.accessibility_model.is_reachable(model_in_from, model_in_to).squeeze()
        from_dg = model_out[:n_landmarks, 1] >= self.thresh_prob
        to_dg = model_out[n_landmarks:, 1] >= self.thresh_prob

        reachability_matrix = np.zeros((n_landmarks + 1, n_landmarks + 1))
        reachability_matrix[1:, 1:] = self.lm_reachability_matrix
        reachability_matrix[0, 0] = 1.
        reachability_matrix[0, 1:] = from_dg
        reachability_matrix[1:, 0] = to_dg
        start_node_idx = self.current_landmark + 1
        target_node_idx = 0
    else:
        # Node layout: 0 = achieved goal, 1 = goal, i + 2 = landmark i;
        # start at the achieved goal.
        ag_rows = np.tile(achieved_goal, (n_landmarks, 1))
        ag_row = np.expand_dims(achieved_goal, axis=0)
        dg_row = np.expand_dims(goal, axis=0)
        model_in_from = np.concatenate([ag_rows, lms, goal_rows, lms, ag_row, dg_row], axis=0)
        model_in_to = np.concatenate([lms, ag_rows, lms, goal_rows, dg_row, ag_row], axis=0)
        model_out = self.accessibility_model.is_reachable(model_in_from, model_in_to).squeeze()
        ag_to_lm = model_out[:n_landmarks, 1] >= self.thresh_prob
        lm_to_ag = model_out[n_landmarks:2 * n_landmarks, 1] >= self.thresh_prob
        dg_to_lm = model_out[2 * n_landmarks:3 * n_landmarks, 1] >= self.thresh_prob
        lm_to_dg = model_out[3 * n_landmarks:4 * n_landmarks, 1] >= self.thresh_prob
        ag_to_dg = model_out[4 * n_landmarks, 1] >= self.thresh_prob
        dg_to_ag = model_out[4 * n_landmarks + 1, 1] >= self.thresh_prob

        reachability_matrix = np.zeros((n_landmarks + 2, n_landmarks + 2))
        reachability_matrix[2:, 2:] = self.lm_reachability_matrix
        reachability_matrix[0, 0] = 1.
        reachability_matrix[1, 1] = 1.
        reachability_matrix[0, 1] = ag_to_dg
        reachability_matrix[1, 0] = dg_to_ag
        reachability_matrix[0, 2:] = ag_to_lm
        reachability_matrix[2:, 0] = lm_to_ag
        reachability_matrix[1, 2:] = dg_to_lm
        reachability_matrix[2:, 1] = lm_to_dg
        start_node_idx = 0
        target_node_idx = 1

    # Nonzero matrix entries (all 1.) are unit-weight directed edges.
    _, predecessors = shortest_path(
        csgraph=csr_matrix(reachability_matrix),
        directed=True,
        indices=start_node_idx,
        return_predecessors=True,
    )
    node_idx = _first_hop_on_path(predecessors, start_node_idx, target_node_idx)

    if node_idx is None:
        if not default_action:
            return None
        # Fall back to the node of [goal] + landmarks closest to the achieved goal.
        node_coords = np.stack([goal] + self.landmarks)
        node_idx = np.argmin(np.linalg.norm(node_coords - achieved_goal, axis=1))
    elif self.current_landmark == -1:
        # Map from the [achieved, goal, landmarks...] layout to [goal, landmarks...].
        node_idx = node_idx - 1

    if node_idx == 0 and landmark_subgoal is not None:
        # Node 0 is the landmark subgoal itself; report it by its landmark index.
        # NOTE(review): the *last* landmark within success_threshold of the
        # subgoal wins, which can differ from landmark_subgoal_id when landmarks
        # are very close — confirm this is intended.
        distances = np.linalg.norm((lms - landmark_subgoal).reshape(n_landmarks, -1), axis=1)
        close = np.flatnonzero(distances < self.success_threshold)
        if close.size == 0:
            raise ValueError('landmark_subgoal should be the index of a landmark')
        node_idx = close[-1] + 1

    return np.int64(node_idx)
    
    
    
    
    
# FUNCTION TO UPDATE THE LANDMARK GRAPH WHEN A NEW LANDMARK IS ADDED


def update_landmark_graph(self):
    """Recompute ``self.lm_reachability_matrix`` for the current landmarks.

    ``lm_reachability_matrix[i, j]`` is 1. when any of the following holds,
    otherwise 0.:

    * column 1 of ``self.accessibility_model.is_reachable`` for the pair
      ``landmarks[i] -> landmarks[j]`` is at least ``self.thresh_prob``;
    * landmark ``j`` (for ``j > 0``) was generated from landmark ``i``, i.e.
      ``self.landmarks_info[j]['generated_from_landmark_id'] == i``;
    * ``i == j`` (self edges).

    All ordered pairs are sent to the model in one batched call. With no
    landmarks the matrix is set to an empty ``(0, 0)`` array and the model is
    not queried.
    """
    n_nodes = len(self.landmarks)
    if n_nodes == 0:
        # np.stack would raise on an empty list; an empty graph has no edges.
        self.lm_reachability_matrix = np.zeros((0, 0))
        return

    lms = np.stack(self.landmarks)
    # Row i * n_nodes + j of the batch holds the pair landmarks[i] -> landmarks[j].
    from_nodes = np.repeat(lms, n_nodes, axis=0)
    to_nodes = np.concatenate([lms] * n_nodes, axis=0)
    model_out = self.accessibility_model.is_reachable(from_nodes, to_nodes, print_bb_info=False)

    # Column 1 of each output row is compared with thresh_prob; trailing
    # singleton axes (e.g. an (N, 2, 1) output) are absorbed by the reshape.
    reach_scores = np.asarray(model_out)[:, 1].reshape(n_nodes, n_nodes)
    lm_reachability_matrix = (reach_scores >= self.thresh_prob).astype(float)

    # A landmark is treated as reachable from the landmark it was generated from.
    for j in range(1, n_nodes):
        source_id = self.landmarks_info[j]['generated_from_landmark_id']
        if source_id is not None and source_id in range(n_nodes):
            lm_reachability_matrix[int(source_id), j] = 1.

    # Landmark self edges.
    np.fill_diagonal(lm_reachability_matrix, 1.)
    self.lm_reachability_matrix = lm_reachability_matrix
