import embodied
import jax
import jax.numpy as jnp
import ruamel.yaml as yaml
tree_map = jax.tree_util.tree_map

def sg(x):
  """Stop gradients through every leaf of the pytree `x`."""
  return tree_map(jax.lax.stop_gradient, x)
from jax import debug
import time
import logging
import numpy as np
import functools
logger = logging.getLogger()


class CheckTypesFilter(logging.Filter):
  """Logging filter that drops records whose message mentions 'check_types'."""

  def filter(self, record):
    # getMessage() renders `msg % args`, so matches inside args count too.
    if 'check_types' in record.getMessage():
      return False
    return True


# NOTE(review): per the `logging` docs, a filter attached to a logger is only
# consulted for records logged directly on that logger, not for records that
# propagate up from child loggers; confirm the messages to silence are emitted
# on the root logger.
logger.addFilter(CheckTypesFilter())

from . import behaviors
from . import jaxagent
from . import jaxutils
from . import nets
from . import ninjax as nj

import tensorflow_probability.substrates.jax as tfp
tfd = tfp.distributions

@jaxagent.Wrapper
class Agent(nj.Module):
  """Agent combining a world model with task and exploration behaviors.

  The class attribute `configs` holds the parsed `configs.yaml` located next
  to this file.
  """

  configs = yaml.YAML(typ='safe').load(
      (embodied.Path(__file__).parent / 'configs.yaml').read())

  def __init__(self, obs_space, act_space, step, config):
    self.config = config
    self.obs_space = obs_space
    self.act_space = act_space['action']
    self.step = step
    self.wm = WorldModel(obs_space, act_space, config, name='wm')
    self.task_behavior = getattr(behaviors, config.task_behavior)(
        self.wm, self.act_space, self.config, name='task_behavior')
    # With expl_behavior == 'None' both roles share a single behavior instance.
    if config.expl_behavior == 'None':
      self.expl_behavior = self.task_behavior
    else:
      self.expl_behavior = getattr(behaviors, config.expl_behavior)(
          self.wm, self.act_space, self.config, name='expl_behavior')
    # Sanity checks on the instruction-related configuration.
    if self.config.multibit_inst:
      assert self.config.inst_prior in ('instnet', 'uniform')
    if self.config.inst_prior != 'instnet':
      assert self.config.first_inst in ('uniform', 'bernoulli')
      assert self.config.task_planning == 'shooting'
      assert self.config.expl_planning == 'shooting'

  def policy_initial(self, batch_size):
    """Return the initial policy carry: (wm carry, task state, expl state)."""
    return (
        self.wm.initial(batch_size),
        self.task_behavior.initial(batch_size),
        self.expl_behavior.initial(batch_size))

  def train_initial(self, batch_size):
    """Return the initial training carry (the world model's carry)."""
    return self.wm.initial(batch_size)

  def policy(self, obs, state, mode='train', activate_prior=False):
    """Compute actions for one environment step.

    Args:
      obs: raw observation dict; passed through `preprocess`.
      state: ((prev_latent, prev_action), task_state, expl_state).
      mode: 'eval' or 'train' act with the task behavior; 'explore' acts with
        the exploration behavior.
      activate_prior: forwarded to both behaviors' `policy`.

    Returns:
      (outs, new_state), where `outs` comes from the selected behavior and
      new_state = ((latent, outs['action']), task_state, expl_state).

    Raises:
      ValueError: if `mode` is not one of 'eval', 'explore', 'train'.
    """
    if mode not in ('eval', 'explore', 'train'):
      # Previously an unknown mode surfaced as an UnboundLocalError on `outs`.
      raise ValueError(f'Unknown policy mode: {mode}')
    self.config.jax.jit and print('Tracing policy function.')
    obs = self.preprocess(obs)
    (prev_latent, prev_action), task_state, expl_state = state
    embed = self.wm.encoder(obs)
    latent, _ = self.wm.rssm.obs_step(
        prev_latent, prev_action, embed, obs['is_first'])
    img_step = self.wm.rssm.img_step
    # Both behaviors are stepped so that both returned states stay current.
    # (An extra, discarded expl_behavior.policy call used to precede these.)
    task_outs, task_state = self.task_behavior.policy(
        latent, task_state, img_step, mode='task',
        activate_prior=activate_prior)
    expl_outs, expl_state = self.expl_behavior.policy(
        latent, expl_state, img_step, mode='expl',
        activate_prior=activate_prior)
    outs = expl_outs if mode == 'explore' else task_outs
    state = ((latent, outs['action']), task_state, expl_state)
    return outs, state

  def train(self, data, state, prob=None):
    """Train the world model, then the behaviors on imagined rollouts.

    Behaviors start imagination from every posterior state of the batch
    (batch and time axes flattened together). `prob` is forwarded to the
    behaviors' `train`.

    Returns:
      (outs, state, metrics) with `outs` always an empty dict.
    """
    self.config.jax.jit and print('Tracing train function.')
    metrics = {}
    data = self.preprocess(data)
    state, wm_outs, mets = self.wm.train(data, state)
    metrics.update(mets)
    context = {**data, **wm_outs['post']}
    # Merge the first two axes so each posterior state becomes a start state.
    start = tree_map(lambda x: x.reshape([-1] + list(x.shape[2:])), context)
    _, mets = self.task_behavior.train(self.wm.imagine, start, context, prob)
    metrics.update(mets)
    if self.config.expl_behavior != 'None':
      _, mets = self.expl_behavior.train(self.wm.imagine, start, context, prob)
      metrics.update({'expl_' + key: value for key, value in mets.items()})
    outs = {}
    return outs, state, metrics

  def report(self, data):
    """Collect world-model and behavior reports for logging."""
    self.config.jax.jit and print('Tracing report function.')
    data = self.preprocess(data)
    report = {}
    report.update(self.wm.report(data))
    mets = self.task_behavior.report(data)
    report.update({f'task_{k}': v for k, v in mets.items()})
    if self.expl_behavior is not self.task_behavior:
      mets = self.expl_behavior.report(data)
      report.update({f'expl_{k}': v for k, v in mets.items()})
    return report

  def preprocess(self, obs):
    """Cast observations for the networks and add a 'cont' entry.

    uint8 arrays of rank > 3 are cast to the compute dtype and scaled to
    [0, 1]; other entries (except 'log_*' and 'key') become float32.
    'cont' is 1 - is_terminal.
    """
    obs = obs.copy()
    for key, value in obs.items():
      if key.startswith('log_') or key in ('key',):
        continue
      if len(value.shape) > 3 and value.dtype == jnp.uint8:
        value = jaxutils.cast_to_compute(value) / 255.0
      else:
        value = value.astype(jnp.float32)
      obs[key] = value
    obs['cont'] = 1.0 - obs['is_terminal'].astype(jnp.float32)
    return obs


class WorldModel(nj.Module):
  """Latent world model: encoder, RSSM dynamics and prediction heads.

  Besides training, it provides `imagine`, which rolls an
  instruction-conditioned policy forward inside the learned dynamics.
  """

  def __init__(self, obs_space, act_space, config):
    # `act_space` is a dict; only its 'action' entry is kept here.
    self.obs_space = obs_space
    self.act_space = act_space['action']
    self.config = config
    # Observation shapes; 'log_*' keys are not passed to the encoder/decoder.
    shapes = {k: tuple(v.shape) for k, v in obs_space.items()}
    shapes = {k: v for k, v in shapes.items() if not k.startswith('log_')}
    self.encoder = nets.MultiEncoder(shapes, **config.encoder, name='enc')
    self.rssm = nets.RSSM(**config.rssm, name='rssm')
    self.heads = {
        'decoder': nets.MultiDecoder(shapes, **config.decoder, name='dec'),
        'reward': nets.MLP((), **config.reward_head, name='rew'),
        'cont': nets.MLP((), **config.cont_head, name='cont')}
    self.opt = jaxutils.Optimizer(name='model_opt', **config.model_opt)
    # Expand the generic 'image'/'vector' loss scales into one scale per
    # decoder output key (CNN-decoded keys use 'image', MLP-decoded 'vector').
    scales = self.config.loss_scales.copy()
    image, vector = scales.pop('image'), scales.pop('vector')
    scales.update({k: image for k in self.heads['decoder'].cnn_shapes})
    scales.update({k: vector for k in self.heads['decoder'].mlp_shapes})
    self.scales = scales

  def initial(self, batch_size):
    """Return the initial (latent, previous action) carry.

    The action part is all zeros with shape (batch_size, *act_space.shape).
    """
    prev_latent = self.rssm.initial(batch_size)
    prev_action = jnp.zeros((batch_size, *self.act_space.shape))
    return prev_latent, prev_action

  def train(self, data, state):
    """Optimize encoder, RSSM and all heads on `self.loss` via `self.opt`.

    Returns:
      (state, outs, metrics): the new (latent, action) carry, the auxiliary
      outputs of `loss`, and loss metrics merged with optimizer metrics.
    """
    modules = [self.encoder, self.rssm, *self.heads.values()]
    mets, (state, outs, metrics) = self.opt(
        modules, self.loss, data, state, has_aux=True)
    metrics.update(mets)
    return state, outs, metrics

  def loss(self, data, state):
    """World-model loss on a batch of sequences.

    Args:
      data: preprocessed batch; axis 0 is batch and axis 1 is time (the
        carried action is prepended along axis 1 below).
      state: (prev_latent, prev_action) carry from the previous chunk.

    Returns:
      (mean loss, (new_state, out, metrics)) where `out` holds embed, post,
      prior and every per-term loss under '<name>_loss'.
    """
    embed = self.encoder(data)
    prev_latent, prev_action = state
    # Action preceding each step: the carried action for t=0, then the data
    # actions shifted right by one step. (Original author's note: "Transform
    # this action into latent action" -- presumably a TODO.)
    prev_actions = jnp.concatenate([
        prev_action[:, None], data['action'][:, :-1]], 1)
    post, prior = self.rssm.observe(
        embed, prev_actions, data['is_first'], prev_latent)
    dists = {}
    feats = {**post, 'embed': embed}
    for name, head in self.heads.items():
      # Heads not listed in `grad_heads` see stop-gradient features.
      out = head(feats if name in self.config.grad_heads else sg(feats))
      out = out if isinstance(out, dict) else {name: out}
      dists.update(out)
    losses = {}
    losses['dyn'] = self.rssm.dyn_loss(post, prior, **self.config.dyn_loss)
    losses['rep'] = self.rssm.rep_loss(post, prior, **self.config.rep_loss)
    for key, dist in dists.items():
      loss = -dist.log_prob(data[key].astype(jnp.float32))
      assert loss.shape == embed.shape[:2], (key, loss.shape)
      losses[key] = loss
    scaled = {k: v * self.scales[k] for k, v in losses.items()}
    model_loss = sum(scaled.values())
    out = {'embed':  embed, 'post': post, 'prior': prior}
    out.update({f'{k}_loss': v for k, v in losses.items()})
    # Carry the last posterior state and last action into the next chunk.
    last_latent = {k: v[:, -1] for k, v in post.items()}
    last_action = data['action'][:, -1]
    state = last_latent, last_action
    metrics = self._metrics(data, dists, post, prior, losses, model_loss)
    return model_loss.mean(), (state, out, metrics)

  def imagine(self, policy, inst_head, start, horizon, prob=None):
    """Roll out an instruction-conditioned policy in latent space.

    Args:
      policy: callable (state, inst) -> action.
      inst_head: callable state -> inst; used when `inst_prior` is 'instnet'.
      start: dict of start states (must contain 'is_terminal' and the RSSM
        state keys; other keys are dropped).
      horizon: number of imagined steps.
      prob: instruction probabilities for `inst_prior == 'prob'`.

    Returns:
      (traj, inst_eyes): `traj` has horizon + 1 entries along axis 0 (the
      start plus each imagined step) and gains 'cont' and 'weight';
      `inst_eyes` is the identity matrix repeated once per start state.

    Raises:
      Exception: for an unknown `config.inst_prior`.
    """
    first_cont = (1.0 - start['is_terminal']).astype(jnp.float32)
    keys = list(self.rssm.initial(1).keys())
    # Keep only the RSSM state entries. NOTE(review): the original note says
    # start has batch 1024 here -- that depends on the config; confirm.
    start = {k: v for k, v in start.items() if k in keys}
    act_n = self.act_space.shape[0]
    # One-hot instructions (rows of the identity), repeated per start state.
    inst_eyes = jnp.repeat(jnp.eye(act_n)[None, :], start['deter'].shape[0], axis=0)
    # Priors used to draw (and optionally re-draw) instructions.
    uni_inst_prior = jaxutils.OneHotDist(probs=jnp.ones(act_n) / act_n)
    bernoulli_inst_prior = tfd.Bernoulli(probs=jnp.ones(act_n) * 0.5)
    # Choose the first instruction for every start state.
    if self.config.inst_prior == 'prob':
      indices = jax.random.choice(nj.rng(), jnp.arange(inst_eyes.shape[1]), shape=(inst_eyes.shape[0], ), p=prob)
      selected_insts = inst_eyes[0][indices]
      selected_actions = policy(start, selected_insts)
    elif self.config.inst_prior == 'uniform':
      if not self.config.multibit_inst:
        selected_insts = uni_inst_prior.sample(sample_shape=(inst_eyes.shape[0], ), seed=nj.rng())
        selected_actions = policy(start, selected_insts)
      else:
        selected_insts = bernoulli_inst_prior.sample(sample_shape=(inst_eyes.shape[0], ), seed=nj.rng())
        selected_actions = policy(start, selected_insts)

    elif self.config.inst_prior == 'instnet':
      selected_insts = inst_head(start)
      selected_actions = policy(start, selected_insts)
    else:
      raise Exception("Not Implemented Inst Prior!")
    new_start = {**start, 'action': selected_actions, 'inst': selected_insts}

    # Scan bodies: each advances the RSSM with the stored action, then picks
    # the next action. `step` keeps the instruction fixed; the *_resample*
    # variants draw a fresh instruction at every step from the matching prior.
    def step(prev, _):
      prev = prev.copy()
      state = self.rssm.img_step(prev, prev.pop('action'))
      return {**state, 'action': policy(state, prev['inst']), 'inst': prev['inst']}

    def step_resample(prev, _):
      prev = prev.copy()
      state = self.rssm.img_step(prev, prev.pop('action'))
      new_inst = uni_inst_prior.sample(sample_shape=(new_start['deter'].shape[0], ), seed=nj.rng())
      return {**state, 'action': policy(state, new_inst), 'inst': new_inst}

    def step_resample_bernoulli(prev, _):
      prev = prev.copy()
      state = self.rssm.img_step(prev, prev.pop('action'))
      new_inst = bernoulli_inst_prior.sample(sample_shape=(new_start['deter'].shape[0], ), seed=nj.rng())
      return {**state, 'action': policy(state, new_inst), 'inst': new_inst}

    def step_resample_instnet(prev, _):
      prev = prev.copy()
      state = self.rssm.img_step(prev, prev.pop('action'))
      new_inst = inst_head(state)
      return {**state, 'action': policy(state, new_inst), 'inst': new_inst}

    def step_resample_prob(prev, _):
      prev = prev.copy()
      state = self.rssm.img_step(prev, prev.pop('action'))
      selected_indices = jax.random.choice(nj.rng(), jnp.arange(inst_eyes.shape[1]), shape=(inst_eyes.shape[0], ), p=prob)
      new_inst = inst_eyes[0][selected_indices]
      return {**state, 'action': policy(state, new_inst), 'inst': new_inst}

    # Every inst_prior value reaching here is one of the three below (others
    # raised above), so `traj` is always bound.
    if self.config.imag_inst_resample:
      if self.config.inst_prior == 'prob':
        traj = jaxutils.scan(
            step_resample_prob, jnp.arange(horizon), new_start, self.config.imag_unroll)
      elif self.config.inst_prior == 'uniform':
        if self.config.multibit_inst:
          traj = jaxutils.scan(
              step_resample_bernoulli, jnp.arange(horizon), new_start, self.config.imag_unroll)
        else:
          traj = jaxutils.scan(
              step_resample, jnp.arange(horizon), new_start, self.config.imag_unroll)
      elif self.config.inst_prior == 'instnet':
        traj = jaxutils.scan(
            step_resample_instnet, jnp.arange(horizon), new_start, self.config.imag_unroll)
    else:
      traj = jaxutils.scan(
          step, jnp.arange(horizon), new_start, self.config.imag_unroll)

    # Prepend the start step to the scanned steps along axis 0.
    traj = {
        k: jnp.concatenate([new_start[k][None], v], 0) for k, v in traj.items()}
    # The first continuation flag comes from the data; later ones from the
    # cont head's mode.
    cont = self.heads['cont'](traj).mode()
    traj['cont'] = jnp.concatenate([first_cont[None], cont[1:]], 0)
    # Cumulative discounted continuation; dividing by `discount` makes
    # weight[0] equal cont[0].
    discount = 1 - 1 / self.config.horizon
    traj['weight'] = jnp.cumprod(discount * traj['cont'], 0) / discount
    return traj, inst_eyes

  def report(self, data):
    """Loss metrics plus open-loop video predictions for image keys.

    Uses the first 6 sequences: observes 5 steps, then imagines the rest with
    the recorded actions. Videos stack truth, prediction and error along
    axis 2.
    """
    state = self.initial(len(data['is_first']))
    report = {}
    # loss(...) returns (loss, (state, out, metrics)); keep only the metrics.
    report.update(self.loss(data, state)[-1][-1])
    context, _ = self.rssm.observe(
        self.encoder(data)[:6, :5], data['action'][:6, :5],
        data['is_first'][:6, :5])
    start = {k: v[:, -1] for k, v in context.items()}
    recon = self.heads['decoder'](context)
    openl = self.heads['decoder'](
        self.rssm.imagine(data['action'][:6, 5:], start))
    for key in self.heads['decoder'].cnn_shapes.keys():
      truth = data[key][:6].astype(jnp.float32)
      model = jnp.concatenate([recon[key].mode()[:, :5], openl[key].mode()], 1)
      error = (model - truth + 1) / 2
      video = jnp.concatenate([truth, model, error], 2)
      report[f'openl_{key}'] = jaxutils.video_grid(video)
    return report

  def _metrics(self, data, dists, post, prior, losses, model_loss):
    """Summary statistics of latent entropies, losses and reward/cont heads."""
    entropy = lambda feat: self.rssm.get_dist(feat).entropy()
    metrics = {}
    metrics.update(jaxutils.tensorstats(entropy(prior), 'prior_ent'))
    metrics.update(jaxutils.tensorstats(entropy(post), 'post_ent'))
    metrics.update({f'{k}_loss_mean': v.mean() for k, v in losses.items()})
    metrics.update({f'{k}_loss_std': v.std() for k, v in losses.items()})
    metrics['model_loss_mean'] = model_loss.mean()
    metrics['model_loss_std'] = model_loss.std()
    metrics['reward_max_data'] = jnp.abs(data['reward']).max()
    # NOTE(review): read before the `'reward' in dists` guard below; safe
    # because __init__ always creates a 'reward' head.
    metrics['reward_max_pred'] = jnp.abs(dists['reward'].mean()).max()
    if 'reward' in dists and not self.config.jax.debug_nans:
      stats = jaxutils.balance_stats(dists['reward'], data['reward'], 0.1)
      metrics.update({f'reward_{k}': v for k, v in stats.items()})
    if 'cont' in dists and not self.config.jax.debug_nans:
      stats = jaxutils.balance_stats(dists['cont'], data['cont'], 0.5)
      metrics.update({f'cont_{k}': v for k, v in stats.items()})
    return metrics


class ImagActorCritic(nj.Module):
  """Instruction-conditioned actor trained on imagined rollouts.

  The actor receives an extra 'inst' input. Besides the return-based actor
  loss, `loss` adds KL terms on the first imagined step: a negated KL
  (maximized) between distributions that should differ and a KL (minimized)
  between distributions that should agree. At acting time `plan` chooses
  among candidate instructions by one-step lookahead ('instnet',
  'shooting') or sequential Monte Carlo planning ('SMCP_*').
  """

  def __init__(self, critics, scales, act_space, config):
    # Keep only critics with a non-zero scale; every non-zero scale must
    # have a matching critic.
    critics = {k: v for k, v in critics.items() if scales[k]}
    for key, scale in scales.items():
      assert not scale or key in critics, key
    self.critics = critics
    self.scales = scales
    self.act_space = act_space
    self.config = config
    disc = act_space.discrete
    self.grad = config.actor_grad_disc if disc else config.actor_grad_cont
    self.actor = nets.MLP(
        name='actor', dims='deter', shape=act_space.shape, **config.actor,
        dist=config.actor_dist_disc if disc else config.actor_dist_cont)
    self.retnorms = {
        k: jaxutils.Moments(**config.retnorm, name=f'retnorm_{k}')
        for k in critics}
    self.opt = jaxutils.Optimizer(name='actor_opt', **config.actor_opt)

    # NOTE(review): the attributes below are not read anywhere in this class.
    self.rng = np.random.default_rng(config.seed)
    available = jax.devices(config.jax.platform)
    self.policy_devices = [available[i] for i in self.config.jax.policy_devices]
    self.train_devices = [available[i] for i in self.config.jax.train_devices]
    self.single_device = (self.policy_devices == self.train_devices) and (
        len(self.policy_devices) == 1)

    # Head producing instruction distributions (multi-bit or one-hot config).
    if self.config.multibit_inst:
      self.inst_head = nets.MLP(act_space.shape, **config.multibit_inst_head, name='inst')
    else:
      self.inst_head = nets.MLP(act_space.shape, **config.inst_head, name='inst')

  def initial(self, batch_size):
    """This behavior keeps no recurrent policy state."""
    return {}

  def plan(self, state, img_step, activate_prior=False, mode='expl'):
    """Choose an action for `state` by planning over candidate instructions.

    Args:
      state: RSSM latent dict for the current step.
      img_step: one-step imagination function (state, action) -> state.
      activate_prior: unused here; kept for interface compatibility.
      mode: 'task' or 'expl'; selects which planning config is read.

    Returns:
      (action, index) of the winning candidate instruction / SMC particle.

    Raises:
      ValueError: for an unknown `mode` or planning method.
      NotImplementedError: for SMCP planning without multi-bit instructions.
    """
    if mode == 'task':
      planning = self.config.task_planning
      plannum = self.config.task_plannum
    elif mode == 'expl':
      planning = self.config.expl_planning
      plannum = self.config.expl_plannum
    else:
      raise ValueError(f'Unknown planning mode: {mode}')
    act_n = self.act_space.shape[0]
    # One one-hot instruction per action dimension, repeated over the batch.
    inst_eyes = jnp.repeat(jnp.eye(act_n)[None, :], state['deter'].shape[0], axis=0)
    if planning == 'instnet':
      insts = self.inst_head(state).sample((plannum, ), nj.rng())
      insts = jnp.permute_dims(insts, (1, 0, 2))

    elif planning == 'shooting':
      if not self.config.multibit_inst:
        insts = inst_eyes
      else:
        bernoulli_prior = tfd.Bernoulli(probs=jnp.ones(act_n) * 0.5)
        insts = bernoulli_prior.sample(sample_shape=(1, plannum, ), seed=nj.rng())

    elif planning == 'SMCP_bernoulli':
      if not self.config.multibit_inst:
        raise NotImplementedError("Have not implemented SMCP_bernoulli on single bit.")
      # TODO: SMCP_Bernoulli
      bernoulli_prior = tfd.Bernoulli(probs=jnp.ones(act_n) * 0.5)
      insts = bernoulli_prior.sample(sample_shape=(1, plannum, ), seed=nj.rng())

    elif planning == 'SMCP_instnet':
      if not self.config.multibit_inst:
        raise NotImplementedError("Have not implemented SMCP_instnet on single bit.")
      insts = self.inst_head(state).sample((plannum, ), nj.rng())
      insts = jnp.permute_dims(insts, (1, 0, 2))

    else:
      # Previously an unknown method surfaced later as a NameError on `insts`.
      raise ValueError(f'Unknown planning method: {planning}')

    if planning not in ('SMCP_instnet', 'SMCP_bernoulli'):
      # One-step lookahead: one action per candidate instruction, imagine a
      # step, score it with the extrinsic critic, keep the best per batch entry.
      # NOTE(review): nj.rng() inside the vmapped function is presumably drawn
      # once at trace time, so all candidates share one sampling key.
      def actor_forward(inst):
        return self.actor({**state, 'inst': inst}).sample(seed=nj.rng())
      def step_fn(action):
        return img_step(state, action)
      def critic_forward(next_state):
        return self.critics['extr'].score_one_step(state, next_state)
      actions = jax.vmap(actor_forward, in_axes=1)(insts)
      next_states = jax.vmap(step_fn)(actions)
      one_step_values = jax.vmap(critic_forward)(next_states)
      action_indices = jnp.argmax(one_step_values, axis=0)
      batch_indices = jnp.arange(one_step_values.shape[1])
      best_actions = actions[action_indices, batch_indices, :]
      return best_actions, action_indices

    # Sequential Monte Carlo planning over `plannum` particles.
    def actor_forward(inst, s, key):
      action_dist = self.actor({**s, 'inst': inst})
      actions = action_dist.sample(seed=key)
      logpi = action_dist.log_prob(actions)
      return actions, logpi
    def step_fn(action, s):
      return img_step(s, action)
    def calc_advantage(s, s1, logpi):
      rps = self.critics['extr'].score_one_step(s, s1)  # r_t + V(s_{t+1})
      vs = self.critics['extr'].score_state(s)  # V(s_t)
      if self.config.calc_vs:
        rps -= vs
      if self.config.calc_logpi:
        rps -= logpi
      return rps
    def update(carry, keys):
      # `last_roots[k]` is the index of the initial particle that particle k
      # descends from; it selects the first action at the end.
      last_states, last_roots, last_weights, last_insts = carry
      insts = last_insts

      if self.config.SMCP_inst_resample:
        if planning == 'SMCP_bernoulli':
          insts = bernoulli_prior.sample(sample_shape=(1, plannum, ), seed=nj.rng())
        elif planning == 'SMCP_instnet':
          # NOTE(review): conditions on the initial `state`, not the particles.
          insts = self.inst_head(state).sample((plannum, ), nj.rng())
          insts = jnp.permute_dims(insts, (1, 0, 2))

      axes = {k: 1 for k in last_states.keys()}
      actions, logpis = jax.vmap(actor_forward, in_axes=(1, axes, 0))(insts, last_states, keys[1:])
      cand_states = jax.vmap(step_fn, in_axes=(0, axes), out_axes=axes)(actions, last_states)
      advantages = jax.vmap(calc_advantage, in_axes=(axes, axes, 0))(last_states, cand_states, logpis)
      log_adv = advantages.squeeze(-1)

      if self.config.SMCP_weight_resample:
        # softmax == exp(a) / sum(exp(a)) without overflowing for large a.
        new_weights = jax.nn.softmax(log_adv)
        # Sample surviving particle positions, gather their states, and carry
        # each survivor's root index along.
        ancestors = jax.random.choice(keys[0], plannum, shape=(plannum, ), p=new_weights)
        roots = last_roots[ancestors]
      else:
        # Normalized w * exp(a), computed stably in log space.
        new_weights = jax.nn.softmax(jnp.log(last_weights) + log_adv)
        ancestors = jnp.arange(plannum)
        roots = last_roots
      new_states = tree_map(lambda x: x[:, ancestors], cand_states)
      return (new_states, roots, new_weights, insts), actions

    init_state = tree_map(lambda x: jnp.stack([x for _ in range(plannum)], axis=1), state)
    init_roots = jnp.arange(plannum)
    init_weights = jnp.ones(plannum)
    keys = nj.rng(amount=self.config.SMCP_horizon * (1+plannum))
    keys = keys.reshape(self.config.SMCP_horizon, 1+plannum, -1)
    final_carry, actions = nj.scan(update, carry=(init_state, init_roots, init_weights, insts), xs=keys)
    _, last_roots, last_weights, _ = final_carry
    # Actions of the first planning step, indexed by initial particle.
    first_actions = actions[0]
    if self.config.SMCP_weight_resample:
      selected_indice = jax.random.choice(nj.rng(), last_roots, shape=(1, ))
    else:
      selected_indice = jax.random.choice(nj.rng(), last_roots, shape=(1, ), p=last_weights)
    return first_actions[selected_indice][0], selected_indice

  def policy(self, state, img_step, carry, activate_prior=False, mode='expl'):
    """Plan an action; returns ({'action', 'log_indice'}, carry unchanged)."""
    action, indice = self.plan(state, img_step, activate_prior=activate_prior, mode=mode)
    return {'action': action, 'log_indice': indice}, carry

  def train(self, imagine, start, context, prob=None):
    """Update actor (and inst_head for 'instnet' prior) and every critic."""
    def loss(start):
      policy = lambda s, i: self.actor({**(sg(s)), 'inst': i}).sample(seed=nj.rng())
      inst = lambda s: self.inst_head(sg(s)).sample(seed=nj.rng())
      traj, insts = imagine(policy, inst, start, self.config.imag_horizon, prob)
      loss, metrics, sample_traj = self.loss(traj, insts)
      return loss, (sample_traj, metrics)
    if self.config.inst_prior == 'instnet':
      mets, (sample_traj, metrics) = self.opt([self.actor, self.inst_head], loss, start, has_aux=True)
    else:
      mets, (sample_traj, metrics) = self.opt(self.actor, loss, start, has_aux=True)

    metrics.update(mets)
    for key, critic in self.critics.items():
      mets = critic.train(sample_traj, self.actor)
      metrics.update({f'{key}_critic_{k}': v for k, v in mets.items()})
    return sample_traj, metrics

  def _sample_first_insts(self, start, action_dim):
    """Sample `imag_per_start` instructions per start state.

    Laid out as (batch, imag_per_start, ...); for the instnet variants this
    assumes inst_head samples with a leading batch axis.
    """
    per_start = self.config.imag_per_start
    if self.config.first_inst == 'bernoulli':
      bernoulli_inst_prior = tfd.Bernoulli(probs=jnp.ones(action_dim) * 0.5)
      return bernoulli_inst_prior.sample(
          sample_shape=(start['deter'].shape[0], per_start, ), seed=nj.rng())
    if self.config.first_inst == 'instnet':
      first_insts = self.inst_head(start).sample(
          sample_shape=(per_start, ), seed=nj.rng())
      return jnp.permute_dims(first_insts, (1, 0, 2))
    if self.config.first_inst == 'sg_instnet':
      first_insts = self.inst_head(start).sample(
          sample_shape=(per_start, ), seed=nj.rng())
      return jnp.permute_dims(sg(first_insts), (1, 0, 2))
    raise Exception("Not Implemented first_int!")

  def loss(self, traj, insts):
    """Actor loss on an imagined trajectory plus contrastive KL terms.

    Args:
      traj: imagined trajectory dict (axis 0 = time), including 'inst',
        'action' and 'weight'.
      insts: instruction set returned by `imagine` (repeated identity);
        its last axis gives the instruction/action dimensionality.

    Returns:
      (loss, metrics, traj).
    """
    metrics = {}
    action_dim = insts.shape[-1]
    start = {k: v[0] for k, v in traj.items() if k != 'inst'}
    if self.config.multibit_inst:
      # Multi-bit mode contrasts freshly sampled instructions. Sampling only
      # here means a `first_inst` such as 'uniform' (allowed by Agent's
      # config checks) no longer crashes the one-hot path, where it is unused.
      insts = self._sample_first_insts(start, action_dim)

    def first_step_ls(inst):
      # loc/scale of the first-step action distribution under `inst`
      # (assumes `.distribution` exposes `loc` and `scale`).
      action = self.actor({**(sg(start)), 'inst': inst})
      action_dist = action.distribution
      return jnp.stack([action_dist.loc, action_dist.scale])
    loc_scale = jax.vmap(first_step_ls, in_axes=1, out_axes=1)(insts)
    first_action_loc, first_action_scale = loc_scale[0], loc_scale[1]

    if not self.config.multibit_inst:
      # One-hot: dimension i under instruction i vs. under every instruction j.
      diag_loc = jnp.diagonal(first_action_loc, axis1=0, axis2=2).T
      diag_scale = jnp.diagonal(first_action_scale, axis1=0, axis2=2).T

      aii_loc = jnp.repeat(diag_loc[:, None, :], action_dim, axis=1)
      aii_scale = jnp.repeat(diag_scale[:, None, :], action_dim, axis=1)

      aji_loc = jnp.transpose(first_action_loc, (2, 0, 1))
      aji_scale = jnp.transpose(first_action_scale, (2, 0, 1))
      aii_dist = tfd.Normal(aii_loc, aii_scale)
      if self.config.sg_kl:
        aji_dist = tfd.Normal(sg(aji_loc), sg(aji_scale))
      else:
        aji_dist = tfd.Normal(aji_loc, aji_scale)
      max_kls_loss = -tfd.kl_divergence(aii_dist, aji_dist) * self.config.loss_scales.contrastive.max_kl

      # For each dimension d, pair each instruction (other than d) with a
      # random other instruction != d and pull their distributions together.
      n_inst = first_action_loc.shape[0]
      bs = first_action_loc.shape[1]
      keys = nj.rng(amount=n_inst*bs)
      keys = keys.reshape((n_inst, bs, keys.shape[-1]))
      def batch_calc(action_loc, action_scale, key):
        shuffled_indices = self.get_shuffled_indice(action_dim, key).astype(jnp.int32)
        shuffled_indices = shuffled_indices.T
        def body_fn(fi, xi):
          return fi[xi]
        shuffled_loc = jax.vmap(body_fn, in_axes=(1, 1), out_axes=1)(action_loc, shuffled_indices)
        shuffled_scale = jax.vmap(body_fn, in_axes=(1, 1), out_axes=1)(action_scale, shuffled_indices)
        return jnp.stack([shuffled_loc, shuffled_scale])

      res = jax.vmap(batch_calc, in_axes=(1, 1, 1), out_axes=2)(first_action_loc, first_action_scale, keys)
      shuffled_action_loc = res[0]
      shuffled_action_scale = res[1]

      original_dist = tfd.Normal(first_action_loc, first_action_scale)
      if self.config.sg_kl:
        shuffled_dist = tfd.Normal(sg(shuffled_action_loc), sg(shuffled_action_scale))
      else:
        shuffled_dist = tfd.Normal(shuffled_action_loc, shuffled_action_scale)
      min_kls_loss = tfd.kl_divergence(original_dist, shuffled_dist) * self.config.loss_scales.contrastive.min_kl

    else:
      # Multi-bit: per dimension, compare against negative (other bit value)
      # and positive (same bit value) instructions from `get_contrastive_samples`.
      n_inst = first_action_loc.shape[0]
      bs = first_action_loc.shape[1]
      keys = nj.rng(amount=n_inst*bs*2*action_dim)
      keys = keys.reshape((n_inst, bs, action_dim, 2, keys.shape[-1]))
      def batch_calc_max(action_loc, action_scale, inst, key):
        negative_indices, positive_indices = self.get_contrastive_samples(inst, key)
        negative_indices = negative_indices.T
        positive_indices = positive_indices.T
        def body_fn(fi, xi):
          return fi[xi]
        negative_action_loc = jax.vmap(body_fn, in_axes=(1, 1), out_axes=1)(action_loc, negative_indices)
        negative_action_scale = jax.vmap(body_fn, in_axes=(1, 1), out_axes=1)(action_scale, negative_indices)
        positive_action_loc = jax.vmap(body_fn, in_axes=(1, 1), out_axes=1)(action_loc, positive_indices)
        positive_action_scale = jax.vmap(body_fn, in_axes=(1, 1), out_axes=1)(action_scale, positive_indices)
        return jnp.stack([negative_action_loc, negative_action_scale, positive_action_loc, positive_action_scale])
      res = jax.vmap(batch_calc_max, in_axes=(1, 1, 0, 1), out_axes=2)(first_action_loc, first_action_scale, insts, keys)
      negative_action_locs = res[0]
      negative_action_scales = res[1]
      positive_action_locs = res[2]
      positive_action_scales = res[3]

      original_dist = tfd.Normal(first_action_loc, first_action_scale)
      if self.config.sg_kl:
        negative_dist = tfd.Normal(sg(negative_action_locs), sg(negative_action_scales))
      else:
        negative_dist = tfd.Normal(negative_action_locs, negative_action_scales)
      max_kls_loss = -tfd.kl_divergence(original_dist, negative_dist) * self.config.loss_scales.contrastive.max_kl

      if self.config.sg_kl:
        positive_dist = tfd.Normal(sg(positive_action_locs), sg(positive_action_scales))
      else:
        positive_dist = tfd.Normal(positive_action_locs, positive_action_scales)
      min_kls_loss = tfd.kl_divergence(original_dist, positive_dist) * self.config.loss_scales.contrastive.min_kl
    sample_traj = traj

    # Standard actor objective on normalized returns.
    advs = []
    total = sum(self.scales[k] for k in self.critics)
    for key, critic in self.critics.items():
      rew, ret, base = critic.score(sample_traj, self.actor)

      offset, invscale = self.retnorms[key](ret)
      normed_ret = (ret - offset) / invscale
      normed_base = (base - offset) / invscale
      advs.append((normed_ret - normed_base) * self.scales[key] / total)

      metrics.update(jaxutils.tensorstats(rew, f'{key}_reward'))
      metrics.update(jaxutils.tensorstats(ret, f'{key}_return_raw'))
      metrics.update(jaxutils.tensorstats(normed_ret, f'{key}_return_normed'))
      metrics[f'{key}_return_rate'] = (jnp.abs(ret) >= 0.5).mean()
    adv = jnp.stack(advs).sum(0)
    policy = self.actor(sg(sample_traj))
    logpi = policy.log_prob(sg(sample_traj['action']))[:-1]
    loss = {'backprop': -adv, 'reinforce': -logpi * sg(adv)}[self.grad]
    ent = policy.entropy()[:-1]
    loss -= self.config.actent * ent
    loss *= sg(sample_traj['weight'])[:-1]
    loss *= self.config.loss_scales.actor

    metrics.update(self._metrics(sample_traj, policy, logpi, ent, adv))
    metrics['max_kls_loss'] = max_kls_loss.mean()
    metrics['min_kls_loss'] = min_kls_loss.mean()
    loss = loss.mean() + max_kls_loss.mean() + min_kls_loss.mean()

    if self.config.instnet_sparsity:
      sp_loss = sample_traj['inst'].mean()
      metrics['sparsity_loss'] = sp_loss
      loss += sp_loss

    return loss, metrics, sample_traj

  def get_shuffled_indice(self, act_n, keys):
    """Per row idx, a permutation of range(act_n) that fixes idx in place.

    The other entries are shuffled among themselves with keys[idx]. Returns
    an (act_n, act_n) float array of indices.
    """
    def shuffle_single_column(column, idx, key):
      def slice_left(i, x):
        return x.at[i].set(jnp.where(i<idx, column[i], 0))
      def slice_right(i, x):
        return x.at[i].set(jnp.where(i>idx, column[i], 0))
      l = jax.lax.fori_loop(0, act_n, slice_left, jnp.zeros(act_n))
      r = jax.lax.fori_loop(0, act_n, slice_right, jnp.zeros(act_n))
      # column without entry idx (the out-of-range write at i = act_n-1 is
      # dropped by JAX scatter semantics).
      def concat(i, x):
        return x.at[i].set(jnp.where(i<idx, l[i], r[i+1]))
      non_diag = jax.lax.fori_loop(0, act_n, concat, jnp.zeros(act_n-1))
      non_diag_shuffle = jax.random.permutation(key, non_diag)
      def concat_shuffle(i, x):
        return x.at[i].set(jnp.where(i != idx, jnp.where(i<idx, non_diag_shuffle[i], non_diag_shuffle[i-1]), column[idx]))
      res = jax.lax.fori_loop(0, act_n, concat_shuffle, jnp.zeros(act_n))
      return res
    column = jnp.stack([jnp.arange(act_n) for _ in range(act_n)])
    idx = jnp.arange(act_n)
    results = jax.vmap(shuffle_single_column, in_axes=(0, 0, 0))(column, idx, keys)
    return results

  def get_contrastive_samples(self, insts, keys):
    """Pick negative/positive partner indices per instruction and bit.

    For instruction i and bit d: the negative partner is a random instruction
    whose bit d is 0 when i's bit d is 1 (else i itself); the positive
    partner is another instruction with bit d == 0 when i's bit is 0 (else i
    itself). When no valid partner exists, i itself is used (zero KL) instead
    of sampling from an all-zero/NaN distribution.
    """
    bs = insts.shape[-2]
    def sample_single_column(column, key):
      zero_mask = 1 - column  # instructions whose bit is 0
      num_zeros = jnp.sum(zero_mask)
      zero_prob = zero_mask / jnp.maximum(num_zeros, 1)
      def get_negative_sample(i, carry):
        x, y = carry
        idx = jax.random.choice(key[i][0], x.shape[0], p=zero_prob)
        use_negative = jnp.logical_and(column[i], num_zeros > 0)
        x = x.at[i].set(jnp.where(use_negative, idx, i))

        others = zero_mask.at[i].set(0)
        num_others = jnp.sum(others)
        new_prob = others / jnp.maximum(num_others, 1)
        idx = jax.random.choice(key[i][1], y.shape[0], p=new_prob)
        keep_self = jnp.logical_or(column[i], num_others == 0)
        y = y.at[i].set(jnp.where(keep_self, i, idx))
        return (x, y)
      init_x = jnp.zeros(bs, dtype=jnp.int32)
      init_y = jnp.zeros(bs, dtype=jnp.int32)
      columns = jax.lax.fori_loop(0, bs, get_negative_sample, (init_x, init_y))
      return columns
    return jax.vmap(sample_single_column, in_axes=(1, 1))(insts, keys)

  def _metrics(self, traj, policy, logpi, ent, adv):
    """Policy/action statistics for logging."""
    metrics = {}
    ent = policy.entropy()[:-1]
    rand = (ent - policy.minent) / (policy.maxent - policy.minent)
    rand = rand.mean(range(2, len(rand.shape)))
    act = traj['action']
    act = jnp.argmax(act, -1) if self.act_space.discrete else act
    metrics.update(jaxutils.tensorstats(act, 'action'))
    metrics.update(jaxutils.tensorstats(rand, 'policy_randomness'))
    metrics.update(jaxutils.tensorstats(ent, 'policy_entropy'))
    metrics.update(jaxutils.tensorstats(logpi, 'policy_logprob'))
    metrics.update(jaxutils.tensorstats(adv, 'adv'))
    metrics['imag_weight_dist'] = jaxutils.subsample(traj['weight'])
    return metrics


class VFunction(nj.Module):
  """State-value critic trained on imagined trajectories.

  `rewfn` maps a trajectory dict (axis 0 = time) to rewards with one fewer
  entry along axis 0 than the trajectory (see the assertion in `score`).
  A slow copy of the network is maintained by `jaxutils.SlowUpdater` and
  used as a regularization target.
  """

  def __init__(self, rewfn, config):
    self.rewfn = rewfn
    self.config = config
    self.net = nets.MLP((), name='net', dims='deter', **self.config.critic)
    self.slow = nets.MLP((), name='slow', dims='deter', **self.config.critic)
    self.updater = jaxutils.SlowUpdater(
        self.net, self.slow,
        self.config.slow_critic_fraction,
        self.config.slow_critic_update)
    self.opt = jaxutils.Optimizer(name='critic_opt', **self.config.critic_opt)

  def train(self, traj, actor):
    """Fit `net` to lambda-return targets, then run the slow-network updater.

    `actor` is unused; it is accepted for interface compatibility.
    """
    target = sg(self.score(traj)[1])
    mets, metrics = self.opt(self.net, self.loss, traj, target, has_aux=True)
    metrics.update(mets)
    self.updater()
    return metrics

  def loss(self, traj, target):
    """Weighted negative log-likelihood of `target` plus slow-net regularizer."""
    traj = {k: v[:-1] for k, v in traj.items()}
    dist = self.net(traj)
    loss = -dist.log_prob(sg(target))
    if self.config.critic_slowreg == 'logprob':
      reg = -dist.log_prob(sg(self.slow(traj).mean()))
    elif self.config.critic_slowreg == 'xent':
      reg = -jnp.einsum(
          '...i,...i->...',
          sg(self.slow(traj).probs),
          jnp.log(dist.probs))
    else:
      raise NotImplementedError(self.config.critic_slowreg)
    loss += self.config.loss_scales.slowreg * reg
    loss = (loss * sg(traj['weight'])).mean()
    loss *= self.config.loss_scales.critic
    metrics = jaxutils.tensorstats(dist.mean())
    return loss, metrics

  def score(self, traj, actor=None):
    """Return (rewards, lambda-returns, baseline values) for `traj`.

    Returns are bootstrapped from the value of the last step and computed
    backwards with `return_lambda` and discount 1 - 1/horizon masked by
    traj['cont'].
    """
    rew = self.rewfn(traj)
    assert len(rew) == len(traj['action']) - 1, (
        'should provide rewards for all but last action')
    discount = 1 - 1 / self.config.horizon
    disc = traj['cont'][1:] * discount
    value = self.net(traj).mean()
    vals = [value[-1]]
    interm = rew + disc * value[1:] * (1 - self.config.return_lambda)
    for t in reversed(range(len(disc))):
      vals.append(interm[t] + disc[t] * self.config.return_lambda * vals[-1])
    ret = jnp.stack(list(reversed(vals))[:-1])
    return rew, ret, value[:-1]

  def score_one_step(self, state, next_state):
    """Return reward + discount * V for the single transition state -> next_state.

    The two states are stacked along a new leading time axis of length 2 --
    the (time, batch, ...) layout `score` and `rewfn` use -- so the result
    has one value per batch entry. The previous version concatenated along
    the batch axis, which mixed different batch entries whenever the batch
    size exceeded 1; for batch size 1 the values are unchanged.
    """
    discount = 1 - 1 / self.config.horizon
    pair = {k: jnp.stack([state[k], next_state[k]], axis=0) for k in state.keys()}
    rew = self.rewfn(pair)
    value = self.net(pair).mean()[1:]
    # Drop the length-1 time axis left after rewfn / [1:].
    return (rew + discount * value)[0]

  def score_state(self, state):
    """Return the mean value prediction V(state)."""
    return self.net(state).mean()
