import argparse
from mpi4py import MPI
from env import Env, AbsEnv
import os.path as osp
import numpy as np
from util.rnd import set_seed
from os import path
from util import log
import math
import json

def argparser(args=None):
  """Parse the command-line options of an experiment run.

  Args:
    args: optional list of argument strings. When None, argparse reads
      `sys.argv[1:]`.

  Returns:
    The populated `argparse.Namespace`.
  """
  prs = argparse.ArgumentParser('CAIRL')
  add = prs.add_argument

  # Environments and bookkeeping.
  add('--env_id', default='Hopper-v3')
  add('--envE_id', default=None)
  add('--seed', type=int, default=0)
  add('--expr_dir', default='expr')

  # Update counts and algorithm / policy selection.
  add('--pi_step', type=int, default=3)
  add('--d_step', type=int, default=3)
  add('--alg', type=str, choices=[
      'sppo', 'gail',
      'airl', 'cairl',
      'fairl'], default='gail')
  add('--policy', type=str, choices=['gaussian', 'gmm'], default='gaussian')

  # Hyper-parameters.
  add('--gamma', type=float, default=0.99)
  add('--lamda', type=float, default=0.97)
  add('--alpha', type=float, default=1e-2)
  # argparse applies `type` only to string defaults; a float literal such as
  # 5e6 would be stored as-is, so the default is given as a real int.
  add('--num_steps', type=int, default=int(5e6))
  add('--burnin_steps', type=int, default=0)
  add('--gp_coeff', type=float, default=1e-4)
  add('--num_demo', type=int, default=1000)

  # `store_true` flags already default to False and derive their dest from
  # the option name.
  add('--random_agent', action='store_true')
  add('--add_absorbing_state', action='store_true')
  add('--ablation', type=str,
      choices=['none', 'no_u', 'no_d_psi', 'no_softplus'], default='none')

  add('--tag', type=str, default='')
  return prs.parse_args(args=args)

def to_sci(i):
  """Render a number as a short tag for use in experiment names.

  * |i| in [1, 99] or exactly 0 -> the integer part, e.g. 50 -> '50'.
  * |i| in [0.1, 1)             -> `str(i)`, e.g. 0.5 -> '0.5'.
  * anything else               -> one leading digit (truncated) and a
    base-10 exponent, e.g. 1e-4 -> '1e-4', 1500 -> '1e3', -1000 -> '-1e3'.

  Args:
    i: a real scalar.

  Returns:
    The formatted string.
  """
  absi = np.abs(i)
  if 1 <= absi <= 99 or absi == 0.0:
    return str(int(i))
  if 0.1 <= absi < 1:
    return str(i)

  # Let the float formatter normalise mantissa and exponent. Computing
  # floor(log(x) / log(10)) is off by one for exact powers of ten
  # (log(1000) / log(10) == 2.9999999999999996), and dividing by 10 ** p
  # can leave a mantissa like 2.9999999999999996 that truncates one too low.
  mantissa, exponent = f'{float(absi):.12e}'.split('e')
  sign = '-' if i < 0 else ''
  # Only the leading digit is kept, matching the truncating behaviour.
  return f'{sign}{int(float(mantissa))}e{int(exponent)}'

def get_expr_name(args):
  """Build the dot-separated experiment name from the parsed options.

  Layout (bracketed parts are optional):
    <alg>.<env>[-<expert env>].α_<alpha>.η_<gp_coeff>[.n_<num_demo>]
    [.add_absorbing_state][.<ablation>].seed_<seed>[.random][.<tag>]

  Environment names are the part of the id before the first '-'.

  Args:
    args: namespace with the fields produced by `argparser`.

  Returns:
    The experiment name as a string.
  """
  uses_expert = args.alg != 'sppo'

  env_part = args.env_id.split('-')[0]
  # The expert environment is appended only when it is set and differs from
  # the learner's. The previous guard compared against 'rl', which is not a
  # valid --alg choice and therefore never filtered anything; 'sppo' is the
  # algorithm for which main() discards --envE_id.
  if uses_expert and args.envE_id and args.envE_id != args.env_id:
    env_part += '-' + args.envE_id.split('-')[0]

  parts = [
      args.alg,
      env_part,
      f'α_{to_sci(args.alpha)}',
      f'η_{to_sci(args.gp_coeff)}',
  ]
  if uses_expert:
    parts.append(f'n_{args.num_demo}')
  if args.add_absorbing_state:
    parts.append('add_absorbing_state')
  if args.ablation != 'none':
    parts.append(args.ablation)

  parts.append(f'seed_{args.seed}')
  if args.random_agent:
    parts.append('random')
  if args.tag != '':
    parts.append(args.tag)
  return '.'.join(parts)

def main(args):
  """Run one training experiment described by the parsed namespace `args`.

  In order: apply per-environment overrides to `args` (mutated in place),
  derive the experiment directory, build and seed the learner and expert
  environments, initialise the project constants / policy distribution /
  running statistics, build the policy and (for every algorithm except
  'sppo') the discriminator and expert dataset, call `train`, and finally
  let MPI rank 0 write the learned parameters to `param.json` inside the
  experiment directory.
  """
  env_id = args.env_id
  nh = 256

  # Per-environment overrides of the command-line settings.
  if 'Humanoid' in env_id: # 8 threads
    args.num_steps = int(5e7)
  if "Walker2d" in env_id: # 3 threads
    args.num_steps = int(3e7)
  elif 'MultiGoalAnt' in env_id: # 16 threads
    args.num_steps = int(3e7)
    args.policy = 'gmm'
    if args.alg != 'sppo':
      nh = 400
  elif 'MultiGoal' in env_id or 'Asymmetric' in env_id:
    is_hard = 'Hard' in env_id
    args.gamma = 0.5
    args.lamda = 0.0
    args.policy = 'gmm'
    args.alpha = 0.1 if is_hard else 1.0
    args.num_steps = int(1.5e7) if is_hard else int(1e7)

  # The name is computed before the Humanoid alpha rescale below, so it
  # records the un-scaled alpha.
  expr_name = get_expr_name(args)
  rank = MPI.COMM_WORLD.Get_rank()
  expr_path = path.join(args.expr_dir, expr_name)

  if 'Humanoid' in env_id and args.alg in ('gail', 'airl', 'fairl'):
    args.alpha *= 0.1

  # Learner environment, then the expert one (shared when there is none).
  make_env = AbsEnv if args.add_absorbing_state else Env
  env = make_env(env_id)
  if args.alg == 'sppo' or args.envE_id is None:
    args.envE_id = env_id
    envE = env
  else:
    envE = make_env(args.envE_id)

  set_seed(args.seed, env, envE)

  # Project-level initialisation; the local imports are kept in their
  # original order relative to the calls around them.
  from const import set_const
  set_const(
      args.alpha, args.gamma, args.lamda, args.gp_coeff,
      np.prod(env.S.shape), np.prod(env.A.shape), nh, env.A.high,
      args.add_absorbing_state)

  from util.pd import set_pd
  set_pd(args.policy)

  from util.rms import init
  init()

  # Only rank 0 configures the logger.
  if rank == 0:
    log.configure(expr_path, formats=['stdout', 'csv', 'tb'])

  if args.policy == 'gaussian':
    from nn.pi import GaussianPolicy
    π = GaussianPolicy()
  elif args.policy == 'gmm':
    from nn.pi import GMMPolicy
    π = GMMPolicy()

  from loss import set_πv_loss
  set_πv_loss(π.vars, π.v_vars, π.fwd, π.fwdv)

  if args.alg == 'sppo':
    # No discriminator and no expert data for this algorithm.
    D = dataset = None
    batch_size = 2048
  else:
    from nn import disc
    if 'gail' in args.alg:
      D = disc.GAIL()
    elif args.alg == 'fairl':
      D = disc.FAIRL()
    elif args.alg == 'airl':
      D = disc.AIRL()
    elif args.alg == 'cairl':
      # Each ablation selects its own discriminator class; anything else
      # uses the plain CAIRL one.
      ablated = {'no_u': 'CAIRL1', 'no_d_psi': 'CAIRL2', 'no_softplus': 'CAIRL3'}
      D = getattr(disc, ablated.get(args.ablation, 'CAIRL'))()

    from loss import set_d_loss
    set_d_loss(D.vars, D.loss_gp)

    from data import get_trj
    expert_name = args.envE_id.split("-")[0]
    absorb_suffix = '_absorb' if args.add_absorbing_state else ''
    expert_traj_path = f'data/trj{args.num_demo}{absorb_suffix}.{expert_name}.npz'

    dataset = get_trj(expert_traj_path, D.intype)
    batch_size = min(2048, dataset.n)

  from train import train
  train(args.alg, env, π,
      D, dataset, batch_size,
      π_step=args.pi_step,
      d_step=args.d_step,
      max_steps=args.num_steps,
      burnin_steps=args.burnin_steps, random=args.random_agent)

  log.close()

  # Only rank 0 writes the parameter dump.
  if rank != 0:
    return

  def as_float_lists(variables):
    # One flat list of Python floats per variable, so the result is
    # JSON-serialisable.
    return [list(map(float, var.numpy().flatten())) for var in variables]

  param_dict = {}
  param_dict["pi"] = as_float_lists(π.vars)

  from util.rms import π_rms
  param_dict["pi_rms"] = as_float_lists(π_rms.vars)

  if D is not None:
    param_dict["d"] = as_float_lists(D.vars)

    from util.rms import s_rms, a_rms
    param_dict["s_rms"] = as_float_lists(s_rms.vars)
    param_dict["a_rms"] = as_float_lists(a_rms.vars)

  param_path = osp.join(expr_path, 'param.json')
  with open(param_path, "w") as out:
    out.write(json.dumps(param_dict))

# Script entry point: parse the command line and run the experiment.
if __name__ == '__main__':
  args = argparser()
  main(args)
