from collections import deque
import numpy as np
import torch

# Small epsilon constant. Nothing in the visible part of this file reads it;
# presumably kept for numerical-stability guards elsewhere (TODO confirm).
EPS = 1e-8 

class ReplayBuffer:
    """Per-environment transition storage that also builds critic targets.

    Each environment owns one bounded ``deque``; once it holds
    ``len_replay_buffer / n_envs`` transitions, the oldest entry is dropped on
    every append. ``getBatches`` turns stored transitions into float32 tensors
    on ``device`` together with reward (and optionally cost) targets.

    A transition is a list with either 8 fields
    ``[state, action, preference, log_prob, reward_vec, done, fail, next_state]``
    or 9 fields (the same, with ``cost_vec`` inserted right after
    ``reward_vec``), depending on how ``addTransition`` was called.
    """

    def __init__(
            self, device:torch.device, 
            len_replay_buffer:int, 
            discount_factor:float, 
            gae_coeff:float, 
            n_envs:int,
            n_steps:int, 
            n_update_steps:int) -> None:
        """Store the hyper-parameters and create one bounded deque per env.

        Args:
            device: torch device on which tensors are created.
            len_replay_buffer: total capacity, split evenly across envs.
            discount_factor: discount used in the target recursion.
            gae_coeff: coefficient that scales the propagated correction term.
            n_envs: number of parallel environments.
            n_steps: total number of "latest" steps, split evenly across envs.
            n_update_steps: total number of steps per update, split evenly
                across envs.
        """
        self.device = device
        self.len_replay_buffer = len_replay_buffer
        self.discount_factor = discount_factor
        self.gae_coeff = gae_coeff
        self.n_envs = n_envs
        self.n_steps = n_steps
        self.n_update_steps = n_update_steps

        # int(a / b) truncates and also accepts float-valued settings (e.g. 1e6),
        # which `//` would turn into a float and break `deque(maxlen=...)`.
        self.n_steps_per_env = int(self.n_steps/self.n_envs)
        self.n_update_steps_per_env = int(self.n_update_steps/self.n_envs)
        self.len_replay_buffer_per_env = int(self.len_replay_buffer/self.n_envs)

        self.storage = [deque(maxlen=self.len_replay_buffer_per_env) for _ in range(self.n_envs)]

    ################
    # Public Methods
    ################

    def getLen(self):
        """Return the number of transitions stored for the first environment.

        ``addTransition`` appends to every environment at once, so buffers
        filled only through it all have this length.
        """
        return len(self.storage[0])

    def addTransition(self, *args):
        """Append one transition per environment.

        Args:
            *args: either 9 per-field sequences
                ``(states, actions, preferences, log_probs, reward_vecs,
                cost_vecs, dones, fails, next_states)`` or the same 8 without
                ``cost_vecs``. Each sequence is indexed by environment index.

        Raises:
            ValueError: if the number of arguments is neither 8 nor 9.
        """
        if len(args) not in (8, 9):
            raise ValueError("Wrong number of arguments!")
        for env_idx in range(self.n_envs):
            # One row per env, fields kept in the order they were passed.
            self.storage[env_idx].append([field[env_idx] for field in args])

    def getBatches(self, obs_rms, reward_rms, actor, reward_critic, cost_critic=None):
        """Build training tensors and critic targets from the stored data.

        The latest ``n_steps_per_env`` transitions of every environment are
        always used. If ``n_update_steps_per_env`` is larger (and enough data
        is stored), one extra randomly positioned contiguous window per
        environment supplies the remaining steps. The buffer must hold at
        least one transition.

        Args:
            obs_rms: object whose ``normalize`` is applied to states and
                next states.
            reward_rms: object whose ``normalize`` is applied to reward vectors.
            actor: policy; see ``_processBatches`` for the calls made on it.
            reward_critic: callable ``(states, actions, preferences) -> values``.
            cost_critic: optional cost critic with the same call signature.

        Returns:
            A tuple of float32 tensors on ``self.device``:
            ``(states, actions, preferences, reward_targets)`` and, when
            ``cost_critic`` is given, ``cost_targets`` as a fifth element.
            The returned states are the normalized ones.

        Raises:
            ValueError: if ``cost_critic`` is given but the stored transitions
                were added without cost vectors.
        """
        state_len = self.getLen()
        n_latest_steps = min(state_len, self.n_steps_per_env)
        n_update_steps = min(state_len, self.n_update_steps_per_env)
        # Without a cost critic the fifth array is only a placeholder and is dropped.
        n_outputs = 5 if cost_critic is not None else 4

        # process the latest trajectories
        batches = list(self._processBatches(
            obs_rms, reward_rms, actor, reward_critic, cost_critic,
            n_latest_steps, n_update_steps, is_latest=True))

        # process the rest trajectories
        if n_update_steps > n_latest_steps:
            extra_batches = self._processBatches(
                obs_rms, reward_rms, actor, reward_critic, cost_critic,
                n_latest_steps, n_update_steps, is_latest=False)
            for idx in range(n_outputs):
                batches[idx] = np.concatenate([batches[idx], extra_batches[idx]], axis=0)

        # convert to tensor and return
        return tuple(
            torch.tensor(array, device=self.device, dtype=torch.float32)
            for array in batches[:n_outputs]
        )

    #################
    # Private Methods
    #################

    @staticmethod
    def _stackField(env_trajs, field_idx):
        """Stack field ``field_idx`` of every transition into one numpy array."""
        return np.array([traj[field_idx] for traj in env_trajs])

    def _processBatches(self, obs_rms, reward_rms, actor, reward_critic, cost_critic, n_latest_steps, n_update_steps, is_latest):
        """Compute per-environment targets and concatenate them over envs.

        For each environment a contiguous slice of transitions is selected:
        the last ``n_latest_steps`` when ``is_latest`` is true, otherwise a
        window of ``n_update_steps - n_latest_steps`` transitions starting at
        a uniformly random index. Targets are then computed backwards in time:

            target[t] = r[t] + gamma*(1 - fail[t])*Q(s'[t], a'[t])
                        + gamma*(1 - done[t])*delta
            delta     = gae_coeff*rho[t]*(target[t] - Q(s[t], a[t]))

        where ``rho`` is the ratio of the current actor's action probability to
        the stored behaviour probability, clamped to ``[0, 1]``, and ``a'`` is
        sampled from the actor at the next state.

        The actor is used through ``actor(states, preferences)`` (returning
        ``(means, _, stds)``), ``actor.updateActionDist(...)`` and
        ``actor.sample(deterministic=False)``.

        Returns:
            Five numpy arrays concatenated over environments: normalized
            states, actions, preferences, reward targets and cost targets
            (zeros shaped like the reward targets when ``cost_critic`` is
            ``None``).

        Raises:
            ValueError: if ``cost_critic`` is given but the stored transitions
                carry no cost vectors.
        """
        use_costs = cost_critic is not None
        states_list, actions_list, preferences_list = [], [], []
        reward_targets_list, cost_targets_list = [], []

        for env_idx in range(self.n_envs):
            transitions = list(self.storage[env_idx])
            if is_latest:
                env_trajs = transitions[-n_latest_steps:]
            else:
                start_idx = np.random.randint(0, len(transitions) - n_update_steps + 1)
                end_idx = start_idx + n_update_steps - n_latest_steps
                env_trajs = transitions[start_idx:end_idx]

            # The field layout is decided by what was actually stored, not by
            # whether a cost critic was passed: 9-field rows carry a cost vector
            # at index 5, which shifts dones/fails/next_states by one.
            has_costs = len(env_trajs[0]) == 9
            if use_costs and not has_costs:
                raise ValueError("A cost critic was given, but the stored transitions have no cost vectors!")
            done_idx = 6 if has_costs else 5

            states = self._stackField(env_trajs, 0)
            actions = self._stackField(env_trajs, 1)
            preferences = self._stackField(env_trajs, 2)
            log_probs = self._stackField(env_trajs, 3)
            reward_vecs = self._stackField(env_trajs, 4)
            if use_costs:
                cost_vecs = self._stackField(env_trajs, 5)
            dones = self._stackField(env_trajs, done_idx)
            fails = self._stackField(env_trajs, done_idx + 1)
            next_states = self._stackField(env_trajs, done_idx + 2)

            # normalize 
            states = obs_rms.normalize(states)
            next_states = obs_rms.normalize(next_states)
            reward_vecs = reward_rms.normalize(reward_vecs)
            if use_costs:
                # On failure steps the cost is scaled by 1/(1 - gamma); other
                # steps keep their cost unchanged.
                fail_mask = fails.reshape(-1, 1)
                cost_vecs = (1.0 - fail_mask)*cost_vecs + fail_mask*cost_vecs/(1.0 - self.discount_factor)

            # convert to tensor
            states_tensor = torch.tensor(states, device=self.device, dtype=torch.float32)
            next_states_tensor = torch.tensor(next_states, device=self.device, dtype=torch.float32)
            actions_tensor = torch.tensor(actions, device=self.device, dtype=torch.float32)
            mu_log_probs_tensor = torch.tensor(log_probs, device=self.device, dtype=torch.float32)
            preferences_tensor = torch.tensor(preferences, device=self.device, dtype=torch.float32)

            # for rho (only the values are needed, so no autograd graph is built)
            with torch.no_grad():
                means_tensor, _, stds_tensor = actor(states_tensor, preferences_tensor)
                dists_tensor = torch.distributions.Normal(means_tensor, stds_tensor)
                old_log_probs_tensor = dists_tensor.log_prob(actions_tensor).sum(dim=-1)
                rhos_tensor = torch.clamp(torch.exp(old_log_probs_tensor - mu_log_probs_tensor), 0.0, 1.0)
            rhos = rhos_tensor.detach().cpu().numpy()

            # get values
            # The actor calls stay outside `no_grad` so whatever state the actor
            # keeps after `updateActionDist` is the same as before this change.
            epsilons_tensor = torch.normal(mean=torch.zeros_like(actions_tensor), std=torch.ones_like(actions_tensor))
            actor.updateActionDist(next_states_tensor, preferences_tensor, epsilons_tensor)
            next_actions_tensor = actor.sample(deterministic=False)[0]
            with torch.no_grad():
                next_reward_values = reward_critic(next_states_tensor, next_actions_tensor, preferences_tensor).detach().cpu().numpy()
                reward_values = reward_critic(states_tensor, actions_tensor, preferences_tensor).detach().cpu().numpy()
                if use_costs:
                    next_cost_values = cost_critic(next_states_tensor, next_actions_tensor, preferences_tensor).detach().cpu().numpy()
                    cost_values = cost_critic(states_tensor, actions_tensor, preferences_tensor).detach().cpu().numpy()

            # get targets
            reward_delta = np.zeros_like(reward_vecs[0]) # n_rewards
            reward_targets = np.zeros_like(reward_vecs) # n_steps x n_rewards
            if use_costs:
                cost_delta = np.zeros_like(cost_vecs[0]) # n_costs
                cost_targets = np.zeros_like(cost_vecs) # n_steps x n_costs
            # Backward pass: `*_delta` carries the correction from step t+1 to t.
            for t in reversed(range(len(reward_targets))):
                reward_targets[t] = reward_vecs[t] + self.discount_factor*(1.0 - fails[t])*next_reward_values[t] \
                                + self.discount_factor*(1.0 - dones[t])*reward_delta
                reward_delta = self.gae_coeff*rhos[t]*(reward_targets[t] - reward_values[t])
                if use_costs:
                    cost_targets[t] = cost_vecs[t] + self.discount_factor*(1.0 - fails[t])*next_cost_values[t] \
                                    + self.discount_factor*(1.0 - dones[t])*cost_delta
                    cost_delta = self.gae_coeff*rhos[t]*(cost_targets[t] - cost_values[t])

            # append
            states_list.append(states)
            actions_list.append(actions)
            preferences_list.append(preferences)
            reward_targets_list.append(reward_targets)
            # Placeholder zeros keep the return arity fixed when no cost critic is used.
            cost_targets_list.append(cost_targets if use_costs else np.zeros_like(reward_targets))

        return np.concatenate(states_list, axis=0), np.concatenate(actions_list, axis=0), np.concatenate(preferences_list, axis=0), \
                np.concatenate(reward_targets_list, axis=0), np.concatenate(cost_targets_list, axis=0)