# Copyright 2022 The Brax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Proximal policy optimization training.

This is modified from the original Brax PPO implementation to be vectorisable.

See: https://arxiv.org/pdf/1707.06347.pdf
"""

import functools
import os
import time
from typing import Any, Callable, Dict, Optional, Tuple

from absl import logging
from brax import envs
from brax.io import model
from brax.training import distribution
from brax.training import networks
from brax.training import normalization
from brax.training import pmap
from brax.training.types import Params
from brax.training.types import PRNGKey
import flax
import jax
import jax.numpy as jnp
import optax


def compute_gae(
    truncation: jnp.ndarray,
    termination: jnp.ndarray,
    rewards: jnp.ndarray,
    values: jnp.ndarray,
    bootstrap_value: jnp.ndarray,
    lambda_: float = 1.0,
    discount: float = 0.99,
):
    r"""Computes Generalized Advantage Estimation (GAE) targets and advantages.

    Args:
      truncation: A float32 tensor of shape [T, B]; 1 where the step was
        truncated, 0 otherwise.
      termination: A float32 tensor of shape [T, B]; 1 where the episode
        terminated, 0 otherwise.
      rewards: A float32 tensor of shape [T, B] with the rewards collected
        while following the behaviour policy.
      values: A float32 tensor of shape [T, B] with the value estimates
        wrt. the target policy.
      bootstrap_value: A float32 tensor of shape [B] with the value estimate
        at time T.
      lambda_: Mix between 1-step (lambda_=0) and n-step (lambda_=1) returns.
      discount: TD discount.

    Returns:
      A pair ``(vs, advantages)`` of float32 tensors of shape [T, B]. ``vs``
      can be used as the regression target of a baseline,
      (V(x_t) - vs_t)^2. Gradients are stopped through both outputs.
    """
    not_truncated = 1 - truncation
    bootstrap_row = jnp.expand_dims(bootstrap_value, 0)

    # Shift the value estimates one step forward: [v_1, ..., v_T].
    next_values = jnp.concatenate([values[1:], bootstrap_row], axis=0)
    # One-step TD errors; truncated steps contribute nothing.
    td_errors = (rewards + discount * (1 - termination) * next_values - values) * not_truncated

    def backward_step(running, step_inputs):
        # `running` holds the discounted sum of later TD errors; it is cut
        # whenever the step terminated or was truncated.
        mask_t, td_error_t, termination_t = step_inputs
        running = td_error_t + discount * (1 - termination_t) * mask_t * lambda_ * running
        return running, running

    # Accumulate from the last time step back to the first.
    _, vs_minus_values = jax.lax.scan(
        backward_step,
        jnp.zeros_like(bootstrap_value),
        (not_truncated, td_errors, termination),
        length=int(not_truncated.shape[0]),
        reverse=True,
    )

    # Add V(x_s) back to obtain the value targets v_s.
    vs = vs_minus_values + values

    next_vs = jnp.concatenate([vs[1:], bootstrap_row], axis=0)
    advantages = (rewards + discount * (1 - termination) * next_vs - values) * not_truncated
    return jax.lax.stop_gradient(vs), jax.lax.stop_gradient(advantages)


@flax.struct.dataclass
class StepData:
    """Contains data for one environment step.

    `train` builds one instance per step in `do_one_step` and stacks them
    with `jax.lax.scan`. `generate_unroll` then appends the final state's
    `obs`, `rewards`, `dones` and `truncation`, so those four fields hold one
    more step than `actions` and `logits`.
    """

    # `state.obs` of the state the action was sampled from (stored before
    # observation normalization is applied).
    obs: jnp.ndarray
    # `state.reward` of that same state.
    rewards: jnp.ndarray
    # `state.done` of that same state.
    dones: jnp.ndarray
    # `state.info["truncation"]` of that same state.
    truncation: jnp.ndarray
    # Sampled actions before the distribution's postprocessing step.
    actions: jnp.ndarray
    # Policy network outputs the actions were sampled from.
    logits: jnp.ndarray


@flax.struct.dataclass
class TrainingState:
    """Contains training state for the learner.

    Carried through the training scans in `train`; `run_epoch` produces a new
    instance after every epoch.
    """

    # Optax optimizer state matching `params`.
    optimizer_state: optax.OptState
    # Learnable parameters; `train` stores a dict with "policy" and "value".
    params: Params
    # PRNG key; `run_epoch` splits it and stores a fresh key each epoch.
    key: PRNGKey
    # Observation-normalizer parameters, updated in `run_epoch`.
    normalizer_params: Params


def compute_drift_loss(
    models: Dict[str, Params],
    data: StepData,
    rng: jnp.ndarray,
    parametric_action_distribution: distribution.ParametricDistribution,
    policy_apply: Any,
    value_apply: Any,
    drift_apply: Any,
    drift_params: Params,
    entropy_cost: float = 1e-4,
    discounting: float = 0.9,
    reward_scaling: float = 1.0,
    lambda_: float = 0.95,
    ppo_init: bool = False,
):
    """Computes the policy, value and entropy loss with a drift penalty.

    Instead of PPO's clipped surrogate, the policy objective is the
    importance-weighted advantage minus a non-negative "drift" term produced
    by `drift_apply`. With `ppo_init=True` a PPO-clip-shaped drift
    (epsilon = 0.2) is added to the network output before the final ReLU.

    Args:
      models: Dict with the "policy" and "value" network parameters.
      data: Rollout data. `obs`, `rewards`, `dones` and `truncation` carry one
        more leading step than `actions` and `logits`; the extra step is only
        used to bootstrap the value estimate.
      rng: PRNG key passed to the distribution's entropy estimate.
      parametric_action_distribution: Distribution providing `log_prob` and
        `entropy` for the policy logits.
      policy_apply: Callable `(policy_params, obs) -> logits`.
      value_apply: Callable `(value_params, obs) -> values` with a trailing
        singleton axis (it is squeezed here).
      drift_apply: Callable `(drift_params, features) -> drift` with a
        trailing singleton axis (it is squeezed here).
      drift_params: Parameters of the drift function.
      entropy_cost: Weight of the entropy bonus.
      discounting: TD discount forwarded to `compute_gae`.
      reward_scaling: Multiplier applied to the rewards.
      lambda_: GAE lambda forwarded to `compute_gae`.
      ppo_init: Whether to add the PPO-clip-shaped drift term.

    Returns:
      A tuple `(total_loss, metrics)` where `metrics` is a dict with the keys
      "clipfrac", "total_loss", "policy_loss", "v_loss" and "entropy_loss".
    """
    policy_params, value_params = models["policy"], models["value"]
    policy_logits = policy_apply(policy_params, data.obs[:-1])
    baseline = jnp.squeeze(value_apply(value_params, data.obs), axis=-1)

    # The value of the extra final step is only used to bootstrap.
    bootstrap_value = baseline[-1]
    baseline = baseline[:-1]

    # `data.actions` and `data.logits` already exclude the extra step (it is
    # never added to them at data generation time), so only the fields below
    # are shifted.
    rewards = data.rewards[1:] * reward_scaling
    truncation = data.truncation[1:]
    termination = data.dones[1:] * (1 - truncation)

    target_action_log_probs = parametric_action_distribution.log_prob(policy_logits, data.actions)
    behaviour_action_log_probs = parametric_action_distribution.log_prob(data.logits, data.actions)

    vs, advantages = compute_gae(
        truncation=truncation,
        termination=termination,
        rewards=rewards,
        values=baseline,
        bootstrap_value=bootstrap_value,
        lambda_=lambda_,
        discount=discounting,
    )

    # Importance ratio between the target and the behaviour policy.
    log_ratio = target_action_log_probs - behaviour_action_log_probs
    rho_s = jnp.exp(log_ratio)

    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

    # Drift features: the log-ratio and (1 - ratio), their squares, and each
    # of those four multiplied by the normalized advantage.
    one_minus_rho = 1 - rho_s
    deltas = jnp.stack([log_ratio, log_ratio * log_ratio, one_minus_rho, one_minus_rho * one_minus_rho], axis=-1)
    advantage_deltas = jnp.expand_dims(advantages, -1) * deltas
    drift_input = jnp.concatenate((deltas, advantage_deltas), axis=-1)
    drift = drift_apply(drift_params, drift_input).squeeze(-1)

    if ppo_init:
        eps = 0.2
        # Bounds are passed positionally: the `a_min`/`a_max` keyword names
        # are deprecated in newer JAX releases.
        drift_ppo = flax.linen.relu((rho_s - jnp.clip(rho_s, 1 - eps, 1 + eps)) * advantages)
        drift = flax.linen.relu(drift_ppo + drift - 0.0001)
    else:
        drift = flax.linen.relu(drift - 0.0001)

    policy_loss = -(rho_s * advantages - drift).mean()

    # Value function loss.
    v_error = vs - baseline
    v_loss = jnp.mean(v_error * v_error) * 0.5 * 0.5

    # Entropy reward.
    entropy = jnp.mean(parametric_action_distribution.entropy(policy_logits, rng))
    entropy_loss = entropy_cost * -entropy

    # Diagnostic only: fraction of samples whose ratio is more than 0.3 away
    # from 1 (note this threshold differs from the 0.2 clip epsilon above).
    clipfrac = (jnp.abs(rho_s - 1.0) > 0.3).mean()

    total_loss = policy_loss + v_loss + entropy_loss
    metrics = {
        "clipfrac": clipfrac,
        "total_loss": total_loss,
        "policy_loss": policy_loss,
        "v_loss": v_loss,
        "entropy_loss": entropy_loss,
    }
    return total_loss, metrics


def train(
    key,
    drift_params,
    environment_fn: Callable[..., envs.Env],
    num_timesteps,
    drift_apply,
    episode_length: int,
    action_repeat: int = 1,
    num_envs: int = 1,
    num_eval_envs: int = 128,
    learning_rate=1e-4,
    entropy_cost=1e-4,
    discounting=0.9,
    unroll_length=10,
    batch_size=32,
    num_minibatches=16,
    num_update_epochs=2,
    log_frequency=10,
    normalize_observations=False,
    reward_scaling=1.0,
    ppo_init=True,
):
    """Trains a policy with the drift-based loss, entirely inside `lax.scan`.

    The outer scan runs `log_frequency + 1` iterations; each iteration runs
    `num_epochs // log_frequency` training epochs followed by one evaluation
    rollout, where
    `num_epochs = num_timesteps // (batch_size * unroll_length *
    num_minibatches * action_repeat)`.

    Args:
      key: PRNG key; split into the training key and the keys used for model
        initialisation, environment reset, evaluation reset and evaluation
        rollouts.
      drift_params: Parameters of the drift function (see
        `compute_drift_loss`).
      environment_fn: Factory called with `action_repeat`, `batch_size` and
        `episode_length` (plus `eval_metrics=True` for the evaluation env).
      num_timesteps: Environment step budget used to derive `num_epochs`.
      drift_apply: Drift function forwarded to `compute_drift_loss`.
      episode_length: Episode length passed to the environments; the
        evaluation rollout runs `episode_length // action_repeat` steps.
      action_repeat: Action repeat passed to the environments.
      num_envs: Batch size of the training environment.
      num_eval_envs: Batch size of the evaluation environment.
      learning_rate: Adam learning rate.
      entropy_cost: Entropy bonus weight.
      discounting: TD discount.
      unroll_length: Number of environment steps per unroll.
      batch_size: Minibatch size; `batch_size * num_minibatches` must be a
        multiple of `num_envs`.
      num_minibatches: Number of minibatches (gradient steps) per update
        epoch.
      num_update_epochs: Number of passes over the collected data per epoch.
      log_frequency: Controls how the epochs are grouped between evaluations.
      normalize_observations: Whether the observation normalizer is active.
      reward_scaling: Reward multiplier used by the loss.
      ppo_init: Forwarded to `compute_drift_loss`.

    Returns:
      A tuple `(params, metrics)`. `params` is
      `(normalizer_params, policy_params)` taken from the final training
      state; `metrics` is the stacked `(losses, eval_metrics)` output of the
      outer scan.
    """
    assert (batch_size * num_minibatches) % num_envs == 0, "batch_size * num_minibatches must be divisible by num_envs"

    key, key_models, key_env, key_eval, key_debug = jax.random.split(key, 5)

    core_env = environment_fn(action_repeat=action_repeat, batch_size=num_envs, episode_length=episode_length)
    step_fn = jax.jit(core_env.step)
    reset_fn = jax.jit(core_env.reset)
    first_state = reset_fn(key_env)

    eval_env = environment_fn(action_repeat=action_repeat, batch_size=num_eval_envs, episode_length=episode_length, eval_metrics=True)
    eval_step_fn = jax.jit(eval_env.step)
    eval_first_state = jax.jit(eval_env.reset)(key_eval)

    parametric_action_distribution = distribution.NormalTanhDistribution(event_size=core_env.action_size)

    policy_model, value_model = networks.make_models(parametric_action_distribution.param_size, core_env.observation_size)
    key_policy, key_value = jax.random.split(key_models)

    optimizer = optax.adam(learning_rate=learning_rate)
    init_params = {"policy": policy_model.init(key_policy), "value": value_model.init(key_value)}
    optimizer_state = optimizer.init(init_params)

    # Named `init_*` so it cannot be confused with the trained normalizer
    # parameters carried in the training state.
    init_normalizer_params, obs_normalizer_update_fn, obs_normalizer_apply_fn = normalization.create_observation_normalizer(core_env.observation_size, normalize_observations, num_leading_batch_dims=2)

    # `lambda_` is not forwarded, so the loss uses its own default.
    loss_fn = functools.partial(
        compute_drift_loss,
        drift_params=drift_params,
        drift_apply=drift_apply,
        parametric_action_distribution=parametric_action_distribution,
        policy_apply=policy_model.apply,
        value_apply=value_model.apply,
        entropy_cost=entropy_cost,
        discounting=discounting,
        reward_scaling=reward_scaling,
        ppo_init=ppo_init,
    )

    grad_loss = jax.grad(loss_fn, has_aux=True)

    def do_one_step_eval(carry, unused_target_t):
        # One evaluation step: normalize, sample a postprocessed action, step.
        state, policy_params, normalizer_params, key = carry
        key, key_sample = jax.random.split(key)
        obs = obs_normalizer_apply_fn(normalizer_params, state.obs)
        logits = policy_model.apply(policy_params, obs)
        actions = parametric_action_distribution.sample(logits, key_sample)
        nstate = eval_step_fn(state, actions)
        return (nstate, policy_params, normalizer_params, key), ()

    @jax.jit
    def run_eval(state, key, policy_params, normalizer_params) -> Tuple[envs.State, PRNGKey]:
        # Rolls the evaluation environment forward for one episode.
        (state, _, _, key), _ = jax.lax.scan(do_one_step_eval, (state, policy_params, normalizer_params, key), (), length=episode_length // action_repeat)
        return state, key

    def do_one_step(carry, unused_target_t):
        # One training step. The raw (un-postprocessed) action is recorded so
        # the loss can evaluate its log-probability; the postprocessed action
        # is what the environment receives.
        state, normalizer_params, policy_params, key = carry
        key, key_sample = jax.random.split(key)
        normalized_obs = obs_normalizer_apply_fn(normalizer_params, state.obs)
        logits = policy_model.apply(policy_params, normalized_obs)
        actions = parametric_action_distribution.sample_no_postprocessing(logits, key_sample)
        postprocessed_actions = parametric_action_distribution.postprocess(actions)
        nstate = step_fn(state, postprocessed_actions)
        return (nstate, normalizer_params, policy_params, key), StepData(obs=state.obs, rewards=state.reward, dones=state.done, truncation=state.info["truncation"], actions=actions, logits=logits)

    def generate_unroll(carry, unused_target_t):
        # Collects `unroll_length` steps and appends the final state so the
        # loss has one extra step to bootstrap from.
        state, normalizer_params, policy_params, key = carry
        (state, _, _, key), data = jax.lax.scan(do_one_step, (state, normalizer_params, policy_params, key), (), length=unroll_length)
        data = data.replace(
            obs=jnp.concatenate([data.obs, jnp.expand_dims(state.obs, axis=0)]),
            rewards=jnp.concatenate([data.rewards, jnp.expand_dims(state.reward, axis=0)]),
            dones=jnp.concatenate([data.dones, jnp.expand_dims(state.done, axis=0)]),
            truncation=jnp.concatenate([data.truncation, jnp.expand_dims(state.info["truncation"], axis=0)]),
        )
        return (state, normalizer_params, policy_params, key), data

    def update_model(carry, data):
        # One gradient step on a single minibatch.
        optimizer_state, params, key = carry
        key, key_loss = jax.random.split(key)
        loss_grad, metrics = grad_loss(params, data, key_loss)

        params_update, optimizer_state = optimizer.update(loss_grad, optimizer_state)
        params = optax.apply_updates(params, params_update)

        return (optimizer_state, params, key), metrics

    def minimize_epoch(carry, unused_t):
        # One pass over the collected data: shuffle along the batch axis,
        # split into minibatches and take one gradient step per minibatch.
        optimizer_state, params, data, key = carry
        key, key_perm, key_grad = jax.random.split(key, 3)
        permutation = jax.random.permutation(key_perm, data.obs.shape[1])

        def convert_data(data, permutation):
            data = jnp.take(data, permutation, axis=1, mode="clip")
            data = jnp.reshape(data, [data.shape[0], num_minibatches, -1] + list(data.shape[2:]))
            # Put the minibatch axis first so `scan` iterates over it.
            data = jnp.swapaxes(data, 0, 1)
            return data

        ndata = jax.tree_util.tree_map(lambda x: convert_data(x, permutation), data)
        (optimizer_state, params, _), metrics = jax.lax.scan(update_model, (optimizer_state, params, key_grad), ndata, length=num_minibatches)
        return (optimizer_state, params, data, key), metrics

    def run_epoch(carry: Tuple[TrainingState, envs.State], unused_t):
        # Collects fresh rollouts, refreshes the normalizer and runs
        # `num_update_epochs` optimisation passes over the data.
        training_state, state = carry
        key_minimize, key_generate_unroll, new_key = jax.random.split(training_state.key, 3)
        (state, _, _, _), data = jax.lax.scan(generate_unroll, (state, training_state.normalizer_params, training_state.params["policy"], key_generate_unroll), (), length=batch_size * num_minibatches // num_envs)
        # Make the time axis lead, then merge the unroll and env axes into a
        # single batch axis.
        data = jax.tree_util.tree_map(lambda x: jnp.swapaxes(x, 0, 1), data)
        data = jax.tree_util.tree_map(lambda x: jnp.reshape(x, [x.shape[0], -1] + list(x.shape[3:])), data)

        # Update normalization params and normalize observations.
        normalizer_params = obs_normalizer_update_fn(training_state.normalizer_params, data.obs[:-1])
        data = data.replace(obs=obs_normalizer_apply_fn(normalizer_params, data.obs))

        (optimizer_state, params, _, _), metrics = jax.lax.scan(minimize_epoch, (training_state.optimizer_state, training_state.params, data, key_minimize), (), length=num_update_epochs)

        new_training_state = TrainingState(optimizer_state=optimizer_state, params=params, normalizer_params=normalizer_params, key=new_key)
        return (new_training_state, state), metrics

    num_epochs = num_timesteps // (batch_size * unroll_length * num_minibatches * action_repeat)

    def minimize_loop(training_state, state):
        return jax.lax.scan(run_epoch, (training_state, state), (), length=num_epochs // log_frequency)

    training_state = TrainingState(optimizer_state=optimizer_state, params=init_params, key=key, normalizer_params=init_normalizer_params)
    state = first_state

    def log_loop(carry, unused_t):
        # Train for a block of epochs, then evaluate from the fixed initial
        # evaluation state.
        training_state, state, key_debug = carry
        (training_state, state), losses = minimize_loop(training_state, state)
        eval_state, key_debug = run_eval(eval_first_state, key_debug, training_state.params["policy"], training_state.normalizer_params)
        eval_metrics = eval_state.info["eval_metrics"]
        return (training_state, state, key_debug), (losses, eval_metrics)

    (training_state, state, key_debug), metrics = jax.lax.scan(log_loop, (training_state, state, key_debug), (), length=log_frequency + 1)

    # Return the normalizer parameters from the final training state; the
    # initial ones would not match the statistics the policy was trained with.
    params = training_state.normalizer_params, training_state.params["policy"]

    return (params, metrics)


def make_inference_fn(observation_size, action_size, normalize_observations):
    """Builds the action-sampling function for a PPO agent.

    Args:
      observation_size: Observation size used to build the normalizer and
        the policy network.
      action_size: Event size of the tanh-normal action distribution.
      normalize_observations: Passed to the observation normalizer factory.

    Returns:
      A function `inference_fn(params, obs, key)` where `params` is the pair
      `(normalizer_params, policy_params)`; it returns an action sampled from
      the policy for `obs`.
    """
    _unused_data, normalize_obs = normalization.make_data_and_apply_fn(observation_size, normalize_observations)
    action_dist = distribution.NormalTanhDistribution(event_size=action_size)
    policy_net, _unused_value_net = networks.make_models(action_dist.param_size, observation_size)

    def inference_fn(params, obs, key):
        normalizer_params, policy_params = params
        normalized_obs = normalize_obs(normalizer_params, obs)
        logits = policy_net.apply(policy_params, normalized_obs)
        return action_dist.sample(logits, key)

    return inference_fn
