import gym
import numpy as np

import collections
import pickle

import d4rl
import argparse

# Command-line options. Both are integer flags.
parser = argparse.ArgumentParser(
	description='Download D4RL datasets and pickle each one as a list of per-episode dicts.')
parser.add_argument('--type', type=int, default=2,
					help='which suites to download: 0 = locomotion only '
						 '(halfcheetah/hopper/walker2d), 1 = antmaze only, '
						 'any other value = both (default: 2)')
parser.add_argument('--regular_ant', type=int, default=1,
					help='0 = for antmaze diverse/play datasets, split episodes only at '
						 'timeouts, use reward 1 on each episode\'s final step instead of the '
						 'dataset rewards, and also save infos/goal; nonzero = keep the '
						 'dataset rewards and terminals (default: 1)')
args = parser.parse_args()

datasets = []

# Dataset variants fetched for every locomotion environment.
all_dataset_types = ['medium', 'medium-replay', 'expert']
# Locomotion environments to fetch; `--type 1` means "antmaze only", so none.
all_envs_name = ['halfcheetah', 'hopper', 'walker2d'] if args.type != 1 else []

# Locomotion suite: split each flat D4RL dataset into complete episodes and
# pickle them as a list of per-episode dicts to '<env>-<dataset_type>-v2.pkl'.
for env_name in all_envs_name:
	for dataset_type in all_dataset_types:
		name = f'{env_name}-{dataset_type}-v2'
		print("#" * 25 + f"  Downloading {name}  " + "#" * 25)
		env = gym.make(name)
		dataset = env.get_dataset()

		N = dataset['rewards'].shape[0]
		use_timeouts = 'timeouts' in dataset
		keys = ['observations', 'next_observations', 'actions', 'rewards', 'terminals']

		# An episode ends on a terminal step or a timeout step. Without a
		# 'timeouts' array, fall back to cutting episodes every 1000 steps.
		# episode_step is the 0-based position of sample i inside its episode.
		# It is reset to 0 (and not incremented) at a boundary, so every
		# fallback episode holds exactly 1000 samples.
		episode_step = 0
		start = 0  # index of the first sample of the current episode
		paths = []
		for i in range(N):
			done_bool = bool(dataset['terminals'][i])
			if use_timeouts:
				final_timestep = bool(dataset['timeouts'][i])
			else:
				final_timestep = (episode_step == 1000 - 1)
			if done_bool or final_timestep:
				# np.array copies the slice, so each episode owns its own arrays.
				paths.append({k: np.array(dataset[k][start:i + 1]) for k in keys})
				start = i + 1
				episode_step = 0
			else:
				episode_step += 1
		# Samples after the last boundary form an unfinished episode and are dropped.

		if not paths:
			# np.max/np.min below would raise on an empty array.
			print(f'No complete trajectories found in {name}; nothing saved.')
			continue

		returns = np.array([np.sum(p['rewards']) for p in paths])
		num_samples = np.sum([p['rewards'].shape[0] for p in paths])
		print(f'Number of samples collected: {num_samples}')
		print(f'Trajectory returns: mean = {np.mean(returns)}, std = {np.std(returns)}, max = {np.max(returns)}, min = {np.min(returns)}')

		with open(f'{name}.pkl', 'wb') as f:
			pickle.dump(paths, f)

# AntMaze suite: split each dataset into complete episodes and pickle them as a
# list of per-episode dicts to '<name>.pkl'.
print("other specific envs:")
specific_envs = []
if args.type != 0:
	specific_envs = ["antmaze-umaze-diverse-v2",
					"antmaze-umaze-v2",
					"antmaze-medium-diverse-v2",
					"antmaze-medium-play-v2",
					"antmaze-large-diverse-v2",
					"antmaze-large-play-v2"]

for name in specific_envs:
	print("#" * 25 + f"  Downloading {name}  " + "#" * 25)
	env = gym.make(name)
	dataset = env.get_dataset()

	N = dataset['rewards'].shape[0]
	# "Hard maze" mode (diverse/play datasets with --regular_ant 0):
	# - terminals are ignored when splitting, so episodes end only on a timeout
	#   (or on the 1000-step fallback);
	# - the dataset rewards are replaced by 1 on each episode's final step and
	#   0 elsewhere;
	# - 'infos/goal' is saved as well.
	ishardmaze = ("diverse" in name or "play" in name) and args.regular_ant == 0
	if ishardmaze:
		print("Using diverse/play method - reward on timeout")

	use_timeouts = 'timeouts' in dataset
	if use_timeouts:
		print("Using timeouts")

	# Same splitting rule as the locomotion loop. episode_step is the 0-based
	# position of sample i inside its episode and is reset (not incremented)
	# at a boundary, so fallback episodes hold exactly 1000 samples.
	episode_step = 0
	start = 0  # index of the first sample of the current episode
	paths = []
	for i in range(N):
		done_bool = bool(dataset['terminals'][i]) and not ishardmaze
		if use_timeouts:
			final_timestep = bool(dataset['timeouts'][i])
		else:
			final_timestep = (episode_step == 1000 - 1)
		if not (done_bool or final_timestep):
			episode_step += 1
			continue

		end = i + 1
		# Rewards are handled separately below, so they are not in this key list.
		episode_data = {k: np.array(dataset[k][start:end]) for k in ['observations', 'actions', 'terminals']}
		if ishardmaze:
			# In hard-maze mode an episode can only end on its final (timeout)
			# step, so that step alone is rewarded. This matches int(timeouts)
			# per step, and also works when the dataset has no 'timeouts'.
			rewards = np.zeros(end - start, dtype=int)
			rewards[-1] = 1
			episode_data['rewards'] = rewards
			episode_data['infos/goal'] = np.array(dataset['infos/goal'][start:end])
		else:
			episode_data['rewards'] = np.array(dataset['rewards'][start:end])
		paths.append(episode_data)
		start = end
		episode_step = 0
	# Samples after the last boundary form an unfinished episode and are dropped.

	if not paths:
		# np.max/np.min below would raise on an empty array.
		print(f'No complete trajectories found in {name}; nothing saved.')
		continue

	returns = np.array([np.sum(p['rewards']) for p in paths])
	num_samples = np.sum([p['rewards'].shape[0] for p in paths])
	print(f'Number of samples collected: {num_samples}')
	print(
		f'Trajectory returns: mean = {np.mean(returns)}, std = {np.std(returns)}, max = {np.max(returns)}, min = {np.min(returns)}')

	with open(f'{name}.pkl', 'wb') as f:
		pickle.dump(paths, f)
