import torch
import numpy as np
import os
import sys

# Resolve the directory containing this file and the directory three levels
# above it, then append the latter to the module search path.
# NOTE(review): presumably ROOT_DIR is the project root, added so that
# project-local packages can be imported -- confirm against the repo layout.
FILE_DIR = os.path.dirname(os.path.realpath(__file__))
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(FILE_DIR)))
sys.path.append(ROOT_DIR)


class Planner():

    def __init__(self,
                 d_discard,
                 d_neigh,
                 tau_q_neigh,
                 tau_q_discard,
                 tau_q_goal,
                 dynamics,
                 policy_low,
                 policy_high,
                 tau_expand_q_std=0.1,
                 tau_expand_q_min=0.8,
                 action_sampler=None,
                 n_iter=100,
                 n_sim=3,
                 t_sparse=1.0,
                 t_value=1.0,
                 batch_size=64,
                 p_sample_node_sparse=0.5,
                 p_sample_node_rand=0.2,
                 p_sample_action_rand=0.0,
                 p_sample_action_model=0.5,
                 max_action=1.0,
                 discount=0.96,
                 elite_size=1,
                 use_q_std_reject=1,
                 device="cuda"):
        """Store hyper-parameters and models, and derive the per-batch sample counts.

        Args:
            d_discard: L2 distance threshold used by ``plan_step`` when
                rejecting candidates that lie close to existing nodes
                (stored as ``self.d_reject``).
            d_neigh: L2 radius below which an existing node is treated as a
                neighbor of a candidate in ``plan_step``.
            tau_q_neigh: value threshold above which a node pair counts as
                a neighbor link in ``plan_step``.
            tau_q_discard: value threshold used by ``plan_step`` to reject
                candidates that are not novel enough.
            tau_q_goal: value-to-go threshold defining the goal region in
                ``compute_path`` and ``plan_step``.
            dynamics: callable invoked as ``dynamics(z, a)``; its output is
                added to ``z`` during the forward simulation.
            policy_low: object providing ``action_dim``, ``n_qs``,
                ``compute_q_min_t`` and ``compute_q_all_t``.
            policy_high: object providing ``compute_q_min_t`` and ``actor``.
            tau_expand_q_std: upper bound on ``q_std`` for a candidate to be
                kept (only checked when ``use_q_std_reject`` is truthy).
            tau_expand_q_min: lower bound on ``q_min`` for a candidate to be
                kept.
            action_sampler: object providing ``decode(z)``, or ``None``. When
                ``None`` no rows of the batch are allotted to it and those
                rows go to the policy instead.
            n_iter: iteration count at which ``plan`` stops.
            n_sim: number of dynamics steps applied per expansion.
            t_sparse: temperature of the sparsity-based sampling weights.
            t_value: temperature of the value-based sampling weights.
            batch_size: requested number of expansions per step. The
                effective ``self.batch_size`` is the sum of the three
                truncated per-strategy counts and may be slightly smaller.
            p_sample_node_sparse: fraction of the batch sampled by sparsity.
            p_sample_node_rand: fraction of the batch sampled uniformly. The
                remainder ``1 - sparse - rand`` is sampled by value.
            p_sample_action_rand: fraction of the batch using uniformly
                random actions.
            p_sample_action_model: fraction of the batch using actions from
                ``action_sampler``. The remainder uses ``policy_high``.
            max_action: actions are bounded to ``[-max_action, max_action]``.
            discount: stored as ``self.discount``.
            elite_size: number of top value-to-go nodes considered by
                ``compute_path`` when no node is in the goal region.
            use_q_std_reject: truthy to also reject candidates by ``q_std``.
            device: torch device for the tensors this class allocates.
        """
        self.device = device

        # Thresholds
        self.d_reject = d_discard
        self.d_neigh = d_neigh
        self.tau_neigh = tau_q_neigh
        self.tau_discard = tau_q_discard
        self.tau_goal = tau_q_goal
        self.q_std_threshold = tau_expand_q_std
        self.q_min_threshold = tau_expand_q_min
        self.use_q_std_reject = use_q_std_reject

        # Search schedule
        self.n_iter = n_iter
        self.n_sim = n_sim
        self.t_sparse = t_sparse
        self.t_value = t_value
        self.elite_size = elite_size
        self.discount = discount

        # Node-sampling split. Each count is truncated to an integer number of
        # rows; the value-based share is scaled by the batch size *before*
        # truncation (truncating the bare fraction would always yield zero for
        # fractions below one). A negative remainder is clamped to zero rows.
        self.p_sample_node_sparse = p_sample_node_sparse
        self.p_sample_node_value = 1.0 - p_sample_node_sparse - p_sample_node_rand
        self.n_sample_node_sparse = int(p_sample_node_sparse * batch_size)
        self.n_sample_node_rand = int(p_sample_node_rand * batch_size)
        self.n_sample_node_value = max(0, int(self.p_sample_node_value * batch_size))
        # Effective batch size is whatever the three truncated counts add up to
        self.batch_size = self.n_sample_node_sparse + self.n_sample_node_value + self.n_sample_node_rand

        # Action-sampling split over the effective batch; the policy gets the rest.
        self.n_sample_action_rand = int(p_sample_action_rand * self.batch_size)
        if action_sampler is None:
            # No sampler to decode from: do not reserve any rows for it
            self.n_sample_action_model = 0
        else:
            self.n_sample_action_model = int(p_sample_action_model * self.batch_size)
        self.n_sample_action_policy = self.batch_size - self.n_sample_action_rand - self.n_sample_action_model

        # Action space
        self.max_action = max_action
        self.min_action = -max_action
        self.action_dim = policy_low.action_dim

        # Models
        self.dynamics = dynamics
        self.policy_low = policy_low
        self.n_qs = self.policy_low.n_qs
        self.policy_high = policy_high
        self.action_sampler = action_sampler

        # Tree storage; placeholders until reset() installs a root node
        self.value_travel = torch.zeros((1), device=self.device)
        self.value_edge = torch.zeros((1), device=self.device)
        self.neighs = torch.zeros((1), dtype=torch.int, device=self.device)
        self.parent = None
        self.value_to_go = None
        self.actions_policy = None
        self.nodes = None
        self.n_nodes = 0
        self.i_iter = 0

    def reset(self, z_init, goal=None):
        """Throw away the current tree and start a new one rooted at ``z_init``.

        The root's value-to-go and policy action are obtained from
        ``policy_high.compute_q_min_t(state=z_init, goal=goal)``. ``goal`` is
        also kept on ``self.goal``. The iteration counter is set back to zero.
        """
        dev = self.device

        # Per-node bookkeeping for a tree that holds a single entry
        self.value_to_go = torch.ones(1, device=dev)
        self.value_travel = torch.ones(1, device=dev)
        self.value_edge = torch.ones(1, device=dev)
        self.neighs = torch.zeros(1, device=dev)
        # -1 marks "no parent" / "not generated by any iteration" for the root
        self.parent = torch.tensor([-1], dtype=torch.int32).to(dev)
        self.generated_at_iter = torch.tensor([-1], dtype=torch.int32).to(dev)

        # Install the root node
        self.nodes = z_init
        self.n_nodes = 1
        self.goal = goal

        # Query the high-level policy for the root's value and action
        root_q, root_action = self.policy_high.compute_q_min_t(state=z_init, goal=goal)
        self.actions_policy = root_action
        self.value_to_go = root_q.reshape(-1)

        self.i_iter = 0

    def plan(self, ):
        self.nodes_per_step = np.zeros(self.n_iter+1)
        while self.i_iter < self.n_iter:
            self.plan_step()
            self.nodes_per_step[self.i_iter] = len(self.nodes)

            # if torch.any(self.value_to_go > self.tau_goal):
            #     break

    def compute_path(self):
        """Pick the best node in the tree and trace it back to the root.

        Selection:
          * if any node has value-to-go above ``tau_goal``, take the one among
            them with the largest ``value_travel``;
          * otherwise, with ``elite_size <= 1`` take the node with the largest
            value-to-go;
          * otherwise take, among the ``elite_size`` nodes with the largest
            value-to-go, the one with the largest ``value_travel``.

        Returns:
            A pair ``(path, idx_path)``: the stacked node states and the
            stacked node indices, both ordered from the root to the chosen
            node.
        """
        in_goal = torch.where(self.value_to_go > self.tau_goal)[0]
        if len(in_goal) > 0:
            best = in_goal[torch.argmax(self.value_travel[in_goal])]
        elif self.elite_size <= 1:
            best = torch.argmax(self.value_to_go)
        else:
            n_elite = min(self.elite_size, len(self.value_to_go))
            elite = torch.argsort(self.value_to_go, descending=True)[:n_elite]
            best = elite[torch.argmax(self.value_travel[elite])]

        # Follow parent links from the chosen node up to the root (whose
        # parent index is negative), then flip into root-first order.
        states = []
        indices = []
        cursor = best
        while cursor >= 0:
            indices.append(cursor)
            # NOTE(review): the list-wrapped index is kept exactly as is; how
            # torch interprets it (and hence the per-node shape) may differ
            # between versions -- confirm before changing.
            states.append(self.nodes[[cursor]])
            cursor = self.parent[cursor]
        states.reverse()
        indices.reverse()

        return torch.stack(states, 0), torch.stack(indices)


    @torch.no_grad()
    def plan_step(self):
        """Run one expansion iteration of the tree.

        Steps:
          1. Pick ``batch_size`` existing nodes to expand. For the first
             iterations (``i_iter < 10``) they are drawn uniformly; afterwards
             they are a mix of sparsity-weighted, value-weighted and uniform
             draws, restricted to nodes below ``tau_goal`` when any exist.
          2. Roll each picked node forward ``n_sim`` times through
             ``dynamics``, using policy, uniformly random and
             ``action_sampler`` actions for consecutive slices of the batch.
          3. From iteration 6 on, drop candidates whose low-level Q statistics
             fail ``q_min_threshold`` / ``q_std_threshold``, and candidates
             that are within ``d_reject`` of an existing node (unless their
             value-to-go beats the current maximum).
          4. Drop candidates for which some considered node already reaches
             them with value ``>= tau_discard`` (same exception as above).
          5. Append the survivors to the tree and update neighbor counts.

        Returns ``None``. When every candidate is rejected the tree is left
        unchanged (``i_iter`` is still incremented).
        """
        self.i_iter += 1

        # -------------------- #
        # --- Sample nodes --- #
        # -------------------- #

        # Sample random the first few steps
        if self.i_iter < 10:
            idx_exp = np.random.choice(np.arange(self.n_nodes), size=self.batch_size)
        else:
            idx_set = []
            not_goal_mask = self.value_to_go < self.tau_goal
            # Candidate indices live on the CPU because np.random.choice
            # consumes them; the mask is brought over so a CPU tensor is never
            # indexed with a mask that sits on another device.
            idx_all = torch.arange(self.nodes.shape[0])
            if torch.any(not_goal_mask):
                idx_all = idx_all[not_goal_mask.cpu()]

            if self.p_sample_node_sparse > 0.0:
                # Sample based on sparsity: fewer neighbors -> larger weight
                max_neighs = torch.max(self.neighs[idx_all])
                p0_sparse = torch.exp((max_neighs - self.neighs[idx_all]) / self.t_sparse).clip(0.0, 1000000.0)
                p_sparse = p0_sparse / torch.sum(p0_sparse)
                p_sparse_np = p_sparse.detach().cpu().numpy()
                idx_exp_sparse = np.random.choice(idx_all, p=p_sparse_np, size=self.n_sample_node_sparse)
                idx_set.append(idx_exp_sparse)

            if self.p_sample_node_value > 0.0:
                # Sample based on value: larger value-to-go -> larger weight
                max_value = torch.max(self.value_to_go[idx_all])
                p0_value = torch.exp((self.value_to_go[idx_all] - max_value) / self.t_value)
                p_value = p0_value / torch.sum(p0_value)
                p_value_np = p_value.detach().cpu().numpy()
                idx_exp_value = np.random.choice(idx_all, p=p_value_np, size=self.n_sample_node_value)
                idx_set.append(idx_exp_value)

            # Sample random node
            if self.p_sample_node_sparse + self.p_sample_node_value < 1.0:
                idx_exp_rand = np.random.choice(idx_all, size=self.n_sample_node_rand)
                idx_set.append(idx_exp_rand)
            idx_exp = np.concatenate(idx_set, axis=0)

        # Select expansion nodes. The shuffle decouples the node-sampling
        # strategy from the action-sampling strategy, which is assigned by
        # position in the batch below.
        idx_exp = torch.from_numpy(np.random.permutation(idx_exp)).to(self.device)
        z_exp = self.nodes[idx_exp]

        # ------------------------ #
        # --- Simulate forward --- #
        # ------------------------ #

        # Generate new states
        z_new = z_exp.clone()
        for i_sim in range(self.n_sim):

            # ---------------------- #
            # --- Sample actions --- #
            # ---------------------- #

            # Actions based on learned policy (cached for the first step)
            if i_sim == 0:
                a_exp_policy = self.actions_policy[idx_exp[:self.n_sample_action_policy]]
            else:
                a_exp_policy = self.policy_high.actor(z_new[:self.n_sample_action_policy])

            # Sample actions randomly
            a_exp_rand = torch.FloatTensor(self.n_sample_action_rand, self.action_dim).uniform_(
                self.min_action, self.max_action).to(self.device)

            # Actions from the action sampler for the trailing rows. Guarded:
            # with a count of zero the slice ``z_new[-0:]`` would select the
            # whole batch, and the sampler may legitimately be None.
            if self.n_sample_action_model > 0:
                a_exp_model = self.action_sampler.decode(
                    z_new[-self.n_sample_action_model:]).clip(self.min_action, self.max_action)
            else:
                a_exp_model = z_new.new_empty((0, self.action_dim))

            a_exp = torch.cat([a_exp_policy, a_exp_rand, a_exp_model], dim=0)

            # Predict displacement
            delta = self.dynamics(z_new, a_exp)
            z_new += delta

        # Predict value variance
        if self.i_iter > 5:
            qs, q_min, q_std, q_mean, action = self.policy_low.compute_q_all_t(z_exp, z_new)

            # Determine if reject
            not_reject_min = q_min.reshape(-1) > self.q_min_threshold

            if self.use_q_std_reject:
                not_reject_std = q_std.reshape(-1) < self.q_std_threshold
                not_reject = torch.logical_and(not_reject_std, not_reject_min)
            else:
                not_reject = not_reject_min

            idx_exp = idx_exp[not_reject]
            z_new = z_new[not_reject]

            # Nothing survived: stop here instead of running the models on an
            # empty batch (the distance filter below would return anyway).
            if z_new.shape[0] == 0:
                return

        # --------------------------------------- #
        # --- Reject node if not novel enough --- #
        # --------------------------------------- #

        # Compute value of new node wrt goal
        value_to_go_new_ = self.policy_high.compute_q_min_t(state=z_new)[0]
        value_to_go_new_ = value_to_go_new_.reshape(-1)

        # Rejection based on l2 distance (rows: candidates, columns: tree nodes)
        dist_l2 = torch.norm(z_new.unsqueeze(1) - self.nodes, dim=-1)
        if self.i_iter > 5:
            # Keep a candidate if it is far from every node, or if it improves
            # on the best value-to-go currently in the tree
            mask_l2 = torch.logical_or(torch.all(dist_l2 > self.d_reject, dim=1),
                                       value_to_go_new_ > torch.max(self.value_to_go))

            if torch.count_nonzero(mask_l2) <= 0:
                return

            # Keep relevant nodes
            idx_exp = idx_exp[mask_l2]
            z_new = z_new[mask_l2]
            dist_l2 = dist_l2[mask_l2]
            value_to_go_new_ = value_to_go_new_[mask_l2]

        # Find neighbors based on l2; the expansion node always counts
        neigh_l2 = dist_l2 < self.d_neigh
        neigh_l2[torch.arange(z_new.shape[0]), idx_exp] = True
        n_neighs_l2 = torch.sum(neigh_l2, 1)
        neigh_l2_idx = torch.where(neigh_l2)

        # Compute value from l2 neighbors to new node
        z_neigh_l2 = self.nodes[neigh_l2_idx[1]]
        z_new_repeat = torch.repeat_interleave(z_new, repeats=n_neighs_l2, dim=0)
        value_new_from_neigh, _ = self.policy_low.compute_q_min_t(state=z_neigh_l2, goal=z_new_repeat)
        value_new_from_neigh = value_new_from_neigh.reshape(-1)

        # Reject new nodes if neighbor exists that is too close
        value_new_from_nodes = torch.zeros_like(dist_l2)
        value_new_from_nodes[neigh_l2] = value_new_from_neigh
        # NOTE(review): entries outside ``neigh_l2`` stay at 0 and therefore
        # also satisfy ``< tau_discard`` for a positive threshold, so this mask
        # is not limited to l2 neighbors -- confirm this is intended.
        mask_neigh_novel_new = value_new_from_nodes < self.tau_discard
        mask_value = torch.all(mask_neigh_novel_new, dim=1)
        mask_value = torch.logical_or(mask_value, value_to_go_new_ > torch.max(self.value_to_go))
        if torch.count_nonzero(mask_value) <= 0:
            return

        # Keep relevant nodes
        idx_exp = idx_exp[mask_value]
        z_new = z_new[mask_value]
        dist_l2 = dist_l2[mask_value]
        value_new_from_nodes = value_new_from_nodes[mask_value]
        n_neighs_l2_novel = torch.sum(mask_neigh_novel_new[mask_value], 1)
        neigh_l2_novel = mask_neigh_novel_new[mask_value]
        neigh_l2_novel_idx = torch.where(neigh_l2_novel)
        z_neigh_l2 = self.nodes[neigh_l2_novel_idx[1]]

        # Find neighbors nodes (transition to new node) based on value
        mask_nodes_to_new_neigh = value_new_from_nodes > self.tau_neigh

        # Ensure that expansion node is always neighbors
        mask_nodes_to_new_neigh[torch.arange(z_new.shape[0]), idx_exp] = True

        # Value in the opposite direction: from the new node to existing nodes
        z_new_repeat = torch.repeat_interleave(z_new, repeats=n_neighs_l2_novel, dim=0)
        value_new_to_neigh, _ = self.policy_low.compute_q_min_t(state=z_new_repeat, goal=z_neigh_l2)
        value_new_to_neigh = value_new_to_neigh.reshape(-1)

        value_new_to_nodes = torch.zeros_like(dist_l2)
        value_new_to_nodes[neigh_l2_novel] = value_new_to_neigh
        mask_new_to_nodes = value_new_to_nodes > self.tau_neigh

        # Compute value of new node wrt goal (also yields the policy actions
        # cached for the first simulation step of later expansions)
        value_to_go_new, actions_policy_new = self.policy_high.compute_q_min_t(state=z_new)
        value_to_go_new = value_to_go_new.reshape(-1)

        # ------------------------------------ #
        # ------- Compute neighbors ---------- #
        # ------------------------------------ #

        # Compute neighbors of node (to which it can transition)
        n_neigh_new = torch.sum(mask_new_to_nodes, 1)

        # Update neighbors of nodes in tree
        n_add_neigh = torch.sum(mask_nodes_to_new_neigh, 0)
        self.neighs += n_add_neigh

        # Compute value of new node (root to new node): the parent's travel
        # value times the value of the edge from the parent
        value_edge_new = value_new_from_nodes[torch.arange(z_new.shape[0]), idx_exp]
        value_travel_new = self.value_travel[idx_exp] * value_edge_new

        # Append the accepted candidates to every per-node array
        self.parent = torch.cat([self.parent, idx_exp], dim=0)
        self.value_travel = torch.cat([self.value_travel, value_travel_new], dim=0)
        self.value_edge = torch.cat([self.value_edge, value_edge_new], dim=0)
        self.nodes = torch.cat([self.nodes, z_new], dim=0)
        self.neighs = torch.cat([self.neighs, n_neigh_new])
        self.actions_policy = torch.cat([self.actions_policy, actions_policy_new], dim=0)
        self.value_to_go = torch.cat([self.value_to_go, value_to_go_new], dim=0)

        self.n_nodes = self.nodes.shape[0]
