import numpy as np
from Environments.acrobot import AcrobotEnv
from Environments.FourRoomsWorld import FourRoomGridWorld
from Environments.FourRoomsWorldNoisy import FourRoomsGridWolrdNoisy
from Environments.FourRoomsWorldDoorKey import FourRoomsGridWolrdDoorKey
from Environments.Maze import MazeGridWorld
from Environments.pinball import PinballEnvironment
from Environments.puddle_world import PuddleEnv
import gym
from gym_minigrid.wrappers import *
from Algorithms.dqn import DQN
from Algorithms.dqn_master_slave import DQN_master_slave
from Algorithms.dqn_gradient_alignment import DQN_gradient_alignment


# Registry mapping each learner's name to the class that implements it.
learner_dict = {
    'DQN': DQN,
    'DQN_master_slave': DQN_master_slave,
    'DQN_gradient_alignment': DQN_gradient_alignment,
}

# Every environment below is constructed at import time, so importing this
# module instantiates all of them (including the gym environments).

# Project-local environments.
acrobot_env = AcrobotEnv()
four_rooms_env = FourRoomGridWorld(stochasticity_fraction=0.0)
four_rooms_noisy_env = FourRoomsGridWolrdNoisy()
four_rooms_door_key_env = FourRoomsGridWolrdDoorKey()
maze_env = MazeGridWorld()
# NOTE(review): the config path is relative -- presumably resolved against the
# current working directory; confirm against PinballEnvironment.
pinball_env = PinballEnvironment('Environments/pinball_simple_single_modified.cfg.txt')


# Gym environments. Their built-in step limits are raised to a very large
# value so that episode cutoff is handled manually rather than by the env.
mountain_car_env = gym.make('MountainCar-v0')
mountain_car_env._max_episode_steps = 1000000
minigrid_env = gym.make('MiniGrid-Empty-5x5-v0')
minigrid_env.max_steps = 1000000 # Deliberately huge so that episode cutoff is handled manually
# The step limit is set on the raw env before it is wrapped.
# NOTE(review): RGBImgObsWrapper / ImgObsWrapper presumably come from the
# gym_minigrid.wrappers star import and yield image-only observations -- confirm.
minigrid_env = RGBImgObsWrapper(minigrid_env)
minigrid_env = ImgObsWrapper(minigrid_env)
minigrid_door_key_env = gym.make('MiniGrid-DoorKey-5x5-v0')
minigrid_door_key_env.max_steps = 1000000 # Deliberately huge so that episode cutoff is handled manually
# Same wrapping order as for minigrid_env: step limit first, wrappers after.
minigrid_door_key_env = RGBImgObsWrapper(minigrid_door_key_env)
minigrid_door_key_env = ImgObsWrapper(minigrid_door_key_env)
puddle_world_env = PuddleEnv()



# Registry mapping each environment name to its module-level instance.
# Both acrobot keys refer to the same acrobot_env object.
env_dict = {'acrobot_fully_observable': acrobot_env,
            'acrobot_partially_observable': acrobot_env,
            'four_rooms': four_rooms_env,
            'four_rooms_noisy': four_rooms_noisy_env,
            'four_rooms_door_key': four_rooms_door_key_env,
            'maze': maze_env,
            'mountain_car': mountain_car_env,
            'minigrid': minigrid_env,
            'minigrid_door_key': minigrid_door_key_env,
            'pinball': pinball_env,
            'puddle_world': puddle_world_env}
# Per-environment arrays keyed like env_dict; the MiniGrid entries are None.
# NOTE(review): presumably per-dimension upper bounds of the observation
# vector (paired with mins_dict for scaling) -- confirm against the consumers.
# NOTE(review): the 'four_rooms_noisy' and 'four_rooms_door_key' entries are
# sized from four_rooms_env._grid, not from their own environments -- confirm
# that all three grids have the same shape.
maxes_dict = {'acrobot_fully_observable': np.array([np.pi, np.pi, np.pi * 4, np.pi * 9]),
              'acrobot_partially_observable': np.concatenate((np.ones(4) * np.pi, np.array([2]))),
              'four_rooms': np.ones(four_rooms_env._grid.shape[0]**2),
              'four_rooms_noisy': np.ones(four_rooms_env._grid.shape[0]**2),
              'four_rooms_door_key': np.ones(four_rooms_env._grid.shape[0]**2 + 2),
              'maze': np.ones(maze_env._grid.shape[0]**2),
              'mountain_car': np.array([0.6, 0.07]),
              'minigrid': None,
              'minigrid_door_key': None,
              'pinball': np.array([1, 1, 1.5, 1.5]),
              'puddle_world': np.array([1, 1])}
# Counterpart of maxes_dict: same keys and array lengths, holding the values
# at the other end of each range. The MiniGrid entries are None.
# NOTE(review): presumably per-dimension lower bounds -- confirm against the
# consumers. The grid-world sizes are again taken from four_rooms_env._grid.
mins_dict = {'acrobot_fully_observable': np.array([-np.pi, -np.pi, -np.pi * 4, -np.pi * 9]),
              'acrobot_partially_observable': np.concatenate((np.ones(4) * -np.pi, np.array([0]))),
              'four_rooms': np.zeros(four_rooms_env._grid.shape[0]**2),
              'four_rooms_noisy': np.zeros(four_rooms_env._grid.shape[0]**2),
              'four_rooms_door_key': np.zeros(four_rooms_env._grid.shape[0]**2 + 2),
              'maze': np.zeros(maze_env._grid.shape[0]**2),
              'mountain_car': np.array([-1.2, -0.07]),
              'minigrid': None,
              'minigrid_door_key': None,
              'pinball': np.array([0, 0, -1.5, -1.5]),
              'puddle_world': np.array([0, 0])}

def convert_to_angle(cos_, sin_):
    """Recover an angle from its cosine and sine.

    Args:
        cos_: Cosine of the angle.
        sin_: Sine of the angle.

    Returns:
        The angle in radians, in the range ``[-pi, pi]``.
    """
    # arctan2 picks the quadrant from the signs of both components, so angles
    # in the fourth quadrant (sin < 0, cos > 0) come out negative instead of
    # being folded onto the first quadrant. It also stays finite when the
    # inputs drift slightly outside [-1, 1] through rounding, where
    # arcsin/arccos would return NaN.
    return np.arctan2(sin_, cos_)

def convert_input(input):
    """Collapse two (cos, sin) pairs into angles, keeping two trailing values.

    Args:
        input: Indexable sequence with at least six entries. Entries 0-1 and
            2-3 are passed to ``convert_to_angle`` as (cos, sin) pairs and
            entries 4-5 are copied through unchanged.
            NOTE(review): presumably the 6-element Acrobot observation --
            confirm against the callers.

    Returns:
        A length-4 float array: ``[angle_0, angle_1, input[4], input[5]]``.
    """
    first_angle = convert_to_angle(input[0], input[1])
    second_angle = convert_to_angle(input[2], input[3])

    converted = np.zeros(4)
    converted[0] = first_angle
    converted[1] = second_angle
    # The last two entries need no conversion.
    converted[2:4] = input[4:6]
    return converted

def moving_average(x, w):
    """Return the moving average of ``x`` over a window of ``w`` samples.

    Only positions where the window fully overlaps ``x`` are kept ('valid'
    convolution), so for ``w <= len(x)`` the result has ``len(x) - w + 1``
    entries.

    Args:
        x: One-dimensional sequence of values.
        w: Window length in samples.

    Returns:
        Array of window means.
    """
    window = np.ones(w)
    window_sums = np.convolve(x, window, mode='valid')
    return window_sums / w

def compute_return(cumulants, gammas):
    """Accumulate the discounted sum of ``cumulants``.

    The i-th cumulant is weighted by the product of ``gammas[0] .. gammas[i-1]``
    (the first cumulant has weight 1).

    Args:
        cumulants: Array whose first axis indexes the time steps.
        gammas: Indexable sequence of per-step discount factors, at least as
            long as ``cumulants``.

    Returns:
        The discounted sum; the integer ``0`` when ``cumulants`` is empty.
    """
    total = 0
    discount = 1
    for step in range(cumulants.shape[0]):
        total += discount * cumulants[step]
        # The discount for this step only affects later cumulants.
        discount *= gammas[step]
    return total

