from typing import Sequence, Tuple
from functools import partial

import distrax
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np

from jax_rl.networks.common import (MLP, Parameter, Params, PRNGKey,
                                    default_init)
# from jax_rl.networks.rpp_emlp_parts import HeadlessRPPEMLP, HeadlessEMLP,parse_rep
# from emlp.nn import uniform_rep
# from rpp.flax import Sequential,Linear
# from emlp.reps import Rep
from jax_rl.networks.pe_mlp import MultiheadPEMLP, MultiheadPMLP


def PENormalTanhPolicy(state_rep, action_rep, action_std_rep, G, ch,
                       hidden_dims_B, action_dim_B,
                       state_transform, inv_action_transform, state_dependent_std=True, small_init=True,
                       det_lam=False):
    """Assemble a ``_PENormalTanhPolicy`` from its two actor heads.

    Head A (the equivariant actor) is a ``MultiheadPEMLP`` built from the
    representation / group arguments (``state_rep``, ``action_rep``,
    ``action_std_rep``, ``G``, ``ch``) and the state / inverse-action
    transforms. Head B (the non-equivariant actor) is a ``MultiheadPMLP``
    built from ``hidden_dims_B`` and ``action_dim_B``.

    ``state_dependent_std`` and ``small_init`` are forwarded unchanged to both
    heads; ``det_lam`` selects the blending rule used by the returned policy.
    """
    # Options common to both heads.
    shared_opts = dict(state_dependent_std=state_dependent_std,
                       small_init=small_init)

    equivariant_head = MultiheadPEMLP(
        state_rep, action_rep, action_std_rep, G, ch,
        state_transform, inv_action_transform,
        **shared_opts,
    )
    plain_head = MultiheadPMLP(
        hidden_dims=hidden_dims_B,
        action_dim=action_dim_B,
        **shared_opts,
    )

    return _PENormalTanhPolicy(actor_A=equivariant_head,
                               actor_B=plain_head,
                               det_lam=det_lam)

def _tanh_squash(base_dist):
    """Wrap ``base_dist`` in an element-wise tanh bijector.

    ``distrax.Block(..., 1)`` marks the last axis as the event dimension, so the
    tanh is applied per action component and its log-det-Jacobian is summed
    over that axis.
    """
    return distrax.Transformed(distribution=base_dist,
                               bijector=distrax.Block(distrax.Tanh(), 1))


class _PENormalTanhPolicy(nn.Module):
    """Tanh-squashed Gaussian policy that blends two actor heads.

    Each head is called as ``actor(observations, temperature=...)`` and must
    return a ``(mean, std)`` pair. The mixing coefficient ``lam`` weights head A
    by ``1 - lam`` and head B by ``lam``. ``lam = 0`` reproduces ``actor_A``
    and ``lam = 1`` reproduces ``actor_B``; in the non-deterministic branch this
    holds only up to the small stabilizing epsilons.
    """
    actor_A: nn.Module  # Equivariant actor (weight 1 - lam)
    actor_B: nn.Module  # Non-equivariant actor (weight lam)
    # True: linearly interpolate mean/std; False: precision-weighted blend.
    det_lam: bool = False
    # proposal_temp: float = 1.0
    # resample_num: int = 32
    # temper_lambda: float = 1.0
    # clamp_lambda:bool=True

    @nn.compact
    def __call__(self,
                 observations: jnp.ndarray,
                 lam: jnp.ndarray,
                 temperature: float = 1.0,
                 ):
        """Return the blended, tanh-squashed action distribution.

        Args:
            observations: Observations forwarded to both actors.
            lam: Mixing coefficient(s), array-like or a Python scalar. A
                trailing axis is appended when ``lam`` has lower rank than the
                actors' means, so e.g. a scalar or a per-row vector broadcasts
                across the action dimension.
            temperature: Forwarded to both actors and also multiplied into the
                blended std before squashing.

        Returns:
            A ``distrax.Transformed`` wrapping a ``MultivariateNormalDiag`` with
            a tanh bijector over the last axis.
        """
        # NOTE(review): temperature is passed to both actors AND multiplied into
        # scale_diag below. Confirm the actors do not already scale their std by
        # it; otherwise it is effectively applied twice.
        muA, stdA = self.actor_A(observations, temperature=temperature)
        muB, stdB = self.actor_B(observations, temperature=temperature)

        # jnp.asarray (unlike lam.astype) also accepts Python floats/ints.
        lam = jnp.asarray(lam, dtype=jnp.float32)
        lam_b = lam[..., None] if lam.ndim < muA.ndim else lam

        if self.det_lam:
            # Linear interpolation of the two heads' means and stds.
            muH = (1.0 - lam_b) * muA + lam_b * muB
            stdH = (1.0 - lam_b) * stdA + lam_b * stdB
        else:
            # Normalized weighted geometric mean of the two Gaussians:
            # precision is the lam-weighted average of the heads' precisions,
            # and the mean is the precision-weighted average of their means.
            # The 1e-8 terms guard against division by zero.
            varA = stdA**2 + 1e-8
            varB = stdB**2 + 1e-8

            invA = 1.0 / varA
            invB = 1.0 / varB

            invH = (1.0 - lam_b) * invA + lam_b * invB
            varH = 1.0 / (invH + 1e-8)
            stdH = jnp.sqrt(varH)
            muH = varH * ((1.0 - lam_b) * muA * invA + lam_b * muB * invB)

        base_dist = distrax.MultivariateNormalDiag(
            loc=muH, scale_diag=stdH * temperature)
        return _tanh_squash(base_dist)

    def heads(self, observations: jnp.ndarray, temperature: float = 1.0):
        """Return the two un-blended tanh-squashed distributions ``(A, B)``.

        As a non-``__call__`` method of a Flax module, this is invoked through
        ``apply(..., method=...)``. Each head's std is multiplied by
        ``temperature``, mirroring ``__call__``.
        """
        muA, stdA = self.actor_A(observations, temperature=temperature)
        muB, stdB = self.actor_B(observations, temperature=temperature)

        distA = distrax.MultivariateNormalDiag(
            loc=muA, scale_diag=stdA * temperature)
        distB = distrax.MultivariateNormalDiag(
            loc=muB, scale_diag=stdB * temperature)

        return _tanh_squash(distA), _tanh_squash(distB)


@partial(jax.jit, static_argnums=(1,))
def sample_actions(rng: PRNGKey,
                   actor_def: nn.Module,
                   actor_params: Params,
                   observations: np.ndarray,
                   lam: jnp.ndarray,
                   temperature: float = 1.0,
                   ) -> Tuple[PRNGKey, jnp.ndarray]:
    """Sample actions from the policy and return an advanced RNG key.

    ``actor_def`` (positional index 1) is a static argument to ``jax.jit``: it
    must be hashable, and a different module definition triggers
    recompilation. ``rng`` is split once; the first half is handed back to
    the caller and the second half seeds the sample.

    Returns:
        ``(new_rng, actions)``.
    """
    next_rng, sample_key = jax.random.split(rng)
    policy = actor_def.apply({'params': actor_params},
                             observations, lam, temperature)
    actions = policy.sample(seed=sample_key)
    return next_rng, actions


