import os

from utils import log_trajectory_statistics
from envs.envs import (ExpertInvertedPendulumEnv, ExpertInvertedDoublePendulumEnv)

from envs.more_envs import CustomReacher2Env, CustomReacher3Env
from samplers import Sampler
from utils import save_expert_trajectories
import tensorflow as tf
from train_expert import train_expert

import argparse

def parse_arguments(argv=None):
    """Parse the command-line options for expert data collection.

    Parameters
    ----------
    argv : list of str, optional
        Arguments to parse instead of ``sys.argv[1:]``. ``None`` (the default)
        keeps the usual command-line behaviour.

    Returns
    -------
    argparse.Namespace
        Parsed ``env_name``, ``max_step`` (int) and ``gpu_id`` (int), plus a
        fixed ``algo`` attribute set to ``'DIFF-IL'``.
    """
    parser = argparse.ArgumentParser(description='Collecting Expert data')
    parser.add_argument('--env_name', help="The expert env name.", default='InvertedDoublePendulum-v2')
    # type=int is required: without it a value passed on the command line stays
    # a str and later breaks the integer arithmetic in collect_expert_data.
    parser.add_argument('--max_step', help="Maximum number of timesteps (observations) to collect.",
                        type=int, default=50000)
    parser.add_argument('--gpu_id', help="Index of the GPU to use when several are available.",
                        type=int, default=0)

    args = parser.parse_args(argv)
    args.algo = 'DIFF-IL'
    return args

# Supported expert environments: name -> (module path, class name, episode length).
# Modules are imported lazily so that optional simulator backends (dm_control,
# the manipulation suites) are only needed when one of their envs is requested.
_EXPERT_ENV_SPECS = {
    'InvertedDoublePendulum-v2': ('envs.envs', 'ExpertInvertedDoublePendulumEnv', 1000),
    'InvertedPendulum-v2': ('envs.envs', 'ExpertInvertedPendulumEnv', 1000),
    'Reacher2-v2': ('envs.more_envs', 'CustomReacher2Env', 50),
    'Reacher3-v2': ('envs.more_envs', 'CustomReacher3Env', 50),
    # ================================================== DMC
    'DMPendulum': ('envs.dmcontrol_envs', 'DMPendulumEnv', 1000),
    'DMCartPoleSwingUp': ('envs.dmcontrol_envs', 'DMCartPoleSwingUpEnv', 1000),
    'DMAcrobot': ('envs.dmcontrol_envs', 'DMAcrobotEnv', 1000),
    'DMCheetah': ('envs.dmcontrol_envs', 'DMCheetahEnv', 200),
    'DMWalker': ('envs.dmcontrol_envs', 'DMWalkerEnv', 200),
    'DMHopper': ('envs.dmcontrol_envs', 'DMHopperEnv', 200),
    # ================================================== Manipulation
    'Pusher-v2': ('envs.manipulation_envs', 'PusherEnv', 200),
    'PusherHumanSim-v2': ('envs.manipulation_envs', 'PusherHumanSimEnv', 200),
    'Reach-v2': ('envs.manipulation_envs', 'ReachEnv', 200),
    'ReachHumanSim-v2': ('envs.manipulation_envs', 'ReachHumanSimEnv', 200),
}


def _make_expert_env(env_name):
    """Import and instantiate the environment class registered under ``env_name``."""
    import importlib

    module_path, class_name, _ = _EXPERT_ENV_SPECS[env_name]
    env_cls = getattr(importlib.import_module(module_path), class_name)
    return env_cls()


def collect_expert_data(agent, env_name, max_timesteps=40000, expert_samples_location='expert_data'):
    """Collect and save demonstrations with trained expert agent.

    Parameters
    ----------
    agent : Trained expert agent.
    env_name : Source environment to collect the demonstrations. Must be a key
        of ``_EXPERT_ENV_SPECS``.
    max_timesteps : Maximum number of visual observations to collect, default is 40000.
        Only whole episodes are collected, i.e. ``max_timesteps // episode_limit``
        episodes. Strings holding an integer are accepted.
    expert_samples_location : Folder to save the expert demonstrations collected.
        The sub-folder ``<expert_samples_location>/<env_name>`` is created if it
        does not exist yet.

    Raises
    ------
    NotImplementedError
        If ``env_name`` is not one of the implemented environments.
    ValueError
        If ``max_timesteps`` is too small to fit a single episode.
    """
    # Validate cheap inputs before building the env or running any rollouts.
    if env_name not in _EXPERT_ENV_SPECS:
        raise NotImplementedError(
            'Please select one of the implemented environments; got {!r}. '
            'Available: {}'.format(env_name, ', '.join(_EXPERT_ENV_SPECS)))
    episode_limit = _EXPERT_ENV_SPECS[env_name][2]

    episodes_n = int(max_timesteps) // episode_limit
    if episodes_n < 1:
        raise ValueError(
            'max_timesteps={} is smaller than one episode ({} steps) for {}; '
            'no demonstrations would be collected.'.format(max_timesteps, episode_limit, env_name))

    print("Collecting Expert datas from expert policy. it takes some times.")
    expert_env = _make_expert_env(env_name)

    saver_sampler = Sampler(expert_env, episode_limit=episode_limit,
                            init_random_samples=0, visual_env=True)

    # Create the output folder *before* sampling (and tolerate an existing one)
    # so that a re-run does not discard freshly collected data on FileExistsError.
    os.makedirs(os.path.join(expert_samples_location, env_name), exist_ok=True)

    print("Save it from buffers to local storage")
    traj = saver_sampler.sample_test_trajectories(agent, 0.0, episodes_n, False, get_ims=True)
    log_trajectory_statistics(traj['ret'])
    save_expert_trajectories(traj, env_name, expert_samples_location,
                             visual_data=True)
    print('Expert trajectories successfully saved.')


def main():
    """Configure GPUs, train the expert for ``--env_name``, then collect its demonstrations.

    Raises
    ------
    ValueError
        If ``--gpu_id`` does not index one of the detected GPUs.
    """
    # Parse arguments
    args = parse_arguments()

    gpu_id = args.gpu_id
    gpu_allow_growth = True
    gpus = tf.config.experimental.list_physical_devices('GPU')
    if gpus:
        # Validate explicitly: a negative index would otherwise silently select a
        # GPU counted from the end, and a too-large one fails with a bare IndexError.
        if not 0 <= gpu_id < len(gpus):
            raise ValueError('gpu_id {} is out of range; {} GPU(s) detected.'.format(gpu_id, len(gpus)))
        try:
            tf.config.experimental.set_visible_devices(gpus[gpu_id], 'GPU')
            if gpu_allow_growth:
                # TensorFlow expects the memory-growth setting to match across GPUs.
                for gpu in gpus:
                    tf.config.experimental.set_memory_growth(gpu, True)
        except RuntimeError as e:
            # Both calls raise RuntimeError once the TF runtime is already initialized.
            print(e)

    expert_agent = train_expert(env_name=args.env_name)

    collect_expert_data(expert_agent, args.env_name, args.max_step)


# Run the train-then-collect pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()