import copy
from functools import partial
from typing import Optional, Tuple, Union

import chex
import distrax
import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
from absl import flags

from wsrl.common.common import JaxRLTrainState, ModuleDict, nonpytree_field
from wsrl.common.optimizers import make_optimizer
from wsrl.common.typing import Batch, Data, Params, PRNGKey
from wsrl.networks.actor_critic_nets import Critic, Policy, ensemblize
from wsrl.networks.mlp import MLP
from wsrl.agents.utils import get_srank, get_weight_norm, get_dormant_ratio

FLAGS = flags.FLAGS
import optax
from wsrl.utils.train_utils import subsample_batch


class TD3Agent(flax.struct.PyTreeNode):
    """
    Twin Delayed Deep Deterministic Policy Gradient (TD3) Agent.
    
    Key TD3 features:
    - Twin critic networks (ensemble size is set by critic_ensemble_size, default 2)
    - Target policy smoothing (adds noise to target actions)
    - Delayed policy updates (updates actor less frequently than critic)
    - Deterministic policy with exploration noise during training
    """

    state: JaxRLTrainState
    config: dict = nonpytree_field()

    def forward_critic(
        self,
        observations: Union[Data, Tuple[Data, Data]],
        actions: jax.Array,
        rng: Optional[PRNGKey] = None,
        *,
        grad_params: Optional[Params] = None,
        train: bool = True,
        get_intermediates: bool = False,
    ) -> Union[jax.Array, Tuple[jax.Array, dict]]:
        """
        Forward pass for the critic networks.

        Args:
            observations: Observations handed to the "critic" module.
            actions: Actions to evaluate.
            rng: Key supplied as the "dropout" RNG stream; required when
                ``train`` is True.
            grad_params: Parameters to use instead of ``self.state.params``
                (e.g. the parameters being differentiated, or the target
                parameters).
            train: Forwarded to the network; also decides whether the dropout
                RNG is supplied.
            get_intermediates: If True, the network is applied with
                ``capture_intermediates=True`` and the captured intermediates
                are returned alongside the Q-values.

        Returns:
            The critic output, or ``(output, intermediates)`` when
            ``get_intermediates`` is True. ``critic_loss_fn`` asserts that the
            output has shape ``(critic_ensemble_size, batch_size)``.
        """
        if train:
            assert rng is not None, "Must specify rng when training"

        # Explicit None check: the truthiness of a parameter container is not
        # a reliable signal for "no override was given".
        params = self.state.params if grad_params is None else grad_params

        result = self.state.apply_fn(
            {"params": params},
            observations,
            actions,
            name="critic",
            rngs={"dropout": rng} if train else {},
            train=train,
            capture_intermediates=get_intermediates,
        )

        if get_intermediates:
            # With captured intermediates the apply call yields a pair.
            q, intermediate = result
            return q, intermediate
        return result

    def forward_target_critic(
        self,
        observations: Union[Data, Tuple[Data, Data]],
        actions: jax.Array,
        rng: PRNGKey,
    ) -> jax.Array:
        """
        Evaluate the critic with the target parameters.

        Delegates to ``forward_critic`` with ``self.state.target_params`` and
        that method's default ``train=True``, which is why ``rng`` is
        mandatory here.
        """
        target_params = self.state.target_params
        return self.forward_critic(
            observations,
            actions,
            rng=rng,
            grad_params=target_params,
        )

    def forward_policy(
        self,
        observations: Union[Data, Tuple[Data, Data]],
        rng: Optional[PRNGKey] = None,
        *,
        grad_params: Optional[Params] = None,
        train: bool = True,
        get_intermediates: bool = False,
    ) -> Union[jax.Array, Tuple[jax.Array, dict]]:
        """
        Forward pass for the policy ("actor") network.

        The return value is whatever the actor module produces. Callers in
        this class accept either an array of actions or a distribution-like
        object exposing ``mode()`` / ``sample()``, so do not assume a raw
        action array here.

        Args:
            observations: Observations handed to the "actor" module.
            rng: Key supplied as the "dropout" RNG stream; required when
                ``train`` is True.
            grad_params: Parameters to use instead of ``self.state.params``
                (e.g. the parameters being differentiated, or the target
                parameters).
            train: Forwarded to the network; also decides whether the dropout
                RNG is supplied.
            get_intermediates: If True, the network is applied with
                ``capture_intermediates=True`` and the captured intermediates
                are returned alongside the policy output.

        Returns:
            The policy output, or ``(output, intermediates)`` when
            ``get_intermediates`` is True.
        """
        if train:
            assert rng is not None, "Must specify rng when training"

        # Explicit None check: the truthiness of a parameter container is not
        # a reliable signal for "no override was given".
        params = self.state.params if grad_params is None else grad_params

        result = self.state.apply_fn(
            {"params": params},
            observations,
            name="actor",
            rngs={"dropout": rng} if train else {},
            train=train,
            capture_intermediates=get_intermediates,
        )

        if get_intermediates:
            # With captured intermediates the apply call yields a pair.
            actions, intermediate = result
            return actions, intermediate
        return result

    def forward_target_policy(
        self,
        observations: Union[Data, Tuple[Data, Data]],
        rng: PRNGKey,
    ) -> jax.Array:
        """
        Compute actions from the policy evaluated with the target parameters.

        The policy output is reduced to actions as follows:
        ``mode()`` if available, otherwise ``sample(seed=rng)``, otherwise the
        output is returned unchanged. Note that ``rng`` is both passed to the
        forward pass and, in the sampling fallback, reused as the sample seed.
        """
        output = self.forward_policy(
            observations,
            rng=rng,
            grad_params=self.state.target_params,
        )

        # Distribution-like output: prefer the deterministic mode.
        if hasattr(output, 'mode'):
            return output.mode()

        # No mode available: draw a sample instead.
        if hasattr(output, 'sample'):
            return output.sample(seed=rng)

        # Otherwise the output is treated as the actions themselves.
        return output

    @jax.jit
    def forward_value(
        self,
        observations: Union[Data, Tuple[Data, Data]],
        *,
        train: bool = False,
    ) -> jax.Array:
        """
        Estimate the state value as min over critics of Q(s, policy(s)).

        Both the policy and the critic are evaluated with ``train=False``;
        the ``train`` argument is accepted but not used by this method.
        Intended for evaluation rather than for the training losses.
        """
        policy_out = self.forward_policy(observations, train=False)

        # Reduce the policy output to concrete actions.
        if hasattr(policy_out, 'mode'):
            greedy_actions = policy_out.mode()
        elif hasattr(policy_out, 'sample'):
            # Fixed key: no seed is available in this code path.
            greedy_actions = policy_out.sample(seed=jax.random.PRNGKey(0))
        else:
            greedy_actions = policy_out

        q_values = self.forward_critic(observations, greedy_actions, train=False)
        # Pessimistic estimate: minimum along the leading (critic) axis.
        return q_values.min(axis=0)

    def _add_target_policy_noise(self, actions: jax.Array, rng: PRNGKey) -> jax.Array:
        """
        Apply target policy smoothing to ``actions``.

        Gaussian noise scaled by ``config["target_policy_noise"]`` is clipped
        to ``+/- config["target_noise_clip"]``, added to the actions, and the
        result is clipped to ``[-1, 1]`` (the action bounds assumed by this
        agent).
        """
        noise_scale = self.config["target_policy_noise"]
        noise_limit = self.config["target_noise_clip"]

        raw_noise = jax.random.normal(rng, actions.shape) * noise_scale
        bounded_noise = jnp.clip(raw_noise, -noise_limit, noise_limit)

        return jnp.clip(actions + bounded_noise, -1.0, 1.0)

    def critic_loss_fn(self, batch, params: Params, rng: PRNGKey):
        """
        TD3 critic loss with target policy smoothing.

        Steps:
          1. Target actions come from the target policy on the next
             observations, perturbed by clipped Gaussian noise.
          2. The target critics evaluate those actions; optionally a random
             subset of the ensemble is kept (``critic_subsample_size``).
          3. The bootstrap target is
             ``rewards + discount * masks * min_over_critics(target_Q)``.
          4. The loss is the mean squared error between every critic's
             prediction and that shared target.

        Args:
            batch: Mapping with "observations", "actions",
                "next_observations", "rewards" and "masks".
            params: Parameters used for the online critic forward pass (the
                ones being differentiated).
            rng: PRNG key; split internally so that each random consumer
                gets its own independent key.

        Returns:
            ``(critic_loss, info)`` where ``info`` holds scalar diagnostics.
        """
        batch_size = batch["rewards"].shape[0]
        ensemble_size = self.config["critic_ensemble_size"]

        # One independent key per consumer. Reusing a key for several draws
        # (e.g. for both the target and the online critic) would correlate
        # their randomness.
        (
            target_policy_rng,
            noise_rng,
            target_critic_rng,
            subsample_rng,
            critic_rng,
        ) = jax.random.split(rng, 5)

        # Target actions from the target policy.
        target_next_actions = self.forward_target_policy(
            batch["next_observations"],
            rng=target_policy_rng,
        )

        # Target policy smoothing.
        target_next_actions = self._add_target_policy_noise(
            target_next_actions, noise_rng
        )

        # Target Q-values from the target critics.
        target_next_qs = self.forward_target_critic(
            batch["next_observations"],
            target_next_actions,
            rng=target_critic_rng,
        )  # (ensemble_size, batch_size)

        # Optional critic subsampling (REDQ-style), without replacement.
        subsample_size = self.config.get("critic_subsample_size")
        if subsample_size is not None:
            subsample_idcs = jax.random.choice(
                subsample_rng,
                ensemble_size,
                (subsample_size,),
                replace=False,
            )
            target_next_qs = target_next_qs[subsample_idcs]

        # Pessimistic target: minimum over the (sub)ensemble.
        target_next_min_q = target_next_qs.min(axis=0)
        chex.assert_shape(target_next_min_q, (batch_size,))

        # Bootstrap target.
        target_q = (
            batch["rewards"]
            + self.config["discount"] * batch["masks"] * target_next_min_q
        )
        chex.assert_shape(target_q, (batch_size,))

        # Current Q-values under the parameters being optimised.
        predicted_qs = self.forward_critic(
            batch["observations"],
            batch["actions"],
            rng=critic_rng,
            grad_params=params,
        )
        chex.assert_shape(predicted_qs, (ensemble_size, batch_size))

        # Every critic regresses onto the same target.
        target_qs = target_q[None].repeat(ensemble_size, axis=0)  # (ensemble_size, batch_size)
        chex.assert_equal_shape([predicted_qs, target_qs])
        critic_loss = jnp.mean((predicted_qs - target_qs) ** 2)

        info = {
            "critic_loss": critic_loss,
            "predicted_qs": jnp.mean(predicted_qs),
            "target_qs": jnp.mean(target_q),
            # Diagnostics for the first two ensemble members.
            "q1": jnp.mean(predicted_qs[0]),
            "q2": jnp.mean(predicted_qs[1]),
        }

        return critic_loss, info

    def policy_loss_fn(self, batch, params: Params, rng: PRNGKey):
        """
        TD3 actor loss based on min over critics of Q(s, policy(s)).

        The Q term is ``-lmbda * mean(min_Q)``. ``lmbda`` is 1 unless
        ``config["normalize_q"]`` is set, in which case it is
        ``q_lambda_alpha / mean(|min_Q|)`` with gradients stopped (TD3+BC
        style scaling). When ``config["use_bc"]`` is set, a mean squared
        error between the policy actions and the batch actions is added.

        Args:
            batch: Mapping with at least "observations", "actions" and
                "rewards"; "mc_returns" is forwarded to ``info`` if present.
            params: Parameters used for the policy forward pass (the ones
                being differentiated). The critic is evaluated with the
                agent's current parameters.
            rng: PRNG key, split into a policy key and a critic key.

        Returns:
            ``(actor_loss, info)``. ``info`` mixes scalar diagnostics with
            per-sample entries ("dataset_rewards", "mc_returns", "actions").
        """
        batch_size = batch["rewards"].shape[0]

        _, policy_rng, critic_rng = jax.random.split(rng, 3)

        policy_out = self.forward_policy(
            batch["observations"],
            rng=policy_rng,
            grad_params=params,
        )

        # Reduce the policy output to concrete actions.
        if hasattr(policy_out, 'mode'):
            actions = policy_out.mode()
        elif hasattr(policy_out, 'sample'):
            actions = policy_out.sample(seed=policy_rng)
        else:
            actions = policy_out

        # Score the policy's actions with the current critics.
        predicted_qs = self.forward_critic(
            batch["observations"],
            actions,
            rng=critic_rng,
        )
        # Unlike original TD3 (which uses the first critic only), the actor
        # is trained against the minimum over the critics.
        predicted_q_min = predicted_qs.min(axis=0)
        chex.assert_shape(predicted_q_min, (batch_size,))

        # (TD3+BC): lambda = alpha / mean(|Q|), treated as a constant.
        lmbda = 1.0
        if self.config.get("normalize_q", False):
            alpha = self.config.get("q_lambda_alpha", 2.5)
            lmbda = jax.lax.stop_gradient(alpha / jnp.abs(predicted_q_min).mean())

        actor_q_loss = -lmbda * jnp.mean(predicted_q_min)

        # Squared deviation from the dataset actions; feeds both the optional
        # BC term and the logged "actions_mse".
        squared_error = (actions - batch["actions"]) ** 2

        if self.config.get("use_bc", False):
            bc_loss = squared_error.mean()
            actor_loss = actor_q_loss + bc_loss
        else:
            bc_loss = jnp.array(0.0)
            actor_loss = actor_q_loss

        info = {
            "actor_loss": actor_loss,
            "actor_q_loss": actor_q_loss,
            "bc_loss": bc_loss,
            "lmbda": lmbda,
            "policy_q_min": jnp.mean(predicted_q_min),
            "policy_q1": jnp.mean(predicted_qs[0]),
            "policy_q2": jnp.mean(predicted_qs[1]),
            "actions_mse": squared_error.sum(axis=-1).mean(),
            "dataset_rewards": batch["rewards"],
            "mc_returns": batch.get("mc_returns", None),
            "actions": actions,
        }

        return actor_loss, info

    def loss_fns(self, batch):
        """
        Build the per-network loss callables for ``batch``.

        Each value is the corresponding loss method with ``batch`` already
        bound, leaving ``(params, rng)`` to be supplied by the caller.
        """
        critic_fn = partial(self.critic_loss_fn, batch)
        actor_fn = partial(self.policy_loss_fn, batch)
        return {"critic": critic_fn, "actor": actor_fn}

    @partial(jax.jit, static_argnames=("pmap_axis", "networks_to_update"))
    def update(
        self,
        batch: Batch,
        *,
        pmap_axis: str = None,
        networks_to_update: frozenset[str] = frozenset({"actor", "critic"}),
    ) -> Tuple["TD3Agent", dict]:
        """
        Run one gradient step on the selected networks.

        Parameters:
            batch: Batch of data to use for the update. Should have keys:
                "observations", "actions", "next_observations", "rewards", "masks".
            pmap_axis: Axis name forwarded to the train state's
                ``apply_loss_fns`` (None when pmap is not used).
            networks_to_update: Names of the networks to update; must be a
                subset of {"actor", "critic"}. Networks left out get a
                constant zero loss for this step. For TD3 the usual pattern
                is critic every step and actor every ``policy_delay`` steps.
        Returns:
            Tuple of (new agent, info dict). The target parameters are
            soft-updated whenever "critic" is among the updated networks.
        """
        num_samples = batch["rewards"].shape[0]
        chex.assert_tree_shape_prefix(batch, (num_samples,))

        # Only the first half of the split is kept; it becomes the state's
        # next RNG once the step is done.
        next_rng, _ = jax.random.split(self.state.rng)

        loss_fns = self.loss_fns(batch)

        assert networks_to_update.issubset(
            loss_fns.keys()
        ), f"Invalid networks to update: {networks_to_update}"

        def _zero_loss(params, rng):
            # Stand-in for networks that are skipped this step.
            return 0.0, {}

        for skipped in loss_fns.keys() - networks_to_update:
            loss_fns[skipped] = _zero_loss

        new_state, info = self.state.apply_loss_fns(
            loss_fns, pmap_axis=pmap_axis, has_aux=True
        )

        # Soft target update, tied to critic updates.
        if "critic" in networks_to_update:
            new_state = new_state.target_update(self.config["soft_target_update_rate"])

        new_state = new_state.replace(rng=next_rng)

        # Expose learning rates for optimizers that carry them as hyperparams.
        for name, opt_state in new_state.opt_states.items():
            if not hasattr(opt_state, "hyperparams"):
                continue
            if "learning_rate" not in opt_state.hyperparams.keys():
                continue
            info[f"{name}_lr"] = opt_state.hyperparams["learning_rate"]

        return self.replace(state=new_state), info

    @partial(jax.jit, static_argnames=("add_noise",))
    def sample_actions(
        self,
        observations: Data,
        *,
        seed: Optional[PRNGKey] = None,
        add_noise: bool = True,
        **kwargs,
    ) -> jnp.ndarray:
        """
        Compute actions from the current policy, optionally with exploration
        noise.

        Args:
            observations: Observations to act on.
            seed: PRNG key. Without it no exploration noise is added.
            add_noise: If True (and ``seed`` is given), Gaussian noise scaled
                by ``config["exploration_noise"]`` is added and the result is
                clipped to ``[-1, 1]``.
            **kwargs: Accepted and ignored.

        Returns:
            The (possibly noisy) actions.
        """
        policy_output = self.forward_policy(observations, rng=seed, train=False)

        # Key for the exploration noise. It is `seed` itself unless the seed
        # also has to feed a policy sample, in which case it is split so the
        # two draws do not share a key.
        noise_rng = seed

        if hasattr(policy_output, 'mode'):
            # Distribution-like output: use the deterministic mode.
            actions = policy_output.mode()
        elif hasattr(policy_output, 'sample'):
            # No mode available: fall back to sampling.
            if seed is None:
                sample_rng = None
            else:
                noise_rng, sample_rng = jax.random.split(seed)
            actions = policy_output.sample(seed=sample_rng)
        else:
            # Otherwise the output is treated as the actions themselves.
            actions = policy_output

        if add_noise and noise_rng is not None:
            noise = jax.random.normal(noise_rng, actions.shape) * self.config["exploration_noise"]
            # Clip to the action bounds assumed by this agent.
            actions = jnp.clip(actions + noise, -1.0, 1.0)

        return actions

    @jax.jit
    def get_debug_metrics(self, batch, **kwargs):
        """
        Evaluate both loss functions on ``batch`` with the current parameters
        and return their merged info dictionaries.

        No parameters are updated and the loss values themselves are
        discarded. On key collisions the policy entries win. ``**kwargs`` is
        accepted and ignored.
        """
        _, critic_rng, actor_rng = jax.random.split(self.state.rng, 3)
        current_params = self.state.params

        _, critic_info = self.critic_loss_fn(batch, current_params, critic_rng)
        _, policy_info = self.policy_loss_fn(batch, current_params, actor_rng)

        metrics = dict(critic_info)
        metrics.update(policy_info)
        return metrics

    def update_config(self, new_config):
        """
        Merge ``new_config`` into ``self.config`` in place.

        Entries in ``new_config`` override existing ones. The attribute is
        rebound with ``object.__setattr__`` so the assignment goes through
        even when normal attribute assignment on the instance is blocked.

        Args:
            new_config: Mapping of config entries to add or replace.
        """
        current = self.config
        # A plain dict has no `copy(overrides)`; build the merged mapping
        # explicitly instead.
        merged = {**current, **new_config}
        if not isinstance(current, dict):
            # Keep the original container type for non-dict mappings.
            # NOTE(review): assumes that type can be constructed from a dict.
            merged = type(current)(merged)
        object.__setattr__(self, "config", merged)

    @classmethod
    def _create_common(
        cls,
        rng: PRNGKey,
        observations: Data,
        actions: jnp.ndarray,
        # Models
        actor_def: nn.Module,
        critic_def: nn.Module,
        # Optimizer
        actor_optimizer_kwargs={
            "learning_rate": 3e-4,
        },
        critic_optimizer_kwargs={
            "learning_rate": 3e-4,
        },
        # Algorithm config
        discount: float = 0.99,
        soft_target_update_rate: float = 0.005,
        # TD3-specific parameters
        target_policy_noise: float = 0.2,
        target_noise_clip: float = 0.5,
        exploration_noise: float = 0.1,
        policy_delay: int = 2,
        normalize_q: bool = False,
        q_lambda_alpha: float = 1.0,
        # Behavioral-cloning toggle
        use_bc: bool = False,
        critic_ensemble_size: int = 2,
        critic_subsample_size: Optional[int] = None,
        **kwargs,
    ):
        """
        Assemble an agent from already-constructed actor and critic modules.

        Wraps the two modules in a ``ModuleDict``, creates one optimizer per
        network, initialises the parameters from the example ``observations``
        and ``actions``, and builds the train state with the target
        parameters starting as the freshly initialised parameters. All
        algorithm settings, plus any extra ``**kwargs``, are stored in
        ``config``.
        """
        model_def = ModuleDict({"actor": actor_def, "critic": critic_def})

        # One optimizer per network, keyed like the networks.
        optimizer_settings = (
            ("actor", actor_optimizer_kwargs),
            ("critic", critic_optimizer_kwargs),
        )
        txs = {
            net_name: make_optimizer(**opt_kwargs)
            for net_name, opt_kwargs in optimizer_settings
        }

        # Parameter initialisation from the example inputs.
        rng, init_rng = jax.random.split(rng)
        variables = model_def.init(
            init_rng,
            actor=[observations],
            critic=[observations, actions],
        )
        params = variables["params"]

        _, state_rng = jax.random.split(rng)
        state = JaxRLTrainState.create(
            apply_fn=model_def.apply,
            params=params,
            txs=txs,
            target_params=params,
            rng=state_rng,
        )

        config = {
            "discount": discount,
            "soft_target_update_rate": soft_target_update_rate,
            "target_policy_noise": target_policy_noise,
            "target_noise_clip": target_noise_clip,
            "exploration_noise": exploration_noise,
            "policy_delay": policy_delay,
            "normalize_q": normalize_q,
            "q_lambda_alpha": q_lambda_alpha,
            "use_bc": use_bc,
            "critic_ensemble_size": critic_ensemble_size,
            "critic_subsample_size": critic_subsample_size,
            **kwargs,
        }

        return cls(state=state, config=config)

    @classmethod
    def create(
        cls,
        rng: PRNGKey,
        observations: Data,
        actions: jnp.ndarray,
        # Model architecture
        encoder_def: nn.Module,
        shared_encoder: bool = True,
        critic_network_kwargs: dict = {
            "hidden_dims": [256, 256, 256],
        },
        policy_network_kwargs: dict = {
            "hidden_dims": [256, 256, 256],
        },
        policy_kwargs: dict = {
            "tanh_squash_distribution": True,
            "std_parameterization": "fixed",
            "fixed_std": 0.0,
        },
        normalize_q: bool = False,
        q_lambda_alpha: float = 1.0,
        use_bc: bool = False,
        critic_ensemble_size: int = 2,
        critic_subsample_size: Optional[int] = None,
        **kwargs,
    ):
        """
        Create a new TD3 agent from an encoder definition.

        Builds an MLP-based ``Policy`` named "actor" and an ensemble of
        ``critic_ensemble_size`` MLP critics wrapped in a ``Critic`` named
        "critic", then hands both to ``_create_common``. With
        ``shared_encoder`` the actor and critic receive the same
        ``encoder_def`` object; otherwise the critic gets a deep copy.
        Remaining ``**kwargs`` are forwarded to ``_create_common``.
        """
        actor_encoder = encoder_def
        critic_encoder = encoder_def if shared_encoder else copy.deepcopy(encoder_def)

        # Policy network.
        policy_def = Policy(
            encoder=actor_encoder,
            network=MLP(**policy_network_kwargs),
            action_dim=actions.shape[-1],
            **policy_kwargs,
            name="actor",
        )

        # Critic ensemble; its size comes from critic_ensemble_size.
        ensemble_cls = ensemblize(
            partial(MLP, **critic_network_kwargs), critic_ensemble_size
        )
        critic_backbone = ensemble_cls(name="critic_ensemble")
        critic_def = Critic(
            encoder=critic_encoder,
            network=critic_backbone,
            name="critic",
        )

        return cls._create_common(
            rng,
            observations,
            actions,
            actor_def=policy_def,
            critic_def=critic_def,
            normalize_q=normalize_q,
            q_lambda_alpha=q_lambda_alpha,
            use_bc=use_bc,
            critic_subsample_size=critic_subsample_size,
            critic_ensemble_size=critic_ensemble_size,
            **kwargs,
        )

    def update_with_policy_delay(
        self,
        batch: Batch,
        step: int,
        *,
        policy_delay: Optional[int] = None,
        pmap_axis: Optional[str] = None,
    ) -> Tuple["TD3Agent", dict]:
        """
        Run ``update`` with the actor included only on delayed steps.

        The critic is updated on every call; the actor is updated as well
        when ``step`` is a multiple of the policy delay.

        Args:
            batch: Batch forwarded to ``update``.
            step: Current training step.
            policy_delay: Delay to use; a falsy value (None or 0) falls back
                to ``config["policy_delay"]``.
            pmap_axis: Forwarded to ``update``.

        Returns:
            Tuple of (new agent, info dict) from ``update``.
        """
        delay = policy_delay or self.config["policy_delay"]

        if (step % delay) == 0:
            selected = frozenset({"actor", "critic"})
        else:
            selected = frozenset({"critic"})

        return self.update(
            batch,
            pmap_axis=pmap_axis,
            networks_to_update=selected,
        )

    def get_critic_metrics(self, intermediate):
        """
        Derive critic diagnostics from captured intermediates.

        Reads ``intermediate['intermediates']['modules_critic']['network']``
        and reports:
          - "critic_dormant_ratio": the "critic/dormant_total" entry returned
            by ``get_dormant_ratio`` (called with ``tau=0.1``);
          - "critic_srank": ``get_srank`` applied to the ``__call__`` capture
            of the LayerNorm whose index equals the highest Dense index.

        Raises:
            ValueError: If no key containing 'Dense_' is present.
        """
        critic_net = intermediate['intermediates']['modules_critic']['network']

        dormant_stats = get_dormant_ratio(critic_net, prefix="critic", tau=0.1)
        metrics = {"critic_dormant_ratio": dormant_stats["critic/dormant_total"]}

        # Indices of the Dense layers; the largest one selects the LayerNorm.
        dense_indices = [
            int(layer_name.split('_')[1])
            for layer_name in critic_net.keys()
            if 'Dense_' in layer_name
        ]
        if not dense_indices:
            raise ValueError("No Dense layers found in critic network")
        last_ln_name = f'LayerNorm_{max(dense_indices)}'

        # NOTE(review): `thershold` is the keyword spelling used at this call
        # site; presumably it matches the helper's signature.
        metrics["critic_srank"] = get_srank(
            critic_net[last_ln_name]['__call__'],
            thershold=0.01,
        )

        return metrics

    def get_actor_metrics(self, intermediate):
        """
        Derive actor diagnostics from captured intermediates.

        Reads ``intermediate['intermediates']['modules_actor']['network']``
        and reports:
          - "actor_dormant_ratio": the "actor/dormant_total" entry returned
            by ``get_dormant_ratio`` (called with ``tau=0.1``);
          - "actor_srank": ``get_srank`` applied to the ``__call__`` capture
            of the LayerNorm whose index equals the highest Dense index.

        Raises:
            ValueError: If no key containing 'Dense_' is present.
        """
        actor_net = intermediate['intermediates']['modules_actor']['network']

        dormant_stats = get_dormant_ratio(actor_net, prefix="actor", tau=0.1)
        metrics = {"actor_dormant_ratio": dormant_stats["actor/dormant_total"]}

        # Indices of the Dense layers; the largest one selects the LayerNorm.
        dense_indices = [
            int(layer_name.split('_')[1])
            for layer_name in actor_net.keys()
            if 'Dense_' in layer_name
        ]
        if not dense_indices:
            raise ValueError("No Dense layers found in actor network")
        last_ln_name = f'LayerNorm_{max(dense_indices)}'

        # NOTE(review): `thershold` is the keyword spelling used at this call
        # site; presumably it matches the helper's signature.
        metrics["actor_srank"] = get_srank(
            actor_net[last_ln_name]['__call__'],
            thershold=0.01,
        )
        return metrics

    @partial(jax.jit, static_argnames=("pmap_axis", "networks_to_update"))
    def get_metrics(
        self,
        batch: Batch,
        *,
        pmap_axis: str = None,
        networks_to_update: frozenset[str] = frozenset({"actor", "critic"}),
    ) -> dict:
        """
        Compute network diagnostics on ``batch`` without updating parameters.

        Runs the critic (and the policy) with ``train=False`` while capturing
        intermediates, and returns the metrics from ``get_critic_metrics``.
        The actor intermediates are computed but currently not turned into
        metrics. ``pmap_axis`` and ``networks_to_update`` are accepted but
        not used by this method.
        """
        num_samples = batch["rewards"].shape[0]
        chex.assert_tree_shape_prefix(batch, (num_samples,))

        # Derive the two keys handed to the forward passes.
        rng, critic_rng = jax.random.split(self.state.rng)
        _, policy_rng = jax.random.split(rng)

        current_params = self.state.params

        # Critic pass with captured intermediates.
        _, critic_intermediate = self.forward_critic(
            batch["observations"],
            batch["actions"],
            rng=critic_rng,
            grad_params=current_params,
            get_intermediates=True,
            train=False,
        )
        metrics = dict(self.get_critic_metrics(critic_intermediate))

        # Policy pass with captured intermediates. The result is not merged
        # into `metrics`: the get_actor_metrics step is disabled here.
        _, actor_intermediate = self.forward_policy(
            batch["observations"],
            rng=policy_rng,
            grad_params=current_params,
            train=False,
            get_intermediates=True,
        )

        return metrics