import numpy as np
import tensorflow as tf
import gym
import time
from spinup.algos.tf1.mc_discrete_sac import sac_core as core
from spinup.algos.tf1.mc_discrete_sac.sac_core import *
from spinup.utils.logx import EpochLogger


class ReplayBuffer:
    """
    Fixed-capacity FIFO store of transitions for off-policy SAC updates.

    Once ``size`` transitions have been stored, each new one overwrites the
    oldest. Observations, actions, rewards and done flags are kept in
    preallocated float32 arrays.
    """

    def __init__(self, obs_dim, act_dim, size):
        self.obs1_buf = np.zeros((size, obs_dim), dtype=np.float32)
        self.obs2_buf = np.zeros((size, obs_dim), dtype=np.float32)
        self.acts_buf = np.zeros((size, act_dim), dtype=np.float32)
        self.rews_buf = np.zeros(size, dtype=np.float32)
        self.done_buf = np.zeros(size, dtype=np.float32)
        # ptr: next slot to write; size: number of valid rows; max_size: capacity
        self.ptr = 0
        self.size = 0
        self.max_size = size

    def store(self, obs, act, rew, next_obs, done):
        """Write one transition at the current slot and advance the write pointer."""
        slot = self.ptr
        self.obs1_buf[slot] = obs
        self.obs2_buf[slot] = next_obs
        self.acts_buf[slot] = act
        self.rews_buf[slot] = rew
        self.done_buf[slot] = done
        # Wrap around so the oldest entry is overwritten once the buffer is full.
        self.ptr = (slot + 1) % self.max_size
        self.size = min(self.max_size, self.size + 1)

    def sample_batch(self, batch_size=32):
        """
        Draw ``batch_size`` stored transitions uniformly at random, with
        replacement.

        Returns a dict with keys ``obs1``, ``obs2``, ``acts``, ``rews`` and
        ``done``, all indexed by the same random row indices.
        """
        rows = np.random.randint(0, self.size, size=batch_size)
        columns = (
            ('obs1', self.obs1_buf),
            ('obs2', self.obs2_buf),
            ('acts', self.acts_buf),
            ('rews', self.rews_buf),
            ('done', self.done_buf),
        )
        return {key: buf[rows] for key, buf in columns}

class StateBuffer:
    """
    Sliding window over the ``m`` most recent observations.

    The window is kept as one flat array: the observations concatenated
    along axis 0, oldest first, newest last. Each append creates a new
    array rather than writing into a preallocated ring.
    """

    def __init__(self, m):
        # number of observations held in the window
        self.m = m

    def init_state(self, init_obs):
        """Fill the window with ``m`` copies of ``init_obs`` and return it."""
        copies = [init_obs for _ in range(self.m)]
        self.current_state = np.concatenate(copies, axis=0)
        return self.current_state

    def append_state(self, obs):
        """Append ``obs`` as the newest entry, drop the oldest, and return the window."""
        drop = obs.shape[0]
        extended = np.concatenate((self.current_state, obs), axis=0)
        self.current_state = extended[drop:]
        return self.current_state

def process_observation(o, obs_dim, observation_type):
    """
    Convert a raw environment observation into network features.

    For ``observation_type == 'Discrete'`` the integer observation ``o`` is
    one-hot encoded as row ``o`` of ``np.eye(obs_dim)``. Any other type
    returns ``o`` unchanged (the same object).
    """
    if observation_type != 'Discrete':
        return o
    return np.eye(obs_dim)[o]

def process_action(a, act_dim):
    """Return the one-hot encoding of integer action ``a`` over ``act_dim`` actions."""
    identity = np.eye(act_dim)
    return identity[a]

def process_reward(reward):
    """
    Hook for reward shaping; currently the identity.

    Clipping or scaling can be added here if needed.
    """
    processed = reward  # no transformation applied yet
    return processed

def linear_anneal(current_step, start=0.1, stop=1.0, steps=1e6):
    """
    Linearly anneal from ``start`` to ``stop`` over ``steps`` steps.

    Returns ``start`` at ``current_step == 0`` and ``stop`` at
    ``current_step == steps``. The value then stays at ``stop`` for every
    later step, so the schedule does not jump back to ``start`` once
    annealing has finished. A non-positive ``steps`` means "already
    annealed" and returns ``stop`` without dividing by zero.

    Args:
        current_step: Current step count.
        start: Value at step 0.
        stop: Value reached at ``steps`` and held afterwards.
        steps: Length of the annealing schedule.

    Returns:
        The interpolated value as a float.
    """
    if steps <= 0 or current_step >= steps:
        return stop
    fraction = current_step / steps
    return start + (stop - start) * fraction

def ClipIfNotNone(grad, grad_clip_val):
    """
    Clip ``grad`` element-wise to ``[-grad_clip_val, grad_clip_val]``.

    ``Optimizer.compute_gradients`` may return ``None`` as the gradient of a
    variable the loss does not depend on. Such entries are passed through
    untouched, because ``tf.clip_by_value`` cannot accept ``None``.
    """
    if grad is not None:
        grad = tf.clip_by_value(grad, -grad_clip_val, grad_clip_val)
    return grad



def sac(env_fn, actor_critic=core.mlp_actor_critic, ac_kwargs=dict(), seed=0,
        steps_per_epoch=4000, epochs=100, replay_size=int(1e6), gamma=0.99,
        polyak=0.995, lr=1e-3, batch_size=100, start_steps=10000,
        update_after=1000, update_every=50, num_test_episodes=10, max_ep_len=1000,
        logger_kwargs=dict(), save_freq=1,
        a_hidden_sizes=(256, 256), q_hidden_sizes=(256, 256),
        state_hist_n=1,
        grad_clip_val=None,
        alpha='auto',
        target_entropy_start=0.3,  # proportion of max_entropy
        target_entropy_stop=0.3,
        target_entropy_steps=1e5):
    """
    Discrete-action Soft Actor-Critic (SAC) with two style-specific Q heads.

    Trains a categorical policy and twin critics on an environment with a
    ``Discrete`` action space. Each critic returns one Q-value head per
    "style" (``NUM_STYLES = 2``). The head used for a sample is chosen from
    the first feature of its state vector: ``> 0`` selects ``style_0``,
    otherwise ``style_1``. The environment reward is taken from
    ``info['RewardBreakdown'][0]`` when the first feature of the raw next
    observation is ``> 0.1``, and from ``info['RewardBreakdown'][1]``
    otherwise.

    Updates happen at the end of every episode, with one gradient step per
    environment step of that episode. Each step runs the policy, critic and
    temperature updates, then a Polyak target update.

    Args:
        env_fn : A function which creates a copy of the environment.
            The environment must satisfy the OpenAI Gym API and have a
            ``Discrete`` action space (``action_space.n`` is read). Its
            observations must support ``[0]`` indexing, and ``step`` must
            return an ``info`` dict with an indexable ``'RewardBreakdown'``
            holding at least two entries.

        actor_critic: A function called as
            ``actor_critic(x, a, num_styles=2, **ac_kwargs)`` inside the
            ``'main'`` and ``'target'`` variable scopes. It must return the
            8-tuple ``(mu, pi, action_probs, log_action_probs, q1_logits,
            q2_logits, q1_a, q2_a)``:

            ====================  ===========================================
            Symbol                Use in this function
            ====================  ===========================================
            ``mu``                Action sent to the env at test time.
            ``pi``                Action sent to the env during training
                                  (after ``start_steps``).
            ``action_probs``      Policy probabilities over actions; assumed
                                  (batch, act_dim) -- they are multiplied
                                  with the Q heads and summed over the last
                                  axis.
            ``log_action_probs``  Log of ``action_probs``, same shape.
            ``q1_logits``,        Dicts keyed ``'style_0'``, ``'style_1'``
            ``q2_logits``         holding per-action Q-values.
            ``q1_a``, ``q2_a``    Dicts with the same keys holding Q-values
                                  of the actions fed in ``a``.
            ====================  ===========================================

            Policy variables are collected with ``get_vars('main/pi')`` and
            critic variables with ``get_vars('main/q')``, so the function is
            expected to create them under matching sub-scopes.

        ac_kwargs (dict): Extra kwargs forwarded to ``actor_critic``.

        seed (int): Seed for the TensorFlow and NumPy random number generators.

        steps_per_epoch (int): Number of steps of interaction (state-action pairs)
            for the agent and the environment in each epoch.

        epochs (int): Number of epochs to run and train agent.

        replay_size (int): Maximum length of replay buffer.

        gamma (float): Discount factor. (Always between 0 and 1.)

        polyak (float): Interpolation factor in polyak averaging for target
            networks. Target networks are updated towards main networks
            according to:

            .. math:: \\theta_{\\text{targ}} \\leftarrow
                \\rho \\theta_{\\text{targ}} + (1-\\rho) \\theta

            where :math:`\\rho` is polyak. (Always between 0 and 1, usually
            close to 1.)

        lr (float): Learning rate of the policy, critic and temperature Adam
            optimizers.

        batch_size (int): Minibatch size for SGD.

        start_steps (int): Number of initial steps that use uniformly random
            actions (``env.action_space.sample()``) instead of the policy.

        update_after (int): Currently unused. Updates start at the end of
            the first episode.

        update_every (int): Currently unused. Updates run at the end of
            each episode, one per step of that episode.

        num_test_episodes (int): Number of episodes run with ``mu`` at the
            end of each epoch.

        max_ep_len (int): Maximum length of trajectory / episode / rollout.

        logger_kwargs (dict): Keyword args for EpochLogger.

        save_freq (int): How often (in terms of gap between epochs) to save
            the current policy and value function. The final epoch is
            always saved.

        a_hidden_sizes, q_hidden_sizes: Currently unused; they are not
            forwarded to ``actor_critic``. Pass network sizes through
            ``ac_kwargs`` instead.

        state_hist_n (int): Number of most recent processed observations
            concatenated (oldest first) to form the network input state.

        grad_clip_val (float or None): If set, clip policy and critic
            gradients element-wise to ``[-grad_clip_val, grad_clip_val]``.

        alpha ('auto' or float): Entropy regularization coefficient. With
            ``'auto'`` it is ``exp(log_alpha)`` and ``log_alpha`` is learned
            towards the target entropy. Otherwise the given value is used as
            a fixed coefficient.

        target_entropy_start, target_entropy_stop (float): Target entropy
            as a proportion of the maximum categorical entropy
            ``log(act_dim)``. It is annealed linearly from start to stop.

        target_entropy_steps (float): Number of env steps over which the
            target entropy is annealed. The proportion is recomputed once
            per epoch.

    """

    logger = EpochLogger(**logger_kwargs)
    logger.save_config(locals())

    tf.set_random_seed(seed)
    np.random.seed(seed)

    env, test_env = env_fn(), env_fn()
    obs_space = env.observation_space
    act_space = env.action_space
    # Discrete observation spaces expose `.n` and are one-hot encoded by
    # process_observation; otherwise the first entry of `.shape` is used.
    try:
        obs_dim = obs_space.n
        observation_type = 'Discrete'
    except AttributeError:
        obs_dim = obs_space.shape[0]
        observation_type = 'Box'

    act_dim = act_space.n

    NUM_STYLES = 2

    # init a state buffer for storing last m states
    train_state_buffer = StateBuffer(m=state_hist_n)
    test_state_buffer = StateBuffer(m=state_hist_n)

    # Experience buffer: states are the concatenation of the last
    # `state_hist_n` processed observations, hence obs_dim * state_hist_n.
    replay_buffer = ReplayBuffer(obs_dim=obs_dim*state_hist_n, act_dim=act_dim, size=replay_size)

    # Inputs to computation graph
    x_ph, a_ph, x2_ph, r_ph, d_ph = placeholders(obs_dim*state_hist_n, act_dim, obs_dim*state_hist_n, None, None)

    # alpha and entropy setup: the target entropy is a fed-in proportion of
    # log(act_dim), the entropy of a uniform categorical over act_dim actions.
    max_target_entropy = tf.log(tf.cast(act_dim, tf.float32))
    target_entropy_prop_ph = tf.placeholder(dtype=tf.float32, shape=())
    target_entropy = max_target_entropy * target_entropy_prop_ph

    log_alpha = tf.get_variable('log_alpha', dtype=tf.float32, initializer=0.0)

    if alpha == 'auto':  # auto tune alpha
        alpha = tf.exp(log_alpha)
    else:  # fixed alpha (log_alpha is still trained below but does not affect the losses)
        alpha = tf.get_variable('alpha', dtype=tf.float32, initializer=alpha)

    # Main outputs from computation graph
    with tf.variable_scope('main'):
        mu, pi, action_probs, log_action_probs, q1_logits, q2_logits, q1_a, q2_a = actor_critic(x_ph, a_ph, num_styles=NUM_STYLES, **ac_kwargs)

    # Policy distribution at the next state, sharing the main weights.
    with tf.variable_scope('main', reuse=True):
        _, _, action_probs_next, log_action_probs_next, _, _, _, _ = actor_critic(x2_ph, a_ph, num_styles=NUM_STYLES, **ac_kwargs)

    # Target value network
    with tf.variable_scope('target'):
        # dont need to pass pi_next in here as we don't need to sample q for policy as we have policy distribution
        # just use a_ph as it doesn't affect anything
        _, _, _, _, q1_logits_targ, q2_logits_targ, _, _ = actor_critic(x2_ph, a_ph, num_styles=NUM_STYLES, **ac_kwargs)

    # Count variables
    var_counts = tuple(count_vars(scope) for scope in ['log_alpha',
                                                       'main/pi',
                                                       'main/q1',
                                                       'main/q2',
                                                       'main'])
    print("""\nNumber of other parameters:
             alpha: %d,
             pi: %d,
             q1: %d,
             q2: %d,
             total: %d\n"""%var_counts)

    # ------------ Multi-Q ------------
    # Get setting info based on state: the first feature of the state vector.
    # NOTE(review): with state_hist_n > 1 this is the first feature of the
    # *oldest* observation in the window (StateBuffer keeps oldest first) --
    # confirm this is the intended style indicator.
    x1_setting = x_ph[:, 0]
    x2_setting = x2_ph[:, 0]

    # Min Double-Q:
    # - computed independently for each style
    # - creates a dictionary in the same format as the Q-nets dict from core2
    min_q_logits = {}
    min_q_logits_targ = {}
    for st in range(NUM_STYLES):
        min_q_logits['style_' + str(st)] = tf.minimum(q1_logits['style_' + str(st)], q2_logits['style_' + str(st)])
        min_q_logits_targ['style_' + str(st)] = tf.minimum(q1_logits_targ['style_' + str(st)], q2_logits_targ['style_' + str(st)])

    # Piecewise Style selection (row-wise: setting > 0 -> style_0, else style_1).
    """ NOTE: THIS PART IS CURRENTLY NOT GENERALIZED FOR MORE THAN 2 STYLES """
    min_q_logits_selection = tf.where(tf.greater(x1_setting, 0), min_q_logits['style_0'], min_q_logits['style_1'])
    min_q_logits_targ_selection = tf.where(tf.greater(x2_setting, 0), min_q_logits_targ['style_0'], min_q_logits_targ['style_1'])

    q1_a_selection = tf.where(tf.greater(x1_setting, 0), q1_a['style_0'], q1_a['style_1'])
    q2_a_selection = tf.where(tf.greater(x1_setting, 0), q2_a['style_0'], q2_a['style_1'])

    # Targets for Q regression: soft value of the next state computed exactly
    # from the policy distribution (no action sampling needed).
    q_backup = r_ph + gamma*(1-d_ph)*tf.stop_gradient(tf.reduce_sum(action_probs_next * (min_q_logits_targ_selection - alpha * log_action_probs_next), axis=-1))

    # critic losses
    q1_loss = 0.5 * tf.reduce_mean((q_backup - q1_a_selection)**2)
    q2_loss = 0.5 * tf.reduce_mean((q_backup - q2_a_selection)**2)
    value_loss = q1_loss + q2_loss

    # policy loss
    pi_backup = tf.reduce_sum(action_probs * (alpha * log_action_probs - min_q_logits_selection), axis=-1)
    pi_loss = tf.reduce_mean(pi_backup)

    # alpha loss for temperature parameter: raises alpha while the policy
    # entropy is below target, lowers it while above.
    pi_entropy = -tf.reduce_sum(action_probs * log_action_probs, axis=-1)
    alpha_backup = tf.stop_gradient(target_entropy - pi_entropy)
    alpha_loss = -tf.reduce_mean(log_alpha * alpha_backup)

    # Policy train op
    # (has to be separate from value train op, because q1_logits appears in pi_loss)
    pi_optimizer = tf.train.AdamOptimizer(learning_rate=lr, epsilon=1e-04)
    if grad_clip_val is not None:
        gvs = pi_optimizer.compute_gradients(pi_loss, var_list=get_vars('main/pi'))
        capped_gvs = [(ClipIfNotNone(grad, grad_clip_val), var) for grad, var in gvs]
        train_pi_op = pi_optimizer.apply_gradients(capped_gvs)
    else:
        train_pi_op = pi_optimizer.minimize(pi_loss, var_list=get_vars('main/pi'))

    # Value train op
    # (control dep of train_pi_op because sess.run otherwise evaluates in nondeterministic order)
    value_optimizer = tf.train.AdamOptimizer(learning_rate=lr, epsilon=1e-04)
    with tf.control_dependencies([train_pi_op]):
        if grad_clip_val is not None:
            gvs = value_optimizer.compute_gradients(value_loss, var_list=get_vars('main/q'))
            capped_gvs = [(ClipIfNotNone(grad, grad_clip_val), var) for grad, var in gvs]
            train_value_op = value_optimizer.apply_gradients(capped_gvs)
        else:
            train_value_op = value_optimizer.minimize(value_loss, var_list=get_vars('main/q'))

    alpha_optimizer = tf.train.AdamOptimizer(learning_rate=lr, epsilon=1e-04)
    with tf.control_dependencies([train_value_op]):
        train_alpha_op = alpha_optimizer.minimize(alpha_loss, var_list=get_vars('log_alpha'))

    # Polyak averaging for target variables
    # (control flow because sess.run otherwise evaluates in nondeterministic order)
    with tf.control_dependencies([train_value_op]):
        target_update = tf.group([tf.assign(v_targ, polyak*v_targ + (1-polyak)*v_main)
                                  for v_main, v_targ in zip(get_vars('main'), get_vars('target'))])

    # All ops to call during one training step
    step_ops = [pi_loss, q1_loss, q2_loss, q1_a_selection, q2_a_selection,
                pi_entropy, target_entropy,
                alpha_loss, alpha,
                train_pi_op, train_value_op, train_alpha_op, target_update]

    # Initializing targets to match main variables
    target_init = tf.group([tf.assign(v_targ, v_main)
                            for v_main, v_targ in zip(get_vars('main'), get_vars('target'))])

    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    sess.run(target_init)

    # Setup model saving
    logger.setup_tf_saver(sess, inputs={'x': x_ph, 'a': a_ph},
                                outputs={'mu': mu, 'pi': pi, 'q1_a': q1_a_selection, 'q2_a': q2_a_selection})

    def get_action(state, deterministic=False):
        act_op = mu if deterministic else pi
        return sess.run(act_op, feed_dict={x_ph: [state]})[0]

    def select_reward(raw_obs, info):
        # Pick the reward component from info['RewardBreakdown'] using the
        # first feature of the raw (unprocessed) next observation.
        # NOTE(review): the threshold here is 0.1 on the *next* observation,
        # while the critic selects its style head with `> 0` on the *current*
        # state (x1_setting) -- confirm the two are meant to differ.
        if raw_obs[0] > 0.1:
            return info['RewardBreakdown'][0]
        return info['RewardBreakdown'][1]

    def reset(env, state_buffer):
        o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0
        o = process_observation(o, obs_dim, observation_type)
        r = process_reward(r)
        state = state_buffer.init_state(init_obs=o)
        return o, r, d, ep_ret, ep_len, state

    def test_agent():
        for j in range(num_test_episodes):
            o, r, d, ep_ret, ep_len, test_state = reset(test_env, test_state_buffer)

            while not (d or (ep_len == max_ep_len)):
                # Take deterministic actions at test time
                o, _, d, info = test_env.step(get_action(test_state, True))
                r = select_reward(o, info)
                o = process_observation(o, obs_dim, observation_type)
                r = process_reward(r)
                test_state = test_state_buffer.append_state(o)
                ep_ret += r
                ep_len += 1

            logger.store(TestEpRet=ep_ret, TestEpLen=ep_len)

    start_time = time.time()
    o, r, d, ep_ret, ep_len, state = reset(env, train_state_buffer)
    total_steps = steps_per_epoch * epochs

    target_entropy_prop = linear_anneal(current_step=0, start=target_entropy_start, stop=target_entropy_stop, steps=target_entropy_steps)

    # Main loop: collect experience in env and update/log each epoch
    for t in range(total_steps):
        # Until start_steps have elapsed, randomly sample actions from a
        # uniform distribution for better exploration. Afterwards, use the
        # learned policy.
        if t > start_steps:
            a = get_action(state)
        else:
            a = env.action_space.sample()

        # Step the env
        o2, _, d, info = env.step(a)
        r = select_reward(o2, info)
        o2 = process_observation(o2, obs_dim, observation_type)
        a = process_action(a, act_dim)
        r = process_reward(r)
        next_state = train_state_buffer.append_state(o2)
        ep_ret += r
        ep_len += 1

        # Ignore the "done" signal if it comes from hitting the time
        # horizon (that is, when it's an artificial terminal signal
        # that isn't based on the agent's state)
        d = False if ep_len == max_ep_len else d

        # Store experience to replay buffer
        replay_buffer.store(state, a, r, next_state, d)

        # Super critical, easy to overlook step: make sure to update
        # most recent observation!
        o = o2
        state = next_state

        if d or (ep_len == max_ep_len):
            # Perform all SAC updates at the end of the trajectory (one per
            # step of the episode). This is a slight difference from the SAC
            # specified in the original paper.
            for j in range(ep_len):
                batch = replay_buffer.sample_batch(batch_size)
                feed_dict = {x_ph:  batch['obs1'],
                             x2_ph: batch['obs2'],
                             a_ph:  batch['acts'],
                             r_ph:  batch['rews'],
                             d_ph:  batch['done'],
                             target_entropy_prop_ph: target_entropy_prop
                            }

                outs = sess.run(step_ops, feed_dict)
                logger.store(LossPi=outs[0],
                             LossQ1=outs[1],    LossQ2=outs[2],
                             Q1Vals=outs[3],    Q2Vals=outs[4],
                             PiEntropy=outs[5], TargEntropy=outs[6],
                             LossAlpha=outs[7], Alpha=outs[8])

            logger.store(EpRet=ep_ret, EpLen=ep_len)
            o, r, d, ep_ret, ep_len, state = reset(env, train_state_buffer)

        # End of epoch wrap-up. `t` is 0-based, so `t + 1` env interactions
        # have happened at this point. Epoch k ends at t + 1 == k * steps_per_epoch,
        # which includes the final epoch (t == total_steps - 1).
        if (t + 1) % steps_per_epoch == 0:
            epoch = (t + 1) // steps_per_epoch

            # update target entropy every epoch
            target_entropy_prop = linear_anneal(current_step=t + 1, start=target_entropy_start, stop=target_entropy_stop, steps=target_entropy_steps)

            # Save model
            if (epoch % save_freq == 0) or (epoch == epochs):
                logger.save_state({'env': env}, None)

            # Test the performance of the deterministic version of the agent.
            test_agent()

            # Log info about epoch
            logger.log_tabular('Epoch', epoch)
            logger.log_tabular('EpRet', with_min_and_max=True)
            logger.log_tabular('TestEpRet', with_min_and_max=True)
            logger.log_tabular('EpLen', average_only=True)
            logger.log_tabular('TestEpLen', average_only=True)
            logger.log_tabular('TotalEnvInteracts', t + 1)
            logger.log_tabular('Q1Vals', with_min_and_max=True)
            logger.log_tabular('Q2Vals', with_min_and_max=True)
            logger.log_tabular('PiEntropy', average_only=True)
            logger.log_tabular('TargEntropy', average_only=True)
            logger.log_tabular('Alpha', average_only=True)
            logger.log_tabular('LossPi', average_only=True)
            logger.log_tabular('LossQ1', average_only=True)
            logger.log_tabular('LossQ2', average_only=True)
            logger.log_tabular('LossAlpha', average_only=True)
            logger.log_tabular('Time', time.time()-start_time)
            logger.dump_tabular()

if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    # NOTE(review): sac() reads `action_space.n`, so the env must have a
    # Discrete action space (HalfCheetah-v2 does not) -- pass --env explicitly.
    parser.add_argument('--env', type=str, default='HalfCheetah-v2')
    parser.add_argument('--hid', type=int, default=256)
    parser.add_argument('--l', type=int, default=2)
    # Optional per-network hidden sizes; when omitted, both networks use
    # `--l` layers of `--hid` units.
    parser.add_argument('--a_hid_size', nargs='+', type=int, default=None)
    parser.add_argument('--q_hid_size', nargs='+', type=int, default=None)
    parser.add_argument('--gamma', type=float, default=0.99)
    parser.add_argument('--seed', '-s', type=int, default=0)
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--exp_name', type=str, default='discrete_sac')

    args = parser.parse_args()

    # Previously referenced the commented-out (and misspelled) args
    # `a_hid_sze` / `q_hid_size`, which raised AttributeError on every run.
    default_hidden = [args.hid] * args.l
    a_hidden_sizes = args.a_hid_size if args.a_hid_size is not None else default_hidden
    q_hidden_sizes = args.q_hid_size if args.q_hid_size is not None else default_hidden

    from rl_smoothness.utils.run_utils import setup_logger_kwargs
    logger_kwargs = setup_logger_kwargs(args.exp_name, args.seed)

    sac(lambda: gym.make(args.env), actor_critic=core.mlp_actor_critic,
        a_hidden_sizes=a_hidden_sizes, q_hidden_sizes=q_hidden_sizes,
        gamma=args.gamma, seed=args.seed, epochs=args.epochs,
        logger_kwargs=logger_kwargs)
