import numpy as np
import random
import torch

# Module-level epsilon. It is not referenced by ReplayBuffer below; presumably
# it is imported by other modules (e.g. to guard divisions) — TODO confirm
# before removing.
EPS = 1e-8 

class ReplayBuffer:
    """Fixed-size circular buffer of transitions from vectorized environments.

    Storage is laid out as ``(len_replay_buffer_per_env, n_envs, dim)``: each
    call to :meth:`addTransition` writes one row holding one transition per
    environment, and the write position wraps to 0 once every row is filled.
    When ``cost_dim > 0`` a per-transition cost vector is stored as well.

    Args:
        device: Torch device on which sampled batches are created.
        len_replay_buffer: Requested total capacity (transitions across all
            envs). It is split evenly over ``n_envs``; any remainder is dropped.
        discount_factor: Discount used to inflate the cost of failed
            transitions in :meth:`getBatches`.
        batch_size: Number of transitions returned by :meth:`getBatches`.
        n_envs: Number of parallel environments stored per row.
        max_episode_len: Stored as an attribute; not used by this class.
        state_dim, action_dim, preference_dim, reward_dim, cost_dim: Feature
            sizes of the stored arrays (``cost_dim == 0`` disables costs).
        chunk_len: Approximate number of transitions (across all envs) in the
            contiguous window of rows each batch is drawn from. Defaults to
            1000, the value previously hard-coded in :meth:`getBatches`.
    """

    def __init__(
            self, device: torch.device,
            len_replay_buffer: int,
            discount_factor: float,
            batch_size: int,
            n_envs: int,
            max_episode_len: int,
            state_dim: int,
            action_dim: int,
            preference_dim: int,
            reward_dim: int,
            cost_dim: int,
            chunk_len: int = 1000) -> None:

        self.device = device
        self.len_replay_buffer = len_replay_buffer
        self.discount_factor = discount_factor
        self.batch_size = batch_size
        self.n_envs = n_envs
        self.max_episode_len = max_episode_len
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.preference_dim = preference_dim
        self.reward_dim = reward_dim
        self.cost_dim = cost_dim
        self.chunk_len = chunk_len

        # Floor division avoids the float round-trip of int(a / b); the outer
        # int() keeps the result usable as an array size if a float was passed.
        self.len_replay_buffer_per_env = int(self.len_replay_buffer // self.n_envs)
        self.is_full = False
        self.pos = 0  # next row to write; wraps to 0 once the buffer is full

        rows = (self.len_replay_buffer_per_env, self.n_envs)
        self.states = np.zeros(rows + (self.state_dim,))
        self.actions = np.zeros(rows + (self.action_dim,))
        self.preferences = np.zeros(rows + (self.preference_dim,))
        self.rewards = np.zeros(rows + (self.reward_dim,))
        if self.cost_dim > 0:
            self.costs = np.zeros(rows + (self.cost_dim,))
        self.dones = np.zeros(rows)
        self.fails = np.zeros(rows)
        self.next_states = np.zeros(rows + (self.state_dim,))

    ################
    # Public Methods
    ################

    def getLen(self):
        """Return the number of transitions currently stored (all envs)."""
        if self.is_full:
            # Actual capacity; may be smaller than len_replay_buffer when it
            # is not divisible by n_envs.
            return self.len_replay_buffer_per_env * self.n_envs
        return self.pos * self.n_envs

    def addTransition(self, *args):
        """Store one row of transitions (one per environment).

        Positional arguments, in order:
        ``states, actions, preferences, rewards, dones, fails, next_states``
        when ``cost_dim == 0``; otherwise ``costs`` is inserted right after
        ``rewards``. Each array's leading dimension must be ``n_envs``.
        """
        if self.cost_dim == 0:
            states, actions, preferences, rewards, dones, fails, next_states = args
            costs = None
        else:
            states, actions, preferences, rewards, costs, dones, fails, next_states = args
        assert len(states) == self.n_envs

        row = self.pos
        self.states[row] = states
        self.actions[row] = actions
        self.preferences[row] = preferences
        self.rewards[row] = rewards
        if self.cost_dim != 0:
            self.costs[row] = costs
        self.dones[row] = dones
        self.fails[row] = fails
        self.next_states[row] = next_states

        self.pos += 1
        if self.pos == self.len_replay_buffer_per_env:
            self.is_full = True
            self.pos = 0

    @torch.no_grad()
    def getBatches(self, obs_rms, reward_rms):
        """Sample ``batch_size`` transitions as float32 tensors on ``self.device``.

        The batch is drawn without replacement from one contiguous window of
        rows chosen by ``_sampleChunkBounds``. States, next states and rewards
        are passed through ``obs_rms.normalize`` / ``reward_rms.normalize``
        (only that method is used; assumed to take and return numpy arrays).

        For failed transitions the cost is replaced by
        ``cost / (1 - discount_factor)``, which equals the discounted sum of
        incurring that cost at every future step.

        Returns:
            ``(states, actions, rewards, costs, preferences, dones, fails,
            next_states)`` when ``cost_dim > 0``, otherwise the same tuple
            without ``costs``.

        Raises:
            ValueError: if too few transitions are stored to form a batch.
        """
        start_idx, end_idx = self._sampleChunkBounds()
        window = slice(start_idx, end_idx)
        # Flatten (rows, n_envs, dim) -> (rows * n_envs, dim).
        states = self.states[window].reshape(-1, self.state_dim)
        actions = self.actions[window].reshape(-1, self.action_dim)
        preferences = self.preferences[window].reshape(-1, self.preference_dim)
        rewards = self.rewards[window].reshape(-1, self.reward_dim)
        dones = self.dones[window].reshape(-1)
        fails = self.fails[window].reshape(-1)
        next_states = self.next_states[window].reshape(-1, self.state_dim)
        if self.cost_dim > 0:
            costs = self.costs[window].reshape(-1, self.cost_dim)
        batch_inds = random.sample(range(len(states)), self.batch_size)

        states_tensor = self._toTensor(obs_rms.normalize(states[batch_inds]))
        next_states_tensor = self._toTensor(obs_rms.normalize(next_states[batch_inds]))
        rewards_tensor = self._toTensor(reward_rms.normalize(rewards[batch_inds]))
        actions_tensor = self._toTensor(actions[batch_inds])
        dones_tensor = self._toTensor(dones[batch_inds])
        fails_tensor = self._toTensor(fails[batch_inds])
        preferences_tensor = self._toTensor(preferences[batch_inds])

        if self.cost_dim > 0:
            costs_tensor = self._toTensor(costs[batch_inds])
            fail_mask = fails_tensor.view(-1, 1)
            costs_tensor = (1.0 - fail_mask) * costs_tensor + fail_mask * costs_tensor / (1.0 - self.discount_factor)
            return states_tensor, actions_tensor, rewards_tensor, costs_tensor, preferences_tensor, dones_tensor, fails_tensor, next_states_tensor
        return states_tensor, actions_tensor, rewards_tensor, preferences_tensor, dones_tensor, fails_tensor, next_states_tensor

    #################
    # Private Methods
    #################

    def _toTensor(self, array):
        """Convert an array to a float32 tensor on ``self.device``."""
        return torch.tensor(array, device=self.device, dtype=torch.float32)

    def _sampleChunkBounds(self):
        """Choose the row window ``[start, end)`` that a batch is drawn from.

        Rows are grouped into aligned chunks of ``len_unit`` rows
        (``chunk_len // n_envs``, at least 1). A row index is drawn uniformly
        from the eligible rows and its chunk is returned, so a chunk is picked
        with probability proportional to its number of eligible rows.
        """
        # Clamp to 1: with n_envs > chunk_len the floor division would be 0.
        len_unit = max(1, self.chunk_len // self.n_envs)

        if self.is_full:
            n_rows = self.len_replay_buffer_per_env
            sampled_idx = np.random.randint(n_rows)
            start_idx = (sampled_idx // len_unit) * len_unit
            end_idx = min(start_idx + len_unit, n_rows)
            # The trailing chunk is short when n_rows is not a multiple of
            # len_unit. If it cannot supply batch_size transitions, slide the
            # window back so it ends at the buffer end instead of crashing in
            # random.sample.
            if (end_idx - start_idx) * self.n_envs < self.batch_size:
                start_idx = max(0, end_idx - len_unit)
            return start_idx, end_idx

        if self.pos == 0:
            raise ValueError("getBatches() called on an empty ReplayBuffer")

        # Rows written so far in the chunk currently being filled (1..len_unit).
        partial_rows = (self.pos - 1) % len_unit + 1
        if partial_rows * self.n_envs < self.batch_size:
            # The in-progress chunk is too small for a batch: only sample
            # from completed chunks.
            completed_rows = self.pos - (self.pos % len_unit)
            if completed_rows == 0:
                raise ValueError(
                    "ReplayBuffer holds too few transitions to sample a batch")
            sampled_idx = np.random.randint(completed_rows)
            start_idx = (sampled_idx // len_unit) * len_unit
            end_idx = start_idx + len_unit
        else:
            sampled_idx = np.random.randint(self.pos)
            start_idx = (sampled_idx // len_unit) * len_unit
            end_idx = min(start_idx + len_unit, self.pos)
        return start_idx, end_idx