import numpy as np

import gym
from gym.envs.registration import register
from shapeFollow_gym.path_follow_2d import PathFollowBaseEnv, DirectionWrapper, ShapeChoiceWrapper, TriShapeWrapper

from shapeFollow_gym.path_follow_2d.shape_gen import CirclePath, SquarePath, TrianglePath
from shapeFollow_gym.path_follow_2d.shape_gen import compute_square_reward, compute_max1_positive_reward, compute_max1_scaled_mse_reward, compute_max1_scaled_mae_reward

# Episode length in steps. Passed both as gym's max_episode_steps and as the
# env's own "max_steps" kwarg for every ID registered in this module.
max_steps = 90
# (360 / max_steps) degrees converted to radians, i.e. 2*pi / max_steps, so
# max_steps multiples of it span one full 360-degree cycle. Presumably the
# per-step phase advance used by PathFollowBaseEnv (passed as "phase_step";
# its sign selects the -v0 / -v1 variant) — confirm against the env.
# The expression is kept in this exact form so the float value is unchanged.
phase_step = (360/max_steps)/180*np.pi

################################################################################################
# ENTRY-POINT FACTORIES
################################################################################################

def _as_env(env):
	"""Return ``gym.make(env)`` when *env* is an ID string, otherwise *env* unchanged.

	The wrapper IDs below store env *IDs* in their kwargs, not env objects
	built with ``gym.make`` while this module runs. The kwargs dict given to
	``register`` is created once. A pre-built env stored in it would therefore
	be the same object inside every wrapper made from that ID.

	Resolving the ID here gives each wrapper instance its own inner env. It also
	avoids constructing environments as an import side effect. An env object
	passed explicitly as an override (``gym.make(wrapper_id, env=my_env)``) is
	still accepted as-is.
	"""
	return gym.make(env) if isinstance(env, str) else env


def _make_direction_env(env):
	"""Entry point for the ``*_directional`` IDs: a ``DirectionWrapper`` around *env*."""
	return DirectionWrapper(env=_as_env(env))


def _make_shape_choice_env(env1, env2):
	"""Entry point for the two-shape IDs: a ``ShapeChoiceWrapper`` over *env1* and *env2*."""
	return ShapeChoiceWrapper(env1=_as_env(env1), env2=_as_env(env2))


def _make_tri_shape_env(env1, env2, env3):
	"""Entry point for the three-shape IDs: a ``TriShapeWrapper`` over *env1*, *env2*, *env3*."""
	return TriShapeWrapper(env1=_as_env(env1), env2=_as_env(env2), env3=_as_env(env3))


def _register_path_env(env_id, shape_gen, step):
	"""Register *env_id* as a ``PathFollowBaseEnv`` on *shape_gen* with ``phase_step=step``.

	``max_steps`` is used both as the env's ``max_steps`` kwarg and as gym's
	``max_episode_steps`` limit.

	NOTE(review): *shape_gen* is one object reused by every env made from
	*env_id*. That is fine only if the path generators are stateless — confirm.
	"""
	register(
		id=env_id,
		entry_point='shapeFollow_gym.path_follow_2d:PathFollowBaseEnv',
		max_episode_steps=max_steps,
		kwargs={"shape_gen": shape_gen,
				"phase_step": step,
				"max_steps": max_steps,
		},
	)

################################################################################################
# BASIC SHAPES
# -v0 base IDs use +phase_step; -v1 base IDs use -phase_step.
################################################################################################

_register_path_env(
	'CircleFollow2D-v0',
	CirclePath(radius=5, reward_fn=compute_max1_positive_reward),
	phase_step,
)

_register_path_env(
	'CircleFollow2D-v1',
	CirclePath(radius=5, reward_fn=compute_max1_positive_reward),
	-phase_step,
)

register(
	id='CircleFollow2D_directional-v0',
	entry_point=_make_direction_env,
	max_episode_steps=max_steps,
	kwargs={"env": 'CircleFollow2D-v0'},
)

_register_path_env(
	'SquareFollow2D-v0',
	SquarePath(side_len=10, reward_fn=compute_max1_positive_reward),
	phase_step,
)

_register_path_env(
	'SquareFollow2D-v1',
	SquarePath(side_len=10, reward_fn=compute_max1_positive_reward),
	-phase_step,
)

register(
	id='SquareFollow2D_directional-v0',
	entry_point=_make_direction_env,
	max_episode_steps=max_steps,
	kwargs={"env": 'SquareFollow2D-v0'},
)

register(
	id='SquareFollow2D_directional-v1',
	entry_point=_make_direction_env,
	max_episode_steps=max_steps,
	kwargs={"env": 'SquareFollow2D-v1'},
)

_register_path_env(
	'TriangleFollow2D-v0',
	TrianglePath(height=10, reward_fn=compute_max1_positive_reward),
	phase_step,
)

register(
	id='SquareAndCircleFollow2D-v0',
	entry_point=_make_shape_choice_env,
	max_episode_steps=max_steps,
	kwargs={"env1": 'CircleFollow2D-v0',
			"env2": 'SquareFollow2D-v0'},
)

# NOTE(review): pairs the directional circle built on CircleFollow2D-v0
# (+phase_step) with the directional square built on SquareFollow2D-v1
# (-phase_step) — confirm the opposite signs are intended.
register(
	id='SquareAndCircleFollow2D_directional-v0',
	entry_point=_make_shape_choice_env,
	max_episode_steps=max_steps,
	kwargs={"env1": 'CircleFollow2D_directional-v0',
			"env2": 'SquareFollow2D_directional-v1'},
)

# NOTE(review): -v0 uses CircleFollow2D-v1 (-phase_step) with the +phase_step
# square/triangle, while -v1 uses CircleFollow2D-v0 — confirm this is intended.
register(
	id='TriShapeFollow2D-v0',
	entry_point=_make_tri_shape_env,
	max_episode_steps=max_steps,
	kwargs={"env1": 'CircleFollow2D-v1',
			"env2": 'SquareFollow2D-v0',
			"env3": 'TriangleFollow2D-v0'},
)

register(
	id='TriShapeFollow2D-v1',
	entry_point=_make_tri_shape_env,
	max_episode_steps=max_steps,
	kwargs={"env1": 'CircleFollow2D-v0',
			"env2": 'SquareFollow2D-v0',
			"env3": 'TriangleFollow2D-v0'},
)
################################################################################################