import torch
from torch.utils.data.sampler import BatchSampler, SequentialSampler, SubsetRandomSampler


class RolloutStorage:
    """Per-environment rollout buffer that keeps a flat replay of successful trajectories.

    Each environment has its own write pointer (``self.step``, shape ``(1, num_envs)``).
    ``add_transitions`` writes environment ``i``'s data into slot ``step[0, i]`` of
    column ``i``. When an environment is listed in ``success_id``, the transitions in
    its column from slot 0 up to its pointer are appended to the flat
    ``success_observations`` / ``success_actions`` buffers, and its pointer is reset
    to 0. ``mini_batch_generator`` yields random index batches into that success buffer.
    """

    def __init__(self, num_envs, num_transitions_per_env, obs_shape, states_shape, actions_shape, device='cpu', sampler='sequential'):
        """Allocate all buffers on ``device``.

        Args:
            num_envs: number of parallel environments (columns).
            num_transitions_per_env: number of slots per environment column.
            obs_shape: shape of a single observation.
            states_shape: accepted for interface compatibility; not used.
            actions_shape: shape of a single action.
            device: torch device for every buffer.
            sampler: stored as ``self.sampler``. NOTE(review): ``mini_batch_generator``
                currently always samples randomly, regardless of this value.
        """
        self.device = device
        self.sampler = sampler

        # Core per-environment buffers, indexed [slot, env, ...].
        self.observations = torch.zeros(num_transitions_per_env, num_envs, *obs_shape, device=self.device)
        self.actions = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=self.device)
        # Scalar per-step signals; only read by get_statistics.
        self.rewards = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device)
        self.dones = torch.zeros(num_transitions_per_env, num_envs, 1, dtype=torch.uint8, device=self.device)

        # Flat buffers holding concatenated successful trajectories.
        capacity = num_transitions_per_env * num_envs
        self.success_observations = torch.zeros(capacity, *obs_shape, device=self.device)
        self.success_actions = torch.zeros(capacity, *actions_shape, device=self.device)

        self.num_transitions_per_env = num_transitions_per_env
        self.num_envs = num_envs

        # Per-environment write pointer into the slot dimension.
        self.step = torch.zeros(1, num_envs, dtype=torch.long, device=self.device)
        # Next free index in the success buffers (plain int).
        self.success_step = 0
        # True once the success buffer has wrapped at least once, i.e. every entry is valid.
        self.full_fill = False

    @staticmethod
    def _as_batch(values, like):
        """Return a per-env batch as one tensor with ``like``'s dtype and device."""
        # Sequences of per-env tensors/scalars are stacked; tensors and arrays pass through.
        if isinstance(values, (list, tuple)):
            values = torch.stack([torch.as_tensor(v) for v in values])
        return torch.as_tensor(values).to(device=like.device, dtype=like.dtype)

    def add_transitions(self, observations, states, actions, rewards, dones, success_id):
        """Record one transition per environment and harvest successful trajectories.

        Args:
            observations: per-env observations with ``num_envs`` leading entries
                (tensor, array-like, or list/tuple of per-env tensors).
            states: accepted for interface compatibility; not used.
            actions: per-env actions, same layout as ``observations``.
            rewards: per-env scalar rewards, or ``None`` to skip recording them.
            dones: per-env done flags, or ``None`` to skip recording them.
            success_id: iterable of environment indices whose trajectory just
                succeeded; each such trajectory is appended to the success buffers.
        """
        env_idx = torch.arange(self.num_envs, device=self.device)
        slots = self.step[0]

        # Write every env's data into its OWN column at its own slot.
        self.observations[slots, env_idx] = self._as_batch(observations, self.observations)
        self.actions[slots, env_idx] = self._as_batch(actions, self.actions)
        if rewards is not None:
            self.rewards[slots, env_idx] = self._as_batch(rewards, self.rewards).reshape(self.num_envs, 1)
        if dones is not None:
            self.dones[slots, env_idx] = self._as_batch(dones, self.dones).reshape(self.num_envs, 1)

        self.step += 1

        capacity = self.num_transitions_per_env * self.num_envs
        # NOTE(review): only envs listed in success_id have their pointer reset. An env
        # that terminates without success keeps writing after its earlier data, so a
        # later success also copies that earlier segment. Confirm whether dones should
        # reset the pointer too.
        for env_id in success_id:
            # The trajectory occupies slots [0, step). Truncate it to the space left
            # in the success buffer; any remainder is dropped.
            traj_len = min(int(self.step[0, env_id]), capacity - self.success_step)
            end = self.success_step + traj_len
            self.success_observations[self.success_step:end] = self.observations[:traj_len, env_id]
            self.success_actions[self.success_step:end] = self.actions[:traj_len, env_id]

            self.success_step = end
            self.step[0, env_id] = 0

            if self.success_step >= capacity:
                self.success_step = 0
                self.full_fill = True

        # Pointers wrap around. A trajectory longer than num_transitions_per_env
        # therefore overwrites its own beginning before it can be harvested.
        self.step %= self.num_transitions_per_env

    def clear(self):
        """Reset every environment's write pointer to slot 0 (success buffers are kept)."""
        # In place: add_transitions indexes self.step as a (1, num_envs) tensor.
        self.step.zero_()

    def get_statistics(self):
        """Return ``(mean trajectory length, mean reward)`` over the stored dones/rewards.

        The last slot of every column is treated as a trajectory end. Rewards and
        dones are recorded only when ``add_transitions`` receives them (not ``None``).
        """
        # Explicit copy: on CPU, .cpu() would alias the stored tensor, and the write
        # below must not modify self.dones.
        done = self.dones.to('cpu', copy=True)
        done[-1] = 1
        flat_dones = done.permute(1, 0, 2).reshape(-1, 1)
        done_indices = torch.cat((flat_dones.new_tensor([-1], dtype=torch.int64), flat_dones.nonzero(as_tuple=False)[:, 0]))
        trajectory_lengths = done_indices[1:] - done_indices[:-1]
        return trajectory_lengths.float().mean(), self.rewards.mean()

    def mini_batch_generator(self, mini_batch_size):
        """Return a ``BatchSampler`` of random index batches into the success buffers.

        Only the filled part of the buffer is sampled until it has wrapped once.
        Incomplete final batches are dropped (``drop_last=True``).
        """
        if self.full_fill:
            batch_size = self.num_transitions_per_env * self.num_envs
        else:
            batch_size = int(self.success_step)

        subset = SubsetRandomSampler(range(batch_size))
        return BatchSampler(subset, int(mini_batch_size), drop_last=True)
