import re

import chex
import einops
import elements
import embodied.jax
import embodied.jax.nets as nn
import jax
import jax.numpy as jnp
import ninjax as nj
import numpy as np
import optax

from . import rssm

# Short aliases for the dtypes used throughout this module.
f32 = jnp.float32
i32 = jnp.int32


def sg(xs, skip=False):
  """Stop gradients through `xs`; return `xs` untouched when `skip` is set."""
  if skip:
    return xs
  return jax.lax.stop_gradient(xs)


def sample(xs):
  """Call `.sample(nj.seed())` on every leaf of the pytree `xs`."""
  return jax.tree.map(lambda dist: dist.sample(nj.seed()), xs)


def prefix(xs, p):
  """Return a new dict whose keys are the keys of `xs` renamed to `p/key`."""
  return {f'{p}/{key}': value for key, value in xs.items()}


def concat(xs, a):
  """Concatenate a sequence of equally structured pytrees along axis `a`."""
  return jax.tree.map(lambda *leaves: jnp.concatenate(leaves, a), *xs)


def isimage(s):
  """Tell whether `s` has dtype uint8 and exactly three dimensions."""
  return s.dtype == np.uint8 and len(s.shape) == 3


class Agent(embodied.jax.Agent):

  banner = [
      r"---  ___                           __   ______ ---",
      r"--- |   \ _ _ ___ __ _ _ __  ___ _ \ \ / /__ / ---",
      r"--- | |) | '_/ -_) _` | '  \/ -_) '/\ V / |_ \ ---",
      r"--- |___/|_| \___\__,_|_|_|_\___|_|  \_/ |___/ ---",
  ]

  def __init__(self, obs_space, act_space, config):
    """Build the world models, heads, normalizers, optimizer and loss scales.

    Args:
      obs_space: Mapping from observation key to space. The keys `is_first`,
        `is_last`, `is_terminal` and `reward` are excluded from the
        encoder/decoder spaces.
      act_space: Mapping from action key to space; each space's `discrete`
        attribute selects the policy output distribution.
      config: Configuration object read via attribute access throughout.
    """
    self.obs_space = obs_space
    self.act_space = act_space
    self.config = config

    # Encoder and decoder see every observation key except the bookkeeping ones.
    exclude = ('is_first', 'is_last', 'is_terminal', 'reward')
    enc_space = {k: v for k, v in obs_space.items() if k not in exclude}
    dec_space = {k: v for k, v in obs_space.items() if k not in exclude}

    # Level-0 dynamics model; `config.dyn.typ` selects the class ('rssm' only).
    self.dyn = {
        'rssm': rssm.RSSM,
    }[config.dyn.typ](act_space, enc_space, dec_space, config, 
                      **config.dyn[config.dyn.typ], name='dyn0')
    
    # One output distribution name per action key: d1 if discrete, else d2.
    d1, d2 = config.policy_dist_disc, config.policy_dist_cont
    outs = {k: d1 if v.discrete else d2 for k, v in act_space.items()}
    self.pol = embodied.jax.MLPHead(
        act_space, outs, **config.policy, name='pol0')
    
    # Flatten the first `feat2tensor_used_hierarchy` levels of features into
    # one tensor: per level `deter` is concatenated with `stoch` (its last two
    # dims merged), and the levels are then concatenated on the last axis.
    self.feat2tensor = lambda feats: jnp.concatenate([
            jnp.concatenate([
                nn.cast(x['deter']),
                nn.cast(x['stoch'].reshape((*x['stoch'].shape[:-2], -1)))], -1)
                    for x in feats[:self.config.feat2tensor_used_hierarchy]
        ], -1)

    # Reward and continuation heads.
    scalar = elements.Space(np.float32, ())
    binary = elements.Space(bool, (), 0, 2)
    self.rew = embodied.jax.MLPHead(scalar, **config.rewhead, name='rew')
    self.con = embodied.jax.MLPHead(binary, **config.conhead, name='con')

    # Value head plus a slow copy that tracks `self.val`.
    self.val = embodied.jax.MLPHead(scalar, **config.value, name='val')
    self.slowval = embodied.jax.SlowModel(
        embodied.jax.MLPHead(scalar, **config.value, name='slowval'),
        source=self.val, **config.slowvalue)

    # Normalizers for returns, values, advantages, and one residual-image
    # normalizer per hierarchy level.
    self.retnorm = embodied.jax.Normalize(**config.retnorm, name='retnorm')
    self.valnorm = embodied.jax.Normalize(**config.valnorm, name='valnorm')
    self.advnorm = embodied.jax.Normalize(**config.advnorm, name='advnorm')
    self.resobsnorms = [
        embodied.jax.Normalize(**config.resobsnorm, name=f'resobsnorm{i}') 
            for i in range(self.config.hierarchy) 
    ]

    # Modules handed to the optimizer (`slowval` is not in this list).
    self.modules = [
        self.dyn, self.pol, self.rew, self.con, self.val]
    
    # Per-level dynamics models; index 0 is `self.dyn`.
    self.hdyns = [self.dyn]
    
    # Supported configuration: more than one level, at least one image key,
    # and no harmony loss.
    assert self.config.hierarchy > 1, "Only hierarchy impl"
    assert self.dyn.enc.imgkeys, "Must have img in obs_space"
    assert self.config.harmony == False, "harmony loss NOT impl"

    henc_space = enc_space # {**enc_space, 'lowerdeter': self.dyn.entry_space['deter']}
    hdec_space = dec_space # {**dec_space, 'lowerdeter': self.dyn.entry_space['deter']}
    for i in range(1, self.config.hierarchy):
      # Upper levels decode image keys as 6-channel uint8 spaces that keep the
      # first two dims; after the first iteration the mapping is a no-op
      # because the entries are already 3-D uint8 spaces with 6 channels.
      hdec_space = {k: elements.Space(np.uint8, (*v.shape[:2], 6),) 
                        if isimage(v) else v 
                            for k, v in hdec_space.items()}
      self.hdyns.append(
          rssm.RSSM(act_space, henc_space, hdec_space, config, 
                    **config.dyn[config.dyn.typ], name=f'dyn{i}')
      )
    self.modules += (
        self.hdyns[1:]
    )
      
    # Copy so the pops below do not mutate the config.
    scales = self.config.loss_scales.copy()
    
    # NOTE(review): unreachable while the `harmony == False` assertion above
    # is active; kept for the unimplemented harmony loss.
    if self.config.harmony:
      self.harmony_rew = nn.Harmonizer(name='har_rew', k=scales.pop('rew', 1.0))
      scales['rew'] = 1.0
      self.harmony_dyn = nn.Harmonizer(name='har_dyn', k=scales.pop('dyn', 1.0))
      scales['dyn'] = 1.0
      self.harmony_rec = nn.Harmonizer(name='har_rec', k=scales.pop('rec', 1.0))
      scales['rec'] = 1.0
      self.modules += [self.harmony_rew, self.harmony_dyn, self.harmony_rec]
      
    self.opt = embodied.jax.Optimizer(
        self.modules, self._make_opt(**config.opt), summary_depth=1,
        name='opt')

    # `rep` is applied inside the dynamics loss rather than as a top-level scale.
    self.rep_scale = scales.pop('rep')

    # Expand the shared `rec` and `dyn` scales into one entry per decoder key
    # and level (`recon_<key>_<i>`) and one per level (`dyn<i>`), matching the
    # loss names produced in `loss`.
    rec = scales.pop('rec')
    dyn = scales.pop('dyn')
    # mimic = scales.pop('mimic')
    scales.update({f'recon_{k}_{i}': rec for k in dec_space for i in range(self.config.hierarchy)})
    scales.update({f'dyn{i}': dyn for i in range(self.config.hierarchy)})
    # scales.update({f'recon_lowerdeter_{i}': rec for k in dec_space for i in range(1, self.config.hierarchy)})
    # scales.update({f'mimic_action_{i}': mimic for i in range(1, self.config.hierarchy)})

    self.scales = scales
    
    # Per-mode settings copied from the config; they are not read elsewhere
    # in the visible part of this file.
    self.mc = dict(
        train = config.trainmc,
        eval = config.evalmc,
    )
    self.mc_len = dict(
        train = config.trainmclen,
        eval = config.evalmclen,
    )

  @property
  def policy_keys(self):
    """Regex over `dyn{i}`, `pol{i}`, `resobsnorm{i}` per level plus rew/con/val."""
    per_level = [
        f'dyn{level}|pol{level}|resobsnorm{level}|'
        for level in range(self.config.hierarchy)]
    return '^(' + ''.join(per_level) + 'rew|con|val)/'

  @property
  def ext_space(self):
    """Extra spaces: `consec`, `stepid`, and replay-context entries if enabled.

    With `config.replay_context` set, the flattened entry spaces of the
    level-0 encoder, every dynamics level and the level-0 decoder are added.
    """
    spaces = {
        'consec': elements.Space(np.int32),
        'stepid': elements.Space(np.uint8, 20),
    }
    if not self.config.replay_context:
      return spaces
    context = dict(
        enc=self.dyn.enc.entry_space,
        dyn=tuple(hdyn.entry_space for hdyn in self.hdyns),
        dec=self.dyn.dec.entry_space)
    spaces.update(elements.tree.flatdict(context))
    return spaces

  def init_policy(self, batch_size):
    """Return the initial carry `(enc, dyns, dec, prevact)` for `batch_size`.

    `dyns` holds one initial state per hierarchy level and `prevact` is a
    zero array per action key.
    """
    def blank(space):
      return jnp.zeros((batch_size, *space.shape), space.dtype)

    enc_carry = self.dyn.enc.initial(batch_size)
    dyn_carrys = tuple(hdyn.initial(batch_size) for hdyn in self.hdyns)
    dec_carry = self.dyn.dec.initial(batch_size)
    prevact = jax.tree.map(blank, self.act_space)
    return (enc_carry, dyn_carrys, dec_carry, prevact)

  def init_train(self, batch_size):
    """Return the initial training carry; same as `init_policy`."""
    carry = self.init_policy(batch_size)
    return carry

  def init_report(self, batch_size):
    """Return the initial report carry; same as `init_policy`."""
    carry = self.init_policy(batch_size)
    return carry
    
  def policy(self, carry, obs, mode='train'):
    """Advance every dynamics level by one step and sample an action.

    Args:
      carry: `(enc_carry, dyn_carrys, dec_carry, prevact)`.
      obs: Observation dict for one step; `obs['is_first']` is the reset flag.
      mode: Accepted for interface compatibility; not read here.

    Returns:
      `(carry, act, out)`. `out['finite']` holds per-row finiteness flags for
      the inputs, the action and each level's features; replay entries and
      visualisation images are added when the matching config flags are set.
    """
    enc_carry, dyn_carrys, dec_carry, prevact = carry
    is_first = obs['is_first']

    dyn_carrys, (dyn_entrys, feat) = self._observe(
        dyn_carrys, obs, prevact, is_first, False)
    dists = self.pol(self.feat2tensor(feat), bdims=1)
    act = sample(dists)

    def finite_per_row(x):
      # Reduce over every axis except the leading one.
      return jnp.isfinite(x).all(range(1, x.ndim))

    # `carry` here is still the incoming carry, before it is rebuilt below.
    checked = dict(obs=obs, carry=carry, act=act)
    for level in range(self.config.hierarchy):
      checked[f'feat{level}'] = feat[level]
    out = {
        'finite': elements.tree.flatdict(
            jax.tree.map(finite_per_row, checked)),
    }

    new_carry = (enc_carry, dyn_carrys, dec_carry, act)
    if self.config.replay_context:
      out.update(elements.tree.flatdict(dict(dyn=dyn_entrys)))

    if self.config.vis_report and self.dyn.enc.imgkeys:
      out['log/vis_obs_hint0'] = feat[0]['vis_obs_hint']
      out['log/vis_obs_recon'] = feat[0]['vis_obs_recon']
    return new_carry, act, out
    
  def train(self, carry, data):
    """Run one optimizer step on `data` and update the slow value model.

    Returns:
      `(carry, outs, metrics)`. `carry` is extended with the last action of
      every action key in `data`. `outs['replay']` is set only when
      `config.replay_context` is enabled.
    """
    carry, obs, prevact, stepid = self._apply_replay_context(carry, data)
    metrics, aux = self.opt(
        self.loss, carry, obs, prevact, training=True, has_aux=True)
    carry, entries, _, loss_metrics = aux
    metrics.update(loss_metrics)
    self.slowval.update()

    outs = {}
    if self.config.replay_context:
      updates = elements.tree.flatdict(dict(
          stepid=stepid, enc=entries[0], dyn=entries[1], dec=entries[2]))
      B, T = obs['is_first'].shape
      # Every replay update must cover the full (batch, time) extent.
      assert all(x.shape[:2] == (B, T) for x in updates.values()), (
          (B, T), {k: v.shape for k, v in updates.items()})
      outs['replay'] = updates
    # Replay priorities are disabled:
    # if self.config.replay.fracs.priority > 0:
    #   outs['replay']['priority'] = losses['model']
    last_act = {k: data[k][:, -1] for k in self.act_space}
    carry = (*carry, last_act)
    return carry, outs, metrics

  @staticmethod
  def norm_res_img_obs(rec, true, norm, training):
    """Encode the residual `rec - true` as a uint8 image.

    The residual is shifted and scaled by the `(offset, scale)` pair that
    `norm(residual, training)` returns, then divided by 4. If `scale` is a
    standard deviation (presumably; confirm against `Normalize`), about 95%
    of the values fall in [-0.5, 0.5]. Magnitudes below 0.05 are zeroed and
    larger ones are shrunk by 0.05 (dead zone). The result is then shifted by
    0.5, scaled to [0, 255], clipped, and cast to uint8.
    """
    residual = f32(rec) - f32(true)
    offset, scale = norm(residual, training)
    scaled = (residual - offset) / scale / 4.
    # Dead zone: suppress small residuals.
    shrunk = jnp.sign(scaled) * jnp.maximum(jnp.abs(scaled) - 0.05, 0)
    # A zero residual maps to mid-gray.
    pixels = jnp.clip((shrunk + 0.5) * 255, 0, 255)
    return pixels.astype(np.uint8)

  def imagine(self, dyn_carrys, policy, length, training, single=False):
    """Roll every dynamics level forward using prior predictions only.

    Args:
      dyn_carrys: Per-level carries, each with `deter` and `stoch`.
      policy: A callable mapping the (gradient-stopped) carries to an action
        dict, or a ready action (single step) / action sequence (multi step).
      length: Number of steps to scan; not read in single-step mode.
      training: Forwarded to the recursive single-step calls; not otherwise
        read.
      single: When true, perform exactly one step.

    Returns:
      Single step: `(new_carrys, (feats, action))`.
      Multi step: `(dyn_carrys, feats, action)` with outputs stacked on axis 1.
    """
    if not single:
      if callable(policy):
        def step(c, _):
          return self.imagine(c, policy, 1, training, single=True)
        inputs = ()
      else:
        def step(c, a):
          return self.imagine(c, a, 1, training, single=True)
        inputs = nn.cast(policy)
      dyn_carrys, (feat, action) = nj.scan(
          step, nn.cast(dyn_carrys), inputs, length, unroll=length, axis=1)
      return dyn_carrys, feat, action

    action = policy(sg(dyn_carrys)) if callable(policy) else policy
    actemb = nn.DictConcat(self.act_space, 1)(action)
    next_carrys = []
    next_feats = []
    for level, carry in enumerate(dyn_carrys):
      hdyn = self.hdyns[level]
      deter = hdyn._core(carry['deter'], carry['stoch'], actemb)
      logit = hdyn._prior(deter)
      stoch = nn.cast(hdyn._dist(logit).sample(seed=nj.seed()))
      next_carrys.append(nn.cast(dict(deter=deter, stoch=stoch)))
      next_feats.append(nn.cast(dict(deter=deter, stoch=stoch, logit=logit)))
    return tuple(next_carrys), (tuple(next_feats), action)

  def _observe(self, dyn_carrys, obs, prevact, reset, training):
    """Run one posterior step on every hierarchy level.

    The step first imagines `config.imag_length_active_gaze` steps ahead with
    the current policy and decodes the imagined features into image "hints".
    Level 0 then observes `obs` with hint 0 appended on the channel axis.
    Each upper level observes `obs`, the normalized reconstruction residual
    of the level below, and its own hint if one exists.

    Args:
      dyn_carrys: Per-level carries.
      obs: Observation dict for a single step.
      prevact: Previous action dict.
      reset: Reset flags (`obs['is_first']` at the call sites).
      training: Forwarded to the models and residual normalizers; the
        visualisation outputs are produced only when it is false.

    Returns:
      `(dyn_carrys, (dyn_entries, feats))`, each a tuple with one item per
      level. Every feat additionally carries `recon_diff`, `recon_loss` and
      `recon`; `feats[0]` also carries `vis_obs_hint` / `vis_obs_recon` when
      `config.vis_report` is set and `training` is false.
    """
    kw = dict(training=training, single=True)
    imgkeys = self.dyn.enc.imgkeys
    imgkey = imgkeys[0]
    stride = self.config.imag_obs_decimation
    assert self.config.imag_length_active_gaze % stride == 0, "imag_length_active_gaze must be n*imag_obs_decimation"

    # Imagine ahead to build the hint observations.
    _, imgfeats, _ = self.imagine(
        dyn_carrys,
        lambda feat: sample(self.pol(self.feat2tensor(feat), 1)),
        self.config.imag_length_active_gaze, training=False
    )
    # Keep every `stride`-th imagined step, ending on the last one.
    imgfeats = [dict(
        deter=imgfeat['deter'][:, stride - 1::stride],
        stoch=imgfeat['stoch'][:, stride - 1::stride],
    ) for imgfeat in imgfeats]

    to_uint8_img = lambda x: jnp.clip(x * 255, 0, 255).astype(jnp.uint8)

    def frames_to_channels(recon, last3):
      # Fold the time axis of each predicted image into the channel axis,
      # optionally keeping only the last three channels of the prediction.
      stacked = {}
      for key in imgkeys:
        pred = recon[key].pred()
        if last3:
          pred = pred[..., -3:]
        stacked[key] = einops.rearrange(pred, 'b t h w c -> b h w (c t)')
      return sg(stacked)

    def recon_losses(recons, true):
      # Image targets are rescaled to [0, 1]; other targets are used as-is.
      return {
          key: rec.loss(f32(true[key]) / 255)
          if key in self.obs_space and isimage(self.obs_space[key])
          else rec.loss(true[key])
          for key, rec in recons.items()}

    if not self.config.imag_obs_residual:
      decoded = [
          self.hdyns[i].dec({}, imgfeat, imgfeat['deter'][:, :, 0], False)[-1]
          for i, imgfeat in enumerate(imgfeats)
      ]
      imgrecons = [frames_to_channels(decoded[0], False)] + [
          frames_to_channels(recon, True) for recon in decoded[1:]]
      residual_enhanced_hints = [
          jax.tree.map(
              to_uint8_img,
              # The upper residual is centered at 0.5.
              {k: (imgrecons[i][k] + imgrecons[i + 1][k] - 0.5)
               for k in imgrecons[i]})
          for i in range(self.config.hierarchy - 1)
      ] + [
          # The highest level has no residual enhancement.
          jax.tree.map(to_uint8_img, imgrecons[-1])
      ]
    else:
      decoded = [
          self.hdyns[i + 1].dec(
              {}, imgfeat, imgfeat['deter'][:, :, 0], False)[-1]
          for i, imgfeat in enumerate(imgfeats[1:])
      ]
      imgrecons = [frames_to_channels(recon, True) for recon in decoded]
      residual_enhanced_hints = [
          jax.tree.map(to_uint8_img, rec) for rec in imgrecons]

    # Level 0 has no lower residual; it only gets the first hint.
    hier_obs = {
        k: jnp.concatenate([obs[k], residual_enhanced_hints[0][k]], axis=-1)
        if k in residual_enhanced_hints[0]
        else obs[k]
        for k in self.dyn.enc.obs_space}
    dyn_carry, dyn_entry, feat = self.dyn.observe(
        dyn_carrys[0], sg(hier_obs), prevact, reset, **kw
    )
    _, _, recons = self.dyn.dec({}, feat, reset, **kw)
    rec_img = self.dyn.recon2obs(recons)
    true = obs
    # Residual image fed to the next level up.
    lower_residual_obs = sg({
        key: self.norm_res_img_obs(
            rec, true[key], self.resobsnorms[0], training)
        for key, rec in rec_img.items()})

    feat['recon_diff'] = lower_residual_obs
    feat['recon_loss'] = recon_losses(recons, true)
    feat['recon'] = rec_img[imgkey]
    new_dyn_carrys, new_dyn_entry, new_feat = [dyn_carry], [dyn_entry], [feat]

    # Levels 1..H-1 additionally observe the residual of the level below.
    upper = zip(self.hdyns[1:], dyn_carrys[1:])
    for h, (hdyn, dyn_carry) in enumerate(upper, start=1):
      has_hint = h < len(residual_enhanced_hints)
      hier_obs = {
          k: jnp.concatenate(
              [obs[k], lower_residual_obs[k]]
              + ([residual_enhanced_hints[h][k]] if has_hint else []),
              axis=-1)
          if k in lower_residual_obs
          else obs[k]
          for k in hdyn.enc.obs_space}
      dyn_carry, dyn_entry, feat = hdyn.observe(
          dyn_carry, sg(hier_obs), prevact, reset, **kw
      )
      _, _, recons = hdyn.dec({}, feat, reset, **kw)
      # Channels 0:3 reconstruct obs, channels 3:6 the lower residual.
      rec_img = hdyn.recon2obs(recons)
      true = {
          k: jnp.concatenate([obs[k], lower_residual_obs[k]], axis=-1)
          if k in lower_residual_obs
          else obs[k]
          for k in hdyn.enc.obs_space}

      # Each level uses its own normalizer. This previously indexed with
      # h - 1, which shared normalizer 0 between levels 0 and 1 and left the
      # last normalizer unused.
      # NOTE: state saved before this fix has no statistics for the last
      # normalizer.
      lower_residual_obs = sg({
          key: self.norm_res_img_obs(
              rec[..., -3:], true[key][..., -3:],
              self.resobsnorms[h], training)
          for key, rec in rec_img.items()})

      feat['recon_diff'] = lower_residual_obs
      feat['recon_loss'] = recon_losses(recons, true)
      feat['recon'] = rec_img[imgkey][..., -3:]
      new_dyn_carrys.append(dyn_carry)
      new_dyn_entry.append(dyn_entry)
      new_feat.append(feat)

    if self.config.vis_report and not training:
      # One row per hint: the observation followed by the hint frames.
      new_feat[0]['vis_obs_hint'] = jnp.concatenate(
          [jnp.concatenate([
              obs[imgkey],
              einops.rearrange(
                  hint[imgkey], 'b h w (c t) -> b h (t w) c', c=3)],
              2) for hint in residual_enhanced_hints],
          1
      )
      # One row per level: the level's target next to its reconstruction.
      vis_obs_recon0 = jnp.concatenate([obs[imgkey], new_feat[0]['recon']], 2)
      new_feat[0]['vis_obs_recon'] = jnp.concatenate(
          [vis_obs_recon0] + [
              jnp.concatenate([new_feat[i - 1]['recon_diff'][imgkey],
                               new_feat[i]['recon']], 2)
              for i in range(1, self.config.hierarchy)
          ],
          1)
    dyn_carrys = tuple(new_dyn_carrys)
    dyn_entries = tuple(new_dyn_entry)
    feats = tuple(new_feat)
    return dyn_carrys, (dyn_entries, feats)

  def loss(self, carry, obs, prevact, training):
    """Compute the total loss: world model, imagination and optional replay.

    Args:
      carry: `(enc_carry, dyn_carrys, dec_carry)`.
      obs: Batch of observation sequences; `obs['is_first']` gives `(B, T)`.
      prevact: Previous actions aligned with `obs`.
      training: Forwarded to `_observe`, `imagine` and the normalizers.

    Returns:
      `(loss, (carry, entries, outs, metrics))`. `entries` is
      `({}, dyn_entries, {})` and `outs` holds `repfeat` and the unscaled
      per-term `losses`.
    """
    enc_carry, dyn_carrys, dec_carry = carry
    reset = obs['is_first']
    B, T = reset.shape
    losses = {}
    metrics = {}

    # World model
    dyn_carrys, prevact = nn.cast((dyn_carrys, prevact))
    dyn_carrys, (dyn_entries, repfeats) = nj.scan(
        lambda carry, inputs: self._observe(carry, *inputs, training),
        dyn_carrys, (obs, prevact, reset),
        unroll=jax.tree.leaves(obs)[0].shape[1],
        axis=1
    )

    for i, level_feat in enumerate(repfeats):
      # Each level is scored with its own prior head and distribution. Using
      # level 0's heads for every level left the upper priors, which
      # `imagine` samples from, without any dynamics-loss gradient.
      hdyn = self.hdyns[i]
      prior = hdyn._prior(level_feat['deter'])
      post = level_feat['logit']
      dyn = hdyn._dist(sg(post)).kl(hdyn._dist(prior))
      rep = hdyn._dist(post).kl(hdyn._dist(sg(prior)))
      if hdyn.free_nats:
        dyn = jnp.maximum(dyn, hdyn.free_nats)
        rep = jnp.maximum(rep, hdyn.free_nats)
      loss_dyn = dyn + self.rep_scale * rep
      losses[f'dyn{i}'] = self.harmony_dyn(loss_dyn) if self.config.harmony else loss_dyn
      metrics[f'dyn_ent{i}'] = hdyn._dist(prior).entropy().mean()
      metrics[f'rep_ent{i}'] = hdyn._dist(post).entropy().mean()
      metrics[f'loss_dyn_prior{i}'] = dyn.mean()
      metrics[f'loss_dyn_post{i}'] = rep.mean()

      # Decoder losses were already computed per level inside `_observe`.
      for key, loss_rec in level_feat['recon_loss'].items():
        losses[f'recon_{key}_{i}'] = self.harmony_rec(loss_rec) if self.config.harmony else loss_rec

    # Keep only the latent-state entries for the heads and for imagination.
    repfeat = tuple(
        [{k: repfeats[i][k] for k in ('deter', 'stoch', 'logit')} for i in range(self.config.hierarchy)]
    )

    inp = sg(self.feat2tensor(repfeat), skip=self.config.reward_grad)
    loss_rew = self.rew(inp, 2).loss(obs['reward'])
    losses['rew'] = self.harmony_rew(loss_rew) if self.config.harmony else loss_rew

    con = f32(~obs['is_terminal'])
    if self.config.contdisc:
      con *= 1 - 1 / self.config.horizon
    losses['con'] = self.con(self.feat2tensor(repfeat), 2).loss(con)

    if self.config.harmony:
      harmony_s_dyn = self.harmony_dyn.get_harmony().mean()
      harmony_s_rew = self.harmony_rew.get_harmony().mean()
      harmony_s_rec = self.harmony_rec.get_harmony().mean()
      metrics.update(dict(
          har_s_dyn=harmony_s_dyn,
          har_s_rew=harmony_s_rew,
          har_s_rec=harmony_s_rec,
          har_coef_dyn = 1 / (jnp.exp(harmony_s_dyn)),
          har_coef_rew = 1 / (jnp.exp(harmony_s_rew)),
          har_coef_rec = 1 / (jnp.exp(harmony_s_rec)),
          har_sigma_dyn = jnp.exp(harmony_s_dyn),
          har_sigma_rew = jnp.exp(harmony_s_rew),
          har_sigma_rec = jnp.exp(harmony_s_rec),
      ))

    # Every world-model loss must be per (batch, time) element.
    shapes = {k: v.shape for k, v in losses.items()}
    assert all(x == (B, T) for x in shapes.values()), ((B, T), shapes)

    # Imagination
    K = min(self.config.imag_last or T, T)
    H = self.config.imag_length
    starts = self.dyn.starts(dyn_entries, dyn_carrys, K)
    policyfn = lambda feat: sample(self.pol(self.feat2tensor(feat), 1))
    _, imgfeat, imgprevact = self.imagine(starts, policyfn, H, training)
    first = jax.tree.map(
        lambda x: x[:, -K:].reshape((B * K, 1, *x.shape[2:])), repfeat)
    imgfeat = concat([sg(first, skip=self.config.ac_grads), sg(imgfeat)], 1)
    lastact = policyfn(jax.tree.map(lambda x: x[:, -1], imgfeat))
    lastact = jax.tree.map(lambda x: x[:, None], lastact)
    imgact = concat([imgprevact, lastact], 1)
    assert all(x.shape[:2] == (B * K, H + 1) for x in jax.tree.leaves(imgfeat))
    assert all(x.shape[:2] == (B * K, H + 1) for x in jax.tree.leaves(imgact))
    inp = self.feat2tensor(imgfeat)
    los, imgloss_out, mets = imag_loss(
        imgact,
        self.rew(inp, 2).pred(),
        self.con(inp, 2).prob(1),
        self.pol(inp, 2),
        self.val(inp, 2),
        self.slowval(inp, 2),
        self.retnorm, self.valnorm, self.advnorm,
        update=training,
        contdisc=self.config.contdisc,
        horizon=self.config.horizon,
        **self.config.imag_loss)
    # Average over the imagined horizon so the shapes line up with (B, K).
    losses.update({k: v.mean(1).reshape((B, K)) for k, v in los.items()})
    metrics.update(mets)

    # Replay
    if self.config.repval_loss:
      feat = sg(repfeat, skip=self.config.repval_grad)
      last, term, rew = [obs[k] for k in ('is_last', 'is_terminal', 'reward')]
      boot = imgloss_out['ret'][:, 0].reshape(B, K)
      feat, last, term, rew, boot = jax.tree.map(
          lambda x: x[:, -K:], (feat, last, term, rew, boot))
      inp = self.feat2tensor(feat)
      los, reploss_out, mets = repl_loss(
          last, term, rew, boot,
          self.val(inp, 2),
          self.slowval(inp, 2),
          self.valnorm,
          update=training,
          horizon=self.config.horizon,
          **self.config.repl_loss)
      losses.update(los)
      metrics.update(prefix(mets, 'reploss'))

    # Every loss needs a scale and every scale a loss.
    assert set(losses.keys()) == set(self.scales.keys()), (
        sorted(losses.keys()), sorted(self.scales.keys()))
    metrics.update({f'loss/{k}': v.mean() for k, v in losses.items()})
    loss = sum([v.mean() * self.scales[k] for k, v in losses.items()])

    carry = (enc_carry, dyn_carrys, dec_carry)
    entries = ({}, dyn_entries, {})
    outs = {'repfeat': repfeat, 'losses': losses}
    return loss, (carry, entries, outs, metrics)

  def report(self, carry, data):
    """Evaluate the loss on `data` without training and collect metrics.

    Returns:
      `(carry, metrics)`. With `config.report` disabled, the input carry and
      an empty dict are returned. Otherwise `metrics` holds the metrics
      produced by `loss` and, if `config.report_gradnorms` is set, one
      `gradnorm/<loss>` entry per loss scale.
    """
    if not self.config.report:
      return carry, {}

    carry, obs, prevact, _ = self._apply_replay_context(carry, data)
    metrics = {}

    # Train metrics
    _, (new_carry, _, _, mets) = self.loss(
        carry, obs, prevact, training=False)
    # Previously `mets.update(mets)`, a no-op that dropped every loss metric.
    metrics.update(mets)

    # Grad norms
    if self.config.report_gradnorms:
      for key in self.scales:
        try:
          # The lambda is consumed within this iteration, so closing over
          # `key` is safe.
          lossfn = lambda data, carry: self.loss(
              carry, obs, prevact, training=False)[1][2]['losses'][key].mean()
          grad = nj.grad(lossfn, self.modules)(data, carry)[-1]
          metrics[f'gradnorm/{key}'] = optax.global_norm(grad)
        except KeyError:
          print(f'Skipping gradnorm summary for missing loss: {key}')

    # TODO: re-implement the open-loop video report (`openloop/<key>`) for the
    # hierarchical model; the earlier version here was commented out.

    carry = (*new_carry, {k: data[k][:, -1] for k in self.act_space})
    return carry, metrics

  def _apply_replay_context(self, carry, data):
    """Split `carry` and `data` into `(carry, obs, prevact, stepid)`.

    Without replay context, `prevact` is the action sequence shifted right by
    one step, with the carried previous action placed first. With a replay
    context of `K` steps, rows whose `data['consec'][:, 0] == 0` restart from
    the stored dynamics entries of the first `K` steps, and every output is
    cut to the steps after `K`.
    """
    enc_carry, dyn_carrys, dec_carry, prevact = carry
    carry = (enc_carry, dyn_carrys, dec_carry)
    stepid = data['stepid']
    obs = {k: data[k] for k in self.obs_space}

    def shift_in(first, seq):
      # Prepend `first` and drop the final step of `seq`.
      return jnp.concatenate([first[:, None], seq[:, :-1]], 1)

    prevact = {k: shift_in(prevact[k], data[k]) for k in self.act_space}
    if not self.config.replay_context:
      return carry, obs, prevact, stepid

    K = self.config.replay_context
    nested = elements.tree.nestdict(data)
    dyn_entries = nested.get('dyn', {})

    def head(xs):
      return jax.tree.map(lambda x: x[:, :K], xs)

    def tail(xs):
      return jax.tree.map(lambda x: x[:, K:], xs)

    rep_dyn = tuple(
        self.hdyns[i].truncate(head(dyn_entries[i]), dyn_carrys[i])
        for i in range(self.config.hierarchy))
    rep_carry = ({}, rep_dyn, {})
    rep_obs = {k: tail(data[k]) for k in self.obs_space}
    rep_prevact = {k: data[k][:, K - 1: -1] for k in self.act_space}
    rep_stepid = tail(stepid)

    first_chunk = (data['consec'][:, 0] == 0)

    def choose(normal, replay):
      return nn.where(first_chunk, replay, normal)

    normal_side = (carry, tail(obs), tail(prevact), tail(stepid))
    replay_side = (rep_carry, rep_obs, rep_prevact, rep_stepid)
    carry, obs, prevact, stepid = jax.tree.map(
        choose, normal_side, replay_side)
    return carry, obs, prevact, stepid

  def _make_opt(
      self,
      lr: float = 4e-5,
      agc: float = 0.3,
      eps: float = 1e-20,
      beta1: float = 0.9,
      beta2: float = 0.999,
      momentum: bool = True,
      nesterov: bool = False,
      wd: float = 0.0,
      wdregex: str = r'/kernel$',
      schedule: str = 'const',
      warmup: int = 1000,
      anneal: int = 0,
  ):
    """Assemble the optax transformation chain.

    The chain is AGC clipping, RMS scaling, momentum, optional weight decay
    on parameters whose name matches `wdregex`, then the learning-rate
    schedule with an optional linear warmup. `momentum` is accepted but not
    read.

    Raises:
      NotImplementedError: If `schedule` is not 'const', 'linear' or 'cosine'.
    """
    chain = [
        embodied.jax.opt.clip_by_agc(agc),
        embodied.jax.opt.scale_by_rms(beta2, eps),
        embodied.jax.opt.scale_by_momentum(beta1, nesterov),
    ]
    if wd:
      assert not wdregex[0].isnumeric(), wdregex
      pattern = re.compile(wdregex)

      def wdmask(params):
        return {name: bool(pattern.search(name)) for name in params}

      chain.append(optax.add_decayed_weights(wd, wdmask))

    # Non-constant schedules need a positive anneal length.
    assert anneal > 0 or schedule == 'const'
    builders = {
        'const': lambda: optax.constant_schedule(lr),
        'linear': lambda: optax.linear_schedule(
            lr, 0.1 * lr, anneal - warmup),
        'cosine': lambda: optax.cosine_decay_schedule(
            lr, anneal - warmup, 0.1 * lr),
    }
    if schedule not in builders:
      raise NotImplementedError(schedule)
    sched = builders[schedule]()
    if warmup:
      ramp = optax.linear_schedule(0.0, lr, warmup)
      sched = optax.join_schedules([ramp, sched], [warmup])
    chain.append(optax.scale_by_learning_rate(sched))
    return optax.chain(*chain)


def imag_loss(
    act, rew, con,
    policy, value, slowvalue,
    retnorm, valnorm, advnorm,
    update,
    contdisc=True,
    slowtar=True,
    horizon=333,
    lam=0.95,
    actent=3e-4,
    slowreg=1.0,
):
  """Policy and value losses on imagined trajectories.

  Args:
    act: Action dict over the imagined trajectory.
    rew, con: Predicted rewards and continuation probabilities.
    policy: Dict of action distributions keyed like `act`.
    value, slowvalue: Value and slow-value output distributions.
    retnorm, valnorm, advnorm: Normalizers; called with `update`.
    update: Whether the normalizers update their statistics.
    contdisc: If true the discount is taken to be inside `con` (disc = 1).
    slowtar: Use the slow value as the bootstrap target.
    horizon, lam: Discount horizon and lambda-return mixing factor.
    actent, slowreg: Entropy bonus scale and slow-value regularizer scale.

  Returns:
    `(losses, outs, metrics)` with `losses` keys `policy` and `value` and
    `outs['ret']` the lambda returns.
  """
  # Denormalize value predictions using the current statistics.
  voffset, vscale = valnorm.stats()
  val = value.pred() * vscale + voffset
  slowval = slowvalue.pred() * vscale + voffset
  tarval = slowval if slowtar else val
  disc = 1 if contdisc else 1 - 1 / horizon
  weight = jnp.cumprod(disc * con, 1) / disc
  ret = lambda_return(
      jnp.zeros_like(con), 1 - con, rew, tarval, tarval, disc, lam)

  # Policy loss: advantage-weighted log-probability plus an entropy bonus.
  roffset, rscale = retnorm(ret, update)
  adv = (ret - tarval[:, :-1]) / rscale
  aoffset, ascale = advnorm(adv, update)
  adv_normed = (adv - aoffset) / ascale
  logpi = sum([v.logp(sg(act[k]))[:, :-1] for k, v in policy.items()])
  ents = {k: v.entropy()[:, :-1] for k, v in policy.items()}
  step_weight = sg(weight[:, :-1])
  objective = logpi * sg(adv_normed) + actent * sum(ents.values())
  losses = {'policy': step_weight * -objective}

  # Value loss: regress onto normalized returns and towards the slow value.
  voffset, vscale = valnorm(ret, update)
  tar_normed = (ret - voffset) / vscale
  tar_padded = jnp.concatenate([tar_normed, 0 * tar_normed[:, -1:]], 1)
  value_terms = (
      value.loss(sg(tar_padded)) +
      slowreg * value.loss(sg(slowvalue.pred())))
  losses['value'] = step_weight * value_terms[:, :-1]

  ret_normed = (ret - roffset) / rscale
  metrics = {
      'adv': adv.mean(),
      'adv_std': adv.std(),
      'adv_mag': jnp.abs(adv).mean(),
      'rew': rew.mean(),
      'con': con.mean(),
      'ret': ret_normed.mean(),
      'val': val.mean(),
      'tar': tar_normed.mean(),
      'weight': weight.mean(),
      'slowval': slowval.mean(),
      'ret_min': ret_normed.min(),
      'ret_max': ret_normed.max(),
      'ret_rate': (jnp.abs(ret_normed) >= 1.0).mean(),
  }
  for k in act:
    metrics[f'ent/{k}'] = ents[k].mean()
    if hasattr(policy[k], 'minent'):
      lo, hi = policy[k].minent, policy[k].maxent
      metrics[f'rand/{k}'] = (ents[k].mean() - lo) / (hi - lo)

  outs = {'ret': ret}
  return losses, outs, metrics


def repl_loss(
    last, term, rew, boot,
    value, slowvalue, valnorm,
    update=True,
    slowreg=1.0,
    slowtar=True,
    horizon=333,
    lam=0.95,
):
  """Value loss on replayed trajectories, bootstrapped from `boot`.

  Returns:
    `(losses, outs, metrics)` with `losses['repval']`, `outs['ret']` the
    lambda returns, and an empty `metrics` dict.
  """
  # Denormalize value predictions using the current statistics.
  voffset, vscale = valnorm.stats()
  val = value.pred() * vscale + voffset
  slowval = slowvalue.pred() * vscale + voffset
  tarval = slowval if slowtar else val
  disc = 1 - 1 / horizon
  # Steps flagged `last` get zero weight.
  weight = f32(~last)
  ret = lambda_return(last, term, rew, tarval, boot, disc, lam)

  voffset, vscale = valnorm(ret, update)
  ret_normed = (ret - voffset) / vscale
  ret_padded = jnp.concatenate([ret_normed, 0 * ret_normed[:, -1:]], 1)
  value_terms = (
      value.loss(sg(ret_padded)) +
      slowreg * value.loss(sg(slowvalue.pred())))
  losses = {'repval': weight[:, :-1] * value_terms[:, :-1]}

  outs = {'ret': ret}
  metrics = {}
  return losses, outs, metrics


def lambda_return(last, term, rew, val, boot, disc, lam):
  """Compute lambda returns by a backward recursion over the time axis.

  All five arrays must share one shape. `val` takes part only in that shape
  check. The result has one step fewer than the inputs along axis 1.
  """
  chex.assert_equal_shape((last, term, rew, val, boot))
  live = (1 - f32(term))[:, 1:] * disc
  cont = (1 - f32(last))[:, 1:] * lam
  interm = rew[:, 1:] + (1 - cont) * live * boot[:, 1:]
  # Start from the final bootstrap value and walk backwards in time.
  running = boot[:, -1]
  backward = []
  for t in range(live.shape[1] - 1, -1, -1):
    running = interm[:, t] + live[:, t] * cont[:, t] * running
    backward.append(running)
  return jnp.stack(backward[::-1], 1)
