from jaxrl_m.dataset import Dataset
import functools
from jaxrl_m.typing import *
import jax
from jax import vmap
from jax.lax import stop_gradient
import jax.numpy as jnp
import numpy as np
import optax
from jaxrl_m.common import TrainState, target_update, nonpytree_field
from jaxrl_m.networks import Policy, Critic, ensemblize
from typing import Any, Tuple, Sequence, Optional
from util_vipo import (
    sample_from_norm,
    get_params_shape,
    get_log_prob,
    get_log_prob_jnp,
    l2_loss,
    msew_loss,
    var_loss,
    nll_loss,
    decay_loss,
)
from model_vipo import (
    EnsembledDynamics,
    ValueCritic,
    EnsembledValueCritics,
    NormalizerState,
    Normalizer,
)
import flax
import flax.linen as nn
from numpy.random import choice
from jax import random


class VIPODynamics(flax.struct.PyTreeNode):
    """Ensemble dynamics model trained jointly with two value critics.

    The object bundles:

    * ``ensembled_dynamics`` -- an ensemble that maps (normalized obs,
      normalized action) to the mean / log-std of ``[next_obs - obs, reward]``.
    * ``true_value_critic`` (+ EMA copy) -- a value function bootstrapped on
      dataset transitions.
    * ``ensembled_value_critic`` (+ EMA copy) -- one value function per
      ensemble member, bootstrapped on that member's predicted transition.

    Fields declared with ``nonpytree_field()`` are static configuration; the
    remaining fields are pytree leaves.
    """

    rng: PRNGKey
    ensembled_dynamics: TrainState
    ensembled_value_critic: TrainState
    ensembled_value_critic_ema: TrainState
    true_value_critic: TrainState
    true_value_critic_ema: TrainState
    # Called as termination_fn(obs, action, next_obs) inside `step`.
    termination_fn: Callable[[np.ndarray, np.ndarray, np.ndarray], np.ndarray] = nonpytree_field()
    normalizer: Normalizer = nonpytree_field()
    action_norm_state: NormalizerState = nonpytree_field()
    obs_norm_state: NormalizerState = nonpytree_field()
    ensemble_size: int = nonpytree_field()
    discount: float = nonpytree_field()
    # Scale of the uncertainty penalty subtracted from rewards in `step`;
    # a falsy value disables the penalty entirely.
    penalty_coeff: float = nonpytree_field()
    # Weight of the value-consistency term in the dynamics loss.
    phi: float = nonpytree_field()
    config: dict = nonpytree_field()
    # One of "aleatoric", "pairwise-diff", "ensemble_std" (see `step`).
    uncertainty_mode: str = nonpytree_field()
    # Weight of the (log_std_max - log_std_min) regularizer in the dynamics loss.
    logstd_loss_coeff: float = 0.02

    @jax.jit
    def update(dynamics, batch: Batch):
        """Run one gradient step on both critics and on the dynamics ensemble.

        Args:
            batch: mapping with "observations", "actions", "rewards" and
                "next_observations". Rewards get a trailing axis appended
                before being concatenated with the observation deltas.

        Returns:
            ``(new_dynamics, info)`` where ``info`` merges the metric dicts of
            the three loss functions.
        """
        normed_obs = dynamics.normalizer.normalize(batch["observations"], dynamics.obs_norm_state)
        normed_actions = dynamics.normalizer.normalize(batch["actions"], dynamics.action_norm_state)
        # Same observations replicated along a new leading ensemble axis.
        repeated_obs = jnp.repeat(batch["observations"][None, ...], dynamics.ensemble_size, axis=0)
        true_obs_diff = batch["next_observations"] - batch["observations"]
        # Regression target of the dynamics model: [obs delta, reward].
        groundtruths = jnp.concatenate([true_obs_diff, batch["rewards"][..., None]], axis=-1)

        def true_value_critic_loss_fn(true_value_critic_params):
            """L2 loss of the data critic against r + discount * V_ema(s')."""
            true_value = dynamics.true_value_critic(batch["observations"], params=true_value_critic_params)
            true_next_value = dynamics.true_value_critic_ema(batch["next_observations"])
            true_value_target = batch["rewards"] + dynamics.discount * true_next_value
            true_v_loss = l2_loss(true_value, true_value_target)
            return true_v_loss, {"true_value_critic_l2_loss": true_v_loss}

        def ensembled_value_critic_loss_fn(ensembled_value_critic_params):
            """Per-member L2 loss against targets built from model predictions."""
            # The dynamics model is evaluated with its current parameters, so
            # no gradient reaches it from this loss.
            pred_diff_means, _ = dynamics.ensembled_dynamics(normed_obs, normed_actions)
            pred_next_obs = batch["observations"] + pred_diff_means[..., :-1]
            pred_next_value = dynamics.ensembled_value_critic_ema(pred_next_obs)  # (ensemble_size, batch_size)
            # The last output channel is the predicted reward.
            pred_value_target = pred_diff_means[..., -1] + dynamics.discount * pred_next_value
            pred_value = dynamics.ensembled_value_critic(repeated_obs, params=ensembled_value_critic_params)
            ensembled_v_loss = vmap(l2_loss, in_axes=(0, 0))(pred_value, pred_value_target)
            return ensembled_v_loss.mean(), {"ensembled_value_critic_l2_loss_mean": ensembled_v_loss.mean()}

        def ensembled_dynamics_loss_fn(ensembled_dynamics_params):
            """Prediction loss + value-consistency loss + parameter regularizers."""
            true_value_stable = dynamics.true_value_critic_ema(batch["observations"])
            pred_value_stable = dynamics.ensembled_value_critic_ema(repeated_obs)
            pred_diff_means, pred_diff_logstds = dynamics.ensembled_dynamics(
                normed_obs,
                normed_actions,
                params=ensembled_dynamics_params,
            )
            pred_diff_stds = jnp.exp(pred_diff_logstds)
            pred_next_obs = batch["observations"] + pred_diff_means[..., :-1]
            pred_next_value = dynamics.ensembled_value_critic_ema(pred_next_obs)  # (ensemble_size, batch_size)
            pred_value_target = pred_diff_means[..., -1] + dynamics.discount * pred_next_value
            # NOTE(review): the first two arguments are both `pred_diff_means`.
            # If get_log_prob is (x, mean, std) this evaluates the density at
            # its own mean -- confirm against util_vipo that this is intended.
            log_prob = vmap(get_log_prob, in_axes=(0, 0, 0))(pred_diff_means, pred_diff_means, pred_diff_stds)
            # model prediction loss
            mloss = vmap(msew_loss, in_axes=(0, 0, None))(pred_diff_means, pred_diff_logstds, groundtruths)
            vloss = vmap(var_loss)(pred_diff_logstds)
            # nloss is reported as a metric only; it is not part of tot_loss.
            nloss = vmap(nll_loss, in_axes=(0, 0, None))(pred_diff_means, pred_diff_stds, groundtruths)
            # ensemble consistency loss: the weight is detached, so gradients
            # reach the dynamics parameters only through log_prob.
            advantage = stop_gradient((true_value_stable[None, ...] - pred_value_stable) * pred_value_target)
            value_loss = jnp.mean(-advantage * log_prob, axis=-1)
            # parameter decay loss
            dloss = decay_loss(ensembled_dynamics_params)
            # total loss
            tot_loss = mloss.sum() + vloss.sum() + dynamics.phi * value_loss.sum() + dloss
            log_std_min = ensembled_dynamics_params["log_std_min"]
            log_std_max = dynamics.ensembled_dynamics(
                params=ensembled_dynamics_params,
                method=EnsembledDynamics.get_log_std_max,
            )
            # Penalize the width of the allowed log-std interval.
            tot_loss += dynamics.logstd_loss_coeff * (log_std_max.sum() - log_std_min.sum())
            return tot_loss, {
                "ensembled_dynamics_msew_loss": mloss.sum(),
                "ensembled_dynamics_var_loss": vloss.sum(),
                "ensembled_dynamics_nll_loss": nloss.sum(),
                "ensembled_dynamics_value_loss": value_loss.sum(),
                "ensembled_dynamics_decay_loss": dloss,
                "ensembled_dynamics_total_loss": tot_loss,
                "debug/advantage": advantage.mean(),
                "debug/log_prob": log_prob.mean(),
                "debug/log_std_max": log_std_max.sum(),
                "debug/log_std_min": log_std_min.sum(),
                "debug/delta_value": (true_value_stable[None, ...] - pred_value_stable).mean(),
                "debug/pred_value_target": pred_value_target.mean(),
                "debug/true_value_stable": true_value_stable.mean(),
                "debug/pred_value_stable": pred_value_stable.mean(),
            }

        # All three losses are evaluated against the pre-update state held in
        # `dynamics`; the EMA updates likewise read the pre-update critics.
        new_true_value_critic, true_value_critic_info = dynamics.true_value_critic.apply_loss_fn(
            loss_fn=true_value_critic_loss_fn,
            has_aux=True,
        )
        new_true_value_critic_ema = target_update(
            dynamics.true_value_critic,
            dynamics.true_value_critic_ema,
            dynamics.config["target_update_rate"],
        )

        new_ensembled_value_critic, ensembled_value_critic_info = dynamics.ensembled_value_critic.apply_loss_fn(
            loss_fn=ensembled_value_critic_loss_fn,
            has_aux=True,
        )
        new_ensembled_value_critic_ema = target_update(
            dynamics.ensembled_value_critic,
            dynamics.ensembled_value_critic_ema,
            dynamics.config["target_update_rate"],
        )

        new_ensembled_dynamics, ensembled_dynamics_info = dynamics.ensembled_dynamics.apply_loss_fn(
            loss_fn=ensembled_dynamics_loss_fn,
            has_aux=True,
        )

        return dynamics.replace(
            true_value_critic=new_true_value_critic,
            true_value_critic_ema=new_true_value_critic_ema,
            ensembled_value_critic=new_ensembled_value_critic,
            ensembled_value_critic_ema=new_ensembled_value_critic_ema,
            ensembled_dynamics=new_ensembled_dynamics,
        ), {**true_value_critic_info, **ensembled_value_critic_info, **ensembled_dynamics_info}

    @jax.jit
    def evaluate(dynamics, batch: Batch):
        """Compute hold-out metrics for the ensemble-averaged prediction.

        Both metrics compare the ensemble-mean prediction with the true
        ``[obs delta, reward]`` targets. The NLL uses ``exp`` of the
        ensemble-mean log-std so that ``nll_loss`` receives a mean and a std
        of matching shape, as it does in ``update``.

        Returns:
            dict with "holdout_l2_loss" and "holdout_nll_loss".
        """
        obs_diffs = batch["next_observations"] - batch["observations"]
        targets = jnp.concatenate([obs_diffs, batch["rewards"][..., None]], axis=-1)
        normed_obs = dynamics.normalizer.normalize(batch["observations"], dynamics.obs_norm_state)
        normed_actions = dynamics.normalizer.normalize(batch["actions"], dynamics.action_norm_state)
        pred_diff_means, pred_diff_logstds = dynamics.ensembled_dynamics(normed_obs, normed_actions)
        # Collapse the leading ensemble axis.
        mean_pred_diffs = pred_diff_means.mean(axis=0)
        mean_pred_logstds = pred_diff_logstds.mean(axis=0)
        holdout_l2_loss = l2_loss(mean_pred_diffs, targets).mean()
        # Previously the un-averaged means and the *log*-stds were passed here,
        # which disagreed with how nll_loss is called in `update`.
        holdout_nll_loss = nll_loss(mean_pred_diffs, jnp.exp(mean_pred_logstds), targets).mean()
        return {
            "holdout_l2_loss": holdout_l2_loss,
            "holdout_nll_loss": holdout_nll_loss,
        }

    @jax.jit
    def step(
        dynamics,
        obs: jnp.ndarray,
        action: jnp.ndarray,
        key: Optional[PRNGKey] = None,
    ):
        """Roll the learned model forward one step: (s, a) -> (s', r).

        Every ensemble member produces a sampled transition; for each batch
        element one member is picked uniformly at random. When
        ``penalty_coeff`` is truthy, an uncertainty penalty selected by
        ``uncertainty_mode`` is subtracted from the reward.

        Args:
            obs: batch of observations; axis 0 is the batch axis.
            action: batch of actions aligned with ``obs``.
            key: PRNG key. Required; it is split internally so that the
                Gaussian sampling and the member selection use independent
                randomness.

        Returns:
            ``(next_obss, rewards, terminals, info)``. ``info`` holds
            "raw_rewards" and, if a penalty was applied, "penalty".

        Raises:
            ValueError: if ``key`` is missing or ``uncertainty_mode`` is not
                a recognized mode.
        """
        if key is None:
            raise ValueError("step() requires a PRNG key")
        # Never feed the same key to two random ops: derive one per use.
        sample_key, member_key = random.split(key)

        normed_obs = dynamics.normalizer.normalize(obs, dynamics.obs_norm_state)
        normed_action = dynamics.normalizer.normalize(action, dynamics.action_norm_state)
        pred_diff_means, pred_diff_logstds = dynamics.ensembled_dynamics(normed_obs, normed_action)
        pred_diff_stds = jnp.exp(pred_diff_logstds)
        # NOTE(review): sample_from_norm receives the log-stds, not
        # pred_diff_stds -- confirm against util_vipo which one it expects.
        pred_diffs = vmap(sample_from_norm, in_axes=(0, 0, None))(pred_diff_means, pred_diff_logstds, sample_key)
        # All but the last channel are observation deltas (add obs back);
        # the last channel is the reward.
        preds = jnp.concatenate([pred_diffs[..., :-1] + obs, pred_diffs[..., -1:]], axis=-1)
        # Pick one ensemble member for each of the batch_size samples.
        batch_size = obs.shape[0]
        model_idxs = random.randint(member_key, (batch_size,), 0, dynamics.ensemble_size)
        samples = preds[model_idxs, jnp.arange(0, batch_size), ...]  # batch_size, output_dim

        next_obss = samples[..., :-1]
        rewards = samples[..., -1]
        terminals = dynamics.termination_fn(obs, action, next_obss).squeeze()
        info = {"raw_rewards": rewards}

        # penalty_coeff and uncertainty_mode are static fields, so this
        # branching is resolved when the function is traced.
        if dynamics.penalty_coeff:
            if dynamics.uncertainty_mode == "aleatoric":
                # Largest per-member L2 norm of the predicted std.
                penalty = jnp.max(jnp.linalg.norm(pred_diff_stds, axis=2), axis=0)
            elif dynamics.uncertainty_mode == "pairwise-diff":
                # Largest L2 distance of a member's mean next-obs from the
                # ensemble-average next-obs.
                next_obs_means = pred_diff_means[..., :-1] + obs
                next_obs_mean = jnp.mean(next_obs_means, axis=0)
                diff = next_obs_means - next_obs_mean
                penalty = jnp.max(jnp.linalg.norm(diff, axis=2), axis=0)
            elif dynamics.uncertainty_mode == "ensemble_std":
                # Root of the across-member variance averaged over obs dims.
                next_obs_means = pred_diff_means[..., :-1] + obs
                penalty = jnp.sqrt(next_obs_means.var(0).mean(1))
            else:
                raise ValueError(f"Unknown uncertainty_mode: {dynamics.uncertainty_mode!r}")
            rewards = rewards - dynamics.penalty_coeff * penalty
            info["penalty"] = penalty
        return next_obss, rewards, terminals, info


def create_learner(
    key: PRNGKey,
    example_batch: Batch,
    ensemble_size: int,
    hidden_dims: Sequence[int],
    discount: float,
    target_update_rate: float,
    penalty_coeff: float,
    phi: float,
    normalizer: Normalizer,
    action_norm_state: NormalizerState,
    obs_norm_state: NormalizerState,
    termination_fn: Callable[[np.ndarray, np.ndarray, np.ndarray], np.ndarray],
    uncertainty_mode: str = "aleatoric",
    **kwargs,
) -> VIPODynamics:
    """Initialize all networks and assemble a ``VIPODynamics`` learner.

    Args:
        key: PRNG key; split into one key per network plus the learner's
            own ``rng``.
        example_batch: batch used only to infer shapes and initialize
            parameters. Needs "observations", "actions" and "rewards".
        ensemble_size: number of members in the dynamics / value ensembles.
        hidden_dims: hidden layer sizes shared by all networks.
        discount: discount factor used in the value targets.
        target_update_rate: rate passed to ``target_update`` for EMA critics.
        penalty_coeff: scale of the uncertainty penalty applied to rewards in
            ``VIPODynamics.step``; 0 disables it.
        phi: weight of the value-consistency term in the dynamics loss.
        normalizer, action_norm_state, obs_norm_state: input normalization.
        termination_fn: called as ``termination_fn(obs, action, next_obs)``.
        uncertainty_mode: penalty variant used by ``VIPODynamics.step``.
        **kwargs: accepted and ignored.

    Returns:
        A freshly initialized ``VIPODynamics``.
    """
    key, true_v_key, ensemble_v_key, dynamics_key = jax.random.split(key, 4)

    observations = example_batch["observations"]
    actions = example_batch["actions"]
    rewards = example_batch["rewards"]

    obs_dim = observations.shape[-1]
    action_dim = actions.shape[-1]
    # Rewards usually carry no feature axis (VIPODynamics.update appends one
    # with `[..., None]`), in which case `shape[-1]` would be the batch size.
    # Only trust the last axis when rewards have the same rank as observations.
    reward_dim = rewards.shape[-1] if np.ndim(rewards) == np.ndim(observations) else 1

    #######################################
    # value critics
    #######################################
    true_value_critic_def = ValueCritic(hidden_dims=hidden_dims)
    outputs, params = true_value_critic_def.init_with_output(true_v_key, observations)
    params = params["params"]
    print(outputs.shape)
    print(get_params_shape(params))

    # TODO: change tx
    true_value_critic = TrainState.create(true_value_critic_def, params=params, tx=optax.adam(3e-5))
    # EMA copies are never optimized directly, hence no optimizer.
    true_value_critic_ema = TrainState.create(true_value_critic_def, params=params, tx=None)

    ensembled_value_critic_def = EnsembledValueCritics(ensemble_size=ensemble_size, hidden_dims=hidden_dims)
    # The ensembled critic takes observations with a leading ensemble axis.
    repeated_obs = jnp.repeat(observations[None, ...], ensemble_size, axis=0)
    outputs, params = ensembled_value_critic_def.init_with_output(ensemble_v_key, repeated_obs)
    params = params["params"]
    print(outputs.shape)
    print(get_params_shape(params))

    # TODO: change tx
    ensembled_value_critic = TrainState.create(ensembled_value_critic_def, params=params, tx=optax.adam(3e-5))
    ensembled_value_critic_ema = TrainState.create(ensembled_value_critic_def, params=params, tx=None)

    #######################################
    # dynamics model
    #######################################
    ensembled_dynamics_def = EnsembledDynamics(
        ensemble_size=ensemble_size,
        hidden_dims=hidden_dims,
        obs_dim=obs_dim,
        action_dim=action_dim,
        reward_dim=reward_dim,
    )
    outputs, params = ensembled_dynamics_def.init_with_output(
        dynamics_key,
        observations,
        actions,
    )
    params = params["params"]
    print(outputs[0].shape, outputs[1].shape)
    print(get_params_shape(params))

    # TODO: change tx
    ensembled_dynamics = TrainState.create(ensembled_dynamics_def, params=params, tx=optax.adam(3e-5))

    #######################################
    # config
    #######################################
    config = flax.core.FrozenDict(
        dict(
            discount=discount,
            target_update_rate=target_update_rate,
            penalty_coeff=penalty_coeff,
            phi=phi,
        )
    )

    return VIPODynamics(
        rng=key,
        normalizer=normalizer,
        action_norm_state=action_norm_state,
        obs_norm_state=obs_norm_state,
        ensemble_size=ensemble_size,
        termination_fn=termination_fn,
        ensembled_dynamics=ensembled_dynamics,
        ensembled_value_critic=ensembled_value_critic,
        ensembled_value_critic_ema=ensembled_value_critic_ema,
        true_value_critic=true_value_critic,
        true_value_critic_ema=true_value_critic_ema,
        # Honor the caller's value; this used to be hard-coded to 0.0, which
        # silently disabled the reward penalty regardless of the argument.
        penalty_coeff=penalty_coeff,
        phi=phi,
        discount=discount,
        uncertainty_mode=uncertainty_mode,
        config=config,
    )
