import numpy as np
from gym_wrapper import BaseEnv
import random

def CL_envs_func(env_name, seed):
	"""Return a BaseEnv for one game picked uniformly at random.

	Only ``env_name == "all"`` is supported: the game is drawn with the
	global NumPy RNG (``np.random.choice``) from breakout, space_invaders
	and freeway, and built with the full action set and full observation.

	Args:
		env_name: Must be the string "all".
		seed: Forwarded unchanged to BaseEnv.

	Raises:
		NotImplementedError: For any other ``env_name``.
	"""
	if env_name != "all":
		raise NotImplementedError

	candidates = ["breakout", "space_invaders", "freeway"]
	chosen = np.random.choice(candidates)
	return BaseEnv(
		chosen,
		seed=seed,
		use_minimal_action_set=False,
		use_minimal_observation=False,
	)

def CL_envs_func_replacement(seq, game_id, seed, evaluation=False):
	"""Return the BaseEnv for one task of a fixed, hard-coded task sequence.

	Args:
		seq: Index of the task sequence in the table below (0-29).
			Ignored when ``evaluation`` is True.
		game_id: Position of the task within the sequence (0-6). When
			``evaluation`` is True it instead selects the game directly:
			0 = breakout, 1 = space_invaders, 2 = freeway.
		seed: Forwarded unchanged to BaseEnv.
		evaluation: If True, pick the game by ``game_id`` only, bypassing
			the sequence table.

	Returns:
		A BaseEnv built with the full action set and full observation.

	Raises:
		ValueError: If ``evaluation`` is True and ``game_id`` is not 0, 1 or 2.
		IndexError: If ``seq`` or ``game_id`` falls outside the table.
			Negative indices are rejected rather than silently wrapping
			around to the end of the table.
	"""
	if evaluation:
		all_envs = ["breakout", "space_invaders", "freeway"]
		# Explicit check rather than `assert`, which `python -O` strips out.
		if game_id not in (0, 1, 2):
			raise ValueError(f"evaluation game_id must be 0, 1 or 2, got {game_id!r}")
		sample_env = all_envs[game_id]
		return BaseEnv(sample_env, seed=seed, use_minimal_action_set=False, use_minimal_observation=False)

	# Sampled with replacement (a game may repeat, even back to back);
	# each sequence was sampled under a different seed.
	envs = [
		['breakout', 'space_invaders', 'breakout', 'space_invaders', 'space_invaders', 'freeway', 'breakout'],
		['space_invaders', 'breakout', 'breakout', 'space_invaders', 'space_invaders', 'breakout', 'breakout'],
		['breakout', 'space_invaders', 'breakout', 'freeway', 'freeway', 'breakout', 'freeway'],
		['freeway', 'breakout', 'space_invaders', 'breakout', 'breakout', 'breakout', 'space_invaders'],
		['freeway', 'freeway', 'space_invaders', 'space_invaders', 'breakout', 'breakout', 'freeway'],
		['freeway', 'space_invaders', 'freeway', 'freeway', 'breakout', 'space_invaders', 'breakout'],
		['freeway', 'space_invaders', 'breakout', 'freeway', 'space_invaders', 'freeway', 'breakout'],
		['breakout', 'space_invaders', 'freeway', 'breakout', 'space_invaders', 'freeway', 'breakout'],
		['breakout', 'space_invaders', 'space_invaders', 'space_invaders', 'freeway', 'breakout', 'breakout'],
		['freeway', 'breakout', 'freeway', 'space_invaders', 'freeway', 'breakout', 'freeway'],
		['space_invaders', 'space_invaders', 'breakout', 'breakout', 'space_invaders', 'breakout', 'space_invaders'],
		['space_invaders', 'breakout', 'space_invaders', 'space_invaders', 'breakout', 'space_invaders', 'freeway'],
		['freeway', 'space_invaders', 'space_invaders', 'freeway', 'breakout', 'breakout', 'freeway'],
		['freeway', 'breakout', 'freeway', 'breakout', 'freeway', 'freeway', 'breakout'],
		['breakout', 'breakout', 'freeway', 'space_invaders', 'freeway', 'freeway', 'breakout'],
		['breakout', 'space_invaders', 'breakout', 'space_invaders', 'breakout', 'breakout', 'space_invaders'],
		['space_invaders', 'freeway', 'space_invaders', 'space_invaders', 'space_invaders', 'breakout', 'space_invaders'],
		['space_invaders', 'freeway', 'freeway', 'space_invaders', 'breakout', 'space_invaders', 'freeway'],
		['freeway', 'breakout', 'space_invaders', 'freeway', 'space_invaders', 'freeway', 'freeway'],
		['space_invaders', 'freeway', 'space_invaders', 'freeway', 'breakout', 'breakout', 'freeway'],
		['freeway', 'breakout', 'freeway', 'space_invaders', 'breakout', 'freeway', 'freeway'],
		['space_invaders', 'breakout', 'breakout', 'breakout', 'breakout', 'breakout', 'breakout'],
		['space_invaders', 'breakout', 'breakout', 'breakout', 'breakout', 'freeway', 'freeway'],
		['freeway', 'breakout', 'space_invaders', 'freeway', 'breakout', 'space_invaders', 'breakout'],
		['freeway', 'breakout', 'space_invaders', 'space_invaders', 'space_invaders', 'breakout', 'breakout'],
		['breakout', 'freeway', 'freeway', 'space_invaders', 'freeway', 'breakout', 'freeway'],
		['space_invaders', 'freeway', 'freeway', 'breakout', 'space_invaders', 'freeway', 'space_invaders'],
		['breakout', 'breakout', 'breakout', 'breakout', 'space_invaders', 'breakout', 'space_invaders'],
		['space_invaders', 'space_invaders', 'space_invaders', 'freeway', 'breakout', 'breakout', 'breakout'],
		['space_invaders', 'breakout', 'space_invaders', 'freeway', 'breakout', 'breakout', 'space_invaders'],
	]
	if not 0 <= seq < len(envs):
		raise IndexError(f"seq must be in [0, {len(envs)}), got {seq!r}")
	env_seq = envs[seq]
	if not 0 <= game_id < len(env_seq):
		raise IndexError(f"game_id must be in [0, {len(env_seq)}), got {game_id!r}")
	single_env = env_seq[game_id]
	return BaseEnv(single_env, seed=seed, use_minimal_action_set=False, use_minimal_observation=False)




def CL_envs_func_new(seq, game_id, seed):
	"""Return the BaseEnv for one task of a fixed, hard-coded task sequence.

	Args:
		seq: Index of the task sequence in the table below (0-9).
		game_id: Position of the task within the sequence (0-6).
		seed: Forwarded unchanged to BaseEnv.

	Returns:
		A BaseEnv built with the full action set and full observation.

	Raises:
		IndexError: If ``seq`` or ``game_id`` falls outside the table.
			Negative indices are rejected rather than silently wrapping
			around to the end of the table.
	"""
	# In every row no game appears twice in a row, and row i starts with
	# ["breakout", "space_invaders", "freeway"][i % 3].
	# NOTE(review): presumably these rows were produced by generate_envs();
	# confirm they still match its output before relying on that.
	envs = [
		['breakout', 'space_invaders', 'freeway', 'space_invaders', 'breakout', 'freeway', 'space_invaders'],
		['space_invaders', 'freeway', 'space_invaders', 'breakout', 'space_invaders', 'freeway', 'space_invaders'],
		['freeway', 'breakout', 'freeway', 'space_invaders', 'breakout', 'space_invaders', 'freeway'],
		['breakout', 'space_invaders', 'breakout', 'freeway', 'space_invaders', 'breakout', 'space_invaders'],
		['space_invaders', 'breakout', 'space_invaders', 'freeway', 'space_invaders', 'freeway', 'breakout'],
		['freeway', 'space_invaders', 'breakout', 'freeway', 'space_invaders', 'breakout', 'space_invaders'],
		['breakout', 'space_invaders', 'freeway', 'space_invaders', 'breakout', 'space_invaders', 'freeway'],
		['space_invaders', 'freeway', 'breakout', 'freeway', 'breakout', 'freeway', 'space_invaders'],
		['freeway', 'space_invaders', 'breakout', 'freeway', 'space_invaders', 'freeway', 'breakout'],
		['breakout', 'space_invaders', 'breakout', 'space_invaders', 'freeway', 'breakout', 'space_invaders'],
	]
	if not 0 <= seq < len(envs):
		raise IndexError(f"seq must be in [0, {len(envs)}), got {seq!r}")
	env_seq = envs[seq]
	if not 0 <= game_id < len(env_seq):
		raise IndexError(f"game_id must be in [0, {len(env_seq)}), got {game_id!r}")
	single_env = env_seq[game_id]
	return BaseEnv(single_env, seed=seed, use_minimal_action_set=False, use_minimal_observation=False)




def generate_envs(num_sequences=10, sequence_length=7):
	"""Generate game sequences in which no game appears twice in a row.

	Sequence ``i`` starts with ``all_envs[i % 3]``. Each following game is
	drawn uniformly from the games other than the previous one, using a
	private ``np.random.RandomState(i)``. This yields the same draws as
	seeding the global NumPy RNG with ``i``, but leaves the global
	``numpy`` and ``random`` RNG states untouched.

	Args:
		num_sequences: Number of sequences to generate (default 10).
		sequence_length: Number of games per sequence (default 7, >= 1).

	Returns:
		A list of ``num_sequences`` lists of plain ``str`` game names.

	Raises:
		ValueError: If ``sequence_length`` is less than 1.
	"""
	if sequence_length < 1:
		raise ValueError(f"sequence_length must be >= 1, got {sequence_length!r}")
	all_envs = ["breakout", "space_invaders", "freeway"]

	env_list = []
	for seed in range(num_sequences):
		rng = np.random.RandomState(seed)
		current = all_envs[seed % len(all_envs)]
		subenv_list = [current]
		for _ in range(sequence_length - 1):
			candidates = [name for name in all_envs if name != current]
			# str(): np.random choice returns np.str_; keep a uniform plain-str
			# result that prints as a copy-pasteable literal.
			current = str(rng.choice(candidates))
			subenv_list.append(current)
		env_list.append(subenv_list)
	return env_list

if __name__ == '__main__':
	# Smoke test: build the game at position 1 of task sequence 0, seed 1.
	demo_env = CL_envs_func_new(seq=0, game_id=1, seed=1)
	print(demo_env)