from collections import deque
from copy import deepcopy
import numpy as np
import ctypes
import torch
import os

# Small positive constant (1e-8). It is not referenced anywhere in the visible
# code; presumably a numerical-stability epsilon -- confirm before relying on it.
EPS = 1e-8 

def ctypeArrayConvert(arr):
    """Copy ``arr`` into a newly allocated ctypes array of C doubles.

    The input is flattened with ``np.ravel`` first, so the result is always
    one-dimensional regardless of the shape that was passed in.

    inputs:
        arr: array-like of real numbers (any shape).
    outputs:
        a ``ctypes.c_double`` array with one entry per element of ``arr``.
    """
    flat_values = np.ravel(arr)
    double_array_type = ctypes.c_double * len(flat_values)
    return double_array_type(*flat_values)

class ReplayBuffer:
    def __init__(
            self, device:torch.device,
            len_replay_buffer:int,
            discount_factor:float,
            gae_coeff:float,
            n_envs:int,
            n_steps:int,
            n_update_steps:int,
            n_target_quantiles:int) -> None:
        """
        Store the hyper-parameters, create one bounded queue per environment
        and load the compiled projection routine.

        inputs:
            device: torch device used for every tensor this buffer creates.
            len_replay_buffer: total capacity, split evenly across environments.
            discount_factor: discount applied to bootstrapped quantiles.
            gae_coeff: mixing coefficient used when building quantile targets.
            n_envs: number of parallel environments (one queue each).
            n_steps: total number of latest steps, split across environments.
            n_update_steps: total number of update steps, split across environments.
            n_target_quantiles: number of quantiles in each produced target.
        raises:
            OSError: if ``cpp_modules/main.so`` next to this file cannot be loaded.
        """
        # hyper-parameters, kept exactly as given
        self.device = device
        self.len_replay_buffer = len_replay_buffer
        self.discount_factor = discount_factor
        self.gae_coeff = gae_coeff
        self.n_envs = n_envs
        self.n_steps = n_steps
        self.n_update_steps = n_update_steps
        self.n_target_quantiles = n_target_quantiles

        # per-environment share of each global budget (fraction is truncated)
        self.n_steps_per_env = int(n_steps/n_envs)
        self.n_update_steps_per_env = int(n_update_steps/n_envs)
        self.len_replay_buffer_per_env = int(len_replay_buffer/n_envs)

        # one independent bounded queue per environment; oldest entries fall out first
        self.storage = []
        for _ in range(n_envs):
            self.storage.append(deque(maxlen=self.len_replay_buffer_per_env))

        # projection operator: shared object shipped alongside this module
        module_dir = os.path.dirname(os.path.abspath(__file__))
        self._lib = ctypes.cdll.LoadLibrary(f'{module_dir}/cpp_modules/main.so')
        self._lib.projection.restype = None

    ################
    # Public Methods
    ################

    def getLen(self):
        """
        Return the number of transitions currently held for one environment.

        Only the first environment's queue is measured; the other queues are
        not inspected here.
        """
        first_env_queue = self.storage[0]
        return len(first_env_queue)

    def addTransition(self, *args):
        """
        Append one transition per environment to the buffer.

        inputs (9 positional arguments):
            states, actions, preferences, log_probs, reward_vecs, cost_vecs,
            dones, fails, next_states
        inputs (8 positional arguments):
            the same order without cost_vecs
        Every argument is indexed by environment index; the values at index i
        are stored together, in argument order, as one list in queue i.

        raises:
            ValueError: if the number of arguments is neither 8 nor 9.
        """
        if len(args) not in (8, 9):
            raise ValueError("Wrong number of arguments!")

        # both layouts are stored the same way: one list per env, fields in call order
        for env_idx in range(self.n_envs):
            self.storage[env_idx].append([field[env_idx] for field in args])

    def getBatches(self, obs_rms, reward_rms, actor, reward_critic, cost_critic=None):
        """
        Build the training batch: the newest steps of every environment, plus
        (when the update budget is larger) a randomly placed older segment.

        inputs:
            obs_rms, reward_rms: objects exposing ``normalize(array)``.
            actor: policy network, used by ``_processBatches``.
            reward_critic: quantile critic for rewards.
            cost_critic: optional quantile critic for costs.
        outputs:
            (states, actions, preferences, reward_targets) as float32 tensors
            on ``self.device``; ``cost_targets`` is appended as a fifth tensor
            only when ``cost_critic`` is given. Rows are concatenated along
            axis 0, newest block first.
        """
        stored_len = self.getLen()
        n_latest_steps = min(stored_len, self.n_steps_per_env)
        n_update_steps = min(stored_len, self.n_update_steps_per_env)
        use_cost = cost_critic is not None
        shared_args = (obs_rms, reward_rms, actor, reward_critic, cost_critic, n_latest_steps, n_update_steps)

        # newest transitions of every environment
        states, actions, preferences, reward_targets, cost_targets = \
            self._processBatches(*shared_args, is_latest=True)

        # older transitions, only when more update steps than latest steps are requested
        if n_update_steps > n_latest_steps:
            older = self._processBatches(*shared_args, is_latest=False)
            states = np.concatenate([states, older[0]], axis=0)
            actions = np.concatenate([actions, older[1]], axis=0)
            preferences = np.concatenate([preferences, older[2]], axis=0)
            reward_targets = np.concatenate([reward_targets, older[3]], axis=0)
            if use_cost:
                cost_targets = np.concatenate([cost_targets, older[4]], axis=0)

        def to_tensor(array):
            return torch.tensor(array, device=self.device, dtype=torch.float32)

        # convert to tensors; the cost targets are dropped when no cost critic is used
        batch = (to_tensor(states), to_tensor(actions), to_tensor(preferences), to_tensor(reward_targets))
        if use_cost:
            batch = batch + (to_tensor(cost_targets),)
        return batch

    #################
    # Private Methods
    #################

    def _projection(self, quantiles1:np.ndarray, weight1:float, quantiles2:np.ndarray, weight2:float) -> np.ndarray:
        """
        Run the compiled ``projection`` routine on two weighted quantile sets.

        inputs:
            quantiles1: first quantile set; must not be longer than quantiles2.
            weight1: weight passed alongside quantiles1.
            quantiles2: second quantile set.
            weight2: weight passed alongside quantiles2.
        outputs:
            new_quantiles: (n_target_quantiles,) array filled in by the C routine.

        NOTE(review): the numerical meaning of the projection lives in
        cpp_modules/main.so and is not visible here -- presumably it mixes the
        two sets by their weights and re-quantizes; confirm against the C++ source.
        """
        n_first = len(quantiles1)
        n_second = len(quantiles2)
        assert n_first <= n_second
        n_out = self.n_target_quantiles

        # C-side buffers: copies of both inputs and a zero-initialized output
        c_first = ctypeArrayConvert(quantiles1)
        c_second = ctypeArrayConvert(quantiles2)
        c_out = (ctypes.c_double*n_out)()

        # the array sizes change between calls, so the signature is declared on every call
        projection = self._lib.projection
        projection.argtypes = [
            ctypes.c_int, ctypes.c_double, ctypes.POINTER(ctypes.c_double*n_first),
            ctypes.c_int, ctypes.c_double, ctypes.POINTER(ctypes.c_double*n_second),
            ctypes.c_int, ctypes.POINTER(ctypes.c_double*n_out),
        ]
        projection(n_first, weight1, c_first, n_second, weight2, c_second, n_out, c_out)
        return np.array(c_out)

    def _getQuantileTargets(self, rewards:np.ndarray, dones:np.ndarray, fails:np.ndarray, 
                            rhos:np.ndarray, next_quantiles:np.ndarray) -> np.ndarray:
        """
        Build target quantiles for every step by sweeping the batch from the
        last step to the first. At each step the one-step bootstrapped target
        is combined, through ``_projection``, with a target carried over from
        the following step.

        inputs:
            rewards: (batch_size,)
            dones: (batch_size,)
            fails: (batch_size,)
            rhos: (batch_size,)
            next_quantiles: (batch_size, n_critics*n_quantiles)
        outputs:
            target_quantiles: (batch_size, n_target_quantiles)
        """
        gamma = self.discount_factor
        lam = self.gae_coeff
        n_rows = next_quantiles.shape[0]
        target_quantiles = np.zeros((n_rows, self.n_target_quantiles))

        # the carried target starts as the last step's own one-step target
        gae_target = rewards[-1] + gamma*(1.0 - fails[-1])*next_quantiles[-1] # (n_critics*n_quantiles,)
        gae_weight = lam
        for t in range(n_rows - 1, -1, -1):
            one_step = rewards[t] + gamma*(1.0 - fails[t])*next_quantiles[t] # (n_critics*n_quantiles,)
            mixed = self._projection(one_step, 1.0 - lam, gae_target, gae_weight) # (n_target_quantiles,)
            target_quantiles[t] = mixed
            if t == 0:
                break
            prev = t - 1
            # the carried weight is scaled by rhos[t] and zeroed when step t-1 is done;
            # it is left untouched when gae_coeff is exactly 1
            if lam != 1.0:
                gae_weight = lam*rhos[t]*(1.0 - dones[prev])*(1.0 - lam + gae_weight)
            # push the mixed target one step back through step t-1's reward
            gae_target = rewards[prev] + gamma*(1.0 - fails[prev])*mixed # (n_target_quantiles,)
        return target_quantiles

    def _processBatches(self, obs_rms, reward_rms, actor, reward_critic, cost_critic, n_latest_steps, n_update_steps, is_latest):
        """
        For every environment, pick a slice of stored transitions and compute
        quantile targets for it.

        inputs:
            obs_rms, reward_rms: objects exposing ``normalize(array)``.
            actor: callable returning (means, _, stds); also provides
                ``updateActionDist`` and ``sample``.
            reward_critic: callable (next_states, next_actions, preferences) -> quantiles.
            cost_critic: same as reward_critic for costs, or None.
            n_latest_steps: length of the newest slice.
            n_update_steps: total number of steps used per environment.
            is_latest: True -> take the newest n_latest_steps transitions;
                False -> take a randomly placed slice of
                n_update_steps - n_latest_steps transitions.
        outputs:
            states, actions, preferences, reward_targets, cost_targets as
            numpy arrays concatenated over environments along axis 0.
            states are the normalized ones; cost_targets is all zeros (shaped
            like reward_targets) when cost_critic is None.
        """
        use_cost = cost_critic is not None
        # stored layout: cost_vecs sits between reward_vecs and dones only in the 9-field form
        n_fields = 9 if use_cost else 8

        def to_tensor(array):
            return torch.tensor(array, device=self.device, dtype=torch.float32)

        states_list, actions_list, preferences_list = [], [], []
        reward_targets_list, cost_targets_list = [], []

        for env_idx in range(self.n_envs):
            env_queue = self.storage[env_idx]
            if is_latest:
                env_trajs = list(env_queue)[-n_latest_steps:]
            else:
                start_idx = np.random.randint(0, len(env_queue) - n_update_steps + 1)
                end_idx = start_idx + n_update_steps - n_latest_steps
                env_trajs = list(env_queue)[start_idx:end_idx]

            # one array per stored field, in storage order
            columns = [np.array([traj[field] for traj in env_trajs]) for field in range(n_fields)]
            if use_cost:
                states, actions, preferences, log_probs, reward_vecs, cost_vecs, dones, fails, next_states = columns
            else:
                states, actions, preferences, log_probs, reward_vecs, dones, fails, next_states = columns

            # normalize
            states = obs_rms.normalize(states)
            next_states = obs_rms.normalize(next_states)
            reward_vecs = reward_rms.normalize(reward_vecs)
            if use_cost:
                # rows flagged as fails have their cost divided by (1 - discount_factor)
                fail_col = fails.reshape(-1, 1)
                cost_vecs = (1.0 - fail_col)*cost_vecs + fail_col*cost_vecs/(1.0 - self.discount_factor)

            # convert to tensor
            states_tensor = to_tensor(states)
            next_states_tensor = to_tensor(next_states)
            actions_tensor = to_tensor(actions)
            mu_log_probs_tensor = to_tensor(log_probs)
            preferences_tensor = to_tensor(preferences)

            # get dimensions
            batch_size = states_tensor.shape[0]
            reward_dim = reward_vecs.shape[1]
            if use_cost:
                cost_dim = cost_vecs.shape[1]

            # rho: ratio of the current policy's action probability to the stored one, clipped to [0, 100]
            means_tensor, _, stds_tensor = actor(states_tensor, preferences_tensor)
            policy_dist = torch.distributions.Normal(means_tensor, stds_tensor)
            pi_log_probs_tensor = policy_dist.log_prob(actions_tensor).sum(dim=-1)
            rhos_tensor = torch.clamp(torch.exp(pi_log_probs_tensor - mu_log_probs_tensor), 0.0, 100.0)
            rhos = rhos_tensor.detach().cpu().numpy() # (batch_size,)

            # get next actions
            epsilons_tensor = torch.randn_like(actions_tensor)
            actor.updateActionDist(next_states_tensor, preferences_tensor, epsilons_tensor)
            next_actions_tensor = actor.sample(deterministic=False)[0]

            # reward targets: one quantile-target matrix per reward dimension
            reward_out = reward_critic(next_states_tensor, next_actions_tensor, preferences_tensor)
            next_reward_quantiles = torch.sort(reward_out.view(batch_size, reward_dim, -1), dim=-1)[0] \
                .detach().cpu().numpy() # (batch_size, reward_dim, n_critics*n_quantiles)
            reward_targets = np.stack([
                self._getQuantileTargets(reward_vecs[:, k], dones, fails, rhos, next_reward_quantiles[:, k, :])
                for k in range(reward_dim)
            ], axis=1) # (batch_size, reward_dim, n_target_quantiles)

            # cost targets: same procedure with the cost critic
            if use_cost:
                cost_out = cost_critic(next_states_tensor, next_actions_tensor, preferences_tensor)
                next_cost_quantiles = torch.sort(cost_out.view(batch_size, cost_dim, -1), dim=-1)[0] \
                    .detach().cpu().numpy() # (batch_size, cost_dim, n_critics*n_quantiles)
                cost_targets = np.stack([
                    self._getQuantileTargets(cost_vecs[:, k], dones, fails, rhos, next_cost_quantiles[:, k, :])
                    for k in range(cost_dim)
                ], axis=1)
            else:
                # placeholder so the return value keeps five entries
                cost_targets = np.zeros_like(reward_targets)

            # append
            states_list.append(states)
            actions_list.append(actions)
            preferences_list.append(preferences)
            reward_targets_list.append(reward_targets)
            cost_targets_list.append(cost_targets)

        return (
            np.concatenate(states_list, axis=0),
            np.concatenate(actions_list, axis=0),
            np.concatenate(preferences_list, axis=0),
            np.concatenate(reward_targets_list, axis=0),
            np.concatenate(cost_targets_list, axis=0),
        )