from collections import deque
from copy import deepcopy
import numpy as np
import random
import ctypes
import torch
import os

# Small positive module-level constant (presumably a numerical-stability epsilon
# for importers of this module; the code in this section does not use it).
EPS = 1.0e-8

def ctypeArrayConvert(arr):
    """Copy ``arr`` (any shape, real numeric) into a new flat ``ctypes.c_double`` array.

    The input is flattened in C order and converted to float64. The returned
    ctypes array owns its memory, so later changes to ``arr`` do not affect it.
    """
    # ascontiguousarray yields a native-endian, C-contiguous float64 buffer, so
    # ctypes can copy the raw bytes in one step instead of converting each
    # element through a Python object (this helper sits in a per-transition loop).
    flat = np.ascontiguousarray(np.ravel(arr), dtype=np.float64)
    return (ctypes.c_double * flat.size).from_buffer_copy(flat)

class ReplayBuffer:
    """Per-environment transition store that builds distributional critic targets.

    One bounded ``deque`` per parallel environment holds transitions in arrival
    order. ``getBatches`` takes the most recent segment of every environment plus
    a random sample of older segments. It turns each segment into quantile
    targets for a reward critic and a cost critic. Target distributions are
    combined by the ``projection`` routine of the compiled library
    ``cpp_modules/main.so``, located next to this file.
    """

    def __init__(
            self,
            device: torch.device,
            len_replay_buffer: int,
            action_bound_min: np.ndarray,
            action_bound_max: np.ndarray,
            discount_factor: float,
            gae_coeff: float,
            n_envs: int,
            n_steps: int,
            n_update_steps: int,
            n_target_quantiles: int) -> None:
        """
        Args:
            device: torch device for network inputs and the returned batches.
            len_replay_buffer: total capacity over all environments; each
                environment gets ``int(len_replay_buffer / n_envs)`` slots.
            action_bound_min, action_bound_max: only their lengths are used;
                sampled next actions are clamped to [-1, 1] in every dimension.
            discount_factor: discount applied when bootstrapping targets.
            gae_coeff: mixing coefficient between the one-step and the recursive
                target (see ``_getQuantileTargets``).
            n_envs: number of parallel environments.
            n_steps: transitions (summed over environments) forming one segment;
                ``getBatches`` treats the last ``n_steps / n_envs`` entries of each
                environment as the latest segment.
            n_update_steps: transitions (summed over environments) requested per
                ``getBatches`` call, including the latest segment.
            n_target_quantiles: number of quantiles in every target distribution.
        """
        self.device = device
        self.len_replay_buffer = len_replay_buffer
        self.discount_factor = discount_factor
        self.gae_coeff = gae_coeff
        self.n_envs = n_envs
        self.n_steps = n_steps
        self.n_update_steps = n_update_steps
        self.n_target_quantiles = n_target_quantiles
        # NOTE(review): the passed bound values are ignored and replaced by -1/+1,
        # presumably because the actor works in a normalized action space — confirm.
        self.action_bound_min = -torch.ones(
            len(action_bound_min), device=device, dtype=torch.float32)
        self.action_bound_max = torch.ones(
            len(action_bound_max), device=device, dtype=torch.float32)
        self.n_steps_per_env = int(self.n_steps/self.n_envs)
        self.n_update_steps_per_env = int(self.n_update_steps/self.n_envs)
        self.len_replay_buffer_per_env = int(self.len_replay_buffer/self.n_envs)
        # oldest transitions are evicted automatically once an env's deque is full
        self.storage = [deque(maxlen=self.len_replay_buffer_per_env) for _ in range(self.n_envs)]

        # projection operator: returns nothing and writes its result into the
        # output buffer passed as the last argument (see _projection)
        self._lib = ctypes.cdll.LoadLibrary(f'{os.path.dirname(os.path.abspath(__file__))}/cpp_modules/main.so')
        self._lib.projection.restype = None

    ################
    # Public Methods
    ################

    def getLen(self):
        """Return the number of transitions stored over all environments."""
        return np.sum([len(env_storage) for env_storage in self.storage])

    def addTransition(self, states, actions, preferences, log_probs, reward_vecs, cost_vecs, dones, fails, next_states):
        """Append one transition per environment.

        Every argument is indexed by environment (``arg[env_idx]``). The stored
        record is ``[state, action, preference, log_prob, reward_vec, cost_vec,
        done, fail, next_state]``; ``_processBatches`` relies on this field order.
        """
        for env_idx in range(self.n_envs):
            self.storage[env_idx].append([
                states[env_idx], actions[env_idx], preferences[env_idx], log_probs[env_idx],
                reward_vecs[env_idx], cost_vecs[env_idx], dones[env_idx], fails[env_idx],
                next_states[env_idx],
            ])

    @torch.no_grad()
    def getBatches(self, obs_rms, reward_rms, actor, reward_critic, cost_critic):
        """Build one training batch from the latest segment and sampled older ones.

        The latest ``n_steps_per_env`` transitions of every environment are always
        used. Older segments are also drawn without replacement:
        ``n_update_steps_per_env // n_steps_per_env - 1`` of them are requested,
        each ``n_steps_per_env`` long and aligned to the front of the buffer, so
        they do not overlap each other or the latest segment. If fewer older
        segments exist than requested (e.g. early in training), all available
        ones are used instead of raising.

        Returns float32 tensors on ``self.device``:
            states, actions, preferences: concatenated over segments and envs.
            reward_targets: (N, reward_dim, n_target_quantiles).
            cost_targets: (N, cost_dim, n_target_quantiles).
            con_vals: (n_segments, cost_dim); per segment, the mean scaled cost
                divided by ``1 - discount_factor``.
        """
        state_len = len(self.storage[0])
        segments = []  # per segment: [states, actions, preferences, reward_targets, cost_targets]
        con_vals = []

        def _collect(start_idx, end_idx):
            # process one segment and record its arrays and its constraint value
            *arrays, costs = self._processBatches(
                obs_rms, reward_rms, actor, reward_critic, cost_critic, start_idx, end_idx)
            segments.append(arrays)
            con_vals.append(np.mean(costs, axis=0)/(1.0 - self.discount_factor))  # (cost_dim,)

        # process the latest trajectories
        _collect(state_len - self.n_steps_per_env, state_len)

        # process the old trajectories (the latest segment is excluded from sampling)
        n_total_trajs = (state_len // self.n_steps_per_env) - 1
        n_update_trajs = (self.n_update_steps_per_env // self.n_steps_per_env) - 1
        # random.sample raises ValueError for a sample size that is negative or
        # larger than the population, so clamp it to what the buffer holds.
        n_update_trajs = max(0, min(n_update_trajs, n_total_trajs))
        for traj_idx in random.sample(range(n_total_trajs), n_update_trajs):
            start_idx = traj_idx*self.n_steps_per_env
            _collect(start_idx, start_idx + self.n_steps_per_env)

        # concatenate each field once (latest segment first, then sampled order)
        states_list, actions_list, preferences_list, reward_targets_list, cost_targets_list = (
            np.concatenate(field, axis=0) for field in zip(*segments))

        def _to_tensor(array):
            return torch.tensor(array, device=self.device, dtype=torch.float32)

        return (
            _to_tensor(states_list),
            _to_tensor(actions_list),
            _to_tensor(preferences_list),
            _to_tensor(reward_targets_list),
            _to_tensor(cost_targets_list),
            _to_tensor(np.array(con_vals)),
        )

    #################
    # Private Methods
    #################

    def _projection(self, quantiles1:np.ndarray, weight1:float, quantiles2:np.ndarray, weight2:float) -> np.ndarray:
        """Combine two weighted quantile vectors into ``n_target_quantiles`` quantiles.

        Thin wrapper over the compiled ``projection`` routine. The combination rule
        itself lives in the C++ source, which is not visible here. Presumably it
        projects the weighted mixture of the two empirical distributions onto
        ``n_target_quantiles`` quantiles — confirm against cpp_modules.
        Requires ``len(quantiles1) <= len(quantiles2)``.
        """
        n_quantiles1 = len(quantiles1)
        n_quantiles2 = len(quantiles2)
        assert n_quantiles1 <= n_quantiles2
        n_quantiles3 = self.n_target_quantiles
        cpp_quantiles1 = ctypeArrayConvert(quantiles1)
        cpp_quantiles2 = ctypeArrayConvert(quantiles2)
        cpp_new_quantiles = ctypeArrayConvert(np.zeros(n_quantiles3))
        # argtypes encode the array lengths, which vary between calls, so they are
        # rebuilt every time
        self._lib.projection.argtypes = [
            ctypes.c_int, ctypes.c_double, ctypes.POINTER(ctypes.c_double*n_quantiles1), ctypes.c_int, ctypes.c_double,
            ctypes.POINTER(ctypes.c_double*n_quantiles2), ctypes.c_int, ctypes.POINTER(ctypes.c_double*n_quantiles3)
        ]
        self._lib.projection(n_quantiles1, weight1, cpp_quantiles1, n_quantiles2,
                             weight2, cpp_quantiles2, n_quantiles3, cpp_new_quantiles)
        new_quantiles = np.array(cpp_new_quantiles)
        return new_quantiles

    def _getQuantileTargets(self, rewards:np.ndarray, dones:np.ndarray, fails:np.ndarray,
                            rhos:np.ndarray, next_quantiles:np.ndarray) -> np.ndarray:
        """
        Compute quantile targets for one scalar signal, iterating backwards in time.

        At each step t, two targets are combined with ``_projection``:
          - the one-step target ``r_t + gamma*(1 - fail_t)*Z(s_{t+1})``, with
            weight ``1 - gae_coeff``;
          - a recursive target, with weight ``gae_weight``.
        The recursive target for t-1 is ``r_{t-1} + gamma*(1 - fail_{t-1})*target_t``.
        Its weight is ``gae_coeff*rho_t*(1 - done_{t-1})*(1 - gae_coeff + gae_weight_t)``,
        so it is dropped across episode boundaries. At the last step the recursive
        target equals the one-step target.

        NOTE(review): when gae_coeff == 1.0 the weight is never updated (it stays
        1.0), so the one-step target gets weight 0. The recursion is then neither
        cut at done boundaries nor scaled by rho — confirm this is intended.

        inputs:
            rewards: (batch_size,)
            dones: (batch_size,)
            fails: (batch_size,)
            rhos: (batch_size,)
            next_quantiles: (batch_size, n_critics*n_quantiles)
        outputs:
            target_quantiles: (batch_size, n_target_quantiles)
        """
        target_quantiles = np.zeros((next_quantiles.shape[0], self.n_target_quantiles))
        gae_target = rewards[-1] + self.discount_factor*(1.0 - fails[-1])*next_quantiles[-1] # (n_critics*n_quantiles,)
        gae_weight = self.gae_coeff
        for t in reversed(range(len(target_quantiles))):
            target = rewards[t] + self.discount_factor*(1.0 - fails[t])*next_quantiles[t] # (n_critics*n_quantiles,)
            target = self._projection(target, 1.0 - self.gae_coeff, gae_target, gae_weight) # (n_target_quantiles,)
            target_quantiles[t, :] = target[:]
            if t != 0:
                if self.gae_coeff != 1.0:
                    gae_weight = self.gae_coeff*rhos[t]*(1.0 - dones[t-1])*(1.0 - self.gae_coeff + gae_weight)
                gae_target = rewards[t-1] + self.discount_factor*(1.0 - fails[t-1])*target # (n_target_quantiles,)
        return target_quantiles

    def _getCriticTargets(self, critic, next_states_tensor, next_actions_tensor, preferences_tensor,
                          values, dones, fails, rhos, batch_size, dim):
        """Quantile targets of shape (batch_size, dim, n_target_quantiles) for one critic.

        The output of ``critic(next_states, next_actions, preferences)`` is
        reshaped to (batch_size, dim, -1) and sorted along the last axis. Each of
        the ``dim`` channels is then passed to ``_getQuantileTargets`` together
        with the matching column of ``values``.
        """
        next_quantiles_tensor = critic(next_states_tensor, next_actions_tensor, preferences_tensor).view(batch_size, dim, -1)
        next_quantiles = torch.sort(next_quantiles_tensor, dim=-1)[0].detach().cpu().numpy() # (batch_size, dim, n_critics*n_quantiles)
        targets = [
            self._getQuantileTargets(values[:, idx], dones, fails, rhos, next_quantiles[:, idx, :])
            for idx in range(dim)
        ]
        return np.stack(targets, axis=1) # (batch_size, dim, n_target_quantiles)

    def _processBatches(self, obs_rms, reward_rms, actor, reward_critic, cost_critic, start_idx, end_idx):
        """Turn the transitions ``[start_idx:end_idx]`` of every environment into targets.

        For each environment, this method:
          - normalizes states and next states with ``obs_rms``, and reward
            vectors with ``reward_rms``;
          - scales the cost of transitions with ``fail == 1`` by
            ``1 / (1 - discount_factor)``;
          - computes importance ratios between the current actor and the stored
            log-probs, clipped to [0, 100];
          - samples next actions, clamped to the action bounds;
          - builds quantile targets for both critics.

        Returns numpy arrays concatenated over environments: states, actions,
        preferences, reward targets (N, reward_dim, n_target_quantiles), cost
        targets (N, cost_dim, n_target_quantiles), and the scaled cost vectors
        (N, cost_dim).
        """
        per_env_outputs = []
        for env_idx in range(self.n_envs):
            records = list(self.storage[env_idx])[start_idx:end_idx]
            # unzip the record fields (order fixed by addTransition)
            (states, actions, preferences, log_probs, reward_vecs, cost_vecs,
             dones, fails, next_states) = (np.array(field) for field in zip(*records))

            # normalize
            states = obs_rms.normalize(states)
            next_states = obs_rms.normalize(next_states)
            reward_vecs = reward_rms.normalize(reward_vecs)
            # c / (1 - gamma) is the discounted sum of receiving c at every future
            # step; presumably a failure is treated as an absorbing state
            cost_vecs = (1.0 - fails.reshape(-1, 1))*cost_vecs + fails.reshape(-1, 1)*cost_vecs/(1.0 - self.discount_factor)

            # convert to tensor
            states_tensor = torch.tensor(states, device=self.device, dtype=torch.float32)
            next_states_tensor = torch.tensor(next_states, device=self.device, dtype=torch.float32)
            actions_tensor = torch.tensor(actions, device=self.device, dtype=torch.float32)
            mu_log_probs_tensor = torch.tensor(log_probs, device=self.device, dtype=torch.float32)
            preferences_tensor = torch.tensor(preferences, device=self.device, dtype=torch.float32)

            # get dimensions
            batch_size = states_tensor.shape[0]
            reward_dim = reward_vecs.shape[1]
            cost_dim = cost_vecs.shape[1]

            # importance ratios: current actor log-prob vs. stored log-prob, clipped
            means_tensor, _, stds_tensor = actor(states_tensor, preferences_tensor)
            dists_tensor = torch.distributions.Normal(means_tensor, stds_tensor)
            pi_log_probs_tensor = dists_tensor.log_prob(actions_tensor).sum(dim=-1)
            rhos_tensor = torch.clamp(torch.exp(pi_log_probs_tensor - mu_log_probs_tensor), 0.0, 100.0)
            rhos = rhos_tensor.detach().cpu().numpy() # (batch_size,)

            # get next actions
            epsilons_tensor = torch.randn_like(actions_tensor)
            actor.updateActionDist(next_states_tensor, preferences_tensor, epsilons_tensor)
            next_actions_tensor = torch.clamp(actor.sample(deterministic=False)[0], self.action_bound_min, self.action_bound_max)

            # reward targets first, then cost targets
            reward_targets = self._getCriticTargets(
                reward_critic, next_states_tensor, next_actions_tensor, preferences_tensor,
                reward_vecs, dones, fails, rhos, batch_size, reward_dim)
            cost_targets = self._getCriticTargets(
                cost_critic, next_states_tensor, next_actions_tensor, preferences_tensor,
                cost_vecs, dones, fails, rhos, batch_size, cost_dim)

            per_env_outputs.append((states, actions, preferences, reward_targets, cost_targets, cost_vecs))

        # (states, actions, preferences, reward_targets, cost_targets, costs)
        return tuple(np.concatenate(field, axis=0) for field in zip(*per_env_outputs))