# coding=utf-8
# Copyright 2019 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Training and evaluation in the online mode."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import os
import time

from absl import logging

import gin
import numpy as np
import tensorflow as tf
from behavior_regularized_offline_rl.brac import bc_agent as bc_agent_module
from behavior_regularized_offline_rl.brac import brac_primal_agent as brac_primal_agent_module
from behavior_regularized_offline_rl.brac import dataset
from behavior_regularized_offline_rl.brac import policies
from behavior_regularized_offline_rl.brac import train_eval_utils
from behavior_regularized_offline_rl.brac import utils


def collect_batch_transition(collector, collect_steps):
  """Collects transitions until `collect_steps` of them have been gathered.

  Repeatedly calls `collector.collect_transition()` and adds its return value
  to a running total. Throughput is logged whenever the running total is a
  multiple of 5000 and once more when it equals `collect_steps`.

  Args:
    collector: Object exposing `collect_transition()`, whose return value is
      added to the number of steps collected so far (a call may return 0).
    collect_steps: Number of transitions to gather. Nothing is collected when
      this is not positive.
  """
  log_freq = 5000
  steps_collected = 0
  timed_at_step = 0
  time_st = time.time()
  while steps_collected < collect_steps:
    count = collector.collect_transition()
    steps_collected += count
    # A call that collected nothing leaves the total unchanged; skip reporting
    # so the same milestone is not logged twice.
    if count <= 0:
      continue
    if steps_collected % log_freq == 0 or steps_collected == collect_steps:
      now = time.time()
      elapsed = now - time_st
      # Two clock reads can return the same value (coarse clock resolution,
      # very small collect_steps); never let a progress message raise
      # ZeroDivisionError.
      if elapsed > 0:
        steps_per_sec = (steps_collected - timed_at_step) / elapsed
      else:
        steps_per_sec = float('inf')
      timed_at_step = steps_collected
      time_st = now
      logging.info('(%d/%d) steps collected at %.4g steps/s.',
                   steps_collected, collect_steps, steps_per_sec)


@gin.configurable
def train_eval_recursive_brac(
    # Basic args.
    log_dir,
    env_name='HalfCheetah-v2',
    # Train and eval args.
    total_train_steps=10000,
    summary_freq=100,
    print_freq=1000,
    save_freq=int(1e8),
    eval_freq=5000,
    n_eval_episodes=20,
    # For saving a partially trained policy.
    eval_target=None,  # Return threshold for saving a partial checkpoint.
    eval_target_n=2,  # Number of latest evals that must reach eval_target.
    # Agent train args.
    replay_buffer_size=int(1e6),
    model_params=(((200, 200),), 2),
    optimizers=(('adam', 0.001),),
    batch_size=256,
    weight_decays=(0.0,),
    update_freq=1,
    update_rate=0.005,
    discount=0.99,
    bc_train_steps=int(5e5),
    bc_model_params=((200, 200),),
    bc_optimizers=(('adam', 5e-4),),
    bc_batch_size=256,
    bc_weight_decays=(0.0,),
    bc_update_freq=1,
    bc_update_rate=0.005,
    bc_discount=0.99,
    data_collection_freq=2000,
    data_collection_steps=200000
    ):
  """Alternates data collection, behavior cloning and BRAC policy training.

  Training is split into `ceil(total_train_steps / data_collection_freq)`
  deployments. Each deployment:
    1. (first deployment only) fills the replay data with
       `data_collection_steps` transitions from a random policy;
    2. builds a fresh behavior-cloning agent and trains it for
       `bc_train_steps` steps;
    3. creates the BRAC agent from the behavior-cloning checkpointer (first
       deployment) or hands it the new checkpointer (later deployments);
    4. trains the BRAC agent for up to `data_collection_freq` steps, writing
       summaries, printing, evaluating and saving at the given frequencies;
    5. (all but the last deployment) collects `data_collection_steps` more
       transitions with the agent's online policy.

  Args:
    log_dir: Directory for outputs. Train summaries go to `<log_dir>/train`,
      eval summaries to `<log_dir>/eval/<policy_key>`, and agent checkpoints
      use the prefix `<log_dir>/agent`.
    env_name: Environment name passed to `train_eval_utils.env_factory`.
    total_train_steps: Total number of BRAC `train_step` calls. Must be
      positive.
    summary_freq: Write a train summary when the global step is a multiple of
      this value (and at `total_train_steps`).
    print_freq: Print train info when the global step is a multiple of this
      value (and at `total_train_steps`).
    save_freq: Save the agent as `agent-<step>` when the global step is a
      multiple of this value.
    eval_freq: Evaluate the test policies when the global step is a multiple
      of this value (and at `total_train_steps`).
    n_eval_episodes: Number of episodes passed to
      `train_eval_utils.eval_policies`.
    eval_target: If not None, the agent is saved once as
      `agent_partial_target` the first time the first entry of each of the
      last `eval_target_n` evaluation results is at least this value.
    eval_target_n: Number of most recent evaluations checked against
      `eval_target`.
    replay_buffer_size: Size passed to `dataset.Dataset`.
    model_params: Forwarded to `utils.Flags` for the BRAC agent.
    optimizers: Forwarded to `utils.Flags` for the BRAC agent.
    batch_size: Forwarded to `utils.Flags` for the BRAC agent.
    weight_decays: Forwarded to `utils.Flags` for the BRAC agent.
    update_freq: Forwarded to `utils.Flags` for the BRAC agent.
    update_rate: Forwarded to `utils.Flags` for the BRAC agent.
    discount: Forwarded to `utils.Flags` for the BRAC agent.
    bc_train_steps: Number of behavior-cloning `train_step` calls per
      deployment.
    bc_model_params: Forwarded to `utils.Flags` for the behavior-cloning agent.
    bc_optimizers: Forwarded to `utils.Flags` for the behavior-cloning agent.
    bc_batch_size: Forwarded to `utils.Flags` for the behavior-cloning agent.
    bc_weight_decays: Forwarded to `utils.Flags` for the behavior-cloning
      agent.
    bc_update_freq: Forwarded to `utils.Flags` for the behavior-cloning agent.
    bc_update_rate: Forwarded to `utils.Flags` for the behavior-cloning agent.
    bc_discount: Forwarded to `utils.Flags` for the behavior-cloning agent.
    data_collection_freq: Number of BRAC train steps per deployment. Must be
      positive.
    data_collection_steps: Number of transitions collected per collection
      phase.

  Returns:
    A numpy array built from one row per evaluation; each row is the global
    step followed by the values returned by `train_eval_utils.eval_policies`.

  Raises:
    ValueError: If `total_train_steps` or `data_collection_freq` is not
      positive.
  """
  # Without at least one deployment no agent is ever built, and a zero
  # frequency would divide by zero below; fail early with a clear message.
  if total_train_steps <= 0:
    raise ValueError(
        'total_train_steps must be positive, got %r.' % (total_train_steps,))
  if data_collection_freq <= 0:
    raise ValueError(
        'data_collection_freq must be positive, got %r.'
        % (data_collection_freq,))

  # Create tf_env to get specs. A separate environment is used for evaluation.
  tf_env = train_eval_utils.env_factory(env_name)
  tf_env_test = train_eval_utils.env_factory(env_name)
  observation_spec = tf_env.observation_spec()
  action_spec = tf_env.action_spec()

  # Initialize dataset.
  with tf.device('/cpu:0'):
    train_data = dataset.Dataset(
        observation_spec,
        action_spec,
        replay_buffer_size,
        circular=True,
        )
  # NOTE: the replay data itself is not checkpointed; only the agent is saved.

  time_st_total = time.time()

  # Prepare savers for models and results.
  train_summary_dir = os.path.join(log_dir, 'train')
  eval_summary_dir = os.path.join(log_dir, 'eval')
  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_summary_dir)
  eval_summary_writers = collections.OrderedDict()
  agent_ckpt_name = os.path.join(log_dir, 'agent')
  eval_results = []

  # Train agent.
  logging.info('Start training ....')
  time_st = time.time()
  timed_at_step = 0
  target_partial_policy_saved = False

  # Ceiling division: the last deployment may be shorter than the others.
  deployments = -(-total_train_steps // data_collection_freq)
  agent = None
  collector = None
  for d in range(deployments):
    if d == 0:
      # Collect data from random policy.
      explore_policy = policies.ContinuousRandomPolicy(action_spec)
      logging.info('Collecting ramdom data ...')
      random_collector = train_eval_utils.DataCollector(
          tf_env, explore_policy, train_data)
      collect_batch_transition(random_collector, data_collection_steps)

    # Behavior cloning: a fresh agent is built and trained every deployment.
    bc_agent_flags = utils.Flags(
        action_spec=action_spec,
        model_params=bc_model_params,
        optimizers=bc_optimizers,
        batch_size=bc_batch_size,
        weight_decays=bc_weight_decays,
        update_freq=bc_update_freq,
        update_rate=bc_update_rate,
        discount=bc_discount,
        train_data=train_data)
    bc_agent_args = bc_agent_module.Config(bc_agent_flags).agent_args
    bc_agent = bc_agent_module.Agent(**vars(bc_agent_args))
    bc_offset = d * data_collection_steps
    for _ in range(bc_train_steps):
      bc_agent.train_step(offset=bc_offset)
    bc_checkpointer = bc_agent.behavior_checkpointer

    # Train BRAC policy.
    if d == 0:
      agent_flags = utils.Flags(
          action_spec=action_spec,
          model_params=model_params,
          optimizers=optimizers,
          batch_size=batch_size,
          weight_decays=weight_decays,
          update_freq=update_freq,
          update_rate=update_rate,
          discount=discount,
          train_data=train_data)
      agent_args = brac_primal_agent_module.Config(agent_flags).agent_args
      agent = brac_primal_agent_module.Agent(
          behavior_checkpointer=bc_checkpointer, **vars(agent_args))
      for policy_key in agent.test_policies.keys():
        eval_summary_writers[policy_key] = (
            tf.compat.v2.summary.create_file_writer(
                os.path.join(eval_summary_dir, policy_key)))
      # From now on data is gathered with the agent's own online policy.
      collector = train_eval_utils.DataCollector(
          tf_env, agent.online_policy, train_data)
    else:
      agent.update_behavior_checkpointer(bc_checkpointer)

    # Cap the final deployment so that exactly total_train_steps train steps
    # are run even when it is not a multiple of data_collection_freq.
    steps_this_deployment = min(
        data_collection_freq, total_train_steps - d * data_collection_freq)
    for _ in range(steps_this_deployment):
      agent.train_step()
      step = agent.global_step
      if step % summary_freq == 0 or step == total_train_steps:
        agent.write_train_summary(train_summary_writer)
      if step % print_freq == 0 or step == total_train_steps:
        agent.print_train_info()
      if step % eval_freq == 0 or step == total_train_steps:
        time_ed = time.time()
        time_cost = time_ed - time_st
        logging.info('Training at %.4g steps/s.',
                     (step - timed_at_step) / time_cost)
        eval_result, eval_infos = train_eval_utils.eval_policies(
            tf_env_test, agent.test_policies, n_eval_episodes)
        eval_results.append([step] + eval_result)
        # Decide whether to save a partially trained policy based on current
        # model performance. The checkpoint is written at most once.
        if (eval_target is not None and len(eval_results) >= eval_target_n
            and not target_partial_policy_saved):
          # Column 0 of each row is the step; column 1 is the first value
          # returned by eval_policies.
          recent_returns = np.array(
              [eval_results[-(i + 1)][1] for i in range(eval_target_n)])
          if recent_returns.size and np.min(recent_returns) >= eval_target:
            agent.save(agent_ckpt_name + '_partial_target')
            target_partial_policy_saved = True
        logging.info('Testing at step %d:', step)
        for policy_key, policy_info in eval_infos.items():
          logging.info(utils.get_summary_str(
              step=None, info=policy_info, prefix=policy_key + ': '))
          utils.write_summary(
              eval_summary_writers[policy_key], step, policy_info)
        # Restart the throughput timer after evaluation.
        time_st = time.time()
        timed_at_step = step
      if step % save_freq == 0:
        agent.save(agent_ckpt_name + '-' + str(step))

    # No further training follows the last deployment, so skip collection.
    if d < deployments - 1:
      collect_batch_transition(collector, data_collection_steps)

  # Final save after training.
  agent.save(agent_ckpt_name + '_final')
  time_cost = time.time() - time_st_total
  logging.info('Training finished, time cost %.4gs.', time_cost)
  return np.array(eval_results)
