import numpy as np
import torch
import utils
import os

class Buffer(object):
    """Fixed-capacity ring buffer of environment transitions.

    Every slot holds ``(obs, action, reward, next_obs, not_done,
    not_done_no_max)`` in pre-allocated NumPy arrays. ``idx`` is the next
    slot to write; once it wraps around to 0 the buffer is marked ``full``
    and new transitions overwrite the oldest ones.
    """

    def __init__(self, obs_shape, action_shape, capacity, device):
        """Allocate storage for ``capacity`` transitions.

        Args:
            obs_shape: shape of a single observation. A 1-D shape is stored
                as float32, any other rank as uint8.
            action_shape: shape of a single action (stored as float32).
            capacity: maximum number of transitions kept.
            device: stored on the instance; the subclasses in this file pass
                it to ``torch.as_tensor`` when sampling.
        """
        self.capacity = capacity
        self.device = device

        # the proprioceptive obs is stored as float32, pixels obs as uint8
        obs_dtype = np.float32 if len(obs_shape) == 1 else np.uint8

        self.obses = np.empty((capacity, *obs_shape), dtype=obs_dtype)
        self.next_obses = np.empty((capacity, *obs_shape), dtype=obs_dtype)
        self.actions = np.empty((capacity, *action_shape), dtype=np.float32)
        self.rewards = np.empty((capacity, 1), dtype=np.float32)
        self.not_dones = np.empty((capacity, 1), dtype=np.float32)
        self.not_dones_no_max = np.empty((capacity, 1), dtype=np.float32)

        # Size of the leading observation axis; load() truncates stored
        # observations to this many columns.
        self.obs_dim = obs_shape[0]

        self.idx = 0        # next slot to write
        self.last_save = 0  # value of idx at the most recent save()
        self.full = False   # True once idx has wrapped around at least once

    def __len__(self):
        """Return the number of transitions currently stored."""
        return self.capacity if self.full else self.idx

    def add(self, obs, action, reward, next_obs, done, done_no_max):
        """Store one transition at the write index and advance it.

        ``done`` and ``done_no_max`` are stored negated (as "not done"
        flags).
        """
        np.copyto(self.obses[self.idx], obs)
        np.copyto(self.actions[self.idx], action)
        np.copyto(self.rewards[self.idx], reward)
        np.copyto(self.next_obses[self.idx], next_obs)
        np.copyto(self.not_dones[self.idx], not done)
        np.copyto(self.not_dones_no_max[self.idx], not done_no_max)

        self.idx = (self.idx + 1) % self.capacity
        self.full = self.full or self.idx == 0

    def save(self, save_dir):
        """Write the transitions added since the last save to ``save_dir``.

        The chunk is stored with ``torch.save`` as ``'<start>_<end>.pt'``
        and contains ``[obses, next_obses, actions, rewards, not_dones,
        not_dones_no_max]`` for slots ``start:end``. Nothing is written when
        no transition was added since the previous call.

        NOTE(review): if the write index has wrapped around since the last
        save, ``end < start`` and the slices below are empty - confirm
        whether callers can hit that case.
        """
        if self.idx == self.last_save:
            return
        start, end = self.last_save, self.idx
        path = os.path.join(save_dir, '%d_%d.pt' % (start, end))
        payload = [
            self.obses[start:end],
            self.next_obses[start:end],
            self.actions[start:end],
            self.rewards[start:end],
            self.not_dones[start:end],
            # Appended last so readers that index entries 0-4 keep working.
            self.not_dones_no_max[start:end],
        ]
        torch.save(payload, path)
        # Only advance the marker once the chunk is actually on disk.
        self.last_save = end

    def load(self, save_dir):
        """Restore the chunks written by :meth:`save` from ``save_dir``.

        Chunks are processed in order of their start index and must be
        contiguous with the current write index.

        Raises:
            AssertionError: if a chunk does not start at the current ``idx``.
        """
        # SECURITY: torch.load unpickles the file; only load trusted chunks.
        # NOTE(review): recent PyTorch releases default to weights_only=True,
        # which may reject these NumPy payloads - confirm against the
        # installed version.
        chunks = sorted(os.listdir(save_dir),
                        key=lambda name: int(name.split('_')[0]))
        for chunk in chunks:
            start, end = [int(x) for x in chunk.split('.')[0].split('_')]
            path = os.path.join(save_dir, chunk)
            payload = torch.load(path)
            assert self.idx == start
            self.obses[start:end] = payload[0][:, :self.obs_dim]
            self.next_obses[start:end] = payload[1][:, :self.obs_dim]
            self.actions[start:end] = payload[2]
            self.rewards[start:end] = payload[3]
            self.not_dones[start:end] = payload[4]
            if len(payload) > 5:
                self.not_dones_no_max[start:end] = payload[5]
            else:
                # Older chunks carry no not_dones_no_max; fall back to
                # not_dones rather than leaving uninitialised memory that
                # sample() would hand out.
                self.not_dones_no_max[start:end] = payload[4]
            self.idx = end

    # shuffle and split replay buffer into training and validation buffer
    def split(self, buffer, valid=1e4):
        """Move the last ``valid`` transitions into ``buffer``.

        The write index is moved back by ``valid`` and the transitions that
        followed it are re-added to ``buffer`` through ``buffer.add``.

        Args:
            buffer: destination buffer; only its ``add`` method is used.
            valid: number of trailing transitions to move. Cast to int so
                the float default (1e4) cannot turn ``idx`` into a float.

        NOTE(review): no check is made that ``valid`` fits in the stored
        data; a larger value leaves ``idx`` negative.
        """
        valid = int(valid)
        self.idx -= valid
        first = self.idx
        for offset in range(valid):
            src = first + offset
            buffer.add(self.obses[src],
                       self.actions[src],
                       self.rewards[src],
                       self.next_obses[src],
                       1.0 - float(self.not_dones[src, 0]),
                       1.0 - float(self.not_dones_no_max[src, 0]))

class ReplayBuffer(Buffer):
    """Transition buffer that samples stored slots uniformly at random."""

    def __init__(self, obs_shape, action_shape, capacity, device):
        super().__init__(obs_shape, action_shape, capacity, device)

    def sample(self, batch_size):
        """Draw ``batch_size`` stored transitions uniformly with replacement.

        Returns:
            Tuple ``(obses, actions, rewards, next_obses, not_dones,
            not_dones_no_max)`` of tensors on ``self.device``; both
            observation tensors are cast to float.
        """
        # Only slots that have been written are eligible.
        upper = self.capacity if self.full else self.idx
        picks = np.random.randint(0, upper, size=batch_size)

        def gather(storage):
            # Select the sampled rows and hand them to torch on the target device.
            return torch.as_tensor(storage[picks], device=self.device)

        return (
            gather(self.obses).float(),
            gather(self.actions),
            gather(self.rewards),
            gather(self.next_obses).float(),
            gather(self.not_dones),
            gather(self.not_dones_no_max),
        )

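# A replay buffer that returns each observation as the latest raw observation concatenated with the deltas between consecutive stacked observations, e.g. [s_t, s_{t-2} - s_{t-3}, s_{t-1} - s_{t-2}, s_t - s_{t-1}] for three deltas.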
class DeltaBuffer(Buffer):
    """Replay buffer whose samples append frame-to-frame observation deltas.

    Each sample covers a window of ``stack_num + 1`` consecutive slots. The
    returned observation is the raw observation of the window's last slot
    concatenated with the flattened differences between consecutive
    observations in the window (oldest difference first).
    """

    def __init__(self, obs_shape, action_shape, capacity, device, stack_num):
        """Create the buffer.

        Args:
            obs_shape, action_shape, capacity, device: forwarded to
                ``Buffer.__init__``.
            stack_num: number of deltas appended to each observation.
        """
        super().__init__(obs_shape, action_shape, capacity, device)
        # stack_num deltas require stack_num + 1 consecutive observations.
        self.stack_num = stack_num + 1

    def sample(self, batch_size):
        """Sample ``batch_size`` windows and return delta-augmented batches.

        Returns:
            Tuple ``(obses, actions, rewards, next_obses, not_dones,
            not_dones_no_max)`` of tensors on ``self.device``. Actions,
            rewards and done flags come from the last slot of each window.
            The slicing below requires 1-D observations, i.e. stacked
            windows of shape (batch, stack, obs_dim).

        NOTE(review): a window that is shifted back is not re-checked, and
        when few transitions are stored it can wrap into slots that were
        never written - the original comment assumes trajectories are at
        least 4 steps long.
        """
        idxs = np.random.randint(0,
                                 self.capacity if self.full else self.idx,
                                 size=batch_size)

        capacity = self.capacity
        # Most recently written slot. Taken modulo capacity so it is also
        # correct when the write index has just wrapped to 0.
        newest = (self.idx - 1) % capacity

        for i in range(batch_size):
            for j in range(self.stack_num):
                # Wrap so a window starting near the end of a full buffer
                # does not index past the arrays.
                pos = (idxs[i] + j) % capacity
                if not self.not_dones[pos, 0] or pos == newest:
                    # then sample before, assume a trajectory won't be less
                    # than 4 in length: move the window so it ends just
                    # before the terminal / newest slot.
                    idxs[i] -= (self.stack_num - j)
                    break

        # (batch, stack) slot indices of every window; the modulo maps the
        # negative starts produced above onto the same slots that negative
        # indexing would select.
        window = (idxs[:, None] + np.arange(self.stack_num)) % capacity
        last = window[:, -1]

        ori_obses = torch.as_tensor(self.obses[last], device=self.device).float()
        actions = torch.as_tensor(self.actions[last], device=self.device)
        rewards = torch.as_tensor(self.rewards[last], device=self.device)
        ori_next_obses = torch.as_tensor(self.next_obses[last],
                                         device=self.device).float()
        not_dones = torch.as_tensor(self.not_dones[last], device=self.device)
        not_dones_no_max = torch.as_tensor(self.not_dones_no_max[last],
                                           device=self.device)

        # Gather every window in one indexing step: (batch, stack, obs_dim).
        obses = torch.as_tensor(self.obses[window], device=self.device).float()
        next_obses = torch.as_tensor(self.next_obses[window],
                                     device=self.device).float()

        # Differences between consecutive frames, flattened per sample.
        batch = next_obses.shape[0]
        delta_obses = (obses[:, 1:, :] - obses[:, :-1, :]).reshape(batch, -1)
        delta_next_obses = (next_obses[:, 1:, :]
                            - next_obses[:, :-1, :]).reshape(batch, -1)

        return (torch.cat((ori_obses, delta_obses), -1),
                actions,
                rewards,
                torch.cat((ori_next_obses, delta_next_obses), -1),
                not_dones,
                not_dones_no_max)