import os
import numpy as np


class PolicyHighData(object):
    """Random transition sampler over offline trajectories stored as ``.npy`` files.

    Per-trajectory states come from ``<exp_dir>/z_all.npy``. Per-step actions and
    rewards come from ``<data_dir>/actions_all.npy`` and ``rewards_all.npy``.
    Per-step goals come from ``<data_dir>/goals_all.npy`` and are loaded only
    when ``goal_conditioned``. ``sample_batch`` draws random
    ``(state, action, reward, state_next)`` transitions from them.
    """

    def __init__(self, data_dir, exp_dir, goal_conditioned=False):
        """Load all trajectories into memory.

        Args:
            data_dir: directory holding ``actions_all.npy``, ``rewards_all.npy``
                and, when ``goal_conditioned``, ``goals_all.npy``.
            exp_dir: directory holding ``z_all.npy``.
            goal_conditioned: if True, also load per-step goals and include
                them in sampled batches.
        """
        self.goal_conditioned = goal_conditioned

        def load(directory, filename):
            # allow_pickle=True unpickles object arrays, which can execute
            # arbitrary code: only point this class at trusted directories.
            return np.load(os.path.join(directory, filename), allow_pickle=True)

        # States, actions and rewards are concatenated along the first axis so
        # that each resulting entry is one trajectory. Presumably the files hold
        # chunks of ragged trajectories; TODO confirm against the writer.
        z_all = np.concatenate(load(exp_dir, "z_all.npy"))
        actions_all = np.concatenate(load(data_dir, "actions_all.npy"))
        rewards_all = np.concatenate(load(data_dir, "rewards_all.npy"))
        if goal_conditioned:
            # NOTE(review): unlike the files above, goals are NOT concatenated,
            # so they are presumably already stored one entry per trajectory.
            # Confirm that they index the same way as z_all.
            self.goals_all = load(data_dir, "goals_all.npy")

        traj_len = np.array([len(traj) for traj in z_all])

        self.z_all = z_all
        self.actions_all = actions_all
        self.rewards_all = rewards_all
        self.n_traj = len(z_all)
        self.traj_len = traj_len
        self.traj_length = traj_len  # alias kept for backward compatibility
        self.state_dim = z_all[0].shape[-1]
        self.action_dim = actions_all[0].shape[-1]
        self.goal_dim = self.goals_all[0].shape[-1] if goal_conditioned else 0

    def sample_batch(self, batch_size):
        """Sample ``batch_size`` random transitions.

        First, a trajectory is drawn uniformly among those with at least two
        states. Then a step ``t`` is drawn uniformly from ``[0, len - 2]``, so
        that the successor state ``t + 1`` exists.

        Args:
            batch_size: number of transitions to draw.

        Returns:
            dict of numpy arrays, one row per transition:
            ``"state"``, ``"state_next"``, ``"action"`` (float32);
            ``"reward"``, ``"done"`` (float32, shape ``(batch_size, 1)``);
            ``"exp_traj"`` (bool: whether the source trajectory contains any
            non-zero reward); ``"goal"`` (float32, empty when not
            goal-conditioned). Goal-conditioned datasets also expose a
            separate copy of the goals under ``"goals"``.

        Raises:
            ValueError: if no trajectory has at least two states.
        """
        sampleable = np.flatnonzero(self.traj_len >= 2)
        if sampleable.size == 0:
            raise ValueError("no trajectory has at least two states to sample from")
        # When every trajectory is sampleable, `sampleable` == arange(n_traj), so
        # this draws the same ids as randint(0, n_traj, batch_size).
        traj_ids = sampleable[np.random.randint(0, sampleable.size, batch_size)]
        # The integer low broadcasts against a per-trajectory exclusive high of
        # len - 1, which keeps step + 1 in range.
        step_ids = np.random.randint(0, self.traj_len[traj_ids] - 1)

        states, next_states, actions, rewards, goals, exp_traj = [], [], [], [], [], []
        for traj, step in zip(traj_ids, step_ids):
            z_traj = self.z_all[traj]
            r_traj = self.rewards_all[traj]
            states.append(z_traj[step])
            next_states.append(z_traj[step + 1])
            actions.append(self.actions_all[traj][step])
            rewards.append(r_traj[step])
            if self.goal_conditioned:
                goals.append(self.goals_all[traj][step])
            exp_traj.append(np.any(r_traj))

        reward = np.array(rewards, dtype=np.float32).reshape((-1, 1))
        goal = np.array(goals, dtype=np.float32)
        batch = {
            "state": np.array(states, dtype=np.float32),
            "goal": goal,
            "state_next": np.array(next_states, dtype=np.float32),
            "reward": reward,
            # NOTE(review): "done" is taken from the reward (terminal iff
            # rewarded), exactly as before. Confirm that this sparse-reward
            # assumption is intended.
            "done": reward.copy(),
            "action": np.array(actions, dtype=np.float32),
            # Builtin bool: the np.bool alias was removed in NumPy 1.24.
            "exp_traj": np.array(exp_traj, dtype=bool),
        }

        if self.goal_conditioned:
            batch["goals"] = goal.copy()

        return batch