import torch
from torch import nn
from torch import distributions as torchd

import models
import networks
import tools


class Random(nn.Module):
  """Exploration behavior whose actor ignores feature values and acts randomly.

  Only the leading (batch) dimensions and the device of the features are used;
  the action dimension comes from `config.num_actions`.
  """

  def __init__(self, config):
    """Store the config; reads `num_actions` and `actor_dist` from it later."""
    # nn.Module bookkeeping must exist before the object is used as a module
    # (e.g. `.parameters()`, `.to()`, or registration inside a parent module).
    super().__init__()
    self._config = config

  def actor(self, feat):
    """Return a random action distribution shaped like `feat` minus its last dim.

    Args:
      feat: tensor whose last dimension is replaced by `config.num_actions`.

    Returns:
      `tools.OneHotDist` built from all-zero values when
      `config.actor_dist == 'onehot'`, otherwise `tools.ContDist` wrapping a
      Uniform(-1, 1) distribution.
    """
    # torch.Size is a tuple, so it must be extended with a tuple (not a list).
    shape = tuple(feat.shape[:-1]) + (self._config.num_actions,)
    if self._config.actor_dist == 'onehot':
      return tools.OneHotDist(torch.zeros(shape, device=feat.device))
    ones = torch.ones(shape, device=feat.device)
    return tools.ContDist(torchd.uniform.Uniform(-ones, ones))

  def train(self, start, context, data=None):
    """No-op training step; there is nothing to learn.

    `data` is accepted (and ignored) so the call signature matches
    `Plan2Explore.train`. Note that this method shadows
    `nn.Module.train(mode)`.

    Returns:
      `(None, {})` — no output and an empty metrics dict.
    """
    return None, {}


class Plan2Explore(nn.Module):
  """Exploration behavior rewarded by disagreement within an ensemble of heads.

  An ensemble of `networks.DenseHead` models is trained to predict a target
  selected by `config.disag_target`. The standard deviation across the heads'
  predicted means is used as the intrinsic reward for an imagination behavior
  (`models.ImagBehavior`).

  Note: `train` keeps its `(start, context, data)` signature for callers and
  therefore shadows `nn.Module.train(mode)`; `.eval()` (which calls
  `train(False)`) will not work on this object.
  """

  def __init__(self, config, world_model, reward=None):
    """Build the behavior, the ensemble heads and the ensemble optimizer.

    Args:
      config: configuration object; the attributes read here are `dyn_stoch`,
        `dyn_discrete`, `dyn_deter`, `cnn_depth`, `disag_target`,
        `disag_layers`, `disag_units`, `disag_models`, `act`, `opt`,
        `model_lr`, `opt_eps` and `weight_decay`.
      world_model: forwarded to `models.ImagBehavior`.
      reward: optional callable `(feat, state, action) -> reward`, used only
        when `config.expl_extr_scale` is truthy.
    """
    # Required before any submodule can be assigned or parameters() is called.
    super().__init__()
    self._config = config
    self._reward = reward
    self._behavior = models.ImagBehavior(config, world_model)
    self.actor = self._behavior.actor
    stoch_size = config.dyn_stoch
    if config.dyn_discrete:
      stoch_size *= config.dyn_discrete
    # Output size of each head, depending on which quantity it predicts.
    size = {
        'embed': 32 * config.cnn_depth,
        'stoch': stoch_size,
        'deter': config.dyn_deter,
        'feat': config.dyn_stoch + config.dyn_deter,
    }[self._config.disag_target]
    # NOTE(review): the heads are fed `context['feat']` (optionally
    # concatenated with the action), so `inp_dim=config.dyn_stoch` looks too
    # small — confirm against networks.DenseHead.
    kw = dict(
        inp_dim=config.dyn_stoch,  # pytorch version
        shape=size, layers=config.disag_layers, units=config.disag_units,
        act=config.act)
    # ModuleList registers the heads so their parameters are tracked by
    # nn.Module (a plain list would hide them from parameters()/to()).
    self._networks = nn.ModuleList(
        [networks.DenseHead(**kw) for _ in range(config.disag_models)])
    # Only the ensemble is trained through this optimizer (see
    # _train_ensemble), so it is built from the ensemble parameters only.
    # NOTE(review): signature of tools.optimizer is not visible here — confirm.
    self._opt = tools.optimizer(config.opt, self._networks.parameters(),
        config.model_lr, config.opt_eps, config.weight_decay)

  def train(self, start, context, data):
    """Train the ensemble, then the behavior on the intrinsic reward.

    Args:
      start: dict providing at least 'stoch' and 'deter'.
      context: dict providing at least 'embed' and 'feat'.
      data: dict providing 'action' (read when `config.disag_action_cond`).

    Returns:
      `(None, metrics)` where metrics merges the ensemble metrics with the
      last element returned by the behavior's `train`.
    """
    metrics = {}
    stoch = start['stoch']
    if self._config.dyn_discrete:
      # Merge the last two dimensions of the discrete latent into one.
      stoch = stoch.flatten(start_dim=-2)
    target = {
        'embed': context['embed'],
        'stoch': stoch,
        'deter': start['deter'],
        'feat': context['feat'],
    }[self._config.disag_target]
    inputs = context['feat']
    if self._config.disag_action_cond:
      inputs = torch.cat([inputs, data['action']], -1)
    metrics.update(self._train_ensemble(inputs, target))
    metrics.update(self._behavior.train(start, self._intrinsic_reward)[-1])
    return None, metrics

  def _intrinsic_reward(self, feat, state, action):
    """Compute the exploration reward from ensemble disagreement.

    The reward is `expl_intr_scale` times the std across heads of their
    predicted means, averaged over the last dimension (optionally logged),
    plus `expl_extr_scale` times the external reward when that scale is set.
    """
    inputs = feat
    if self._config.disag_action_cond:
      inputs = torch.cat([inputs, action], -1)
    # NOTE(review): assumes each head accepts a dtype as second argument and
    # returns an object exposing `mean()` — confirm against networks/tools.
    preds = torch.stack(
        [head(inputs, torch.float32).mean() for head in self._networks], 0)
    # Population std (unbiased=False) stays finite with a single head; the
    # default sample std would divide by zero there.
    disag = preds.std(0, unbiased=False).mean(-1)
    if self._config.disag_log:
      disag = torch.log(disag)
    reward = self._config.expl_intr_scale * disag
    if self._config.expl_extr_scale:
      extrinsic = self._reward(feat, state, action)
      reward = reward + (
          self._config.expl_extr_scale * extrinsic).to(torch.float32)
    return reward

  def _train_ensemble(self, inputs, targets):
    """Run one optimisation step of the ensemble on (inputs, targets).

    When `config.disag_offset` is set, inputs at index t along dimension 1
    are paired with targets at index t + offset. Gradients are blocked from
    flowing into `inputs` and `targets`.

    Returns:
      Whatever the optimizer call returns (used as a metrics dict by `train`).
    """
    offset = self._config.disag_offset
    if offset:
      targets = targets[:, offset:]
      inputs = inputs[:, :-offset]
    targets = targets.detach()
    inputs = inputs.detach()
    preds = [head(inputs) for head in self._networks]
    likes = [pred.log_prob(targets).mean() for pred in preds]
    # Negative summed mean log-likelihood over all heads.
    loss = -torch.stack(likes).sum().to(torch.float32)
    # NOTE(review): there is no gradient tape in PyTorch; this assumes the
    # object returned by tools.optimizer is callable as (loss, parameters)
    # and returns metrics — confirm against tools.
    metrics = self._opt(loss, self._networks.parameters())
    return metrics
