import os
import numpy as np


class PolicyLowData:
    """Sample goal-conditioned transition batches from stored trajectories.

    Three nested collections are loaded from disk, each indexed as
    ``[context][trajectory][step]``:

    * ``z_all``       -- state encodings, read from ``exp_dir/z_all.npy``
    * ``actions_all`` -- actions, read from ``data_dir/actions_all.npy``
    * ``rewards_all`` -- rewards, read from ``data_dir/rewards_all.npy``

    ``sample_batch`` draws random transitions and pairs each one with a goal
    chosen by one of three strategies (see that method).

    Note:
        ``traj_length`` is indexed with a pair of index arrays in
        ``sample_batch``, so it must be a regular 2-D array; i.e. every
        context has to hold the same number of trajectories.
    """

    # Fraction of a batch whose goal is the transition's own next state.
    _GOAL_TERMINAL_FRACTION = 0.3
    # Fraction of a batch whose goal is a later state of the same trajectory.
    _GOAL_FUTURE_FRACTION = 0.3
    # Future goals are drawn from steps [step + MIN, step + MAX) and then
    # clamped to the last step of the trajectory.
    _FUTURE_OFFSET_MIN = 2
    _FUTURE_OFFSET_MAX = 30

    def __init__(self, data_dir, exp_dir, add_neg_noise_samples=False):
        """Load the trajectory arrays and derive bookkeeping sizes.

        Args:
            data_dir: Directory holding ``actions_all.npy`` and
                ``rewards_all.npy``.
            exp_dir: Directory holding ``z_all.npy`` and ``z_stats.npy``.
            add_neg_noise_samples: When true, ``sample_batch`` appends extra
                samples whose goals are noisy copies of the random goals.
        """
        # NOTE(security): allow_pickle=True unpickles arbitrary objects while
        # loading; only point this class at trusted files.
        self.z_all = np.load(os.path.join(exp_dir, "z_all.npy"), allow_pickle=True)
        self.actions_all = np.load(os.path.join(data_dir, "actions_all.npy"), allow_pickle=True)
        self.rewards_all = np.load(os.path.join(data_dir, "rewards_all.npy"), allow_pickle=True)
        # First entry of the stats file; used as the noise scale for goals.
        self.z_abs_diff = np.load(os.path.join(exp_dir, "z_stats.npy"), allow_pickle=True)[0]

        self.add_neg_noise_samples = add_neg_noise_samples
        self.state_dim = self.z_all[0][0].shape[-1]
        self.action_dim = self.actions_all[0][0].shape[-1]

        self.n_context = len(self.z_all)
        self.n_traj_per_context = np.array([len(context) for context in self.z_all])
        self.traj_length = np.array([[len(traj) for traj in context] for context in self.z_all])

    def sample_batch(self, batch_size):
        """Draw ``batch_size`` random transitions and assign a goal to each.

        The batch is split into three consecutive groups:

        1. the first ``int(0.3 * batch_size)`` samples use their own next
           state as goal and get ``reward = 1`` and ``done = True``;
        2. the next ``int(0.3 * batch_size)`` samples use a later state of
           the same trajectory as goal;
        3. the remaining samples use a random state of a random trajectory
           from the same context (it may be the same trajectory).

        Groups 2 and 3 keep ``reward = 0`` and ``done = False``.

        If ``add_neg_noise_samples`` is set, the group-3 samples are appended
        a second time with Gaussian noise (std ``2 * z_abs_diff``) added to
        their goals, so every returned array has
        ``batch_size + n_group_3`` rows.

        Args:
            batch_size: Number of transitions to draw.

        Returns:
            dict with keys ``"state"``, ``"goal"``, ``"state_next"``,
            ``"action"`` (float32), ``"reward"`` (float32, shape ``(N, 1)``),
            ``"done"`` (bool, shape ``(N, 1)``) and ``"exp_traj"`` (bool,
            shape ``(N,)``; true when the source trajectory contains any
            non-zero reward).
        """
        n_goal_terminal = int(self._GOAL_TERMINAL_FRACTION * batch_size)
        n_goal_future = int(self._GOAL_FUTURE_FRACTION * batch_size)
        n_goal_other = batch_size - n_goal_terminal - n_goal_future
        # Index of the first "random goal" sample. Slicing from an explicit
        # start (instead of [-n_goal_other:]) stays correct when the group is
        # empty, where [-0:] would select everything.
        other_start = n_goal_terminal + n_goal_future

        context_ids = np.random.randint(0, self.n_context, batch_size)
        traj_ids = np.random.randint(0, self.n_traj_per_context[context_ids], batch_size)
        # Exclusive upper bound is length - 1 so that step + 1 always exists.
        step_ids = np.random.randint(0, self.traj_length[context_ids, traj_ids] - 1)

        state = []
        state_next = []
        action = []
        exp_traj = []
        # No per-item copies needed: np.array(...) below copies the data.
        for ctx, traj, step in zip(context_ids, traj_ids, step_ids):
            encodings = self.z_all[ctx][traj]
            state.append(encodings[step])
            state_next.append(encodings[step + 1])
            action.append(self.actions_all[ctx][traj][step])
            exp_traj.append(np.any(self.rewards_all[ctx][traj]))

        reward = np.zeros(batch_size, dtype=np.float32)
        done = np.zeros(batch_size, dtype=np.bool_)

        # Group 1: the goal is the next state, so the transition reaches it.
        reward[:n_goal_terminal] = 1.0
        done[:n_goal_terminal] = True
        goals = list(state_next[:n_goal_terminal])

        # Group 2: goal sampled from the future of the same trajectory.
        future_slice = slice(n_goal_terminal, other_start)
        future_context_ids = context_ids[future_slice]
        future_traj_ids = traj_ids[future_slice]
        future_step_ids = step_ids[future_slice]
        goal_future_ids = np.random.randint(future_step_ids + self._FUTURE_OFFSET_MIN,
                                            future_step_ids + self._FUTURE_OFFSET_MAX)
        # Clamp to the last valid step of each trajectory.
        last_step_ids = self.traj_length[future_context_ids, future_traj_ids] - 1
        goal_future_ids = np.minimum(goal_future_ids, last_step_ids)
        for ctx, traj, step in zip(future_context_ids, future_traj_ids, goal_future_ids):
            goals.append(self.z_all[ctx][traj][step])

        # Group 3: goal is a random state of a random trajectory in the same
        # context.
        goal_rand_context_ids = context_ids[other_start:]
        goal_rand_traj_ids = np.random.randint(0, self.n_traj_per_context[goal_rand_context_ids])
        goal_rand_step_ids = np.random.randint(
            0, self.traj_length[goal_rand_context_ids, goal_rand_traj_ids])
        for ctx, traj, step in zip(goal_rand_context_ids, goal_rand_traj_ids, goal_rand_step_ids):
            goals.append(self.z_all[ctx][traj][step])

        if self.add_neg_noise_samples and n_goal_other > 0:
            # Duplicate the group-3 samples with noise added to their goals.
            noise_scale = 2 * self.z_abs_diff * np.ones((n_goal_other, self.state_dim))
            gaussian_noise = np.random.normal(0.0, noise_scale)

            noisy_goals = np.array(goals[other_start:]) + gaussian_noise
            goals.extend(noisy_goals)
            state.extend(state[other_start:])
            state_next.extend(state_next[other_start:])
            action.extend(action[other_start:])
            # Keep exp_traj aligned with the other per-sample arrays.
            exp_traj.extend(exp_traj[other_start:])
            done = np.concatenate([done, done[other_start:]], 0)
            reward = np.concatenate([reward, reward[other_start:]], 0)

        batch = {
            "state": np.array(state, dtype=np.float32),
            "goal": np.array(goals, dtype=np.float32),
            "state_next": np.array(state_next, dtype=np.float32),
            "reward": reward.reshape((-1, 1)),
            "done": done.reshape((-1, 1)),
            "action": np.array(action, dtype=np.float32),
            # np.bool was removed in NumPy 1.24; np.bool_ works everywhere.
            "exp_traj": np.array(exp_traj, dtype=np.bool_),
        }

        return batch