import jax, flax, optax, distrax
import jax.numpy as jnp
from flax import linen as nn
from typing import Any, Tuple, Sequence, Optional
from jaxrl_m.typing import *
from jax.typing import *
from jaxrl_m.networks import MLP, default_init, ValueCritic, ensemblize
from util_vipo import (
    soft_clip,
)
from flax.struct import dataclass


@dataclass
class NormalizerState:
    """Running statistics that ``Normalizer`` methods read and return."""
    # Running mean; its last axis must match the feature axis of the batches
    # passed to ``Normalizer.update_stats``.
    mean: Array
    # Running standard deviation; ``Normalizer.update_stats`` floors it at 1e-3.
    std: Array
    # Number of samples folded into the statistics so far. ``Normalizer.reset``
    # stores a Python 0, while ``Normalizer.update_stats`` stores the result of
    # ``jnp.minimum`` (a JAX scalar) capped at ``Normalizer.max_points``.
    num_points: int


@dataclass(frozen=True)
class Normalizer:
    """Stateless helper that maintains running mean/std in a ``NormalizerState``.

    The instance only carries ``max_points``; every statistic lives in the
    ``NormalizerState`` that is passed in and returned, so the methods are
    pure functions of their arguments.
    """

    # Cap on the stored sample count. It is built through a default factory
    # because an array *instance* used as a plain dataclass default is
    # rejected as a mutable (unhashable) default on Python 3.11+.
    max_points: Array = flax.struct.field(
        default_factory=lambda: jnp.array(1e8, dtype=jnp.int32)
    )

    @staticmethod
    def reset(normalizer_state: NormalizerState) -> NormalizerState:
        """Return a fresh state (zero mean, unit std, zero count).

        The shapes and dtypes of ``mean`` and ``std`` are copied from
        ``normalizer_state``.
        """
        return NormalizerState(
            mean=jnp.zeros_like(normalizer_state.mean),
            std=jnp.ones_like(normalizer_state.std),
            num_points=0,
        )

    def update_stats(self, x: Array, normalizer_state: NormalizerState) -> NormalizerState:
        """Fold a batch into the running statistics and return the new state.

        Args:
            x: 2-D batch whose last axis matches ``normalizer_state.mean``.
            normalizer_state: statistics accumulated so far.

        Returns:
            A new ``NormalizerState``. The variance is the population variance
            (divided by the total count), ``std`` is floored at 1e-3 and
            ``num_points`` is capped at ``self.max_points``. An empty batch
            returns ``normalizer_state`` unchanged.
        """
        assert len(x.shape) == 2 and x.shape[-1] == normalizer_state.mean.shape[-1]
        num_points = x.shape[0]
        if num_points == 0:
            # Shapes are static, so this is a plain Python branch. Without it a
            # fresh state (num_points == 0) would compute 0 / 0 and store NaNs.
            return normalizer_state
        total_points = num_points + normalizer_state.num_points
        mean = (normalizer_state.mean * normalizer_state.num_points + jnp.sum(x, axis=0)) / total_points
        # Sum of squared deviations around the *new* mean: the old spread, the
        # batch spread, and the shift of the old mean towards the new one.
        new_s_n = (
            jnp.square(normalizer_state.std) * normalizer_state.num_points
            + jnp.sum(jnp.square(x - mean), axis=0)
            + normalizer_state.num_points * jnp.square(normalizer_state.mean - mean)
        )

        new_var = new_s_n / total_points
        # Floor the std so that normalize() never divides by (almost) zero.
        std = jnp.clip(jnp.sqrt(new_var), min=1e-3)
        new_normalizer_state = NormalizerState(
            mean=mean,
            std=std,
            num_points=jnp.minimum(total_points, self.max_points),  # keep at most max number of points to avoid overflow
        )
        return new_normalizer_state

    @staticmethod
    def normalize(x: Array, normalizer_state: NormalizerState):
        """Return ``(x - mean) / std`` using the statistics in ``normalizer_state``."""
        return (x - normalizer_state.mean) / normalizer_state.std

    @staticmethod
    def denormalize(norm_x: Array, normalizer_state: NormalizerState):
        """Invert ``normalize``: return ``norm_x * std + mean``."""
        return norm_x * normalizer_state.std + normalizer_state.mean

    @staticmethod
    def scale(unscaled_x: Array, normalizer_state: NormalizerState):
        """Multiply ``unscaled_x`` by ``std`` without adding the mean."""
        return unscaled_x * normalizer_state.std


# s,a -> s',r
class DynamicsModel(nn.Module):
    """Single dynamics network producing means and clipped log-stds.

    The concatenated (observation, action) input goes through an ``MLP``
    trunk. Both heads have ``obs_dim + reward_dim`` outputs (next state and
    reward, per the file's ``s,a -> s',r`` note).
    """

    hidden_dims: Sequence[int]
    obs_dim: int
    action_dim: int
    reward_dim: int = 1
    dependent_std: bool = True
    final_fc_init_scale: float = 1e-2

    @nn.compact
    def __call__(
        self,
        observations: jnp.ndarray,
        actions: jnp.ndarray,
        log_std_min: jnp.ndarray,  # supplied by the caller
        log_std_max: jnp.ndarray,  # supplied by the caller
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """Return ``(means, log_stds)`` for the given observations and actions.

        ``log_stds`` is passed through ``soft_clip`` with the caller-provided
        ``log_std_min`` / ``log_std_max`` before being returned.
        """
        head_dim = self.obs_dim + self.reward_dim
        head_init = nn.initializers.normal(self.final_fc_init_scale)

        # Shared trunk over the concatenated (observation, action) input.
        trunk_in = jnp.concatenate([observations, actions], axis=-1)
        features = MLP(self.hidden_dims, activate_final=True)(trunk_in)

        # The mean head is created first so the auto-generated submodule
        # names stay in the same order.
        means = nn.Dense(head_dim, kernel_init=head_init)(features)

        if not self.dependent_std:
            # Input-independent log-std: one learned vector, initialised to zero.
            raw_log_stds = self.param("log_stds", nn.initializers.zeros, (head_dim,))
        else:
            # Input-dependent log-std predicted from the trunk features.
            raw_log_stds = nn.Dense(head_dim, kernel_init=head_init)(features)

        # Constrain the log-stds with the externally supplied bounds.
        return means, soft_clip(raw_log_stds, log_std_min, log_std_max)


# Ensembled dynamics model
class EnsembledDynamics(nn.Module):
    """Ensemble of ``DynamicsModel`` members sharing one pair of log-std bounds.

    Each member has its own parameters (stacked along axis 0 by ``nn.vmap``).
    The lower bound ``log_std_min`` and the raw increment that defines the
    upper bound are single parameters owned by this module and passed to
    every member.
    """

    ensemble_size: int
    hidden_dims: Sequence[int]
    obs_dim: int
    action_dim: int
    reward_dim: int = 1
    dependent_std: bool = True
    final_fc_init_scale: float = 1e-2

    def setup(self):
        """Create the shared bound parameters and the vmapped ensemble."""
        out_dim = self.obs_dim + self.reward_dim
        bound_shape = (out_dim,)

        # Shared lower bound, initialised to -5.
        self.log_std_min = self.param(
            "log_std_min",
            nn.initializers.constant(-5.0),
            bound_shape,
        )
        # Raw increment, initialised to 5.25. get_log_std_max() applies
        # softplus and adds log_std_min, so the initial upper bound is
        # roughly 0.25 (softplus(5.25) - 5 + epsilon).
        self.log_std_max_increament = self.param(
            "log_std_max_increament",
            nn.initializers.constant(5.25),
            bound_shape,
        )

        member_kwargs = dict(
            hidden_dims=self.hidden_dims,
            obs_dim=self.obs_dim,
            action_dim=self.action_dim,
            reward_dim=self.reward_dim,
            dependent_std=self.dependent_std,
            final_fc_init_scale=self.final_fc_init_scale,
        )
        # Lift DynamicsModel over a leading ensemble axis: parameters are
        # independent per member (separate init RNGs), inputs are broadcast
        # to all members and outputs are stacked along axis 0.
        vmapped_member = nn.vmap(
            DynamicsModel,
            variable_axes={"params": 0},
            split_rngs={"params": True},
            in_axes=None,
            out_axes=0,
            axis_size=self.ensemble_size,
        )
        self.ensembled_dynamics = vmapped_member(**member_kwargs)

    def get_log_std_max(self, epsilon: float = 1e-3):
        """Return the upper bound ``softplus(increment) + log_std_min + epsilon``.

        Softplus is positive and ``epsilon`` is added on top, so the result
        stays above ``log_std_min`` for a positive ``epsilon``.
        """
        positive_gap = jax.nn.softplus(self.log_std_max_increament)
        return positive_gap + self.log_std_min + epsilon

    def __call__(
        self,
        observations: jnp.ndarray,
        actions: jnp.ndarray,
    ):
        """Run every ensemble member on the same inputs with the shared bounds."""
        lower_bound = self.log_std_min
        upper_bound = self.get_log_std_max()
        return self.ensembled_dynamics(observations, actions, lower_bound, upper_bound)


class EnsembledValueCritics(nn.Module):
    """Ensemble of ``ValueCritic`` networks built with ``ensemblize``."""

    ensemble_size: int
    hidden_dims: Sequence[int]

    def setup(self):
        """Build the ensembled critic; inputs are mapped over their leading axis."""
        # in_axes=0: each critic receives its own input slice (multi-head),
        # e.g. the per-member outputs of EnsembledDynamics.
        critic_ensemble_cls = ensemblize(ValueCritic, self.ensemble_size, in_axes=0)
        self.ensembled_value_critics = critic_ensemble_cls(hidden_dims=self.hidden_dims)

    # NOTE(review): `@nn.compact` looks redundant here because no submodule is
    # declared inline; confirm before removing it.
    @nn.compact
    def __call__(self, observations: jnp.ndarray):
        """Evaluate the ensembled critics on ``observations`` and return the result."""
        critic_values = self.ensembled_value_critics(observations)
        return critic_values


def main():
    """Smoke test: initialise an ensembled critic and print its output and shapes.

    Also prints the value of ``util_vipo.decay_loss`` computed on the
    initialised variables.
    """
    from pprint import pprint

    def _shape_of(leaf):
        return leaf.shape

    critics = EnsembledValueCritics(ensemble_size=3, hidden_dims=[256, 256])
    # Dummy input laid out as (ensemble_size, batch_size, input_dim).
    dummy_obs = jnp.ones((3, 2, 10))
    init_key = jax.random.PRNGKey(0)
    values, variables = critics.init_with_output(init_key, dummy_obs)
    print(values)
    pprint(jax.tree.map(_shape_of, variables))

    # Imported late on purpose, matching the original execution order.
    import util_vipo

    print(util_vipo.decay_loss(variables, 1e-4))


if __name__ == "__main__":
    main()
