import functools
from typing import Optional, Sequence, Tuple, Any

import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from tensorflow_probability.substrates import jax as tfp

# Short aliases for the TFP (JAX substrate) submodules used throughout.
tfd, tfb = tfp.distributions, tfp.bijectors

from common import MLP, Params, PRNGKey, default_init

# Default clip range for predicted log standard deviations; NormalTanhPolicy
# falls back to these when its log_std_min / log_std_max fields are unset.
LOG_STD_MIN: float = -10.0
LOG_STD_MAX: float = 2.0


class NormalTanhPolicy(nn.Module):
    """Diagonal-Gaussian policy head on an MLP torso, optionally tanh-squashed.

    Attributes:
        hidden_dims: Hidden layer sizes of the MLP torso.
        action_dim: Dimension of the action space.
        state_dependent_std: If True, log-stds are predicted from the MLP
            features by a Dense layer; otherwise a single learned
            ``log_stds`` parameter (zero-initialized) is used for every input.
        dropout_rate: Optional dropout rate forwarded to the MLP.
        log_std_scale: Scale passed to ``default_init`` for the log-std head.
        log_std_min: Lower clip bound for log-stds; ``None`` selects
            ``LOG_STD_MIN``.
        log_std_max: Upper clip bound for log-stds; ``None`` selects
            ``LOG_STD_MAX``.
        tanh_squash_distribution: If True, the Gaussian is wrapped in a Tanh
            bijector; otherwise tanh is applied to the means and the plain
            Gaussian is returned.
    """
    hidden_dims: Sequence[int]
    action_dim: int
    state_dependent_std: bool = True
    dropout_rate: Optional[float] = None
    log_std_scale: float = 1.0
    log_std_min: Optional[float] = None
    log_std_max: Optional[float] = None
    tanh_squash_distribution: bool = True

    @nn.compact
    def __call__(self,
                 observations: jnp.ndarray,
                 temperature: float = 1.0,
                 training: bool = False) -> tfd.Distribution:
        """Build the action distribution for a batch of observations.

        Args:
            observations: Inputs fed to the MLP torso.
            temperature: Multiplier applied to the standard deviations.
            training: Forwarded to the MLP (presumably toggles dropout).

        Returns:
            A ``tfd.TransformedDistribution`` with a ``tfb.Tanh`` bijector when
            ``tanh_squash_distribution`` is True, otherwise a
            ``tfd.MultivariateNormalDiag`` whose means were passed through tanh.
        """
        outputs = MLP(self.hidden_dims,
                      activate_final=True,
                      dropout_rate=self.dropout_rate)(observations,
                                                      training=training)

        means = nn.Dense(self.action_dim, kernel_init=default_init())(outputs)

        if self.state_dependent_std:
            log_stds = nn.Dense(self.action_dim,
                                kernel_init=default_init(
                                    self.log_std_scale))(outputs)
        else:
            log_stds = self.param('log_stds', nn.initializers.zeros,
                                  (self.action_dim, ))

        # Compare against None explicitly: a caller-supplied bound of 0.0 is
        # falsy and must not be silently replaced by the module default.
        log_std_min = (LOG_STD_MIN if self.log_std_min is None
                       else self.log_std_min)
        log_std_max = (LOG_STD_MAX if self.log_std_max is None
                       else self.log_std_max)
        log_stds = jnp.clip(log_stds, log_std_min, log_std_max)

        if not self.tanh_squash_distribution:
            # Without the Tanh bijector, bound the means directly instead.
            means = nn.tanh(means)

        base_dist = tfd.MultivariateNormalDiag(loc=means,
                                               scale_diag=jnp.exp(log_stds) *
                                               temperature)
        if self.tanh_squash_distribution:
            return tfd.TransformedDistribution(distribution=base_dist,
                                               bijector=tfb.Tanh())
        return base_dist

# Distribution-like wrapper so deterministic policies can stand in for a TFP
# distribution.
class DeterministicDistribution:
    """Point-mass "distribution" located at a fixed array of actions.

    Only a minimal subset of the TFP ``Distribution`` interface is provided:
    ``sample``, ``log_prob``, ``entropy``, ``mean`` and ``mode``, plus
    ``batch_shape`` / ``event_shape`` stored as plain shape tuples.
    """

    def __init__(self, actions: jnp.ndarray):
        shape = actions.shape
        # Last axis is the action (event) dimension; leading axes are batch.
        self.batch_shape = shape[:-1]
        self.event_shape = shape[-1:]
        self.actions = actions

    def sample(self, seed: Optional[Any] = None, sample_shape: Any = ()) -> jnp.ndarray:
        """Return the stored actions unchanged.

        Args:
            seed: Accepted for interface compatibility; not used.
            sample_shape: Accepted for interface compatibility; not used, so no
                extra sample dimensions are prepended.

        Returns:
            The exact array given to the constructor.
        """
        del seed, sample_shape  # Nothing is drawn; the output is fixed.
        return self.actions

    def log_prob(self, actions: jnp.ndarray) -> jnp.ndarray:
        """Log-density of a point mass: 0 where ``actions`` match, -inf elsewhere.

        A row matches when every component is within ``rtol=1e-5, atol=1e-5``
        of the stored actions (per ``jnp.isclose``).

        Args:
            actions: Actions to score, broadcastable against the stored actions.

        Returns:
            Array reduced over the last axis holding 0.0 or -inf.
        """
        per_component = jnp.isclose(actions, self.actions, rtol=1e-5, atol=1e-5)
        row_matches = per_component.all(axis=-1)
        return jnp.where(row_matches, 0.0, -jnp.inf)

    def entropy(self) -> jnp.ndarray:
        """Zero entropy for every batch element of a point mass."""
        return jnp.zeros(shape=self.batch_shape)

    def mean(self) -> jnp.ndarray:
        """The mean of a point mass is the stored actions themselves."""
        return self.actions

    def mode(self) -> jnp.ndarray:
        """The mode coincides with the mean (the stored actions)."""
        return self.mean()


class DetPolicy(nn.Module):
    """A deterministic policy that returns a point-mass distribution wrapper.

    Actions are ``tanh(Dense(MLP(observations)))`` and are wrapped in a
    ``DeterministicDistribution``. Code such as ``_sample_actions`` can
    therefore call ``dist.sample(seed=...)`` on it unchanged.

    Attributes:
        hidden_dims: Sequence of hidden layer dimensions.
        action_dim: Dimension of the action space.
        dropout_rate: Optional dropout rate forwarded to the MLP.
        state_dependent_std: Unused, kept for API consistency.
        log_std_scale: Unused, kept for API consistency.
        log_std_min: Unused, kept for API consistency.
        log_std_max: Unused, kept for API consistency.
        tanh_squash_distribution: Unused, kept for API consistency.
    """
    hidden_dims: Sequence[int]
    action_dim: int
    dropout_rate: Optional[float] = None
    # Accepted so this module can be constructed like NormalTanhPolicy;
    # none of these fields are read.
    state_dependent_std: bool = False
    log_std_scale: float = 1.0
    log_std_min: Optional[float] = None
    log_std_max: Optional[float] = None
    tanh_squash_distribution: bool = False

    @nn.compact
    def __call__(self,
                 observations: jnp.ndarray,
                 temperature: float = 1.0,
                 training: bool = False) -> DeterministicDistribution:
        """Forward pass computing deterministic actions.

        Args:
            observations: Batch of observations.
            temperature: Unused; accepted so the call signature matches
                NormalTanhPolicy.
            training: Forwarded to the MLP (affects dropout).

        Returns:
            A DeterministicDistribution wrapping actions in [-1, 1].
        """
        # MLP and default_init come from the module-level `common` import.
        outputs = MLP(self.hidden_dims,
                      activate_final=True,
                      dropout_rate=self.dropout_rate)(observations,
                                                      training=training)

        actions = nn.Dense(self.action_dim, kernel_init=default_init())(outputs)
        # Bound actions to [-1, 1] directly; no stochastic squashing is involved.
        actions = nn.tanh(actions)

        return DeterministicDistribution(actions)


# Trailing comma makes this a real one-element tuple; ('actor_def') is just a
# parenthesized string.
@functools.partial(jax.jit, static_argnames=('actor_def',))
def _sample_actions(rng: PRNGKey,
                    actor_def: nn.Module,
                    actor_params: Params,
                    observations: np.ndarray,
                    temperature: float = 1.0) -> Tuple[PRNGKey, jnp.ndarray]:
    """Jitted core of ``sample_actions``.

    ``actor_def`` is a static argument, so it is part of the compilation
    cache key rather than traced. JAX requires static arguments to be
    hashable.

    Args:
        rng: Current PRNG key.
        actor_def: Policy module; ``apply`` must return an object with
            ``sample(seed=...)``.
        actor_params: Parameters passed as ``{'params': actor_params}``.
        observations: Batch of observations.
        temperature: Forwarded positionally to the policy's ``__call__``.

    Returns:
        ``(new_rng, actions)``. The input key is split, and one half is
        consumed for sampling.
    """
    dist = actor_def.apply({'params': actor_params}, observations, temperature)
    rng, key = jax.random.split(rng)
    return rng, dist.sample(seed=key)


def sample_actions(rng: PRNGKey,
                   actor_def: nn.Module,
                   actor_params: Params,
                   observations: np.ndarray,
                   temperature: float = 1.0) -> Tuple[PRNGKey, jnp.ndarray]:
    """Sample actions from ``actor_def`` and advance the RNG.

    Public, non-jitted entry point that delegates to ``_sample_actions``.

    Args:
        rng: Current PRNG key.
        actor_def: Policy module whose output exposes ``sample(seed=...)``.
        actor_params: Parameters for ``actor_def``.
        observations: Batch of observations for the policy.
        temperature: Forwarded to the policy call.

    Returns:
        ``(new_rng, actions)``: the advanced key and the sampled actions.
    """
    next_rng, actions = _sample_actions(rng, actor_def, actor_params,
                                        observations, temperature)
    return next_rng, actions
