from typing import Tuple, Optional

import jax
import jax.numpy as jnp

from jax_rl.agents.actor_critic_temp import ActorCriticTemp, Pre
from jax_rl.datasets import Batch
from jax_rl.networks.common import InfoDict, Params

import collections

def isDict(pars):
    """Return True when *pars* is a mapping (``dict`` or any other Mapping type).

    The abstract base class is taken from ``collections.abc``; the old
    ``collections.Mapping`` alias was removed in Python 3.10 and raises
    ``AttributeError`` there.
    """
    # Function-scope import so this block does not depend on the submodule
    # having been imported elsewhere in the file.
    from collections.abc import Mapping

    return isinstance(pars, Mapping)

def get_l2(pars):
    """Sum the squared entries of parameters, grouped by key suffix.

    Walks *pars* recursively. Every value that ``isDict`` recognises as a
    mapping is descended into; every other value is treated as a leaf and
    contributes ``(value ** 2).sum()`` to one of two running totals,
    selected by the suffix of its key:

    * keys ending in ``"_basic"`` add to the first total,
    * keys ending in ``"_equiv"`` add to the second total,
    * leaves with any other key are ignored.

    Returns:
        A ``(basic_l2, equiv_l2)`` tuple. Both are the float ``0.`` when no
        matching leaf was found.
    """
    basic_total, equiv_total = 0., 0.

    for name, value in pars.items():
        if isDict(value):
            # Nested collection: fold in the totals of the whole subtree.
            nested_basic, nested_equiv = get_l2(value)
            basic_total = basic_total + nested_basic
            equiv_total = equiv_total + nested_equiv
            continue

        # Leaf: route by key suffix; unmatched keys contribute nothing.
        if name.endswith("_basic"):
            basic_total = basic_total + (value ** 2).sum()
        elif name.endswith("_equiv"):
            equiv_total = equiv_total + (value ** 2).sum()

    return basic_total, equiv_total

def target_update(sac: ActorCriticTemp, tau: float) -> ActorCriticTemp:
    """Blend the critic parameters into the target critic.

    Each target parameter leaf becomes ``p * tau + tp * (1 - tau)``, where
    ``p`` is the matching leaf of ``sac.critic.params`` and ``tp`` the current
    leaf of ``sac.target_critic.params``.

    Args:
        sac: Agent state providing ``critic`` and ``target_critic``.
        tau: Weight given to the critic parameters in the blend.

    Returns:
        The result of ``sac.replace`` with ``target_critic`` set to the
        target critic carrying the blended parameters.
    """
    def _blend(online_leaf, target_leaf):
        return online_leaf * tau + target_leaf * (1 - tau)

    blended_params = jax.tree_util.tree_map(
        _blend, sac.critic.params, sac.target_critic.params)

    return sac.replace(
        target_critic=sac.target_critic.replace(params=blended_params))

def update(sac: ActorCriticTemp, batch: Batch, discount: float, soft_critic: bool,
           cbasic_wd: float, cequiv_wd: float,
           pre: Optional[Pre]) -> Tuple[ActorCriticTemp, InfoDict]:
    """Run one gradient step on the critic and return the updated agent state.

    The TD target is ``rewards + discount * masks * min(next_q1, next_q2)``,
    where the next-state Q-values come from ``sac.target_critic`` evaluated on
    ``batch.next_observations`` with ``pre.next_actions`` and ``pre.lam_tgt``.
    When ``soft_critic`` is true, ``discount * masks * sac.temp() *
    pre.next_log_probs`` is subtracted from that target.

    Args:
        sac: Agent state; its ``critic``, ``target_critic``, ``actor``,
            ``temp`` and ``rng`` are read.
        batch: Transition batch; ``observations``, ``actions``, ``rewards``,
            ``masks`` and ``next_observations`` are read.
        discount: Discount factor applied to the bootstrapped value.
        soft_critic: Whether to subtract the temperature-weighted
            next-action log-probability term from the target.
        cbasic_wd: Coefficient on the squared-L2 sum of ``*_basic`` critic
            parameters (see ``get_l2``).
        cequiv_wd: Coefficient on the squared-L2 sum of ``*_equiv`` critic
            parameters (see ``get_l2``).
        pre: Precomputed values (``lam_tgt``, ``next_actions``,
            ``next_log_probs``, ``lam_q``). Typed ``Optional`` but required.

    Returns:
        ``(new_sac, info)``: ``new_sac`` is ``sac`` with ``critic`` replaced
        by the result of ``apply_gradient`` and ``rng`` replaced by a fresh
        key; ``info`` holds ``critic_loss``, ``q1``, ``q2`` and ``target_q``.

    Raises:
        ValueError: If ``pre`` is ``None``.
    """
    # Validate before the first dereference. The previous ``assert`` sat
    # inside the loss function, after ``pre`` had already been read, and
    # asserts are stripped under ``python -O``.
    if pre is None:
        raise ValueError("update() requires precomputed `pre` values; got None.")

    # Only the first half of the split is carried forward in the agent state.
    rng, _ = jax.random.split(sac.rng)

    lam_tgt = pre.lam_tgt
    next_actions = pre.next_actions
    next_log_probs = pre.next_log_probs

    next_q1, next_q2 = sac.target_critic(batch.next_observations, next_actions, lam_tgt)

    next_q = jnp.minimum(next_q1, next_q2)
    target_q = batch.rewards + discount * batch.masks * next_q

    if soft_critic:
        target_q -= discount * batch.masks * sac.temp() * next_log_probs

    def critic_loss_fn(critic_params: Params) -> Tuple[jnp.ndarray, InfoDict]:
        """Return the regularised critic loss and logging values for *critic_params*."""
        # lam_q is treated as a constant with respect to the critic gradient.
        lam_q = jax.lax.stop_gradient(pre.lam_q)

        if sac.actor.apply_fn.det_lam:
            # heads_apply returns four per-head outputs (q1e, q1n, q2e, q2n).
            # NOTE(review): the "e"/"n" heads are presumably the equivariant
            # and non-equivariant critic heads - confirm against the critic.
            q1e, q1n, q2e, q2n = sac.critic.heads_apply(
                {'params': critic_params}, batch.observations, batch.actions)

            # Give lam_q a trailing axis when its rank differs from the heads'.
            lam_q_b = lam_q if lam_q.ndim == q1n.ndim else lam_q[..., None]
            w_e = 1.0 - lam_q_b
            w_n = 1.0  # the "n" heads always receive full weight

            cl_n = (q1n - target_q)**2 + (q2n - target_q)**2
            cl_e = (q1e - target_q)**2 + (q2e - target_q)**2
            critic_loss = (w_n * cl_n + w_e * cl_e).mean()

            # Logged Q-values are the lam_q-weighted mix of the two heads.
            q1_log = ((1.0 - lam_q_b) * q1e + lam_q_b * q1n).mean()
            q2_log = ((1.0 - lam_q_b) * q2e + lam_q_b * q2n).mean()
        else:
            q1, q2 = sac.critic.apply(
                {'params': critic_params}, batch.observations, batch.actions, lam_q)
            critic_loss = ((q1 - target_q)**2 + (q2 - target_q)**2).mean()
            q1_log = q1.mean()
            q2_log = q2.mean()

        # Weight decay is identical in both branches; the reported
        # 'critic_loss' includes it.
        basic_l2, equiv_l2 = get_l2(critic_params)
        closs = critic_loss + cbasic_wd * basic_l2 + cequiv_wd * equiv_l2

        return closs, {'critic_loss': closs, 'q1': q1_log, 'q2': q2_log}

    new_critic, info = sac.critic.apply_gradient(critic_loss_fn)
    new_sac = sac.replace(critic=new_critic, rng=rng)
    info.update({'target_q': target_q.mean()})
    return new_sac, info
