from typing import Sequence, Tuple, Callable, Optional
import distrax
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np

from jax_rl.networks.common import MLP, Parameter, default_init
from jax_rl.networks.rpp_emlp_parts import HeadlessEMLP,parse_rep
from rpp.flax import Sequential,Linear


# Bounds on the log standard deviation, and the same bounds mapped through exp.
# In this file only STD_MIN / STD_MAX are used directly: they clip the
# softplus output of the std heads of the policy networks.
LOG_STD_MIN, LOG_STD_MAX = -10.0, 2.0
STD_MIN, STD_MAX = np.exp(LOG_STD_MIN), np.exp(LOG_STD_MAX)

# Small numerical epsilon; not referenced in the visible part of this file.
EPS = 1e-6

""" MLPs for policies """

# def std_bias_init(key, shape, dtype=jnp.float32):
#     sigma0 = 1.0
#     return jnp.full(shape, jnp.log(jnp.expm1(sigma0)), dtype)

def MultiheadPEMLP(state_rep, action_rep, action_std_rep, G, ch: Sequence[int],
                   state_transform, inv_action_transform,
                   state_dependent_std=True, small_init=True):
    """Assemble a `_MultiheadPEMLP` from a `HeadlessEMLP` body and two `Linear` heads.

    Args:
        state_rep: Callable invoked as ``state_rep(G)``; the result is the
            input representation handed to `HeadlessEMLP`.
        action_rep: Callable invoked as ``action_rep(G)``; the result is the
            output representation of the mean head.
        action_std_rep: Optional callable invoked as ``action_std_rep(G)`` to
            get the output representation of the std head. When ``None`` the
            std head reuses the mean head's representation.
        G: Passed through to the representation callables, `HeadlessEMLP`
            and `parse_rep`.
        ch: Channel specification forwarded to `HeadlessEMLP` and `parse_rep`.
        state_transform: Stored on the module; applied to observations.
        inv_action_transform: Stored on the module; applied to head outputs.
        state_dependent_std: Must be truthy; any other value trips the assert.
        small_init: When truthy both heads use ``init_scale=0.01``,
            otherwise ``1.0``.

    Returns:
        A `_MultiheadPEMLP` wrapping the body, both heads and the transforms.
    """
    assert state_dependent_std, "only supporting one option for now"

    # Both heads share the same initialisation scale.
    head_scale = 0.01 if small_init else 1.0

    in_rep = state_rep(G)
    mean_rep = action_rep(G)
    std_rep = mean_rep if action_std_rep is None else action_std_rep(G)

    trunk = HeadlessEMLP(in_rep, G, ch)
    # The last entry returned by parse_rep is used as the input rep of the heads.
    trunk_out_rep = parse_rep(ch, G, len(ch))[-1]

    return _MultiheadPEMLP(
        trunk,
        Linear(trunk_out_rep, mean_rep, init_scale=head_scale),
        Linear(trunk_out_rep, std_rep, init_scale=head_scale),
        state_transform,
        inv_action_transform,
    )

class _MultiheadPEMLP(nn.Module):
    """Policy network with a shared body and separate mean / std heads.

    All submodules and transforms are supplied by the caller as fields; this
    module only wires them together.
    """
    # Shared feature extractor, called on the transformed observations.
    body_emlp: nn.Module
    # Head producing the (pre-transform) means from the body features.
    mean_head: nn.Module
    # Head producing the (pre-transform, pre-softplus) std values.
    std_head: nn.Module
    # Applied to the observations before they enter the body.
    state_transform: Callable
    # Applied to the output of each head.
    inv_action_transform: Callable

    @nn.compact
    def __call__(self, observations: jnp.ndarray, temperature: float = 1.0):
        """Return ``(means, stds)`` for a Normal (pre-Tanh).

        Args:
            observations: Input passed through ``state_transform`` and then
                the body.
            temperature: Accepted but not used by this method.

        Returns:
            A ``(means, stds)`` tuple. ``stds`` is the softplus of the
            transformed std-head output, clipped to ``[STD_MIN, STD_MAX]``.
        """
        features = self.body_emlp(self.state_transform(observations))
        means = self.inv_action_transform(self.mean_head(features))
        # These are raw head outputs, not log-stds: positivity comes from
        # softplus (not exp) below.
        raw_stds = self.inv_action_transform(self.std_head(features))
        stds = jnp.clip(jax.nn.softplus(raw_stds), STD_MIN, STD_MAX)
        return means, stds
    
class MultiheadPMLP(nn.Module):
    """Plain-MLP policy network producing a mean and a std per action dimension."""
    # Layer widths forwarded to `MLP`.
    hidden_dims: Sequence[int]
    # Output width of the mean head (and of the std head / parameter).
    action_dim: int
    # True: std comes from a Dense head on the features.
    # False: std comes from a `Parameter` of shape (action_dim,).
    state_dependent_std: bool = True
    # True: kernel initialisers of the Dense heads are scaled by 0.01.
    small_init: bool = True

    @nn.compact
    def __call__(self, observations: jnp.ndarray, temperature: float = 1.0):
        """Return ``(means, stds)`` for a Normal (pre-Tanh).

        ``temperature`` is accepted but not used by this method. ``stds`` is
        a softplus output clipped to ``[STD_MIN, STD_MAX]``.
        """
        init_scale = 0.01 if self.small_init else 1.0

        def scaled_kernel_init(*args, **kwargs):
            # Default initialiser, shrunk by `init_scale`.
            return default_init()(*args, **kwargs) * init_scale

        hidden = MLP(self.hidden_dims, activate_final=True)(observations)
        means = nn.Dense(self.action_dim, kernel_init=scaled_kernel_init)(hidden)

        if not self.state_dependent_std:
            # State-independent branch: value comes from the project's
            # `Parameter` module rather than from the features.
            pre_softplus = Parameter(shape=(self.action_dim,))()
        else:
            pre_softplus = nn.Dense(self.action_dim, kernel_init=scaled_kernel_init)(hidden)

        return means, jnp.clip(jax.nn.softplus(pre_softplus), STD_MIN, STD_MAX)
    

""" MLPs for PE Networks """

def const_bias_init(v:float):
    """Build an initializer that fills the requested shape with the constant ``v``.

    The returned callable has the signature ``(key, shape, dtype=jnp.float32)``
    and returns ``jnp.full(shape, v, dtype)``.
    """
    def _fill_with_constant(key, shape, dtype=jnp.float32):
        # `key` is accepted but never read: the result does not depend on it.
        return jnp.full(shape, fill_value=v, dtype=dtype)

    return _fill_with_constant

def PEMLP(in_rep, out_rep, G, ch: Sequence[int], small_init:bool=True):
    """Assemble a `_PEMLP`: a `HeadlessEMLP` body followed by one `Linear` head.

    Args:
        in_rep: Input representation, passed to `HeadlessEMLP` as given.
        out_rep: Output representation of the head, passed to `Linear` as given.
        G: Forwarded to `HeadlessEMLP` and `parse_rep`.
        ch: Channel specification forwarded to `HeadlessEMLP` and `parse_rep`.
        small_init: When truthy the head uses ``init_scale=0.01``,
            otherwise ``1.0``.

    Returns:
        A `_PEMLP` wrapping the body and the head.
    """
    head_scale = 0.01 if small_init else 1.0

    trunk = HeadlessEMLP(in_rep, G, ch)
    # The last entry returned by parse_rep is used as the head's input rep.
    layer_reps = parse_rep(ch, G, len(ch))
    readout = Linear(layer_reps[-1], out_rep, init_scale=head_scale)
    return _PEMLP(trunk, readout)

class _PEMLP(nn.Module):
    """Chain a caller-supplied body module and head module."""
    # Called with all positional inputs of `__call__`.
    body_emlp: nn.Module
    # Called on the body's output; its result is returned.
    head: nn.Module

    @nn.compact
    def __call__(self, *x):
        """Return ``head(body_emlp(*x))``."""
        return self.head(self.body_emlp(*x))
    
class PMLP(nn.Module):
    """Plain MLP with a final Dense readout layer."""
    # Layer widths forwarded to `MLP`.
    hidden_dims: Sequence[int]
    # Width of the Dense readout.
    out_dim: int
    # True: the readout kernel initialiser is scaled by 0.01.
    small_init: bool = True
    # Constant used to initialise the readout bias.
    bias_value: float = 0.0

    @nn.compact
    def __call__(self, *x):
        """Apply the MLP and readout to the inputs.

        A single positional input is used as is; several inputs are
        concatenated along the last axis first.
        """
        init_scale = 0.01 if self.small_init else 1.0

        def scaled_kernel_init(*args, **kwargs):
            # Default initialiser, shrunk by `init_scale`.
            return default_init()(*args, **kwargs) * init_scale

        inputs = jnp.concatenate(x, axis=-1) if len(x) != 1 else x[0]
        hidden = MLP(self.hidden_dims, activate_final=True)(inputs)
        readout = nn.Dense(
            self.out_dim,
            kernel_init=scaled_kernel_init,
            bias_init=const_bias_init(self.bias_value),
        )
        return readout(hidden)