from typing import Tuple, Optional

import jax
import jax.numpy as jnp

from jax_rl.agents.actor_critic_temp import ActorCriticTemp, Pre
# from jax_rl.networks import _RPPNormalTanhPolicy
from jax_rl.datasets import Batch
from jax_rl.networks.common import InfoDict, Params
import collections

def isDict(pars):
    """Return True when ``pars`` is a mapping (``dict`` or any ``Mapping`` subclass)."""
    # The ``collections.Mapping`` alias was removed in Python 3.10; the ABC
    # lives in ``collections.abc``. Imported locally so this function does not
    # depend on ``collections.abc`` having been loaded by another module.
    from collections.abc import Mapping

    return isinstance(pars, Mapping)

def get_l2(pars):
    """Sum squared parameter values, split by the suffix of the leaf name.

    Walks ``pars`` recursively: values that ``isDict`` recognises as mappings
    are descended into; any other value is treated as a leaf. A leaf whose key
    ends in ``"_basic"`` adds ``(leaf ** 2).sum()`` to the first total, one
    whose key ends in ``"_equiv"`` adds to the second, and every other leaf is
    ignored.

    Returns:
        A ``(basic_l2, equiv_l2)`` pair; both are ``0.`` when nothing matched.
    """
    sq_basic = 0.
    sq_equiv = 0.
    for name, entry in pars.items():
        if isDict(entry):
            # Nested sub-tree: fold its totals into ours and move on.
            nested_basic, nested_equiv = get_l2(entry)
            sq_basic += nested_basic
            sq_equiv += nested_equiv
            continue
        if name.endswith("_basic"):
            sq_basic += (entry ** 2).sum()
        elif name.endswith("_equiv"):
            sq_equiv += (entry ** 2).sum()
    return sq_basic, sq_equiv

def critic_pointwise(sac, obs_flat, act_flat):
    """Evaluate both critic heads one sample at a time via ``jax.vmap``.

    Each (observation, action) pair taken along the leading axis of
    ``obs_flat`` / ``act_flat`` is given a leading axis of size 1, passed to
    ``sac.critic.apply`` together with the gater's ``apply_fn`` and ``params``,
    and that leading axis is squeezed from both outputs again.

    Returns:
        ``(q1, q2)`` as stacked by ``jax.vmap`` over the leading axis.
    """
    critic_vars = {'params': sac.critic.params}

    def _score_single(obs, act):
        # Re-add a singleton leading axis so the critic sees a batch of one.
        head1, head2 = sac.critic.apply(
            critic_vars,
            obs[None, :],
            act[None, :],
            sac.gater.apply_fn,
            sac.gater.params,
        )
        return jnp.squeeze(head1, axis=0), jnp.squeeze(head2, axis=0)

    return jax.vmap(_score_single)(obs_flat, act_flat)

def update(sac: ActorCriticTemp,
           batch: Batch,
           basic_wd, equiv_wd,
           pre: Pre) -> Tuple[ActorCriticTemp, InfoDict]:
    """Apply one gradient step to the actor and return the updated agent.

    The loss is ``temp * log_prob - min(q1, q2)`` averaged over the batch,
    plus L2 penalties on the actor parameters as split by ``get_l2``:
    ``basic_wd`` scales the ``*_basic`` total and ``equiv_wd`` the ``*_equiv``
    total.

    Args:
        sac: Agent state. Its ``rng``, ``actor``, ``critic`` and ``temp`` are
            read here.
        batch: Training batch; only ``batch.observations`` is used.
        basic_wd: Weight on the squared sum of ``*_basic`` actor parameters.
        equiv_wd: Weight on the squared sum of ``*_equiv`` actor parameters.
        pre: Must not be ``None``. ``pre.lam_p`` is used with gradients
            stopped.

    Returns:
        ``(new_sac, info)``. ``new_sac`` is ``sac`` with ``actor`` and ``rng``
        replaced. ``info`` holds ``'actor_loss'`` and ``'entropy'``, plus
        ``'entropyA'`` / ``'entropyB'`` when ``sac.actor.apply_fn.det_lam``
        is truthy.
    """
    rng, key = jax.random.split(sac.rng)

    def actor_loss_fn(actor_params: Params) -> Tuple[jnp.ndarray, InfoDict]:
        assert pre is not None

        # lam_p is treated as a constant for the actor update.
        lam_p = jax.lax.stop_gradient(pre.lam_p)

        if sac.actor.apply_fn.det_lam:
            # Weights carry a trailing axis: a 1-D lam_p becomes a column.
            lam_b = lam_p if lam_p.ndim > 1 else lam_p[:, None]
            distA, distB = sac.actor.heads_apply({'params': actor_params}, batch.observations)

            # Independent sampling keys for the two heads.
            keyA, keyB = jax.random.split(key, 2)
            actionsA, log_probsA = distA.sample_and_log_prob(seed=keyA)
            actionsB, log_probsB = distB.sample_and_log_prob(seed=keyB)

            q1A, q2A = sac.critic(batch.observations, actionsA, lam_p)
            q1B, q2B = sac.critic(batch.observations, actionsB, lam_p)

            qA = jnp.minimum(q1A, q2A)
            qB = jnp.minimum(q1B, q2B)

            # Head A is weighted by (1 - lambda); head B gets unit weight.
            wA = 1.0 - lam_b
            wB = jnp.ones_like(wA)

            # Give the per-sample terms the same trailing axis as the weights
            # so every product below is element-wise.
            log_probsA_col = log_probsA[..., None]
            log_probsB_col = log_probsB[..., None]

            # Weighted means; the epsilon guards against an all-zero weight sum.
            lossA = (wA * (sac.temp() * log_probsA_col - qA[..., None])).sum() / (wA.sum() + 1e-8)
            lossB = (wB * (sac.temp() * log_probsB_col - qB[..., None])).sum() / (wB.sum() + 1e-8)
            actor_loss = lossA + lossB

            basic_l2, equiv_l2 = get_l2(actor_params)
            actor_loss = actor_loss + basic_wd * basic_l2 + equiv_wd * equiv_l2

            info = {
                'actor_loss': actor_loss,
                # Use the column-shaped log-probs here as well. Multiplying the
                # (batch, 1) weights by 1-D log-probs would broadcast to a
                # (batch, batch) outer product and log a wrong entropy.
                # NOTE(review): assumes log-probs are 1-D per sample, as the
                # loss terms above already do - confirm.
                'entropy': -(wA * log_probsA_col + wB * log_probsB_col).mean(),
                'entropyA': -log_probsA.mean(),
                'entropyB': -log_probsB.mean(),
            }

        else:
            # Single policy conditioned on lam_p.
            dist = sac.actor.apply({'params': actor_params}, batch.observations, lam_p)
            actions, log_probs = dist.sample_and_log_prob(seed=key)
            q1, q2 = sac.critic(batch.observations, actions, lam_p)
            q = jnp.minimum(q1, q2)
            actor_loss = (log_probs * sac.temp() - q).mean()

            basic_l2, equiv_l2 = get_l2(actor_params)
            actor_loss = actor_loss + basic_wd * basic_l2 + equiv_wd * equiv_l2

            info = {'actor_loss': actor_loss, 'entropy': -log_probs.mean()}

        return actor_loss, info

    new_actor, info = sac.actor.apply_gradient(actor_loss_fn)
    new_sac = sac.replace(actor=new_actor, rng=rng)
    return new_sac, info

