from util import log
from mpi4py import MPI
import numpy as np
import tensorflow as tf
from operator import itemgetter
from random import shuffle
from util.act import sample_τ_segment, sample_τ_segment_abs
from util.opt import MpiAdam
from util.rms import π_rms, s_rms, a_rms
from const import γ, add_absorbing_state
import time

# Short aliases for the TensorFlow reductions used in this module.
Σ = tf.add_n          # sum of a list of tensors
𝔼 = tf.reduce_mean    # mean over all elements, or over the given axis
# Reciprocal of (1 - γ); γ is imported from `const`.
γ_inv = 1.0 / (1.0 - γ)

def train(alg, env, π, D, expert_dataset, batch_size, *,
     π_step, d_step,
     lr_π=1e-4, lr_v=4e-4, lr_d=4e-4,
     π_rollouts=15, d_rollouts=15,
     burnin_steps=0, max_steps=0, random=False):
  """Run the policy / reward-model training loop.

  Each iteration (1) samples `π_step` trajectory segments, labels them
  with a reward and computes advantages/returns, (2) makes `π_rollouts`
  passes of policy+value gradient steps over those segments, (3) when
  `D` is given, samples `d_step` (policy, expert) batch pairs and makes
  `d_rollouts` passes of gradient steps on `D`, and (4) averages the
  collected statistics over all MPI workers and logs them on rank 0.

  Args:
    alg: Name selecting the gradient function imported from
      `loss.dloss`: 'gail', 'fairl', 'airl' or 'cairl'. Only consulted
      when `D` is not None.
    env: Environment handed to the trajectory sampler.
    π: Policy object; `π.vars` and `π.v_vars` are optimized jointly.
    D: Reward/discriminator model, or None. When None the sampled
      'r_true' entry is used directly as the reward and the whole
      `D`-training phase is skipped.
    expert_dataset: Must provide `get_next_batch(batch_size)`; only used
      when `D` is not None.
    batch_size: Passed to the trajectory sampler and to the expert
      dataset.
    π_step: Number of segments sampled per iteration for the policy
      phase.
    d_step: Number of (policy, expert) batch pairs sampled per iteration
      for the `D` phase.
    lr_π: Learning rate applied to the entries of `π.vars`.
    lr_v: Learning rate applied to the entries of `π.v_vars`.
    lr_d: Learning rate passed to the `D` optimizer.
    π_rollouts: Number of passes over the policy segments per iteration.
    d_rollouts: Number of passes over the `D` batches per iteration.
    burnin_steps: While the step counter is below this value, the
      policy part of the learning-rate vector is zero (only the value
      variables get a non-zero rate).
    max_steps: Stop once the step counter reaches this value; 0 (the
      default) means loop forever.
    random: If True, `lr_π` is forced to 0.

  Returns:
    None. The loop only ends when `max_steps` is non-zero and reached.
  """
  if random:
    lr_π = 0.

  comm = MPI.COMM_WORLD
  nw = comm.Get_size()
  rank = comm.Get_rank()

  # Both optimizer factories use the same Adam hyper-parameters.
  Opt1 = lambda x: MpiAdam(x, β1=0., β2=0.99)
  Opt2 = lambda x: MpiAdam(x, β1=0., β2=0.99)

  # Policy and value variables share one optimizer; their different
  # learning rates are expressed through the per-entry `lr` vector below.
  πv_vars = π.vars + π.v_vars

  πopt = Opt1(πv_vars)
  πopt.sync()

  do_irl = D is not None

  if do_irl:
    dopt = Opt2(D.vars)
    dopt.sync()
    # `rgetter` below yields a bare value for a single key and a tuple
    # for several, so the wrapper has to match.
    if len(D.rtype) == 1:
      @tf.function
      def rwd(s):
        return D.rwd(s)
    else:
      @tf.function
      def rwd(args):
        return D.rwd(*args)
  else:
    # Without a reward model the sampled true reward is used as-is.
    def rwd(r_true):
      return r_true

  rgetter = itemgetter(*(D.rtype)) \
      if do_irl else itemgetter('r_true')

  from loss.πloss import get_logpπ, get_logpE, get_adv_ret, get_adv_ret_abs, πvgrad, πvgrad_abs

  # The absorbing-state variants take one extra entry, 'y'.
  if add_absorbing_state:
    vgetter = itemgetter(
        's','u','r','ś','t','b','y','logπ')
    πgetter = itemgetter('s','u','r','logπ','adv','ret','y')
  else:
    vgetter = itemgetter(
        's','u','r','ś','t','b','logπ')
    πgetter = itemgetter('s','u','r','logπ','adv','ret')
  logpgetter = itemgetter(*'su')

  if do_irl:
    if alg == 'gail':
      from loss.dloss import gail_grad as dgrad
    elif alg == 'fairl':
      from loss.dloss import fairl_grad as dgrad
    elif alg == 'airl':
      from loss.dloss import airl_grad as dgrad
    elif alg == 'cairl':
      from loss.dloss import cairl_grad as dgrad

    dgetter = itemgetter(*D.intype)
  if add_absorbing_state:
    τ_sampler = sample_τ_segment_abs(π, env, batch_size)
  else:
    τ_sampler = sample_τ_segment(π, env, batch_size)
  getτ = τ_sampler.__next__
  if do_irl:
    getτE = expert_dataset.get_next_batch

  ep_total = 0
  steps = 0
  i = 0
  # Step-counter increment per iteration, summed over all workers.
  step_size = batch_size * π_step * nw

  # One learning-rate entry per scalar parameter: the total element
  # count of each variable list gives the vector length.
  lr_π_vec = lr_π * tf.ones(Σ([
      tf.math.reduce_prod(x)
      for x in tf.shape_n(π.vars)]), tf.float32)
  lr_v_vec = lr_v * tf.ones(Σ([
      tf.math.reduce_prod(x)
      for x in tf.shape_n(π.v_vars)]), tf.float32)
  # `lr0` zeroes the policy part and is used during burn-in.
  lr0 = tf.concat([tf.zeros_like(lr_π_vec), lr_v_vec], 0)
  lr = tf.concat([lr_π_vec, lr_v_vec], 0)

  def locals1(ep_cnt, ep_lens, ep_rets):
    # Episode statistics of this worker: count, mean length, mean return.
    ep_cnt_= tf.cast(ep_cnt,tf.float32)
    ep_len = np.mean(ep_lens, dtype=np.float32)
    ep_ret = np.mean(ep_rets, dtype=np.float32)
    return tf.stack([ep_cnt_, ep_len, ep_ret])

  @tf.function
  def locals2(rs, r_trues, πvals, dvals):
    # Mean used reward, mean true reward, then the averaged policy/value
    # and reward-model diagnostics, all flattened into one vector.
    r = 𝔼(tf.concat(rs, 0))
    r_true = 𝔼(tf.concat(r_trues, 0))
    ep_vals = tf.parallel_stack([r, r_true])
    return tf.concat([ep_vals, 𝔼(πvals, 0),
         tf.reshape(𝔼(dvals, 0),[-1])], 0)

  # Per-iteration accumulators; cleared at the top of every iteration.
  ep_cnt = 0
  ep_lens = []
  ep_rets = []
  rs = []
  r_trues = []
  πvals = []
  dvals = []
  τs = []
  τπs = []
  τEs = []

  if rank == 0:
    # Column names, in the same order as the concatenated value vector.
    keys = ['ep/ep_cnt', 'ep/ep_len', 'ep/ep_ret',
        'ep/r', 'ep/r_true']
    πkeys = ['pi_loss', 'entropy', 'oob', 'ent_r']
    vkeys = ['vloss', 'vgp']
    πvkeys = ['pi/'+k for k in πkeys] + \
        ['value/'+k for k in vkeys]
    keys += πvkeys
    if do_irl:
      dkeys = D.keys + ['rE']
      # Only the 'cairl' variants append v / vE to the logged values in
      # the reward-model phase, so only they get the matching names.
      if 'cairl' in alg:
        dkeys = dkeys + ['v', 'vE']
      dkeys = ['disc/'+k for k in dkeys]
      keys += dkeys

  if do_irl:
    # Prime the running statistics with one policy and one expert batch
    # before any update. Assumes the first two `D.intype` entries are
    # the state and action -- TODO confirm against D's definition.
    τπ = dgetter(getτ())
    τE = getτE(batch_size)
    s_rms.update(tf.concat((τπ[0], τE[0]), 0))
    a_rms.update(tf.concat((τπ[1], τE[1]), 0))

  while True:
    if max_steps and steps >= max_steps:
      break
    train_π = steps >= burnin_steps

    ep_cnt = 0
    ep_lens.clear()
    ep_rets.clear()
    rs.clear()
    r_trues.clear()
    πvals.clear()
    dvals.clear()
    τs.clear()

    start_time = time.time()

    # ---- collect policy data -------------------------------------
    for _ in range(π_step):
      τ = getτ()
      logπ = get_logpπ(*logpgetter(τ))
      τ['logπ'] = logπ
      τ['r'] = rwd(rgetter(τ))
      ep_cnt += len(τ['ep_len'])
      ep_lens.extend(τ['ep_len'])
      ep_rets.extend(τ['ep_ret'])
      rs.append(τ['r'])
      r_trues.append(τ['r_true'])
      if add_absorbing_state:
        adv, ret = get_adv_ret_abs(*vgetter(τ))
      else:
        adv, ret = get_adv_ret(*vgetter(τ))
      τ['adv'] = adv
      τ['ret'] = ret
      π_rms.update(τ['s'])
      τs.append(τ)

    # ---- policy / value updates ----------------------------------
    for j in range(π_rollouts):
      for τ in τs:
        if add_absorbing_state:
          ṽLv, vs = πvgrad_abs(*πgetter(τ))
        else:
          ṽLv, vs = πvgrad(*πgetter(τ))
        if train_π:
          πopt.update(ṽLv, lr)
        else:
          # Burn-in: policy entries get a zero learning rate.
          πopt.update(ṽLv, lr0)
        # Only the diagnostics of the last pass are logged.
        if j == π_rollouts - 1:
          πvals.append(vs)
      shuffle(τs)

    if do_irl:
      # ---- collect reward-model data -----------------------------
      τπs.clear()
      τEs.clear()
      for _ in range(d_step):
        τ = getτ()
        τπ = dgetter(τ)
        τE = getτE(batch_size)
        if 'cairl' in alg:
          v́π = D.fwdv(τπ[2])
          v́E = D.fwdv(τE[2])
          τπ = τπ + (v́π,)
          τE = τE + (v́E,)
        if alg != 'gail' and alg != 'fairl':
          # The remaining algorithms also take action log-probabilities
          # as the last element of each batch tuple.
          logpπ = get_logpπ(*logpgetter(τ))
          logpE = get_logpE(*τE[:2])
          τπ = τπ + (logpπ,)
          τE = τE + (logpE,)
        s_rms.update(tf.concat((τπ[0],τE[0]), 0))
        a_rms.update(tf.concat((τπ[1],τE[1]), 0))
        τπs.append(τπ)
        τEs.append(τE)

      # ---- reward-model updates ----------------------------------
      for j in range(d_rollouts):
        for τπ, τE in zip(τπs, τEs):
          ṽLd, vs = dgrad(*τπ, *τE)
          dopt.update(ṽLd, lr_d)
          # Only the diagnostics of the last pass are logged.
          if j == d_rollouts - 1:
            if alg == 'cairl':
              rE = 𝔼(rwd((τE[0],τE[1])))
            elif alg == 'airl':
              rE = 𝔼(rwd(τE[:3]+(τE[-1],)))
            else:
              rE = 𝔼(rwd(τE[:len(D.rtype)]))
            if 'cairl' in alg:
              v = 𝔼(D.fwdv(τπ[0]))
              vE = 𝔼(D.fwdv(τE[0]))
              vs = tf.concat([vs, [rE, v, vE]], -1)
            else:
              vs = tf.concat([vs, [rE]], -1)
            dvals.append(vs)
        # The two lists are shuffled independently, so policy and expert
        # batches are re-paired on every pass.
        shuffle(τπs)
        shuffle(τEs)

    # ---- aggregate and log ---------------------------------------
    local_vals1 = locals1(ep_cnt, ep_lens, ep_rets)
    local_vals2 = locals2(rs, r_trues, πvals, dvals)
    local_vals = tf.concat([local_vals1, local_vals2], 0)

    if nw > 1:
      if i == 0:
        # Receive buffer, allocated once and reused afterwards.
        vals = np.empty_like(local_vals)
      comm.Allreduce(local_vals, vals, op=MPI.SUM)
      # Average everything across workers, except the episode count
      # (entry 0), which is restored to the cross-worker total.
      vals /= nw
      vals[0] *= nw
      ep_cnt = int(vals[0])
    else:
      vals = local_vals

    ep_total += ep_cnt
    steps += step_size

    if rank == 0:
      for k, v in zip(keys, vals):
        log.record_tab(k, v)
      log.record_tab('ep/ep_total', ep_total)
      log.record_tab('ep/steps', steps)
      log.dump_tab(i, time.time() - start_time)
    i += 1