import torch


class Controller:
    """Select actions by following planner waypoints with a low-level policy.

    On every call to :meth:`get_action` the controller:

    1. Falls back to the high-level policy when the high-level Q-value of the
       current state exceeds ``tau_q_stop_plan``.
    2. Otherwise re-plans when the internal countdown has run out, follows the
       planned path waypoint by waypoint, and queries the low-level policy
       for an action towards the current waypoint.

    Args:
        planner: Object exposing ``reset(z_init=...)``, ``plan()`` and
            ``compute_path()``; the latter returns ``(path, indices)``.
        policy_low: Waypoint-conditioned policy exposing
            ``compute_q_min_t(state, goal)``, ``actor(state, goal)`` and a
            ``device`` attribute.
        policy_high: Policy exposing ``compute_q_min_t(state)`` and
            ``actor(state)``.
        replan_every: Value the countdown is reset to after each re-plan.
        tau_q_stop_plan: High-level Q threshold above which planning is
            skipped and the high-level actor is used directly.
        tau_q_next_wp: Low-level Q threshold above which a waypoint counts
            as reached.

    NOTE(review): ``state`` is presumably a tensor with a leading batch
    dimension of size 1 and ``path`` a tensor with one waypoint per row --
    confirm against the planner / policy implementations.
    """

    def __init__(self, planner, policy_low, policy_high, replan_every, tau_q_stop_plan, tau_q_next_wp):
        self.planner = planner
        self.policy_low = policy_low
        self.policy_high = policy_high
        self.replan_every = replan_every
        self.tau_q_stop_plan = tau_q_stop_plan
        self.tau_q_next_wp = tau_q_next_wp

        # Countdown until the next re-plan; a value <= 0 triggers planning.
        self.step_cnt = 0
        # Remaining waypoints of the current plan and the one being tracked.
        self.curr_path = []
        self.curr_wp = None
        # Raw output of the most recent ``planner.compute_path()`` call.
        self.last_plan, self.last_indices = None, None
        self.device = policy_low.device

    def reset(self):
        """Clear the per-episode plan state so the next action re-plans."""
        self.step_cnt = 0
        self.curr_path = []
        # The tracked waypoint belongs to the path that was just discarded.
        self.curr_wp = None

    def get_action(self, state):
        """Return the action for ``state`` as a NumPy array.

        Args:
            state: Current state, passed unchanged to the policies and planner.

        Returns:
            The first element of the selected actor's output, detached, moved
            to CPU and converted to a NumPy array.
        """
        # The countdown advances on every call, including the early return
        # below that skips planning.
        self.step_cnt -= 1

        if self.policy_high.compute_q_min_t(state)[0] > self.tau_q_stop_plan:
            return self._high_level_action(state)

        if self.step_cnt <= 0:
            self._replan(state)

        # ``len`` rather than truthiness: the path may be a tensor, whose
        # boolean value is ambiguous for more than one element.
        if len(self.curr_path) == 0:
            return self._high_level_action(state)

        self._advance_waypoint(state)

        action = self.policy_low.actor(state, self.curr_wp.unsqueeze(0))[0].detach().cpu().numpy()
        return action

    def _high_level_action(self, state):
        """Query the high-level actor directly, ignoring any waypoints."""
        return self.policy_high.actor(state)[0].detach().cpu().numpy()

    def _replan(self, state):
        """Run the planner from ``state`` and install the resulting path."""
        self.planner.reset(z_init=state)
        self.planner.plan()
        self.step_cnt = self.replan_every
        path, indices = self.planner.compute_path()
        self.last_plan = path
        self.last_indices = indices
        self.curr_path = path[1:]  # first element is current state
        if len(self.curr_path) > 0:
            self.curr_wp = self.curr_path[0]

    def _advance_waypoint(self, state):
        """Skip past reached waypoints, if any, on the current path.

        A waypoint counts as reached when its low-level Q-value exceeds
        ``tau_q_next_wp``. If at least one waypoint is reached, the new
        target is the unreached waypoint with the highest Q-value, and the
        path is truncated to start at that waypoint.
        """
        n_wp = len(self.curr_path)
        q_all = self.policy_low.compute_q_min_t(
            torch.repeat_interleave(state, repeats=n_wp, dim=0), self.curr_path)[0]

        reached = q_all > self.tau_q_next_wp
        if not torch.any(reached):
            return

        # Exclude reached waypoints with -inf. Filling them with the overall
        # minimum would tie them with the lowest-scoring unreached waypoint,
        # and argmax would then pick a reached waypoint at a lower index
        # whenever only one waypoint is still unreached.
        candidates = q_all.masked_fill(reached, float("-inf"))
        # NOTE(review): if every waypoint is reached all entries tie and
        # index 0 is selected, so the path is not shortened -- confirm this
        # is the intended behaviour.
        next_wp_id = int(torch.argmax(candidates.squeeze()))
        self.curr_wp = self.curr_path[next_wp_id]
        self.curr_path = self.curr_path[next_wp_id:]



