# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Training and evaluation in the offline mode."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import time
import sys
import os

from absl import logging

import gin
import gym
import numpy as np
import tensorflow as tf0
import tensorflow.compat.v1 as tf

from behavior_regularized_offline_rl.brac import dataset
from behavior_regularized_offline_rl.brac import train_eval_utils
from behavior_regularized_offline_rl.brac import utils

from gym.wrappers import time_limit

from tf_agents.environments import tf_py_environment
from tf_agents.environments import gym_wrapper

import h5py


def get_data(dataset, data_path, isMediumExpert, num=int(1e5)):
    """Subsample an offline dataset and optionally append data from an HDF5 file.

    The input dict is modified in place and also returned.

    Args:
      dataset: dict with 'observations', 'actions', 'rewards' and 'terminals'
        arrays, indexed by step along axis 0.
      data_path: optional path to an HDF5 file holding the same four keys. Its
        arrays are appended after the subsampled data; 'rewards' from the file
        is squeezed before concatenation.
      isMediumExpert: if False, keep the first `num` steps. If True, keep the
        first `num` steps followed by the last `num` steps (presumably the
        medium and expert halves of a medium-expert dataset -- verify).
      num: number of steps to take from each end. Defaults to 100000.

    Returns:
      The same dict, with its four arrays replaced.
    """
    keys = ('observations', 'actions', 'rewards', 'terminals')
    for key in keys:
        values = dataset[key]
        if isMediumExpert:
            # Explicit start index instead of `values[-num:]` so that num == 0
            # yields an empty tail rather than the whole array. If the dataset
            # has fewer than 2 * num steps, head and tail overlap and some steps
            # appear twice.
            tail = values[max(len(values) - num, 0):]
            dataset[key] = np.concatenate((values[:num], tail), axis=0)
        else:
            dataset[key] = values[:num]

    if data_path:
        # Context manager guarantees the HDF5 handle is closed; every array is
        # fully read into memory (`[()]`) before the file is closed.
        with h5py.File(data_path, 'r') as data:
            for key in keys:
                extra = data[key][()]
                if key == 'rewards':
                    extra = np.squeeze(extra)
                dataset[key] = np.concatenate((dataset[key], extra), axis=0)
    return dataset


def get_offline_data(tf_env, data_path, isMediumExpert, data_process=True):
    """Load the environment's offline dataset into a `dataset.Dataset`.

    Args:
      tf_env: TFPyEnvironment; `tf_env.pyenv.envs[0].get_dataset()` must return
        a dict with 'observations', 'actions', 'rewards' and 'terminals'.
      data_path: optional HDF5 path forwarded to `get_data`.
      isMediumExpert: forwarded to `get_data`.
      data_process: if True (the default), pass the raw dataset through
        `get_data` before building transitions.

    Returns:
      A `dataset.Dataset` sized to the (processed) number of steps, filled with
      one transition (s1, s2, a1, a2, discount, reward) per selected step i.
      Step i is selected when it is not terminal and is not the final step.
      The transition pairs step i with step i + 1; discount is
      1 - terminals[i + 1] and reward is rewards[i].
    """
    gym_env = tf_env.pyenv.envs[0]
    offline_dataset = gym_env.get_dataset()

    if data_process:
        offline_dataset = get_data(offline_dataset, data_path, isMediumExpert)

    dataset_size = len(offline_dataset['observations'])
    logging.info('Offline dataset size: %d.', dataset_size)
    tf_dataset = dataset.Dataset(
        tf_env.observation_spec(),
        tf_env.action_spec(),
        size=dataset_size)
    observation_dtype = tf_env.observation_spec().dtype
    action_dtype = tf_env.action_spec().dtype

    # Flatten possible trailing singleton dims so the arrays are 1-D per step.
    offline_dataset['terminals'] = np.squeeze(offline_dataset['terminals'])
    offline_dataset['rewards'] = np.squeeze(offline_dataset['rewards'])

    # The final step is excluded because it has no successor step to pair with.
    nonterminal_steps, = np.where(
        np.logical_and(
            np.logical_not(offline_dataset['terminals']),
            np.arange(dataset_size) < dataset_size - 1))
    logging.info('Found %d non-terminal steps out of a total of %d steps.',
                 len(nonterminal_steps), dataset_size)
    # NOTE(review): when get_data concatenates segments (medium/expert halves or
    # an appended HDF5 file), the last step before each seam is paired with the
    # first step of the next segment unless it is terminal -- confirm whether
    # those cross-segment transitions should be dropped.
    next_steps = nonterminal_steps + 1

    observations = offline_dataset['observations']
    actions = offline_dataset['actions']
    s1 = tf.convert_to_tensor(observations[nonterminal_steps],
                              dtype=observation_dtype)
    s2 = tf.convert_to_tensor(observations[next_steps],
                              dtype=observation_dtype)
    a1 = tf.convert_to_tensor(actions[nonterminal_steps], dtype=action_dtype)
    a2 = tf.convert_to_tensor(actions[next_steps], dtype=action_dtype)
    discount = tf.convert_to_tensor(
        1. - offline_dataset['terminals'][next_steps],
        dtype=tf.float32)
    reward = tf.convert_to_tensor(offline_dataset['rewards'][nonterminal_steps],
                                  dtype=tf.float32)

    transitions = dataset.Transition(
        s1, s2, a1, a2, discount, reward)

    tf_dataset.add_transitions(transitions)
    return tf_dataset


def env_factory(env_name):
    """Create a TF-Agents environment wrapping the gym task `env_name`.

    When the registered gym spec has no episode limit (`max_episode_steps` is
    0 or None), the gym env is first wrapped in a 1000-step TimeLimit.
    """
    base_env = gym.make(env_name)
    step_limit = gym.spec(env_name).max_episode_steps
    needs_time_limit = step_limit in (0, None)
    if needs_time_limit:
        base_env = time_limit.TimeLimit(base_env, max_episode_steps=1000)
    py_env = gym_wrapper.GymWrapper(base_env)
    return tf_py_environment.TFPyEnvironment(py_env)


def _build_agent_kwargs(agent_module, agent_args, behavior_ckpt_file,
                        value_penalty, alpha):
    """Return the keyword arguments used to construct `agent_module.Agent`.

    Starts from a copy of `vars(agent_args)`. When the agent module's name
    contains 'brac_primal', it also adds `behavior_ckpt_file`, `value_penalty`
    and `alpha`.
    """
    agent_kwargs = dict(vars(agent_args))
    if 'brac_primal' in agent_module.__name__:
        agent_kwargs['behavior_ckpt_file'] = behavior_ckpt_file
        agent_kwargs['value_penalty'] = value_penalty
        agent_kwargs['alpha'] = alpha
    return agent_kwargs


@gin.configurable
def train_eval_offline(
        # Basic args.
        log_dir,
        data_file,
        agent_module,
        env_name='HalfCheetah-v2',
        n_train=int(1e6),
        shuffle_steps=0,
        seed=0,
        use_seed_for_data=False,
        # Train and eval args.
        total_train_steps=int(1e6),
        summary_freq=100,
        print_freq=1000,
        save_freq=int(2e4),
        eval_freq=5000,
        n_eval_episodes=20,
        # Agent args.
        model_params=(((200, 200),), 2),
        behavior_ckpt_file=None,
        value_penalty=True,
        alpha=1.0,
        optimizers=(('adam', 0.001),),
        batch_size=256,
        weight_decays=(0.0,),
        update_freq=1,
        update_rate=0.005,
        discount=0.99,
        data_path=None,
        isMediumExpert=False,
):
    """Training a policy with a fixed dataset.

    Steps:
      1. Build the environment with `env_factory`.
      2. Load transitions with `get_offline_data`.
      3. Keep up to `n_train` shuffled transitions as training data.
      4. Construct `agent_module.Agent`, restoring it from `<log_dir>/agent`
         if a checkpoint index exists.
      5. Train until `agent.global_step` reaches `total_train_steps`.

    Outputs:
      - Train summaries go to `<log_dir>/train`.
      - Eval summaries go to `<log_dir>/eval/<policy_key>`.
      - Each evaluation result is appended as a line to
        `<log_dir>/results.txt`.

    Args:
      log_dir: directory for checkpoints, summaries and results.txt.
      data_file: only printed here. NOTE(review): the data actually comes from
        the environment's `get_dataset()` plus `data_path` -- confirm this
        argument is intentionally unused.
      agent_module: module providing `Config(flags).agent_args` and `Agent`.
      env_name: gym environment id passed to `env_factory`.
      n_train: maximum number of training transitions, capped at the dataset
        size.
      shuffle_steps: forwarded to `utils.shuffle_indices_with_steps`.
      seed: seed for the data shuffle, used only if `use_seed_for_data`.
      use_seed_for_data: if False, the data shuffle always uses seed 0.
      total_train_steps: training stops once `agent.global_step` reaches this.
      summary_freq: step interval for writing train summaries; summaries are
        also written at `total_train_steps`.
      print_freq: step interval for printing train info; also printed at
        `total_train_steps`.
      save_freq: step interval for checkpointing. A checkpoint is always
        written after the loop.
      eval_freq: step interval for evaluation; also evaluated at
        `total_train_steps`.
      n_eval_episodes: forwarded to `train_eval_utils.eval_policies`.
      model_params, optimizers, batch_size, weight_decays, update_freq,
        update_rate, discount: forwarded to `agent_module.Config` via
        `utils.Flags`.
      behavior_ckpt_file, value_penalty, alpha: passed to the agent only when
        the module name contains 'brac_primal'.
      data_path, isMediumExpert: forwarded to `get_offline_data`.

    Returns:
      `np.array` of rows `[step] + eval_result`, one per evaluation, where
      `eval_result` is the first value returned by
      `train_eval_utils.eval_policies` (presumably one score per test policy).
    """
    # Create tf_env to get specs.
    print('[train_eval_offline.py] env_name=', env_name)
    print('[train_eval_offline.py] data_file=', data_file)
    print('[train_eval_offline.py] agent_module=', agent_module)
    print('[train_eval_offline.py] model_params=', model_params)
    print('[train_eval_offline.py] optimizers=', optimizers)
    print('[train_eval_offline.py] bckpt_file=', behavior_ckpt_file)
    print('[train_eval_offline.py] value_penalty=', value_penalty)

    tf_env = env_factory(env_name)
    observation_spec = tf_env.observation_spec()
    action_spec = tf_env.action_spec()

    # Prepare data.
    full_data = get_offline_data(tf_env, data_path, isMediumExpert)

    # Split data.
    n_train = min(n_train, full_data.size)
    logging.info('n_train %s.', n_train)
    rand = np.random.RandomState(seed if use_seed_for_data else 0)
    shuffled_indices = utils.shuffle_indices_with_steps(
        n=full_data.size, steps=shuffle_steps, rand=rand)
    train_indices = shuffled_indices[:n_train]
    train_data = full_data.create_view(train_indices)

    # Create agent.
    agent_flags = utils.Flags(
        observation_spec=observation_spec,
        action_spec=action_spec,
        model_params=model_params,
        optimizers=optimizers,
        batch_size=batch_size,
        weight_decays=weight_decays,
        update_freq=update_freq,
        update_rate=update_rate,
        discount=discount,
        train_data=train_data)
    agent_args = agent_module.Config(agent_flags).agent_args
    my_agent_arg_dict = _build_agent_kwargs(
        agent_module, agent_args, behavior_ckpt_file, value_penalty, alpha)
    print('agent:', agent_module.__name__)
    print('agent_args:', my_agent_arg_dict)
    agent = agent_module.Agent(**my_agent_arg_dict)
    agent_ckpt_name = os.path.join(log_dir, 'agent')

    # Restore agent from checkpoint if there exists one.
    if tf.io.gfile.exists('{}.index'.format(agent_ckpt_name)):
        logging.info('Checkpoint found at %s.', agent_ckpt_name)
        agent.restore(agent_ckpt_name)

    # Train agent.
    train_summary_dir = os.path.join(log_dir, 'train')
    eval_summary_dir = os.path.join(log_dir, 'eval')
    train_summary_writer = tf0.compat.v2.summary.create_file_writer(
        train_summary_dir)
    eval_summary_writers = collections.OrderedDict()
    for policy_key in agent.test_policies.keys():
        eval_summary_writer = tf0.compat.v2.summary.create_file_writer(
            os.path.join(eval_summary_dir, policy_key))
        eval_summary_writers[policy_key] = eval_summary_writer
    eval_results = []
    results_path = os.path.join(log_dir, 'results.txt')

    time_st_total = time.time()
    time_st = time.time()
    step = agent.global_step
    timed_at_step = step
    while step < total_train_steps:
        agent.train_step()
        step = agent.global_step
        # Summaries, printing and evaluation also fire on the very last step.
        is_last_step = step == total_train_steps
        if step % summary_freq == 0 or is_last_step:
            agent.write_train_summary(train_summary_writer)
        if step % print_freq == 0 or is_last_step:
            agent.print_train_info()
        if step % eval_freq == 0 or is_last_step:
            time_ed = time.time()
            time_cost = time_ed - time_st
            logging.info(
                'Training at %.4g steps/s.', (step - timed_at_step) / time_cost)
            eval_result, eval_infos = train_eval_utils.eval_policies(
                tf_env, agent.test_policies, n_eval_episodes)
            eval_results.append([step] + eval_result)
            with open(results_path, 'a') as logfile:
                logfile.write(str(eval_result) + '\n')
            logging.info('Testing at step %d:', step)
            for policy_key, policy_info in eval_infos.items():
                logging.info(utils.get_summary_str(
                    step=None, info=policy_info, prefix=policy_key + ': '))
                utils.write_summary(eval_summary_writers[policy_key], step, policy_info)
            # Reset the throughput timer so evaluation time is not counted.
            time_st = time.time()
            timed_at_step = step
        if step % save_freq == 0:
            agent.save(agent_ckpt_name)
            logging.info('Agent saved at %s.', agent_ckpt_name)

    agent.save(agent_ckpt_name)
    time_cost = time.time() - time_st_total
    logging.info('Training finished, time cost %.4gs.', time_cost)
    return np.array(eval_results)
