import numpy as np
import scipy
import gym
import random
from collections import namedtuple


def discount_cumsum(x, discount):
    """
    Compute discounted cumulative sums of a vector (trick taken from rllab).

    input:
        vector x,
        [x0,
         x1,
         x2]
    output:
        [x0 + discount * x1 + discount^2 * x2,
         x1 + discount * x2,
         x2]

    The accumulation runs along axis 0 of ``x``; ``x`` must support reversed
    slicing (``x[::-1]``), e.g. a numpy array.
    """
    # A bare ``import scipy`` does not guarantee that the ``signal`` submodule
    # is loaded (older SciPy releases raise AttributeError on
    # ``scipy.signal``), so import the filter function explicitly here.
    from scipy.signal import lfilter

    # The IIR filter y[n] = x[n] + discount * y[n - 1] applied to the reversed
    # input accumulates from the end; reversing again restores the order.
    return lfilter([1], [1, float(-discount)], x[::-1], axis=0)[::-1]


def compute_gae(next_value, rewards, masks, values, gamma=0.99, lam=0.95):
    """
    Compute GAE(lambda) advantages and the matching returns for one rollout.

    input:
        next_value: value estimate for the state following the last step
        rewards:    per-step rewards (sequence indexed by step)
        masks:      per-step multipliers applied to the bootstrap term;
                    a 0 entry stops both bootstrapping and the GAE
                    accumulation at that step
        values:     per-step value estimates (list, tuple or array)
        gamma:      discount factor
        lam:        GAE lambda
    output:
        (returns, advantages) where ``returns[t] = advantages[t] + values[t]``;
        both are numpy arrays indexed by step.
    """
    # ``list(values)`` makes this an append for every sequence type; with an
    # ndarray, ``values + [next_value]`` would be a broadcast addition.
    values = list(values) + [next_value]

    # Advantages must be stored as floats: ``np.zeros_like`` on integer
    # rewards would yield an integer array and silently truncate them.
    rewards_arr = np.asarray(rewards)
    if np.issubdtype(rewards_arr.dtype, np.floating):
        adv_dtype = rewards_arr.dtype
    else:
        adv_dtype = np.float64
    advantages = np.zeros(rewards_arr.shape, dtype=adv_dtype)

    gae = 0
    returns_reversed = []
    for step in reversed(range(len(rewards))):
        delta = rewards[step] + gamma * values[step + 1] * masks[step] - values[step]
        gae = delta + gamma * lam * masks[step] * gae
        advantages[step] = gae
        # Collected back-to-front and reversed once at the end, instead of an
        # O(n) ``insert(0, ...)`` on every iteration.
        returns_reversed.append(gae + values[step])
    return np.array(returns_reversed[::-1]), advantages


def compute_gae_sb(last_values, rewards, dones, episode_starts, values, gamma, lam):
    """
    Compute GAE(lambda) advantages and TD(lambda) returns from
    ``episode_starts`` flags (the recursion appears to follow
    Stable-Baselines' rollout buffer, as the ``_sb`` suffix suggests).

    input:
        last_values:    value estimate for the state after the final step;
                        a torch tensor or an array-like (it is flattened)
        rewards:        per-step rewards, indexed by step along axis 0
        dones:          terminal flag(s) for the state after the final step
        episode_starts: per-step flags; ``episode_starts[t]`` set means step
                        ``t`` begins a new episode, so step ``t - 1`` must
                        not bootstrap from it
        values:         per-step value estimates, same layout as ``rewards``
        gamma:          discount factor
        lam:            GAE lambda
    output:
        (returns, advantages) with ``returns = advantages + values``.

    NOTE(review): the original body referenced ``self.*`` although this is a
    plain function, so it always raised NameError and returned nothing. The
    ``(returns, advantages)`` order mirrors ``compute_gae`` -- confirm against
    the intended callers.
    """
    # Accept torch tensors as well as array-likes for the bootstrap value.
    if hasattr(last_values, "detach"):
        last_values = last_values.detach().cpu().numpy()
    last_values = np.asarray(last_values).flatten()

    rewards = np.asarray(rewards)
    values = np.asarray(values)
    # Converted to float so that ``1.0 - flag`` is well defined for boolean,
    # integer and float flags alike.
    episode_starts = np.asarray(episode_starts, dtype=np.float64)
    final_non_terminal = 1.0 - np.asarray(dones, dtype=np.float64)

    if np.issubdtype(values.dtype, np.floating):
        adv_dtype = values.dtype
    else:
        adv_dtype = np.float64
    advantages = np.zeros(values.shape, dtype=adv_dtype)

    last_gae_lam = 0
    n_steps = len(rewards)
    for step in reversed(range(n_steps)):
        if step == n_steps - 1:
            next_non_terminal = final_non_terminal
            next_values = last_values
        else:
            next_non_terminal = 1.0 - episode_starts[step + 1]
            next_values = values[step + 1]
        delta = rewards[step] + gamma * next_values * next_non_terminal - values[step]
        last_gae_lam = delta + gamma * lam * next_non_terminal * last_gae_lam
        advantages[step] = last_gae_lam

    # TD(lambda) estimator, see Github PR #375 or "Telescoping in TD(lambda)"
    # in David Silver Lecture 4: https://www.youtube.com/watch?v=PnHCvfgC_ZA
    returns = advantages + values
    return returns, advantages

def _shape_from_env(env, shape_attr, space_attr):
    """Resolve one shape of ``env``: explicit attribute first, then the gym space."""
    if hasattr(env, shape_attr):
        return getattr(env, shape_attr)
    if hasattr(env, space_attr):
        space = getattr(env, space_attr)
        if isinstance(space, gym.spaces.Discrete):
            # a discrete value is stored as a single number
            return (1,)
        if isinstance(space, gym.spaces.Box):
            return space.shape
        raise ValueError(
            "unsupported {} type: {}".format(space_attr, type(space).__name__))
    raise ValueError(
        "env defines neither '{}' nor '{}'".format(shape_attr, space_attr))


def get_shapes(env):
    """
    Return ``(shape_obs, shape_a)`` for an environment.

    For each of the two shapes, an explicit ``env.shape_obs`` / ``env.shape_a``
    attribute takes precedence; otherwise the shape is derived from
    ``env.observation_space`` / ``env.action_space`` (``(1,)`` for
    ``gym.spaces.Discrete``, ``space.shape`` for ``gym.spaces.Box``).

    Raises:
        ValueError: if a shape cannot be determined, i.e. the env has neither
            attribute or the space is neither Discrete nor Box. (Previously
            this surfaced as an UnboundLocalError at the return statement.)
    """
    shape_a = _shape_from_env(env, 'shape_a', 'action_space')
    shape_obs = _shape_from_env(env, 'shape_obs', 'observation_space')
    return shape_obs, shape_a


def get_buf_dtype(env):
    """
    Build the structured numpy dtype describing one buffer record for ``env``.

    Every field is float32; only the observation ('o') and action ('a')
    fields carry a sub-array shape, obtained from ``get_shapes(env)``.
    """
    obs_shape, act_shape = get_shapes(env)
    scalar = np.float32

    record_fields = [
        ('o', scalar, obs_shape),   # observation
        ('r', scalar),              # collected reward
        ('r_gt', scalar),           # collected ground truth reward
        ('r_pref', scalar),         # presumably the preference-based reward -- TODO confirm
        ('v', scalar),              # value estimate output by network
        ('a', scalar, act_shape),   # action
        ('a_logp', scalar),         # action log probability
        ('adv', scalar),            # advantage estimate (computed before update)
        ('v_tgt', scalar),          # value target (computed before update)
    ]
    return np.dtype(record_fields)


class Buffer(dict):
    """
    Fixed-size numpy record buffer with per-trajectory GAE post-processing.

    Records are written sequentially into a structured array; when the write
    position reaches ``buf_size`` it wraps back to 0. A trajectory is the run
    of records stored since the last ``start_traj()`` (or since construction
    / ``reset()``); ``finish_traj()`` fills its 'adv' and 'v_tgt' fields.
    """

    def __init__(self, buf_size, buf_dtype):
        self.buf = np.empty(buf_size, dtype=buf_dtype)
        self.buf_size = buf_size
        self.buf_dtype = buf_dtype
        self.ptr = 0  # buffer position where next item will be stored
        # Initialised here so finish_traj() works even when start_traj() was
        # never called (it used to raise AttributeError in that case).
        self.traj_start_ptr = 0
        self._traj_len = 0  # number of items stored in the current trajectory

    def store(self, buf_item):
        """Write ``buf_item`` at the current position and advance (with wrap-around)."""
        self.buf[self.ptr] = buf_item
        self.ptr += 1
        self._traj_len += 1
        if self.ptr == self.buf_size:
            self.ptr = 0

    def reset(self):
        """Rewind the write position and forget the current trajectory."""
        self.ptr = 0
        self.traj_start_ptr = 0
        self._traj_len = 0

    def get_batch(self):
        """Placeholder; not implemented (returns None)."""
        pass

    def start_traj(self):
        """Mark the current write position as the start of a new trajectory."""
        self.traj_start_ptr = self.ptr
        self._traj_len = 0

    def finish_traj(self, gamma, lambd, last_v=0):
        """
        Fill 'adv' and 'v_tgt' for the records of the current trajectory.

        input:
            gamma:  discount factor
            lambd:  GAE lambda
            last_v: bootstrap value appended after the last reward/value
                    (0 by default)
        """
        start = self.traj_start_ptr
        end = self.ptr
        # If the trajectory filled the buffer exactly, store() has already
        # wrapped ptr to 0; slicing [start:0] would be empty and the
        # trajectory would be skipped, so use the true end instead.
        if end == 0 and self._traj_len > 0 and start + self._traj_len == self.buf_size:
            end = self.buf_size
        traj_slice = slice(start, end)

        rews = np.append(self.buf['r'][traj_slice], last_v)
        vals = np.append(self.buf['v'][traj_slice], last_v)

        # GAE-Lambda advantage calculation
        deltas = rews[:-1] + gamma * vals[1:] - vals[:-1]
        self.buf['adv'][traj_slice] = discount_cumsum(deltas, gamma * lambd)

        # Value targets
        self.buf['v_tgt'][traj_slice] = discount_cumsum(rews, gamma)[:-1]


# Record type for one stored transition; the field order below is the
# positional order expected by ``Transition(*args)``.
Transition = namedtuple(
    'Transition',
    [
        'state',
        'action',
        'next_state',
        'gt_reward',
        'done',
        'irl_reward',
        'pref_reward',
        'v_pred',
        'v_tgt',
        'adv',
        'logp_a',
    ],
)


class ReplayMemory(object):
    """
        The memory object for off-policy RL algorithms

        Stores up to ``capacity`` transitions; once full, new transitions
        overwrite the oldest ones in insertion order.
    """
    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0  # index the next pushed transition is written to

    def push(self, *args):
        """Saves a transition (``args`` are the Transition fields, in order)."""
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[self.position] = Transition(*args)
        self.position = (self.position + 1) % self.capacity

    def clear(self):
        """Drops all stored transitions."""
        self.memory = []
        self.position = 0

    def sample(self, batch_size):
        """Returns ``batch_size`` distinct stored transitions chosen at random."""
        return random.sample(self.memory, batch_size)

    def iter(self, batch_size):
        """
        Yields consecutive, non-overlapping batches of ``batch_size``
        transitions in storage order; a trailing partial batch is dropped.
        """
        for i in range(len(self.memory) // batch_size):
            # Batch i starts at i * batch_size; slicing from i itself produced
            # overlapping windows shifted by a single element.
            start = i * batch_size
            yield self.memory[start:start + batch_size]

    def __len__(self):
        return len(self.memory)


class RolloutBuffer(object):
    """
        The memory object for on-policy RL algorithms (PG)

        Stores up to ``capacity`` transitions; once full, new transitions
        overwrite the oldest ones in insertion order.
    """
    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0  # index the next pushed transition is written to

    def push(self, *args):
        """Saves a transition (``args`` are the Transition fields, in order)."""
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[self.position] = Transition(*args)
        self.position = (self.position + 1) % self.capacity

    def shuffle(self):
        """Shuffles the stored transitions in place."""
        np.random.shuffle(self.memory)

    def clear(self):
        """Drops all stored transitions."""
        self.memory = []
        self.position = 0

    def sample(self, batch_size):
        """Returns ``batch_size`` distinct stored transitions chosen at random."""
        return random.sample(self.memory, batch_size)

    def iter(self, batch_size):
        """
        Yields consecutive, non-overlapping batches of ``batch_size``
        transitions in storage order; a trailing partial batch is dropped.
        """
        for i in range(len(self.memory) // batch_size):
            # Batch i starts at i * batch_size; slicing from i itself produced
            # overlapping windows shifted by a single element.
            start = i * batch_size
            yield self.memory[start:start + batch_size]

    def __len__(self):
        return len(self.memory)
