import os

from utils import log_trajectory_statistics
from envs.envs import (ExpertInvertedPendulumEnv, ExpertInvertedDoublePendulumEnv)
from envs.more_envs import CustomReacher2Env, CustomReacher3Env

from samplers import Sampler
from utils import save_expert_trajectories, save_learner_trajectories

import argparse

def parse_arguments(argv=None):
    """Parse the command-line options for prior-data collection.

    Parameters
    ----------
    argv : Optional list of argument strings. ``None`` (the default) makes
        argparse read ``sys.argv[1:]``, which is the original behaviour.

    Returns
    -------
    argparse.Namespace with ``env_name`` (str), ``max_step`` (int) and a
    fixed ``algo`` tag set to ``'DIFF-IL'``.
    """
    parser = argparse.ArgumentParser(description='Collecting random data')
    parser.add_argument('--env_name', help="The random env name.", default='InvertedPendulum')
    # type=int: without it a value passed on the command line arrives as a str,
    # which breaks the integer division performed by collect_prior_data.
    parser.add_argument('--max_step', type=int, help="Max length of buffer size.", default=50000)

    args = parser.parse_args(argv)
    args.algo = 'DIFF-IL'
    return args

def _build_prior_envs(realm_name):
    """Instantiate a realm's environments together with their save names and episode limits.

    Parameters
    ----------
    realm_name : Environment realm name, e.g. 'InvertedPendulum', 'Reacher',
        'DMHopper' or 'Robot_Manipulation'.

    Returns
    -------
    Tuple ``(prior_envs, prior_env_names, episode_limit)`` of three parallel lists:
    environment instances, the name each one is saved under, and its episode length limit.

    Raises
    ------
    NotImplementedError : If ``realm_name`` is not one of the implemented realms.
    """
    if realm_name == 'InvertedPendulum':
        prior_envs = [ExpertInvertedPendulumEnv(), ExpertInvertedPendulumEnv(),
                      ExpertInvertedDoublePendulumEnv(), ExpertInvertedDoublePendulumEnv()]
        prior_env_names = ['InvertedPendulum-v2', 'InvertedPendulum-v2_learner',
                           'InvertedDoublePendulum-v2', 'InvertedDoublePendulum-v2_learner']
        episode_limit = [50, 1000, 50, 1000]

    elif realm_name == 'Reacher':
        prior_envs = [CustomReacher2Env(), CustomReacher2Env(), CustomReacher3Env(), CustomReacher3Env()]
        prior_env_names = ['Reacher2-v2', 'Reacher2-v2_learner',
                           'Reacher3-v2', 'Reacher3-v2_learner']
        episode_limit = [50, 50, 50, 50]

    # DMC realms: envs.dmcontrol_envs is imported lazily so it is only loaded
    # when one of these realms is actually requested.
    elif realm_name == 'DMCartPoleSwingUp':
        from envs.dmcontrol_envs import DMCartPoleSwingUpEnv
        prior_envs = [DMCartPoleSwingUpEnv(), DMCartPoleSwingUpEnv()]
        prior_env_names = ['DMCartPoleSwingUp', 'DMCartPoleSwingUp_learner']
        episode_limit = [200, 1000]
    elif realm_name == 'DMPendulum':
        from envs.dmcontrol_envs import DMPendulumEnv
        prior_envs = [DMPendulumEnv(), DMPendulumEnv()]
        prior_env_names = ['DMPendulum', 'DMPendulum_learner']
        episode_limit = [200, 1000]
    elif realm_name == 'DMAcrobot':
        from envs.dmcontrol_envs import DMAcrobotEnv
        prior_envs = [DMAcrobotEnv(), DMAcrobotEnv()]
        prior_env_names = ['DMAcrobot', 'DMAcrobot_learner']
        episode_limit = [200, 1000]
    elif realm_name == 'DMHopper':
        from envs.dmcontrol_envs import DMHopperEnv
        prior_envs = [DMHopperEnv(), DMHopperEnv()]
        prior_env_names = ['DMHopper', 'DMHopper_learner']
        episode_limit = [200, 200]
    elif realm_name == 'DMWalker':
        from envs.dmcontrol_envs import DMWalkerEnv
        prior_envs = [DMWalkerEnv(), DMWalkerEnv()]
        prior_env_names = ['DMWalker', 'DMWalker_learner']
        episode_limit = [200, 200]
    elif realm_name == 'DMCheetah':
        from envs.dmcontrol_envs import DMCheetahEnv
        prior_envs = [DMCheetahEnv(), DMCheetahEnv()]
        prior_env_names = ['DMCheetah', 'DMCheetah_learner']
        episode_limit = [200, 200]

    elif realm_name == 'Robot_Manipulation':
        from envs.manipulation_envs import PusherEnv, PusherHumanSimEnv, ReachEnv, ReachHumanSimEnv
        prior_envs = [PusherEnv(), PusherHumanSimEnv(), ReachEnv(), ReachHumanSimEnv(),
                      PusherEnv(), PusherHumanSimEnv(), ReachEnv(), ReachHumanSimEnv()]
        prior_env_names = ['Pusher-v2', 'PusherHumanSim-v2', 'Reach-v2', 'ReachHumanSim-v2',
                           'Pusher-v2_learner', 'PusherHumanSim-v2_learner',
                           'Reach-v2_learner', 'ReachHumanSim-v2_learner']
        episode_limit = [200, 200, 200, 200, 200, 200, 200, 200]

    else:
        print('Please select one of the implemented environments')
        raise NotImplementedError(
            'Unknown environment realm: {!r}'.format(realm_name))

    return prior_envs, prior_env_names, episode_limit


def collect_prior_data(realm_name, max_timesteps=40000, prior_samples_location='prior_data'):
    """Collect and save prior visual observations for an environment realm.

    For every environment of the realm, ``max_timesteps // episode_limit``
    episodes are sampled with images, their returns are logged, and the
    trajectories are saved under ``prior_samples_location/<env_name>``.

    Parameters
    ----------
    realm_name : Environment realm to collect the visual observations.
    max_timesteps : Maximum number of visual observations to collect per
        environment, default is 40000. Any value accepted by ``int()``
        (e.g. a numeric string from the command line) is allowed.
    prior_samples_location : Folder to save the prior visual observations collected.

    Raises
    ------
    NotImplementedError : If ``realm_name`` is not an implemented realm.
    """
    prior_envs, prior_env_names, episode_limit = _build_prior_envs(realm_name)
    # Coerce once so callers passing an un-typed CLI value (a str) still work.
    max_timesteps = int(max_timesteps)

    for env, env_name, epi_limit in zip(prior_envs, prior_env_names, episode_limit):
        # NOTE(review): this is 0 when max_timesteps < epi_limit; confirm that
        # Sampler / log_trajectory_statistics handle an empty collection.
        episodes_n = max_timesteps // epi_limit
        saver_sampler = Sampler(env, episode_limit=epi_limit,
                                init_random_samples=0, visual_env=True)

        # Positional arguments forwarded unchanged from the original call;
        # presumably a None policy (random actions) -- confirm against Sampler.
        traj = saver_sampler.sample_test_trajectories(None, 0.0, episodes_n, False, get_ims=True)
        log_trajectory_statistics(traj['ret'])
        os.makedirs(os.path.join(prior_samples_location, env_name), exist_ok=True)
        save_expert_trajectories(traj, env_name, prior_samples_location,
                                 visual_data=True)
    print('Prior trajectories successfully saved.')


def main():
    """Script entry point.

    Reads the command-line options and collects prior visual data for the
    selected environment realm with the requested step budget.
    """
    options = parse_arguments()
    realm, step_budget = options.env_name, options.max_step
    collect_prior_data(realm, step_budget)


if __name__ == "__main__":
    # Run the collection only when this file is executed as a script, not when imported.
    main()