import copy
from functools import partial
from typing import Optional, Tuple, Union

import chex
import distrax
import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
from absl import flags

from wsrl.common.common import JaxRLTrainState, ModuleDict, nonpytree_field
from wsrl.common.optimizers import make_optimizer
from wsrl.common.typing import Batch, Data, Params, PRNGKey
from wsrl.networks.actor_critic_nets import Critic, Policy, ensemblize
from wsrl.networks.mlp import MLP
#from wsrl.agents.utils import get_srank, get_weight_norm, get_dormant_ratio
from wsrl.agents.td3 import TD3Agent

# Handle to absl's global flag container; not referenced by the code visible
# in this part of the file.
FLAGS = flags.FLAGS



class ReBRACAgent(TD3Agent):
    """
    ReBRAC agent: TD3 with behavior-cloning (BC) regularization.

    ReBRAC ("Revisited BRAC", i.e. revisited behavior-regularized
    actor-critic) penalizes deviation from the dataset actions in both the
    actor objective and the critic target, with a separate coefficient for
    each.

    What this class implements on top of ``TD3Agent``:
    - Critic target: min over the (optionally subsampled) target critic
      ensemble, minus ``critic_bc_coef * ||a'_target - a'_dataset||^2`` when
      the batch carries ``"next_actions"``.
    - Actor loss: ``actor_bc_coef * ||pi(s) - a_dataset||^2 - lambda * min_i Q_i``,
      where ``lambda`` is ``1 / mean|Q|`` if ``normalize_q`` is set, else 1.
    - ``create`` / ``_create_common`` constructors that store the
      ReBRAC-specific coefficients in ``config``.

    The loss functions index ``predicted_qs[0]`` and ``predicted_qs[1]`` for
    logging, so they require a critic ensemble of at least two members.
    """

    def critic_loss_fn(
        self, batch, params: Params, rng: PRNGKey
    ) -> Tuple[jnp.ndarray, dict]:
        """
        Compute the ReBRAC critic loss (TD error against a BC-penalized target).

        Args:
            batch: mapping with "observations", "actions", "rewards", "masks",
                "next_observations" and, optionally, "next_actions".
            params: parameters to differentiate with respect to; passed to
                ``forward_critic`` as ``grad_params``.
            rng: PRNG key; split so every stochastic call gets its own key.

        Returns:
            ``(critic_loss, info)`` where ``critic_loss`` is the mean squared
            TD error over all ensemble members and ``info`` is a dict of
            logging scalars.
        """
        batch_size = batch["rewards"].shape[0]
        ensemble_size = self.config["critic_ensemble_size"]

        rng, target_policy_rng, noise_rng = jax.random.split(rng, 3)
        # Give each remaining consumer a distinct key instead of passing the
        # same `rng` to several calls (JAX keys must not be reused).
        target_critic_rng, subsample_rng, critic_rng = jax.random.split(rng, 3)

        # NOTE(review): assumes forward_target_policy returns an action array
        # (not a distribution) that _add_target_policy_noise accepts -- both
        # are defined on TD3Agent, which is not visible here.
        target_next_actions = self.forward_target_policy(
            batch["next_observations"],
            rng=target_policy_rng,
        )

        # Target policy smoothing noise.
        target_next_actions = self._add_target_policy_noise(
            target_next_actions, noise_rng
        )

        # Target Q-values for every ensemble member.
        target_next_qs = self.forward_target_critic(
            batch["next_observations"],
            target_next_actions,
            rng=target_critic_rng,
        )

        # Optional critic subsampling (REDQ-style). Sampling without
        # replacement guarantees the selected critics are distinct.
        subsample_size = self.config.get("critic_subsample_size")
        if subsample_size is not None:
            subsample_idcs = jax.random.choice(
                subsample_rng,
                ensemble_size,
                (subsample_size,),
                replace=False,
            )
            target_next_qs = target_next_qs[subsample_idcs]

        # Pessimistic target: minimum over the (subsampled) ensemble.
        target_next_min_q = target_next_qs.min(axis=0)
        chex.assert_shape(target_next_min_q, (batch_size,))

        # ReBRAC: subtract a BC penalty from the critic target. The penalty is
        # skipped when the batch has no "next_actions" entry.
        critic_bc_coef = self.config.get("critic_bc_coef", 0.0)
        use_bc_penalty = critic_bc_coef > 0 and "next_actions" in batch
        if use_bc_penalty:
            # ||noised target_policy(s') - dataset_next_action||^2
            bc_penalty = (
                (target_next_actions - batch["next_actions"]) ** 2
            ).sum(axis=-1)
            target_next_min_q = target_next_min_q - critic_bc_coef * bc_penalty

        # Bellman target.
        target_q = (
            batch["rewards"]
            + self.config["discount"] * batch["masks"] * target_next_min_q
        )
        chex.assert_shape(target_q, (batch_size,))

        # Current Q-values, differentiated with respect to `params`.
        predicted_qs = self.forward_critic(
            batch["observations"],
            batch["actions"],
            rng=critic_rng,
            grad_params=params,
        )
        chex.assert_shape(predicted_qs, (ensemble_size, batch_size))

        # Every ensemble member regresses onto the same target; broadcasting
        # the (batch_size,) target over the leading ensemble axis.
        td_errors = predicted_qs - target_q[None, :]
        chex.assert_equal_shape([predicted_qs, td_errors])
        critic_loss = jnp.mean(td_errors**2)

        info = {
            "critic_loss": critic_loss,
            "predicted_qs": jnp.mean(predicted_qs),
            "target_qs": jnp.mean(target_q),
            "q1": jnp.mean(predicted_qs[0]),
            "q2": jnp.mean(predicted_qs[1]),
        }
        if use_bc_penalty:
            info["critic_bc_penalty"] = jnp.mean(bc_penalty)

        return critic_loss, info

    def policy_loss_fn(
        self, batch, params: Params, rng: PRNGKey
    ) -> Tuple[jnp.ndarray, dict]:
        """
        Compute the ReBRAC actor loss.

        The loss is ``mean(actor_bc_coef * bc_penalty - lmbda * min_i Q_i)``
        where ``bc_penalty`` is the squared distance between the policy
        action and the dataset action, and ``lmbda`` is ``1 / mean|min_i Q_i|``
        (gradient-stopped) when ``normalize_q`` is enabled, otherwise 1.

        Args:
            batch: mapping with "observations", "actions", "rewards" and,
                optionally, "mc_returns".
            params: parameters to differentiate with respect to; passed to
                ``forward_policy`` as ``grad_params``.
            rng: PRNG key.

        Returns:
            ``(actor_loss, info)``; ``info`` holds logging values, including
            the raw ``actions`` and ``dataset_rewards`` arrays.
        """
        batch_size = batch["rewards"].shape[0]

        rng, policy_rng, critic_rng = jax.random.split(rng, 3)

        policy_output = self.forward_policy(
            batch["observations"],
            rng=policy_rng,
            grad_params=params,
        )

        # The policy may return a distribution or plain actions.
        if hasattr(policy_output, "mode"):
            actions = policy_output.mode()
        elif hasattr(policy_output, "sample"):
            # policy_rng was already consumed by the forward pass above, so
            # sample with the leftover key from the split instead.
            actions = policy_output.sample(seed=rng)
        else:
            actions = policy_output

        # Q-values under the current critic parameters (no grad_params: the
        # gradient reaches the actor only through `actions`).
        predicted_qs = self.forward_critic(
            batch["observations"],
            actions,
            rng=critic_rng,
        )

        # Conservative estimate: minimum over the critic ensemble.
        predicted_q_min = predicted_qs.min(axis=0)
        chex.assert_shape(predicted_q_min, (batch_size,))

        # ReBRAC: BC penalty towards the dataset action.
        actor_bc_coef = self.config.get("actor_bc_coef", 0.0)
        bc_penalty = ((actions - batch["actions"]) ** 2).sum(axis=-1)
        mean_bc_penalty = jnp.mean(bc_penalty)

        # Q normalization balances the Q term against the BC term.
        lmbda = 1.0
        if self.config.get("normalize_q", False):
            mean_abs_q = jnp.abs(predicted_q_min).mean()
            # Floor the denominator so an all-zero critic output cannot
            # produce an inf/NaN scale.
            lmbda = jax.lax.stop_gradient(1.0 / jnp.maximum(mean_abs_q, 1e-8))

        actor_loss = (actor_bc_coef * bc_penalty - lmbda * predicted_q_min).mean()

        info = {
            "actor_loss": actor_loss,
            "policy_q_min": jnp.mean(predicted_q_min),
            "policy_q1": jnp.mean(predicted_qs[0]),
            "policy_q2": jnp.mean(predicted_qs[1]),
            "bc_penalty": mean_bc_penalty,
            # Same quantity as "bc_penalty"; both keys are kept for logging.
            "actions_mse": mean_bc_penalty,
            "dataset_rewards": batch["rewards"],
            "mc_returns": batch.get("mc_returns", None),
            "actions": actions,
            "lmbda": lmbda,
        }

        return actor_loss, info

    @classmethod
    def _create_common(
        cls,
        rng: PRNGKey,
        observations: Data,
        actions: jnp.ndarray,
        # Models
        actor_def: nn.Module,
        critic_def: nn.Module,
        # Optimizer
        actor_optimizer_kwargs: Optional[dict] = None,
        critic_optimizer_kwargs: Optional[dict] = None,
        # Algorithm config
        discount: float = 0.99,
        soft_target_update_rate: float = 0.005,
        # ReBRAC-specific parameters
        actor_bc_coef: float = 1.0,
        critic_bc_coef: float = 1.0,
        normalize_q: bool = True,
        target_policy_noise: float = 0.2,
        target_noise_clip: float = 0.5,
        exploration_noise: float = 0.1,
        policy_delay: int = 2,
        critic_ensemble_size: int = 2,
        critic_subsample_size: Optional[int] = None,
        **kwargs,
    ):
        """
        Initialize parameters, optimizers and train state for given networks.

        Args:
            rng: PRNG key used for parameter init and stored in the train state.
            observations: example observations used to initialize both networks.
            actions: example actions used to initialize the critic.
            actor_def: actor module.
            critic_def: critic module.
            actor_optimizer_kwargs: kwargs for ``make_optimizer``; defaults to
                ``{"learning_rate": 1e-3}`` when None.
            critic_optimizer_kwargs: kwargs for ``make_optimizer``; defaults to
                ``{"learning_rate": 1e-3}`` when None.
            discount: discount factor used in the critic target.
            actor_bc_coef: weight of the BC penalty in the actor loss.
            critic_bc_coef: weight of the BC penalty in the critic target.
            normalize_q: whether the actor loss rescales Q by ``1 / mean|Q|``.
            critic_ensemble_size: number of critics in the ensemble.
            critic_subsample_size: if set, number of target critics sampled
                (without replacement) before taking the minimum.
            soft_target_update_rate, target_policy_noise, target_noise_clip,
            exploration_noise, policy_delay: stored in ``config`` unchanged;
                not read by the code in this class (presumably consumed by
                ``TD3Agent`` -- confirm there).
            **kwargs: extra entries merged into ``config``.

        Returns:
            A new agent instance of ``cls``.

        Raises:
            ValueError: if ``critic_subsample_size`` is set but not in
                ``[1, critic_ensemble_size]``.
        """
        # None sentinels avoid sharing mutable dict defaults across calls.
        if actor_optimizer_kwargs is None:
            actor_optimizer_kwargs = {"learning_rate": 1e-3}
        if critic_optimizer_kwargs is None:
            critic_optimizer_kwargs = {"learning_rate": 1e-3}

        # Fail early: sampling without replacement in critic_loss_fn cannot
        # draw more critics than exist, and an empty subsample has no minimum.
        if critic_subsample_size is not None and not (
            1 <= critic_subsample_size <= critic_ensemble_size
        ):
            raise ValueError(
                "critic_subsample_size must be between 1 and "
                f"critic_ensemble_size ({critic_ensemble_size}), "
                f"got {critic_subsample_size}"
            )

        model_def = ModuleDict(
            {
                "actor": actor_def,
                "critic": critic_def,
            }
        )

        # One optimizer per network.
        txs = {
            "actor": make_optimizer(**actor_optimizer_kwargs),
            "critic": make_optimizer(**critic_optimizer_kwargs),
        }

        rng, init_rng = jax.random.split(rng)
        params = model_def.init(
            init_rng,
            actor=[observations],
            critic=[observations, actions],
        )["params"]

        rng, create_rng = jax.random.split(rng)
        # Target parameters start as a copy of the online parameters.
        state = JaxRLTrainState.create(
            apply_fn=model_def.apply,
            params=params,
            txs=txs,
            target_params=params,
            rng=create_rng,
        )

        return cls(
            state=state,
            config=dict(
                discount=discount,
                soft_target_update_rate=soft_target_update_rate,
                # ReBRAC-specific parameters
                actor_bc_coef=actor_bc_coef,
                critic_bc_coef=critic_bc_coef,
                normalize_q=normalize_q,
                # TD3 parameters
                target_policy_noise=target_policy_noise,
                target_noise_clip=target_noise_clip,
                exploration_noise=exploration_noise,
                policy_delay=policy_delay,
                critic_ensemble_size=critic_ensemble_size,
                critic_subsample_size=critic_subsample_size,
                **kwargs,
            ),
        )

    @classmethod
    def create(
        cls,
        rng: PRNGKey,
        observations: Data,
        actions: jnp.ndarray,
        # Model architecture
        encoder_def: nn.Module,
        shared_encoder: bool = True,
        critic_network_kwargs: Optional[dict] = None,
        policy_network_kwargs: Optional[dict] = None,
        policy_kwargs: Optional[dict] = None,
        critic_ensemble_size: int = 2,
        critic_subsample_size: Optional[int] = None,
        **kwargs,
    ):
        """
        Create a new ReBRAC agent from an encoder and MLP settings.

        Args:
            rng: PRNG key forwarded to ``_create_common``.
            observations: example observations for network initialization.
            actions: example actions; the last axis sets the action dimension.
            encoder_def: encoder module for the actor and critic.
            shared_encoder: if True both networks use the same ``encoder_def``
                object; otherwise the critic gets a deep copy.
            critic_network_kwargs: kwargs for the critic ``MLP``; defaults to
                ``{"hidden_dims": [256, 256, 256]}`` when None.
            policy_network_kwargs: kwargs for the policy ``MLP``; defaults to
                ``{"hidden_dims": [256, 256, 256]}`` when None.
            policy_kwargs: extra kwargs for ``Policy``; defaults to a
                tanh-squashed distribution with a fixed std of 0.0 when None.
            critic_ensemble_size: number of critics in the ensemble.
            critic_subsample_size: forwarded to ``_create_common``.
            **kwargs: forwarded to ``_create_common``.

        Returns:
            A new agent instance of ``cls``.
        """
        # None sentinels avoid sharing mutable dict defaults across calls.
        if critic_network_kwargs is None:
            critic_network_kwargs = {"hidden_dims": [256, 256, 256]}
        if policy_network_kwargs is None:
            policy_network_kwargs = {"hidden_dims": [256, 256, 256]}
        if policy_kwargs is None:
            policy_kwargs = {
                "tanh_squash_distribution": True,
                "std_parameterization": "fixed",
                "fixed_std": 0.0,
            }

        critic_encoder = encoder_def if shared_encoder else copy.deepcopy(encoder_def)
        encoders = {
            "actor": encoder_def,
            "critic": critic_encoder,
        }

        # Policy network; with the default policy_kwargs its std is fixed at 0.
        policy_def = Policy(
            encoder=encoders["actor"],
            network=MLP(**policy_network_kwargs),
            action_dim=actions.shape[-1],
            **policy_kwargs,
            name="actor",
        )

        # Critic ensemble with `critic_ensemble_size` members.
        critic_backbone = partial(MLP, **critic_network_kwargs)
        critic_backbone = ensemblize(critic_backbone, critic_ensemble_size)(
            name="critic_ensemble"
        )
        critic_def = Critic(
            encoder=encoders["critic"],
            network=critic_backbone,
            name="critic",
        )

        return cls._create_common(
            rng,
            observations,
            actions,
            actor_def=policy_def,
            critic_def=critic_def,
            critic_subsample_size=critic_subsample_size,
            critic_ensemble_size=critic_ensemble_size,
            **kwargs,
        )