import torch
from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler
import config

def _flatten_helper(T, N, _tensor):
    """Merge the first two dimensions of ``_tensor`` into one of size ``T * N``.

    Dimensions from index 2 onward are kept as they are.  The result comes
    from ``Tensor.view``, so no data is copied.
    """
    trailing_dims = tuple(_tensor.shape[2:])
    merged_shape = (T * N,) + trailing_dims
    return _tensor.view(*merged_shape)

# Step-indexed buffers are laid out as (num_steps [+ 1], num_processes, *feature_shape).
class RolloutStorage(object):
    """Fixed-length rollout buffer holding main-policy and ``*_sub`` tensors.

    Every buffer is indexed as ``(step, process, ...)``.  Buffers allocated
    with ``num_steps + 1`` entries keep one extra slot: ``insert`` and
    ``insert_sub`` write those buffers at ``step + 1``, and ``after_update``
    copies the last slot back to index 0 to seed the next rollout.
    """

    # Every tensor attribute owned by the buffer; ``to`` moves exactly these.
    _TENSOR_FIELDS = (
        'obs', 'obs_sub', 'hmap', 'hmap_sub', 'box',
        'recurrent_hidden_states', 'recurrent_hidden_states_sub',
        'rewards', 'value_preds', 'value_preds_sub',
        'returns', 'returns_sub',
        'action_log_probs', 'action_log_probs_sub', 'env_loss',
        'actions', 'actions_sub',
        'masks', 'bad_masks', 'location_masks', 'location_masks_sub',
    )

    # Buffers whose last slot is copied to slot 0 by ``after_update``.
    _CARRY_OVER_FIELDS = (
        'obs', 'obs_sub', 'hmap', 'box', 'hmap_sub',
        'recurrent_hidden_states', 'recurrent_hidden_states_sub',
        'masks', 'bad_masks', 'location_masks', 'location_masks_sub',
    )

    def __init__(self, num_steps, num_processes, obs_shape, action_space,
                 recurrent_hidden_state_size, hmap_shape, can_give_up,
                 enable_rotation, pallet_size):
        """Allocate all buffers on the default device.

        Args:
            num_steps: Number of transitions stored per rollout.
            num_processes: Size of the second (process) dimension.
            obs_shape: Trailing shape of one main observation.
            action_space: Space object; a class named ``Discrete`` yields
                one int64 action column, anything else uses
                ``action_space.shape[0]`` float columns.
            recurrent_hidden_state_size: Width of the recurrent state.
            hmap_shape: Accepted for interface compatibility; not read here
                (the height-map shape comes from ``config``).
            can_give_up: Accepted for interface compatibility; not read here.
            enable_rotation: When true the location mask has
                ``2 * pallet_size`` entries instead of ``pallet_size``.
            pallet_size: Base width of the location mask.
        """
        horizon = num_steps + 1

        self.obs = torch.zeros(horizon, num_processes, *obs_shape)
        # NOTE(review): 100, 10x10 and 120x4 below are hard-coded sizes for
        # the sub-policy/box buffers; their meaning is not visible here.
        self.obs_sub = torch.zeros(horizon, num_processes, 100)
        self.hmap = torch.zeros(
            horizon, num_processes, config.channel,
            config.container_size[0] * 10, config.container_size[1] * 10)
        self.hmap_sub = torch.zeros(
            horizon, num_processes, config.channel, 10, 10)
        self.box = torch.zeros(horizon, num_processes, 120, 4)
        self.recurrent_hidden_states = torch.zeros(
            horizon, num_processes, recurrent_hidden_state_size)
        self.recurrent_hidden_states_sub = torch.zeros(
            horizon, num_processes, recurrent_hidden_state_size)
        self.rewards = torch.zeros(num_steps, num_processes, 1)
        self.value_preds = torch.zeros(horizon, num_processes, 1)
        self.value_preds_sub = torch.zeros(horizon, num_processes, 1)
        self.returns = torch.zeros(horizon, num_processes, 1)
        self.returns_sub = torch.zeros(horizon, num_processes, 1)
        self.action_log_probs = torch.zeros(num_steps, num_processes, 1)
        self.action_log_probs_sub = torch.zeros(num_steps, num_processes, 1)
        self.env_loss = torch.zeros(num_steps, num_processes, 1)

        is_discrete = action_space.__class__.__name__ == 'Discrete'
        action_shape = 1 if is_discrete else action_space.shape[0]
        self.actions = torch.zeros(num_steps, num_processes, action_shape)
        self.actions_sub = torch.zeros(num_steps, num_processes, action_shape)
        if is_discrete:
            self.actions = self.actions.long()
            # Convert the sub buffer itself.  Calling ``.long()`` on the
            # already-int64 ``self.actions`` returns that same tensor, which
            # would make both attributes share storage.
            self.actions_sub = self.actions_sub.long()

        self.masks = torch.ones(horizon, num_processes, 1)

        location_width = 2 * pallet_size if enable_rotation else pallet_size
        self.location_masks = torch.zeros(
            horizon, num_processes, location_width)
        self.location_masks_sub = torch.ones(horizon, num_processes, 100)

        # Masks that indicate whether it's a true terminal state
        # or time limit end state
        self.bad_masks = torch.ones(horizon, num_processes, 1)
        self.num_steps = num_steps
        self.step = 0

    def to(self, device):
        """Move every buffer to ``device``, rebinding the attributes."""
        for field in self._TENSOR_FIELDS:
            setattr(self, field, getattr(self, field).to(device))

    def insert_sub(self, obs_sub, hmap_sub, recurrent_hidden_states_sub,
                   actions_sub, action_log_probs_sub, value_preds_sub,
                   location_masks_sub):
        """Store one sub-policy transition at the current step.

        State-like inputs go to slot ``step + 1``; action, log-prob and
        value go to slot ``step``.  ``self.step`` is not advanced here (only
        ``insert`` advances it), so this presumably has to run before
        ``insert`` for the same transition - confirm against the caller.
        """
        nxt = self.step + 1
        self.obs_sub[nxt].copy_(obs_sub)
        self.hmap_sub[nxt].copy_(hmap_sub)
        self.recurrent_hidden_states_sub[nxt].copy_(
            recurrent_hidden_states_sub)
        self.actions_sub[self.step].copy_(actions_sub)
        self.action_log_probs_sub[self.step].copy_(action_log_probs_sub)
        self.value_preds_sub[self.step].copy_(value_preds_sub)
        self.location_masks_sub[nxt].copy_(location_masks_sub)

    def insert(self, obs, hmap, recurrent_hidden_states, actions,
               action_log_probs, env_loss, value_preds, masks, bad_masks,
               location_masks):
        """Store one main-policy transition and advance the step counter.

        State-like inputs go to slot ``step + 1``; action, log-prob and
        value go to slot ``step``.  The counter wraps modulo ``num_steps``.

        NOTE(review): ``env_loss`` is accepted but never written to
        ``self.env_loss``, and ``self.rewards`` is not written here either;
        confirm whether the caller fills them directly.
        """
        nxt = self.step + 1
        self.obs[nxt].copy_(obs)
        self.hmap[nxt].copy_(hmap)
        self.recurrent_hidden_states[nxt].copy_(recurrent_hidden_states)
        self.actions[self.step].copy_(actions)
        self.action_log_probs[self.step].copy_(action_log_probs)
        self.value_preds[self.step].copy_(value_preds)
        self.masks[nxt].copy_(masks)
        self.bad_masks[nxt].copy_(bad_masks)
        self.location_masks[nxt].copy_(location_masks)

        self.step = (self.step + 1) % self.num_steps

    def after_update(self):
        """Copy the last slot of each carry-over buffer into slot 0."""
        for field in self._CARRY_OVER_FIELDS:
            buf = getattr(self, field)
            buf[0].copy_(buf[-1])

    def _fill_returns(self, returns, value_preds, next_value, use_gae,
                      gamma, gae_lambda, use_proper_time_limits):
        """Write discounted returns into ``returns`` in place.

        Shared by ``compute_returns`` and ``compute_returns_sub`` so both
        heads use the same recurrence, each on its own value predictions.
        With ``use_gae`` the bootstrap value is stored in
        ``value_preds[-1]``; otherwise it is stored in ``returns[-1]``.
        ``bad_masks`` is consulted only when ``use_proper_time_limits`` is
        true.
        """
        rewards, masks, bad_masks = self.rewards, self.masks, self.bad_masks
        steps = reversed(range(rewards.size(0)))

        if use_gae:
            value_preds[-1] = next_value
            gae = 0
            for step in steps:
                delta = (rewards[step]
                         + gamma * value_preds[step + 1] * masks[step + 1]
                         - value_preds[step])
                gae = delta + gamma * gae_lambda * masks[step + 1] * gae
                if use_proper_time_limits:
                    # A zero bad mask drops the advantage, so the return
                    # falls back to the value prediction.
                    gae = gae * bad_masks[step + 1]
                returns[step] = gae + value_preds[step]
            return

        returns[-1] = next_value
        for step in steps:
            target = (returns[step + 1] * gamma * masks[step + 1]
                      + rewards[step])
            if use_proper_time_limits:
                # A zero bad mask replaces the bootstrapped target with the
                # value prediction for that step.
                target = (target * bad_masks[step + 1]
                          + (1 - bad_masks[step + 1]) * value_preds[step])
            returns[step] = target

    def compute_returns(self,
                        next_value,
                        use_gae,
                        gamma,
                        gae_lambda,
                        use_proper_time_limits=True):
        """Fill ``self.returns`` from ``self.rewards`` and ``self.value_preds``.

        Args:
            next_value: Bootstrap value for the state after the last step.
            use_gae: Use the GAE recurrence instead of plain discounting.
            gamma: Discount factor.
            gae_lambda: GAE lambda; only used when ``use_gae`` is true.
            use_proper_time_limits: Apply ``self.bad_masks`` so masked steps
                fall back to the value prediction.
        """
        self._fill_returns(self.returns, self.value_preds, next_value,
                           use_gae, gamma, gae_lambda, use_proper_time_limits)

    def compute_returns_sub(self,
                            next_value_sub,
                            use_gae,
                            gamma,
                            gae_lambda,
                            use_proper_time_limits=True):
        """Fill ``self.returns_sub`` from ``self.rewards`` and ``self.value_preds_sub``.

        Same recurrence and arguments as ``compute_returns``.  It reads only
        the sub head's value predictions and never touches
        ``self.value_preds`` or ``self.returns``.
        """
        self._fill_returns(self.returns_sub, self.value_preds_sub,
                           next_value_sub, use_gae, gamma, gae_lambda,
                           use_proper_time_limits)

    # def feed_forward_generator(self,
    #                            advantages,
    #                            num_mini_batch=None,
    #                            mini_batch_size=None):
    #     num_steps, num_processes = self.rewards.size()[0:2]
    #     batch_size = num_processes * num_steps

    #     if mini_batch_size is None:
    #         assert batch_size >= num_mini_batch, (
    #             "PPO requires the number of action_sub ({}) "
    #             "* number of steps ({}) = {} "
    #             "to be greater than or equal to the number of PPO mini batches ({})."
    #             "".format(num_processes, num_steps, num_processes * num_steps,
    #                       num_mini_batch))
    #         mini_batch_size = batch_size // num_mini_batch
    #     sampler = BatchSampler(
    #         SubsetRandomSampler(range(batch_size)),
    #         mini_batch_size,
    #         drop_last=True)
    #     for indices in sampler:
    #         obs_batch = self.obs[:-1].view(-1, *self.obs.size()[2:])[indices]
    #         recurrent_hidden_states_batch = self.recurrent_hidden_states[:-1].view(
    #             -1, self.recurrent_hidden_states.size(-1))[indices]
    #         actions_batch = self.actions.view(-1,
    #                                           self.actions.size(-1))[indices]
    #         value_preds_batch = self.value_preds[:-1].view(-1, 1)[indices]
    #         return_batch = self.returns[:-1].view(-1, 1)[indices]
    #         masks_batch = self.masks[:-1].view(-1, 1)[indices]
    #         old_action_log_probs_batch = self.action_log_probs.view(-1,
    #                                                                 1)[indices]
    #         if advantages is None:
    #             adv_targ = None
    #         else:
    #             adv_targ = advantages.view(-1, 1)[indices]

    #         yield obs_batch, recurrent_hidden_states_batch, actions_batch, \
    #             value_preds_batch, return_batch, masks_batch, old_action_log_probs_batch, adv_targ

    # def recurrent_generator(self, advantages, num_mini_batch):
    #     num_processes = self.rewards.size(1)
    #     assert num_processes >= num_mini_batch, (
    #         "PPO requires the number of processes ({}) "
    #         "to be greater than or equal to the number of "
    #         "PPO mini batches ({}).".format(num_processes, num_mini_batch))
    #     num_envs_per_batch = num_processes // num_mini_batch
    #     perm = torch.randperm(num_processes)
    #     for start_ind in range(0, num_processes, num_envs_per_batch):
    #         obs_batch = []
    #         recurrent_hidden_states_batch = []
    #         actions_batch = []
    #         value_preds_batch = []
    #         return_batch = []
    #         masks_batch = []
    #         old_action_log_probs_batch = []
    #         adv_targ = []

    #         for offset in range(num_envs_per_batch):
    #             ind = perm[start_ind + offset]
    #             obs_batch.append(self.obs[:-1, ind])
    #             recurrent_hidden_states_batch.append(
    #                 self.recurrent_hidden_states[0:1, ind])
    #             actions_batch.append(self.actions[:, ind])
    #             value_preds_batch.append(self.value_preds[:-1, ind])
    #             return_batch.append(self.returns[:-1, ind])
    #             masks_batch.append(self.masks[:-1, ind])
    #             old_action_log_probs_batch.append(
    #                 self.action_log_probs[:, ind])
    #             adv_targ.append(advantages[:, ind])

    #         T, N = self.num_steps, num_envs_per_batch
    #         # These are all tensors of size (T, N, -1)
    #         obs_batch = torch.stack(obs_batch, 1)
    #         actions_batch = torch.stack(actions_batch, 1)
    #         value_preds_batch = torch.stack(value_preds_batch, 1)
    #         return_batch = torch.stack(return_batch, 1)
    #         masks_batch = torch.stack(masks_batch, 1)
    #         old_action_log_probs_batch = torch.stack(
    #             old_action_log_probs_batch, 1)
    #         adv_targ = torch.stack(adv_targ, 1)

    #         # States is just a (N, -1) tensor
    #         recurrent_hidden_states_batch = torch.stack(
    #             recurrent_hidden_states_batch, 1).view(N, -1)

    #         # Flatten the (T, N, ...) tensors to (T * N, ...)
    #         obs_batch = _flatten_helper(T, N, obs_batch)
    #         actions_batch = _flatten_helper(T, N, actions_batch)
    #         value_preds_batch = _flatten_helper(T, N, value_preds_batch)
    #         return_batch = _flatten_helper(T, N, return_batch)
    #         masks_batch = _flatten_helper(T, N, masks_batch)
    #         old_action_log_probs_batch = _flatten_helper(T, N, \
    #                 old_action_log_probs_batch)
    #         adv_targ = _flatten_helper(T, N, adv_targ)

    #         yield obs_batch, recurrent_hidden_states_batch, actions_batch, \
    #             value_preds_batch, return_batch, masks_batch, old_action_log_probs_batch, adv_targ
