"""Implementations of algorithms for continuous control."""

import functools
from jaxrl_m.typing import *

import jax
import jax.numpy as jnp
import numpy as np
import optax
from jaxrl_m.common import TrainState, target_update, nonpytree_field
from jaxrl_m.networks import Policy, Critic, ensemblize
from dynamics_vipo import VIPODynamics
from collections import defaultdict
from util_vipo import (
    sample_from_norm,
    merge_batch,
)
import flax
import flax.linen as nn


class RandomAgent(flax.struct.PyTreeNode):
    """Placeholder agent that acts randomly and does not learn.

    Actions are drawn from a standard normal and clipped to [-1, 1]; ``update``
    only advances the RNG key. The class also provides ``rollout`` helpers that
    step a dynamics model with this agent's action sampler.
    """

    rng: PRNGKey
    # Static (non-pytree) settings; the methods below read config["action_dim"].
    config: dict = nonpytree_field()

    @jax.jit
    def update(agent, dynamics: VIPODynamics, real_batch: Batch, fake_batch: Batch):
        """Perform one (no-op) training step.

        Args:
            dynamics: dynamics model; unused by this placeholder.
            real_batch: unused by this placeholder.
            fake_batch: unused by this placeholder.

        Returns:
            ``(new_agent, info)``: the agent with a freshly split RNG key and an
            empty info dict.
        """
        # TODO: implement your own update logic
        new_rng, _ = jax.random.split(agent.rng, 2)
        return agent.replace(rng=new_rng), {}

    @jax.jit
    def sample_actions(
        agent,
        observations: np.ndarray,
        *,
        seed: PRNGKey,
        temperature: float = 1.0,
    ) -> jnp.ndarray:
        """Sample random actions; only the *shape* of ``observations`` is used.

        Args:
            observations: a single observation (1-D) or a batch whose leading
                axis is the batch axis.
            seed: PRNG key used for sampling.
            temperature: accepted for API compatibility; currently ignored.

        Returns:
            Actions of shape ``(action_dim,)`` for a 1-D input, otherwise
            ``(observations.shape[0], action_dim)``, clipped to [-1, 1].
        """
        # TODO: use your own sampling method
        if observations.ndim == 1:
            actions = jax.random.normal(seed, shape=(agent.config["action_dim"],))
        else:
            actions = jax.random.normal(seed, shape=(observations.shape[0], agent.config["action_dim"]))
        actions = jnp.clip(actions, -1, 1)
        return actions

    @functools.partial(jax.jit, static_argnames=["rollout_length"])
    def _rollout(
        agent,
        dynamics: VIPODynamics,
        init_obss: jnp.ndarray,  # (batch_size, obs_dim)
        rollout_length: int,
        key: PRNGKey,
    ):
        """Roll out ``rollout_length`` model steps from ``init_obss`` (jitted).

        Every trajectory is stepped on every iteration, but ``valid_steps``
        records only the steps taken while the trajectory was still active
        (i.e. had not yet hit a terminal). The loop stops early once all
        trajectories have terminated.

        Returns:
            The final loop carry:
            ``(step, key, current_observations, active_mask, observations,
            actions, rewards, masks, done_float, next_observations,
            valid_steps)``. The per-step arrays have a leading
            ``(batch_size, rollout_length)`` shape.
        """
        batch_size = init_obss.shape[0]
        obs_dim = init_obss.shape[1]
        action_dim = agent.config["action_dim"]

        # Preallocate per-step buffers; filled column-by-column inside the loop.
        observations = jnp.zeros((batch_size, rollout_length, obs_dim))
        actions = jnp.zeros((batch_size, rollout_length, action_dim))
        rewards = jnp.zeros((batch_size, rollout_length))
        masks = jnp.zeros((batch_size, rollout_length))
        done_float = jnp.zeros((batch_size, rollout_length))
        next_observations = jnp.zeros((batch_size, rollout_length, obs_dim))
        valid_steps = jnp.zeros((batch_size, rollout_length), dtype=jnp.bool_)

        carry = (
            0,  # step
            key,
            init_obss,  # current_observations
            jnp.ones(batch_size, dtype=jnp.bool_),  # active_mask
            observations,
            actions,
            rewards,
            masks,
            done_float,
            next_observations,
            valid_steps,
        )

        def cond_fun(carry):
            # Continue while steps remain and at least one trajectory is active.
            step, _, _, active_mask, _, _, _, _, _, _, _ = carry
            return jnp.logical_and(step < rollout_length, jnp.any(active_mask))

        def body_fun(carry):
            (
                step,
                key,
                current_observations,
                active_mask,
                observations,
                actions,
                rewards,
                masks,
                done_float,
                next_observations,
                valid_steps,
            ) = carry
            # 4-way split kept so the random stream is stable; first key unused.
            _, sample_key, step_key, new_key = jax.random.split(key, 4)

            # Sample actions for all trajectories (active or not).
            actions_all = agent.sample_actions(current_observations, seed=sample_key)

            # Step the model for all trajectories; the 4th return value is ignored.
            next_observations_all, rewards_all, terminals_all, _ = dynamics.step(current_observations, actions_all, step_key)

            # The model may return rewards/terminals shaped (B,) or (B, 1), and
            # terminals as bool or 0/1 floats. Normalize so the column writes
            # below are well-shaped and `~` is a logical (not bitwise/float) op.
            rewards_all = jnp.reshape(rewards_all, (batch_size,))
            terminals_bool = jnp.reshape(terminals_all, (batch_size,)).astype(jnp.bool_)
            terminals_float = terminals_bool.astype(masks.dtype)

            # Record this step for every trajectory; validity is tracked separately.
            observations = observations.at[:, step, :].set(current_observations)
            actions = actions.at[:, step, :].set(actions_all)
            rewards = rewards.at[:, step].set(rewards_all)
            masks = masks.at[:, step].set(1.0 - terminals_float)
            done_float = done_float.at[:, step].set(terminals_float)
            next_observations = next_observations.at[:, step, :].set(next_observations_all)
            # A step is valid iff the trajectory was active when it was taken,
            # so the terminating transition itself is kept.
            valid_steps = valid_steps.at[:, step].set(active_mask)

            # Advance only trajectories that are still active; finished ones
            # keep their last observation (their later steps are marked invalid).
            still_active = jnp.logical_and(active_mask, ~terminals_bool)
            current_observations = jnp.where(still_active[:, None], next_observations_all, current_observations)

            return (
                step + 1,
                new_key,
                current_observations,
                still_active,
                observations,
                actions,
                rewards,
                masks,
                done_float,
                next_observations,
                valid_steps,
            )

        return jax.lax.while_loop(cond_fun, body_fun, carry)

    def rollout(
        agent,
        dynamics: VIPODynamics,
        init_obss: jnp.ndarray,  # (batch_size, obs_dim)
        rollout_length: int,
        key: PRNGKey,
    ):
        """Roll out the model and return only the valid transitions, flattened.

        Runs ``_rollout`` under jit, then (outside jit, since the number of
        valid steps is data-dependent) gathers the transitions marked valid.

        Returns:
            ``(batch, info)`` where ``batch`` maps "observations", "actions",
            "rewards", "masks", "dones_float" and "next_observations" to arrays
            whose leading axis enumerates valid transitions (row-major over
            trajectory, then step), and ``info`` holds "num_transitions" and
            "reward_mean" (NaN if there are no valid transitions).
        """
        final_carry = agent._rollout(dynamics, init_obss, rollout_length, key)
        (
            _,
            _,
            _,
            _,
            observations,
            actions,
            rewards,
            masks,
            done_float,
            next_observations,
            valid_steps,
        ) = final_carry

        # (trajectory_indices, step_indices) of valid transitions, computed once.
        valid_idx = jnp.nonzero(valid_steps)
        rewards_valid = rewards[valid_idx]

        batch = {
            "observations": observations[valid_idx],
            "actions": actions[valid_idx],
            "rewards": rewards_valid,
            "masks": masks[valid_idx],
            "dones_float": done_float[valid_idx],
            "next_observations": next_observations[valid_idx],
        }

        info = {
            "num_transitions": valid_steps.sum(),
            "reward_mean": rewards_valid.mean(),
        }

        return batch, info


def create_learner(
    key: PRNGKey,
    observations: jnp.ndarray,
    actions: jnp.ndarray,
    # TODO: add other parameters
    **kwargs,
):
    """Build a ``RandomAgent`` (placeholder learner factory).

    Args:
        key: PRNG key; the first half of a split becomes the agent's RNG.
        observations: example observations; unused by this placeholder.
        actions: example actions; the size of the last axis is stored as
            ``config["action_dim"]``.
        **kwargs: extra options; currently only printed.

    Returns:
        A ``RandomAgent`` with a frozen config dict.
    """
    print("Extra kwargs:", kwargs)

    agent_rng = jax.random.split(key)[0]

    # TODO: add your network definitions

    # TODO: add other config parameters
    agent_config = flax.core.FrozenDict({"action_dim": actions.shape[-1]})

    return RandomAgent(agent_rng, config=agent_config)
