"""
TensorFlow policy class used for CQL.
"""
from functools import partial
import numpy as np
import gym
import logging
from typing import Dict, List, Type, Union

import ray
import ray.experimental.tf_utils
from src.rllib.agents.sac.sac_tf_policy import \
    apply_gradients as sac_apply_gradients, \
    compute_and_clip_gradients as sac_compute_and_clip_gradients,\
    get_distribution_inputs_and_class, build_sac_model, \
    postprocess_trajectory, setup_late_mixins, stats, validate_spaces, \
    ActorCriticOptimizerMixin as SACActorCriticOptimizerMixin, \
    ComputeTDErrorMixin, TargetNetworkMixin
from src.rllib.models.modelv2 import ModelV2
from src.rllib.models.tf.tf_action_dist import TFActionDistribution
from src.rllib.policy.tf_policy_template import build_tf_policy
from src.rllib.policy.policy import Policy
from src.rllib.policy.sample_batch import SampleBatch
from src.rllib.utils.exploration.random import Random
from src.rllib.utils.framework import get_variable, \
    try_import_tf, try_import_tfp
from src.rllib.utils.typing import LocalOptimizer, ModelGradients, \
    TensorType, TrainerConfigDict

# TensorFlow handles are obtained through the framework helpers rather than
# via direct imports. `tf1` is used below for the static-graph optimizer API
# and `tf` for all tensor ops.
tf1, tf, tfv = try_import_tf()
# NOTE(review): `tfp` is not referenced in the visible part of this file.
tfp = try_import_tfp()

# Module-level logger.
logger = logging.getLogger(__name__)

# NOTE(review): MEAN_MIN / MEAN_MAX are not referenced in the visible part of
# this file; presumably bounds for the policy mean - confirm they are still
# needed.
MEAN_MIN = -9.0
MEAN_MAX = 9.0


def policy_actions_repeat(model, action_dist, obs, num_repeat=1):
    """Sample policy actions and log-probs from tiled observations (CQL loss).

    Every row of `obs` is repeated `num_repeat` times, the tiled rows are
    pushed through the model's policy head, and one action is sampled per
    tiled row.

    Args:
        model: Model exposing `get_policy_output()`.
        action_dist: Action distribution class; instantiated here as
            `action_dist(dist_inputs, model)`.
        obs: Rank-2 tensor (rows are tiled along a new axis 1).
        num_repeat: How many times each row of `obs` is repeated.

    Returns:
        Tuple `(actions, logp)`: the sampled actions (one per tiled row) and
        their log-probs reshaped to `[tf.shape(obs)[0], num_repeat, 1]`.
    """
    # [B, F] -> [B, 1, F] -> [B, num_repeat, F] -> [B * num_repeat, F].
    tiled = tf.tile(tf.expand_dims(obs, 1), [1, num_repeat, 1])
    flat_obs = tf.reshape(tiled, [-1, obs.shape[1]])

    dist = action_dist(model.get_policy_output(flat_obs), model)
    sampled_actions, sampled_logp = dist.sample_logp()

    # Give the log-probs a trailing unit axis, then regroup per original row.
    logp_column = tf.expand_dims(sampled_logp, -1)
    grouped_logp = tf.reshape(logp_column, [tf.shape(obs)[0], num_repeat, 1])
    return sampled_actions, grouped_logp


def q_values_repeat(model, obs, actions, twin=False):
    """Evaluate Q-values for several actions per observation row.

    The repeat factor is derived from the leading dimensions:
    `tf.shape(actions)[0] // tf.shape(obs)[0]`. `obs` is tiled by that factor
    so that it lines up row-by-row with `actions`.

    Args:
        model: Model exposing `get_q_values()` and `get_twin_q_values()`.
        obs: Rank-2 tensor of (model-output) observations.
        actions: Actions whose leading dimension is a multiple of that of
            `obs`.
        twin: If truthy, query the twin Q-head instead of the main one.

    Returns:
        Q-values reshaped to `[tf.shape(obs)[0], num_repeat, 1]`.
    """
    num_repeat = tf.shape(actions)[0] // tf.shape(obs)[0]

    # [B, F] -> [B, 1, F] -> [B, num_repeat, F] -> [B * num_repeat, F].
    tiled_obs = tf.reshape(
        tf.tile(tf.expand_dims(obs, 1), [1, num_repeat, 1]),
        [-1, tf.shape(obs)[1]])

    q_fn = model.get_twin_q_values if twin else model.get_q_values
    flat_q = q_fn(tiled_obs, actions)
    return tf.reshape(flat_q, [tf.shape(obs)[0], num_repeat, 1])


def cql_loss(policy: Policy, model: ModelV2,
             dist_class: Type[TFActionDistribution],
             train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]:
    """Build the CQL loss: SAC actor/critic/alpha losses plus CQL regularizer.

    Side effects: increments `policy.cur_iter` on every call and stores the
    individual loss/stat tensors as attributes on `policy` so that the stats
    function can report them.

    Args:
        policy: The policy to compute the loss for. Its config, target model,
            `dist_class`, `cur_iter` counter and `_random_action_generator`
            are read here.
        model: The model providing policy and Q heads, `log_alpha`,
            `target_entropy` and (with `lagrangian=True`) `log_alpha_prime`.
        dist_class: Unused here; `policy.dist_class` is used instead.
        train_batch: Batch holding CUR_OBS, ACTIONS, REWARDS, NEXT_OBS and
            DONES columns.

    Returns:
        A single tensor: actor loss + sum of critic losses + alpha loss
        (+ alpha-prime loss when `lagrangian=True`).
    """
    logger.info("Current iteration = %s", policy.cur_iter)
    policy.cur_iter += 1
    # NOTE(review): the `cur_iter >= bc_iters` test further down is a plain
    # Python branch, evaluated only when this function runs. In static-graph
    # mode that may be at graph-construction time only - confirm.

    # For best performance, turn deterministic off
    deterministic = policy.config["_deterministic_loss"]
    assert not deterministic
    twin_q = policy.config["twin_q"]
    discount = policy.config["gamma"]

    # CQL Parameters
    bc_iters = policy.config["bc_iters"]
    cql_temp = policy.config["temperature"]
    num_actions = policy.config["num_actions"]
    min_q_weight = policy.config["min_q_weight"]
    use_lagrange = policy.config["lagrangian"]
    target_action_gap = policy.config["lagrangian_thresh"]

    obs = train_batch[SampleBatch.CUR_OBS]
    actions = tf.cast(train_batch[SampleBatch.ACTIONS], tf.float32)
    rewards = tf.cast(train_batch[SampleBatch.REWARDS], tf.float32)
    next_obs = train_batch[SampleBatch.NEXT_OBS]
    terminals = train_batch[SampleBatch.DONES]

    # Base-model outputs for the current / next observations (online model)
    # and for the next observations (target model).
    model_out_t, _ = model({
        "obs": obs,
        "is_training": True,
    }, [], None)

    model_out_tp1, _ = model({
        "obs": next_obs,
        "is_training": True,
    }, [], None)

    target_model_out_tp1, _ = policy.target_model({
        "obs": next_obs,
        "is_training": True,
    }, [], None)

    action_dist_t = policy.dist_class(
        model.get_policy_output(model_out_t), policy.model)
    policy_t, log_pis_t = action_dist_t.sample_logp()
    log_pis_t = tf.expand_dims(log_pis_t, -1)

    # Unlike original SAC, Alpha and Actor Loss are computed first.
    # Alpha Loss
    alpha_loss = -tf.reduce_mean(
        model.log_alpha * tf.stop_gradient(log_pis_t + model.target_entropy))

    # Policy Loss (Either Behavior Clone Loss or SAC Loss)
    alpha = tf.math.exp(model.log_alpha)
    if policy.cur_iter >= bc_iters:
        min_q = model.get_q_values(model_out_t, policy_t)
        if twin_q:
            twin_q_ = model.get_twin_q_values(model_out_t, policy_t)
            min_q = tf.math.minimum(min_q, twin_q_)
        actor_loss = tf.reduce_mean(
            tf.stop_gradient(alpha) * log_pis_t - min_q)
    else:
        # Behavior-cloning phase: maximize the log-likelihood of the batch
        # actions while keeping the entropy term.
        bc_logp = action_dist_t.logp(actions)
        actor_loss = tf.reduce_mean(
            tf.stop_gradient(alpha) * log_pis_t - bc_logp)

    # Critic Loss (Standard SAC Critic L2 Loss + CQL Entropy Loss)
    # SAC Loss:
    # Q-values for the batched actions.
    action_dist_tp1 = policy.dist_class(
        model.get_policy_output(model_out_tp1), policy.model)
    policy_tp1, _ = action_dist_tp1.sample_logp()

    q_t = model.get_q_values(model_out_t, actions)
    q_t_selected = tf.squeeze(q_t, axis=-1)
    if twin_q:
        twin_q_t = model.get_twin_q_values(model_out_t, actions)
        twin_q_t_selected = tf.squeeze(twin_q_t, axis=-1)

    # Target q network evaluation.
    q_tp1 = policy.target_model.get_q_values(target_model_out_tp1, policy_tp1)
    if twin_q:
        twin_q_tp1 = policy.target_model.get_twin_q_values(
            target_model_out_tp1, policy_tp1)
        # Take min over both twin-NNs.
        q_tp1 = tf.math.minimum(q_tp1, twin_q_tp1)

    q_tp1_best = tf.squeeze(input=q_tp1, axis=-1)
    # Zero out the bootstrap value for terminal transitions.
    q_tp1_best_masked = (1.0 - tf.cast(terminals, tf.float32)) * q_tp1_best

    # compute RHS of bellman equation
    q_t_target = tf.stop_gradient(
        rewards + (discount**policy.config["n_step"]) * q_tp1_best_masked)

    # Compute the TD-error (potentially clipped), for priority replay buffer
    base_td_error = tf.math.abs(q_t_selected - q_t_target)
    if twin_q:
        twin_td_error = tf.math.abs(twin_q_t_selected - q_t_target)
        td_error = 0.5 * (base_td_error + twin_td_error)
    else:
        td_error = base_td_error

    critic_loss_1 = tf.keras.losses.MSE(q_t_selected, q_t_target)
    if twin_q:
        critic_loss_2 = tf.keras.losses.MSE(twin_q_t_selected, q_t_target)

    # CQL Loss (We are using Entropy version of CQL (the best version))
    # Three action sets per observation row: random actions, actions sampled
    # from the policy at the current obs, and at the next obs.
    rand_actions, _ = policy._random_action_generator.get_exploration_action(
        action_distribution=policy.dist_class(
            tf.tile(action_dist_tp1.inputs, (num_actions, 1)), model),
        timestep=0,
        explore=True)
    curr_actions, curr_logp = policy_actions_repeat(model, policy.dist_class,
                                                    model_out_t, num_actions)
    next_actions, next_logp = policy_actions_repeat(model, policy.dist_class,
                                                    model_out_tp1, num_actions)

    # All three action sets are scored against the *current* observations.
    q1_rand = q_values_repeat(model, model_out_t, rand_actions)
    q1_curr_actions = q_values_repeat(model, model_out_t, curr_actions)
    q1_next_actions = q_values_repeat(model, model_out_t, next_actions)

    if twin_q:
        q2_rand = q_values_repeat(model, model_out_t, rand_actions, twin=True)
        q2_curr_actions = q_values_repeat(
            model, model_out_t, curr_actions, twin=True)
        q2_next_actions = q_values_repeat(
            model, model_out_t, next_actions, twin=True)

    # Log-density of the random actions: log(0.5 ** action_dim).
    random_density = np.log(0.5**int(curr_actions.shape[-1]))
    cat_q1 = tf.concat([
        q1_rand - random_density,
        q1_next_actions - tf.stop_gradient(next_logp),
        q1_curr_actions - tf.stop_gradient(curr_logp)
    ], 1)
    if twin_q:
        cat_q2 = tf.concat([
            q2_rand - random_density,
            q2_next_actions - tf.stop_gradient(next_logp),
            q2_curr_actions - tf.stop_gradient(curr_logp)
        ], 1)

    # Temperature-scaled logsumexp over the sampled actions, minus the mean
    # Q-value of the batch actions.
    min_qf1_loss_ = tf.reduce_mean(
        tf.reduce_logsumexp(cat_q1 / cql_temp,
                            axis=1)) * min_q_weight * cql_temp
    min_qf1_loss = min_qf1_loss_ - (tf.reduce_mean(q_t) * min_q_weight)
    if twin_q:
        min_qf2_loss_ = tf.reduce_mean(
            tf.reduce_logsumexp(cat_q2 / cql_temp,
                                axis=1)) * min_q_weight * cql_temp
        min_qf2_loss = min_qf2_loss_ - (
            tf.reduce_mean(twin_q_t) * min_q_weight)

    if use_lagrange:
        # `log_alpha_prime` is a TF variable: it has no torch-style `.exp()`
        # method, and it may be a scalar (which cannot be indexed with `[0]`).
        # Flatten first so both scalar and 1-element shapes work.
        log_alpha_prime = tf.reshape(model.log_alpha_prime, [-1])[0]
        alpha_prime = tf.clip_by_value(
            tf.math.exp(log_alpha_prime), 0.0, 1000000.0)
        min_qf1_loss = alpha_prime * (min_qf1_loss - target_action_gap)
        if twin_q:
            min_qf2_loss = alpha_prime * (min_qf2_loss - target_action_gap)
            alpha_prime_loss = 0.5 * (-min_qf1_loss - min_qf2_loss)
        else:
            alpha_prime_loss = -min_qf1_loss

    cql_loss = [min_qf1_loss]
    if twin_q:
        cql_loss.append(min_qf2_loss)

    critic_loss = [critic_loss_1 + min_qf1_loss]
    if twin_q:
        critic_loss.append(critic_loss_2 + min_qf2_loss)

    # Save for stats function.
    policy.q_t = q_t_selected
    policy.policy_t = policy_t
    policy.log_pis_t = log_pis_t
    policy.td_error = td_error
    policy.actor_loss = actor_loss
    policy.critic_loss = critic_loss
    policy.alpha_loss = alpha_loss
    policy.log_alpha_value = model.log_alpha
    policy.alpha_value = alpha
    policy.target_entropy = model.target_entropy
    # CQL Stats
    policy.cql_loss = cql_loss
    if use_lagrange:
        policy.log_alpha_prime_value = log_alpha_prime
        policy.alpha_prime_value = alpha_prime
        policy.alpha_prime_loss = alpha_prime_loss

    # Return all loss terms corresponding to our optimizers.
    if use_lagrange:
        return actor_loss + tf.math.add_n(critic_loss) + alpha_loss + \
               alpha_prime_loss
    return actor_loss + tf.math.add_n(critic_loss) + alpha_loss


def cql_stats(policy: Policy,
              train_batch: SampleBatch) -> Dict[str, TensorType]:
    """Extend the SAC stats dict with CQL-specific entries.

    Args:
        policy: Policy on which the loss function stored its tensors.
        train_batch: The batch, forwarded unchanged to the SAC `stats`.

    Returns:
        The SAC stats dict plus `cql_loss` (mean over the stacked CQL loss
        terms) and, when `lagrangian` is enabled in the config, the three
        alpha-prime entries.
    """
    report = stats(policy, train_batch)
    report["cql_loss"] = tf.reduce_mean(tf.stack(policy.cql_loss))

    if policy.config["lagrangian"]:
        # Each key is reported under the name of the policy attribute that
        # holds it.
        for key in ("log_alpha_prime_value", "alpha_prime_value",
                    "alpha_prime_loss"):
            report[key] = getattr(policy, key)
    return report


class ActorCriticOptimizerMixin(SACActorCriticOptimizerMixin):
    """SAC optimizer mixin extended with an alpha-prime optimizer.

    When `config["lagrangian"]` is set, an additional Adam optimizer
    (`_alpha_prime_optimizer`) is created, using the critic learning rate.
    """

    def __init__(self, config):
        super().__init__(config)
        if not config["lagrangian"]:
            return

        # Eager frameworks get the Keras optimizer, static graph the tf1 one.
        eager = config["framework"] in ["tf2", "tfe"]
        critic_lr = config["optimization"]["critic_learning_rate"]
        if eager:
            self._alpha_prime_optimizer = tf.keras.optimizers.Adam(
                learning_rate=critic_lr)
        else:
            self._alpha_prime_optimizer = tf1.train.AdamOptimizer(
                learning_rate=critic_lr)


def setup_early_mixins(policy: Policy, obs_space: gym.spaces.Space,
                       action_space: gym.spaces.Space,
                       config: TrainerConfigDict) -> None:
    """Prepare the Policy before its own initialization runs.

    Resets the iteration counter, runs the optimizer mixin's constructor,
    creates the Lagrangian `log_alpha_prime` variable (if enabled) and
    attaches a random action generator used by the CQL loss.

    Args:
        policy (Policy): The Policy object.
        obs_space (gym.spaces.Space): The Policy's observation space.
        action_space (gym.spaces.Space): The Policy's action space.
        config (TrainerConfigDict): The Policy's config.
    """
    # Iteration counter read (and incremented) by the loss function.
    policy.cur_iter = 0
    ActorCriticOptimizerMixin.__init__(policy, config)

    if config["lagrangian"]:
        # NOTE(review): this reads `policy.model` from a `before_init` hook;
        # confirm the model already exists at this point.
        policy.model.log_alpha_prime = get_variable(
            0.0, framework="tf", trainable=True, tf_name="log_alpha_prime")
        # NOTE(review): looks redundant with `_alpha_prime_optimizer` created
        # by the mixin above - confirm whether anything reads this attribute.
        critic_lr = config["optimization"]["critic_learning_rate"]
        policy.alpha_prime_optim = tf.keras.optimizers.Adam(
            learning_rate=critic_lr)

    # Generic random action generator for calculating CQL-loss.
    random_kwargs = dict(
        model=None,
        framework="tf2",
        policy_config=config,
        num_workers=0,
        worker_index=0,
    )
    policy._random_action_generator = Random(action_space, **random_kwargs)


def compute_gradients_fn(policy: Policy, optimizer: LocalOptimizer,
                         loss: TensorType) -> ModelGradients:
    """Compute SAC gradients and, if enabled, the alpha-prime gradients.

    Args:
        policy: The policy; its config, model and stored `alpha_prime_loss`
            are read here.
        optimizer: The local optimizer. In eager mode its `tape` attribute
            is used to differentiate the alpha-prime loss.
        loss: The loss tensor, forwarded to the SAC gradient function.

    Returns:
        The SAC grads-and-vars list, extended in place by the (optionally
        clipped) alpha-prime grads-and-vars when `lagrangian` is enabled.
    """
    grads_and_vars = sac_compute_and_clip_gradients(policy, optimizer, loss)
    if not policy.config["lagrangian"]:
        return grads_and_vars

    alpha_prime_vars = [policy.model.log_alpha_prime]
    if policy.config["framework"] in ["tf2", "tfe"]:
        # Eager: Use GradientTape (which is a property of the `optimizer`
        # object (an OptimizerWrapper): see rllib/policy/eager_tf_policy.py).
        raw_grads = optimizer.tape.gradient(policy.alpha_prime_loss,
                                            alpha_prime_vars)
        raw_grads_and_vars = list(zip(raw_grads, alpha_prime_vars))
    else:
        # Tf1.x: Use optimizer.compute_gradients()
        raw_grads_and_vars = policy._alpha_prime_optimizer.compute_gradients(
            policy.alpha_prime_loss, var_list=alpha_prime_vars)

    # Clip by norm only if a (truthy) `grad_clip` is configured.
    clip_norm = policy.config["grad_clip"]
    clip = partial(tf.clip_by_norm, clip_norm=clip_norm) \
        if clip_norm else tf.identity

    # Drop variables without gradient; keep the rest for the apply step.
    kept = []
    for grad, var in raw_grads_and_vars:
        if grad is None:
            continue
        kept.append((clip(grad), var))
    policy._alpha_prime_grads_and_vars = kept

    grads_and_vars += kept
    return grads_and_vars


def apply_gradients_fn(policy, optimizer, grads_and_vars):
    """Apply the SAC gradients and, if enabled, the alpha-prime gradients.

    Args:
        policy: The policy holding `_alpha_prime_optimizer` and the
            `_alpha_prime_grads_and_vars` saved by the gradient function.
        optimizer: Forwarded to the SAC apply function.
        grads_and_vars: Forwarded to the SAC apply function.

    Returns:
        Without `lagrangian`: whatever the SAC apply function returns.
        With `lagrangian`: None in eager mode, otherwise a grouped op of the
        SAC result and the alpha-prime apply op.
    """
    sac_results = sac_apply_gradients(policy, optimizer, grads_and_vars)
    if not policy.config["lagrangian"]:
        return sac_results

    alpha_prime_optimizer = policy._alpha_prime_optimizer
    pending = policy._alpha_prime_grads_and_vars

    # Eager mode -> Just apply and return None.
    if policy.config["framework"] in ["tf2", "tfe"]:
        alpha_prime_optimizer.apply_gradients(pending)
        return None

    # Tf static graph -> Return grouped op.
    alpha_prime_apply_op = alpha_prime_optimizer.apply_gradients(
        pending, global_step=tf1.train.get_or_create_global_step())
    return tf.group([sac_results, alpha_prime_apply_op])


# Build a child class of `TFPolicy`, given the custom functions defined
# above: the CQL loss/stats, the CQL-specific gradient hooks and the mixin
# setup, combined with the model builder, trajectory postprocessing, space
# validation and action-distribution function imported from SAC.
# The default config is resolved inside a lambda, i.e. only when it is called.
# NOTE(review): the default config is taken from `ray.rllib...` while this
# file's other imports come from `src.rllib...` - confirm this is intended.
CQLTFPolicy = build_tf_policy(
    name="CQLTFPolicy",
    loss_fn=cql_loss,
    get_default_config=lambda: ray.rllib.agents.cql.cql.CQL_DEFAULT_CONFIG,
    validate_spaces=validate_spaces,
    stats_fn=cql_stats,
    postprocess_fn=postprocess_trajectory,
    before_init=setup_early_mixins,
    after_init=setup_late_mixins,
    make_model=build_sac_model,
    mixins=[
        ActorCriticOptimizerMixin, TargetNetworkMixin, ComputeTDErrorMixin
    ],
    action_distribution_fn=get_distribution_inputs_and_class,
    compute_gradients_fn=compute_gradients_fn,
    apply_gradients_fn=apply_gradients_fn,
)
