import sys
import time
from collections import defaultdict

import embodied
import numpy as np


def parallel(agent, replay, logger, make_env, num_envs, args):
  """Build the actor, learner and env workers and hand them to the runner.

  The actor and learner always run as threads. A single environment also runs
  as a thread and shares the timer so its `step` calls are timed; with more
  than one environment, each runs in its own process without the timer.

  Args:
    agent: Agent whose `policy`/`train`/`report`/`save` calls are timed.
    replay: Replay buffer whose `add`/`save` calls are timed.
    logger: Logger whose `write` calls are timed; its `step` counter is shared
      by the workers.
    make_env: Zero-argument environment factory passed to each env worker.
    num_envs: Number of environment workers to launch.
    args: Run config; `args.actor_addr` is where env workers reach the actor.
  """
  step = logger.step
  timer = embodied.Timer()
  wrapped = (
      ('agent', agent, ['policy', 'train', 'report', 'save']),
      ('replay', replay, ['add', 'save']),
      ('logger', logger, ['write']),
  )
  for label, target, methods in wrapped:
    timer.wrap(label, target, methods)

  workers = [
      embodied.distr.Thread(
          actor, step, agent, replay, logger, args.actor_addr, args),
      embodied.distr.Thread(
          learner, step, agent, replay, logger, timer, args),
  ]
  if num_envs != 1:
    workers.extend(
        embodied.distr.Process(env, make_env, args.actor_addr, idx, args)
        for idx in range(num_envs))
  else:
    workers.append(embodied.distr.Thread(
        env, make_env, args.actor_addr, 0, args, timer))
  embodied.distr.run(workers)


def actor(step, agent, replay, logger, actor_addr, args):
  """Serve batched policy requests from environment workers.

  Starts an `embodied.BatchServer` at `actor_addr` that collects
  `args.actor_batch` observations per call. For each batch the server:
  - computes actions with `agent.policy`, keeping a separate policy state per
    worker address;
  - adds every transition to `replay`;
  - advances the global `step` counter;
  - logs episode length, score and configured video keys when an episode
    ends.

  Args:
    step: Shared step counter, incremented by `args.actor_batch` per call.
    agent: Agent providing `policy`, `sync` and `agent.obs_space`.
    replay: Replay buffer that receives one transition per worker per call.
    logger: Logger receiving episode and metric summaries.
    actor_addr: Address the batch server listens on.
    args: Run config (`actor_batch`, `log_every`, `log_keys_video`).
  """
  metrics = embodied.Metrics()
  # Per-worker episode buffers, keyed by worker address and then by key. A
  # worker's entry is popped when its episode ends.
  scalars = defaultdict(lambda: defaultdict(list))
  videos = defaultdict(lambda: defaultdict(list))
  should_log = embodied.when.Clock(args.log_every)

  # Obtain an initial policy state by running the policy once on a batch of
  # zeros, then keep only the first entry as the default for unseen workers.
  _, initial = agent.policy(dummy_data(
      agent.agent.obs_space, (args.actor_batch,)))
  initial = embodied.treemap(lambda x: x[0], initial)
  allstates = defaultdict(lambda: initial)
  agent.sync()

  def callback(obs, env_addrs):
    """Return batched actions for observations sent by workers `env_addrs`."""
    # Gather the per-worker states into one batch, in env_addrs order.
    states = [allstates[a] for a in env_addrs]
    states = embodied.treemap(lambda *xs: list(xs), *states)
    act, states = agent.policy(obs, states)
    # Mark the action as a reset whenever the observation ends an episode.
    act['reset'] = obs['is_last'].copy()
    for i, a in enumerate(env_addrs):
      allstates[a] = embodied.treemap(lambda x: x[i], states)

    trans = {**obs, **act}
    for i, a in enumerate(env_addrs):
      tran = {k: v[i].copy() for k, v in trans.items()}
      replay.add(tran.copy(), worker=a)
      for k, v in tran.items():
        if v.size == 1:
          scalars[a][k].append(v)
      for k in args.log_keys_video:
        videos[a][k].append(tran[k])
    step.increment(args.actor_batch)

    for i, a in enumerate(env_addrs):
      if not trans['is_last'][i]:
        continue
      # `defaultdict.pop` does not call the default factory, and a worker's
      # buffer only exists once something was appended to it. For example,
      # with an empty `args.log_keys_video`, `videos[a]` is never created.
      # Default to an empty dict instead of raising KeyError.
      ep = {**scalars.pop(a, {}), **videos.pop(a, {})}
      ep = {k: embodied.convert(v) for k, v in ep.items()}
      logger.add({
          # The first transition is not counted (presumably the reset
          # observation) — TODO confirm.
          'length': len(ep['reward']) - 1,
          'score': sum(ep['reward']),
      }, prefix='episode')
      stats = {}
      for key in args.log_keys_video:
        stats[f'policy_{key}'] = ep[key]
      metrics.add(stats, prefix='stats')

    if should_log():
      logger.add(metrics.result())

    return act

  print('[actor] Start server')
  embodied.BatchServer(actor_addr, args.actor_batch, callback).run()


def learner(step, agent, replay, logger, timer, args):
  """Train the agent on replay data in an endless loop.

  Restores a checkpoint stored under `args.logdir`, or creates one there.
  When `args.from_checkpoint` is set, that checkpoint is loaded first. The
  learner then repeatedly draws a batch from the replay dataset and runs one
  training update. It also periodically:
  - syncs the agent (every `args.sync_every` updates);
  - logs train, report, timer, replay and throughput metrics
    (`args.log_every`);
  - saves a checkpoint (`args.save_every`).

  Args:
    step: Shared step counter advanced by the actor; stored in the checkpoint.
    agent: Agent providing `dataset`, `train`, `report` and `sync`.
    replay: Replay buffer providing `dataset` and `stats`; checkpointed.
    logger: Logger that receives metrics and is written on each log.
    timer: `embodied.Timer` whose stats are logged.
    args: Run config (`logdir`, `log_every`, `save_every`, `sync_every`,
      `from_checkpoint`).
  """
  logdir = embodied.Path(args.logdir)
  metrics = embodied.Metrics()
  should_log = embodied.when.Clock(args.log_every)
  should_save = embodied.when.Clock(args.save_every)
  should_sync = embodied.when.Every(args.sync_every)
  updates = embodied.Counter()

  checkpoint = embodied.Checkpoint(logdir / 'checkpoint.ckpt')
  checkpoint.step = step
  checkpoint.agent = agent
  checkpoint.replay = replay
  if args.from_checkpoint:
    checkpoint.load(args.from_checkpoint)
  checkpoint.load_or_save()

  dataset = agent.dataset(replay.dataset)
  state = None
  # Throughput bookkeeping since the last log. Durations use a monotonic,
  # high-resolution clock: wall-clock time.time() can jump (NTP, manual
  # changes), and its coarse resolution on some platforms can give a zero
  # duration.
  stats = dict(
      last_time=time.perf_counter(), last_step=int(step), batch_entries=0)
  while True:
    batch = next(dataset)
    _, state, mets = agent.train(batch, state)
    metrics.add(mets)
    updates.increment()
    # Count the entries consumed by this update (the size of `is_first`).
    stats['batch_entries'] += batch['is_first'].size

    if should_sync(updates):
      agent.sync()

    if should_log():
      train = metrics.result()
      report = agent.report(batch)
      # Drop report entries that duplicate a train metric.
      # NOTE(review): this only filters anything if metrics.result() keys
      # carry a 'train/' prefix; confirm against the agent's metric names.
      report = {k: v for k, v in report.items() if 'train/' + k not in train}
      logger.add(train, prefix='train')
      logger.add(report, prefix='report')
      logger.add(timer.stats(), prefix='timer')
      logger.add(replay.stats, prefix='replay')

      # Floor the duration so the FPS divisions below cannot raise
      # ZeroDivisionError.
      duration = max(time.perf_counter() - stats['last_time'], 1e-9)
      actor_fps = (int(step) - stats['last_step']) / duration
      learner_fps = stats['batch_entries'] / duration
      logger.add({
          'actor_fps': actor_fps,
          'learner_fps': learner_fps,
          'train_ratio': learner_fps / actor_fps if actor_fps else np.inf,
      }, prefix='parallel')
      stats = dict(
          last_time=time.perf_counter(), last_step=int(step), batch_entries=0)

      logger.write(fps=True)

    if should_save():
      checkpoint.save()


def env(make_env, actor_addr, i, args, timer=None):
  """Step one environment forever, querying the actor server for actions.

  Args:
    make_env: Zero-argument factory returning the environment.
    actor_addr: Address of the actor's batch server.
    i: Index of this env worker; used only in the log message.
    args: Run config. Not read here; kept so every worker has a uniform
      signature.
    timer: Optional `embodied.Timer`. When given, the environment's `step`
      calls are timed.

  Calls `sys.exit(0)` when waiting for the actor's reply raises RuntimeError.
  """
  # TODO: Optionally write NPZ episodes.
  print(f'[env{i}] Make env')
  # Local names differ from the module-level `env`/`actor` functions to avoid
  # shadowing them.
  environment = make_env()
  if timer is not None:
    timer.wrap('env', environment, ['step'])
  client = embodied.Client(actor_addr)
  # The initial action is sampled from the action space; its 'reset' entry is
  # set below before every step.
  act = {k: v.sample() for k, v in environment.act_space.items()}
  done = False
  while True:
    # Request a reset on the step after the previous observation was last.
    act['reset'] = done
    obs = environment.step(act)
    obs = {k: np.asarray(v) for k, v in obs.items()}
    done = obs['is_last']
    promise = client(obs)
    try:
      act = promise()
    except RuntimeError:
      # NOTE(review): presumably raised once the actor server goes away;
      # treated as a clean shutdown.
      sys.exit(0)
    # Drop 'log_*' entries from the returned actions before stepping the env.
    act = {k: v for k, v in act.items() if not k.startswith('log_')}


def dummy_data(spaces, batch_dims):
  """Return an all-zeros example per space with leading batch dimensions.

  Args:
    spaces: Mapping from key to a space object exposing `shape` and `dtype`.
    batch_dims: Sequence of leading dimensions to prepend, outermost first.

  Returns:
    Dict with the same keys, in the same order, as `spaces`. Each value is a
    zero array of shape `(*batch_dims, *space.shape)` with the space's dtype.
  """
  # TODO: Get rid of this function by adding initial_policy_state() and
  # initial_train_state() to the agent API.
  def zeros_for(space):
    # Let numpy resolve the space's shape/dtype once, then allocate the
    # batched array directly instead of repeating along each batch dim.
    single = np.zeros(space.shape, space.dtype)
    return np.zeros((*batch_dims, *single.shape), single.dtype)

  return {key: zeros_for(space) for key, space in spaces.items()}
