import torch
from torch import nn, Tensor


def compute_returns(self, rewards, values, dones, last_values, gamma, lam):
    advantage = 0
    returns = torch.zeros_like(values)
    for step in reversed(range(self.num_transitions_per_env)):
        if step == self.num_transitions_per_env - 1:
            next_values = last_values
        else:
            next_values = values[step + 1]
        next_is_not_terminal = 1.0 - dones[step].float()
        delta = rewards[step] + next_is_not_terminal * gamma * next_values - values[step]
        advantage = delta + next_is_not_terminal * gamma * lam * advantage
        returns[step] = advantage + values[step]

    # Compute and normalize the advantages
    advantages = returns - values
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
        
class RolloutStorage(nn.Module):

    def __init__(self, num_envs, num_transitions_per_env, device='cuda'):
        
        super().__init__()

        self.device = device

        self.num_transitions_per_env = num_transitions_per_env
        self.num_envs = num_envs

        # rnn
        # self.saved_hidden_states_a = None
        # self.saved_hidden_states_c = None

        self.step = 0
        self.stored_keys = list()
        
    def register_key(self, key: str, shape=(), dtype=torch.float):
        assert not hasattr(self, key), key
        assert isinstance(shape, (list, tuple)), "shape must be a list or tuple"
        buffer = torch.zeros(
            (self.num_transitions_per_env, self.num_envs) + shape, dtype=dtype, device=self.device
        )
        self.register_buffer(key, buffer, persistent=False)
        self.stored_keys.append(key)
    
    def increment_step(self):
        self.step += 1

    def update_key(self, key: str, data: Tensor):
        assert not data.requires_grad
        assert self.step < self.num_transitions_per_env, "Rollout buffer overflow"
        getattr(self, key)[self.step].copy_(data)
        
    def batch_update_data(self, key: str, data: Tensor):
        assert not data.requires_grad
        getattr(self, key)[:] = data
        # self.store_dict[key] += self.total_sum()

    def _save_hidden_states(self, hidden_states):
        assert NotImplementedError
        if hidden_states is None or hidden_states==(None, None):
            return
        # make a tuple out of GRU hidden state sto match the LSTM format
        hid_a = hidden_states[0] if isinstance(hidden_states[0], tuple) else (hidden_states[0],)
        hid_c = hidden_states[1] if isinstance(hidden_states[1], tuple) else (hidden_states[1],)

        # initialize if needed 
        if self.saved_hidden_states_a is None:
            self.saved_hidden_states_a = [torch.zeros(self.observations.shape[0], *hid_a[i].shape, device=self.device) for i in range(len(hid_a))]
            self.saved_hidden_states_c = [torch.zeros(self.observations.shape[0], *hid_c[i].shape, device=self.device) for i in range(len(hid_c))]
        # copy the states
        for i in range(len(hid_a)):
            self.saved_hidden_states_a[i][self.step].copy_(hid_a[i])
            self.saved_hidden_states_c[i][self.step].copy_(hid_c[i])

    def clear(self):
        self.step = 0


    def get_statistics(self):
        raise NotImplementedError
        done = self.dones
        done[-1] = 1
        flat_dones = done.permute(1, 0, 2).reshape(-1, 1)
        done_indices = torch.cat((flat_dones.new_tensor([-1], dtype=torch.int64), flat_dones.nonzero(as_tuple=False)[:, 0]))
        trajectory_lengths = (done_indices[1:] - done_indices[:-1])
        return trajectory_lengths.float().mean(), self.rewards.mean()
    
    def query_key(self, key: str):
        assert hasattr(self, key), key
        return getattr(self, key)

    # def mini_batch_generator(self, num_mini_batches, num_epochs=8):
    #     batch_size = self.num_envs * self.num_transitions_per_env
    #     mini_batch_size = batch_size // num_mini_batches
    #     indices = torch.randperm(num_mini_batches*mini_batch_size, requires_grad=False, device=self.device)
        
    #     _buffer_dict = {key: getattr(self, key)[:].flatten(0, 1) for key in self.stored_keys}

    #     for epoch in range(num_epochs):
    #         for i in range(num_mini_batches):

    #             start = i*mini_batch_size
    #             end = (i+1)*mini_batch_size
    #             batch_idx = indices[start:end]

    #             _batch_buffer_dict = {key: _buffer_dict[key][batch_idx] for key in self.stored_keys}
    #             yield _batch_buffer_dict

    def mini_batch_generator(self, num_mini_batches, num_epochs=8):
        batch_size = self.num_envs * self.num_transitions_per_env
        mini_batch_size = batch_size // num_mini_batches

#         _buffer_dict = {
#             key: getattr(self, key)[:].flatten(0, 1).contiguous() 
#             for key in self.stored_keys
#         }
#         indices = torch.randperm(batch_size,device=self.device)
# 
#         shuffled_data = {
#             key: _buffer_dict[key][indices] 
#             for key in self.stored_keys
#         }

        indices = torch.randperm(batch_size,device=self.device)
        
        
        shuffled_data = {
            key: getattr(self, key)[:].flatten(0, 1)[indices].contiguous()
            for key in self.stored_keys
        }
        
        for epoch in range(num_epochs):
            
            for i in range(num_mini_batches):
                start = i * mini_batch_size
                end = (i + 1) * mini_batch_size
                _batch_buffer_dict = {
                    key: shuffled_data[key][start:end] 
                    for key in self.stored_keys
                }
                yield _batch_buffer_dict