import os
import numpy as np


class EncoderTrainData:
    """Trajectory dataset that serves random training batches.

    The constructor reads four ``.npy`` files from ``data_dir``
    (``obs_all.npy``, ``proprioception_all.npy``, ``actions_all.npy`` and
    ``rewards_all.npy``).  Each file is indexed ``[context][trajectory]``;
    the context axis is flattened on load so that every trajectory gets a
    global id ``context * n_traj_per_context + local_index``.  The number of
    trajectories per context is taken from the first context and is assumed
    to be the same for all contexts.

    Attributes:
        obs_all: List of per-trajectory observation arrays, each transposed
            with axes ``(0, 3, 1, 2)`` (presumably ``T,H,W,C -> T,C,H,W``).
        prop_all: Per-trajectory proprioception arrays.
        action_all: Per-trajectory action arrays.
        reward_all: Per-trajectory reward arrays.
        traj_length / traj_len: Number of observations in each trajectory.
        valid_traj: Boolean mask of trajectories longer than
            ``n_step_prediction``.
        ids_traj_valid: Global ids of the valid trajectories.
        context_to_traj: ``(n_context, n_traj_per_context)`` table of global
            trajectory ids.
        traj_to_context: Context id of every global trajectory id.
        context_to_traj_valid: For each context, the list of its valid
            trajectory ids.
        prop_dim: Number of leading proprioception entries kept when
            sampling (vectors wider than 5 are cut to their first 4 entries).
        exp_traj: Per-trajectory flag, true if any reward is non-zero.
    """

    # A "future" goal is drawn uniformly from
    # [step + _FUTURE_GOAL_MIN_OFFSET, step + _FUTURE_GOAL_MAX_OFFSET)
    # and then clipped to the last step of the trajectory.
    _FUTURE_GOAL_MIN_OFFSET = 2
    _FUTURE_GOAL_MAX_OFFSET = 30

    def __init__(self, data_dir, n_step_prediction=1):
        """Load the dataset and build the index tables used for sampling.

        Args:
            data_dir: Directory that contains the four ``.npy`` files.
            n_step_prediction: Number of consecutive transitions returned per
                sample; trajectories that are not strictly longer than this
                are excluded from sampling.
        """
        # NOTE(security): allow_pickle=True lets np.load unpickle arbitrary
        # objects, which can execute code.  Only load trusted data files.

        # Load observation data
        obs_all = np.load(os.path.join(data_dir, "obs_all.npy"), allow_pickle=True)
        n_context = len(obs_all)
        n_traj_per_context = len(obs_all[0])
        context_to_traj = np.arange(n_traj_per_context * n_context).reshape((n_context, -1))
        traj_to_context = np.arange(n_context).repeat(n_traj_per_context, 0)
        obs_all = np.concatenate(obs_all)
        obs_all = [seq.transpose(0, 3, 1, 2) for seq in obs_all]
        traj_length = np.array([len(t) for t in obs_all])

        # Load proprioceptive data
        proprioception_all = np.load(os.path.join(data_dir, "proprioception_all.npy"), allow_pickle=True)
        proprioception_all = np.concatenate(proprioception_all)

        # Load action data
        action_all = np.load(os.path.join(data_dir, "actions_all.npy"), allow_pickle=True)
        action_all = np.concatenate(action_all)

        # Load reward data
        reward_all = np.load(os.path.join(data_dir, "rewards_all.npy"), allow_pickle=True)
        self.reward_all = np.concatenate(reward_all)

        self.valid_traj = traj_length > n_step_prediction
        self.context_to_traj = context_to_traj
        self.traj_to_context = traj_to_context
        self.n_step_prediction = n_step_prediction
        self.obs_all = obs_all
        self.action_dim = action_all[0].shape[-1]
        self.traj_length = traj_length
        self.prop_all = proprioception_all
        self.action_all = action_all

        self.n_traj = len(obs_all)
        self.ids_traj_valid = np.where(self.valid_traj)[0]
        # Same values as ``traj_length``; kept as a separate array because
        # the sampling code indexes ``traj_len``.
        self.traj_len = traj_length.copy()
        self.prop_dim = proprioception_all[0].shape[-1]
        if self.prop_dim > 5:
            self.prop_dim = 4

        # Per-context list of trajectories that are long enough to sample from.
        self.context_to_traj_valid = [
            [t_id for t_id in traj_row if self.valid_traj[t_id]]
            for traj_row in context_to_traj
        ]

        # For visualization purpose
        proprioception_all_flat = np.concatenate(proprioception_all, axis=0)
        self.max_x = np.max(proprioception_all_flat[:, 0])
        self.max_y = np.max(proprioception_all_flat[:, 1])
        self.min_x = np.min(proprioception_all_flat[:, 0])
        self.min_y = np.min(proprioception_all_flat[:, 1])

        # For sampling: fraction of the batch whose goal is the next state
        # (terminal) and fraction whose goal is a later state of the same
        # trajectory (future).  The remainder gets a random same-context goal.
        self.p_goal_terminal = 0.3
        self.p_goal_future = 0.3

        self.exp_traj = np.array([np.any(r) for r in self.reward_all])

    def _gather_obs(self, traj_ids, step_ids):
        """Return copies of the observations at the given (trajectory, step) pairs."""
        return [self.obs_all[t][s].copy() for t, s in zip(traj_ids, step_ids)]

    def _gather_prop(self, traj_ids, step_ids):
        """Return copies of the truncated proprioception vectors at the given pairs."""
        return [self.prop_all[t][s][:self.prop_dim].copy() for t, s in zip(traj_ids, step_ids)]

    def _sample_in_context(self, contexts):
        """Draw one random valid trajectory and one random step for each context id."""
        picked_traj = np.array(
            [np.random.choice(self.context_to_traj_valid[c]) for c in contexts],
            dtype=np.intp,
        )
        picked_step = np.random.randint(0, self.traj_len[picked_traj])
        return picked_traj, picked_step

    def sample_batch(self, batch_size):
        """Sample a random training batch.

        Batch positions are split into three consecutive groups by goal type:
        the first ``int(p_goal_terminal * batch_size)`` samples use the next
        state as goal (``reward_low = 1``, ``done_low = True``), the following
        ``int(p_goal_future * batch_size)`` samples use a later state of the
        same trajectory, and the rest use a random state of a random valid
        trajectory from the sample's own context.

        Args:
            batch_size: Number of samples to draw.

        Returns:
            A dict of arrays.  ``obs_n_step`` / ``prop_n_step`` have a leading
            axis of ``n_step_prediction + 1`` steps and ``action_n_step`` of
            ``n_step_prediction`` steps, each followed by the batch axis.
            ``obs_goal`` / ``prop_goal`` and ``obs_same_context`` /
            ``prop_same_context`` have the batch axis first.  Observations
            are float32 divided by 255.  ``reward_high`` is the stored reward
            of the last sampled transition, ``done_high`` is its truth value,
            and ``exp_traj`` flags samples whose trajectory contains any
            non-zero reward.
        """
        n_goal_terminal = int(self.p_goal_terminal * batch_size)
        n_goal_future = int(self.p_goal_future * batch_size)
        future_end = n_goal_terminal + n_goal_future

        traj_ids = np.random.choice(self.ids_traj_valid, batch_size)
        # Leave room for n_step_prediction further steps after the start step.
        step_ids = np.random.randint(0, self.traj_len[traj_ids] - self.n_step_prediction)
        contexts = self.traj_to_context[traj_ids]

        # Collect state data for n steps to train multi step dynamics
        obs_n_step = [self._gather_obs(traj_ids, step_ids + j)
                      for j in range(self.n_step_prediction + 1)]
        prop_n_step = [self._gather_prop(traj_ids, step_ids + j)
                       for j in range(self.n_step_prediction + 1)]
        action_n_step = [
            [self.action_all[t][s + j].copy() for t, s in zip(traj_ids, step_ids)]
            for j in range(self.n_step_prediction)
        ]

        # Terminal transitions: the goal is the very next state.
        reward_low = np.zeros(batch_size, dtype=np.float32)
        terminal_low = np.zeros(batch_size, dtype=bool)
        reward_low[:n_goal_terminal] = 1.0
        terminal_low[:n_goal_terminal] = True
        obs_goal = self._gather_obs(traj_ids[:n_goal_terminal], step_ids[:n_goal_terminal] + 1)
        prop_goal = self._gather_prop(traj_ids[:n_goal_terminal], step_ids[:n_goal_terminal] + 1)

        # Goals from the future of the same trajectory.
        future_traj = traj_ids[n_goal_terminal:future_end]
        future_start = step_ids[n_goal_terminal:future_end]
        goal_future_ids = np.random.randint(future_start + self._FUTURE_GOAL_MIN_OFFSET,
                                            future_start + self._FUTURE_GOAL_MAX_OFFSET)
        goal_future_ids = goal_future_ids.clip(max=self.traj_len[future_traj] - 1)
        obs_goal += self._gather_obs(future_traj, goal_future_ids)
        prop_goal += self._gather_prop(future_traj, goal_future_ids)

        # Random goals from another trajectory of the same context.  The
        # contexts must be those of the samples that receive these goals,
        # i.e. the batch positions after the terminal and future groups.
        goal_rand_traj_ids, goal_rand_step_ids = self._sample_in_context(contexts[future_end:])
        obs_goal += self._gather_obs(goal_rand_traj_ids, goal_rand_step_ids)
        prop_goal += self._gather_prop(goal_rand_traj_ids, goal_rand_step_ids)

        # Sample states from same context
        traj_ids_context, step_ids_context = self._sample_in_context(contexts)
        obs_same_context = self._gather_obs(traj_ids_context, step_ids_context)
        prop_same_context = self._gather_prop(traj_ids_context, step_ids_context)

        # Get environment reward of the last transition in the n-step window
        last_offset = self.n_step_prediction - 1
        reward_high = [self.reward_all[t][s + last_offset] for t, s in zip(traj_ids, step_ids)]

        exp_traj = self.exp_traj[traj_ids]

        # Builtin ``bool`` is used as dtype: the ``np.bool`` alias does not
        # exist in NumPy 1.24-1.26.
        batch = {
            "obs_n_step": np.array(obs_n_step, dtype=np.float32) / 255.0,
            "obs_goal": np.array(obs_goal, dtype=np.float32) / 255.0,
            "obs_same_context": np.array(obs_same_context, dtype=np.float32) / 255.0,
            "action_n_step": np.array(action_n_step, dtype=np.float32),
            "prop_n_step": np.array(prop_n_step, dtype=np.float32),
            "prop_goal": np.array(prop_goal, dtype=np.float32),
            "prop_same_context": np.array(prop_same_context, dtype=np.float32),
            "reward_high": np.array(reward_high, dtype=np.float32).reshape(-1, 1),
            "done_high": np.array(reward_high, dtype=bool).reshape(-1, 1),
            "reward_low": np.array(reward_low, dtype=np.float32).reshape(-1, 1),
            "done_low": np.array(terminal_low, dtype=bool).reshape(-1, 1),
            "exp_traj": np.array(exp_traj, dtype=bool),
        }
        return batch


class EncoderEvalData:
    """Evaluation dataset that returns all data of one context at a time.

    Reads ``obs_all.npy`` and ``proprioception_all.npy`` from ``data_dir``.
    Both are indexed ``[context][trajectory]``.  Every observation sequence is
    transposed with axes ``(0, 3, 1, 2)`` (presumably ``T,H,W,C -> T,C,H,W``).
    The extrema of the first two proprioception entries over the whole
    dataset are stored as ``min_x`` / ``max_x`` / ``min_y`` / ``max_y``.
    """

    def __init__(self, data_dir):
        """Load observations and proprioception from ``data_dir``."""
        # Load observation data
        raw_obs = np.load(os.path.join(data_dir, "obs_all.npy"), allow_pickle=True)
        transposed = []
        for context_seqs in raw_obs:
            transposed.append([seq.transpose(0, 3, 1, 2) for seq in context_seqs])

        # Load proprioceptive data
        prop_path = os.path.join(data_dir, "proprioception_all.npy")
        self.prop_all = np.load(prop_path, allow_pickle=True)
        self.obs_all = transposed
        self.n_contexts = len(transposed)

        # Flatten first over contexts, then over trajectories, to get one
        # row per time step.
        per_traj = np.concatenate(self.prop_all, axis=0)
        flat = np.concatenate(per_traj, axis=0)
        xs = flat[:, 0]
        ys = flat[:, 1]
        self.max_x = np.max(xs)
        self.max_y = np.max(ys)
        self.min_x = np.min(xs)
        self.min_y = np.min(ys)

    def sample_context_data(self, c_id=None):
        """Return all observations and proprioception of one context.

        Args:
            c_id: Context index; a uniformly random context is used if None.

        Returns:
            Dict with ``"obs"`` (all trajectories of the context concatenated
            along the first axis, float32, divided by 255) and ``"prop"``
            (concatenated the same way, float32).
        """
        if c_id is None:
            context_id = np.random.randint(self.n_contexts)
        else:
            context_id = c_id

        obs = np.concatenate(self.obs_all[context_id]).astype(np.float32) / 255.0
        prop = np.concatenate(self.prop_all[context_id]).astype(np.float32)
        return {"obs": obs, "prop": prop}