import gym
from gym.wrappers import Monitor
# from option import Option
from env.mujoco_env.reacher_env import ReacherGymEnv
from env.mujoco_env.reacher_env import ReacherGymEnvEval
import os
import torch
import gym
from monitor import Monitor
from option import *
import numpy as np
import contextlib

###############
# Load Option #
###############
# Pretrained option checkpoint, located relative to $LOF_PKG_PATH.
option_load_path = os.path.join(
    os.environ['LOF_PKG_PATH'],
    'experiments', 'red', 'pyt_save', 'model1000.pt',
)
# Discard anything Option's constructor writes to stdout while loading.
with open(os.devnull, "w") as devnull_stream, contextlib.redirect_stdout(devnull_stream):
    option = Option(option_load_path)

##################
# Define Problem #
##################
def make_subgoals_reacher(env):
    """Create one Subgoal per coloured target of the reacher env.

    Each Subgoal is constructed as ``Subgoal(name, prop_index, subgoal_index,
    state)``. Both indices equal the target's position in red/green/blue/yellow
    order. ``state`` is read from the env's info dict.

    Args:
        env: wrapped env whose ``env.env.env.get_info()`` returns a dict with
            'red_p', 'green_p', 'blue_p' and 'yellow_p' entries.

    Returns:
        list of Subgoal, in red, green, blue, yellow order.
    """
    info = env.env.env.get_info()
    targets = (
        ('red', 'red_p'),
        ('green', 'green_p'),
        ('blue', 'blue_p'),
        ('yellow', 'yellow_p'),
    )
    return [Subgoal(name, idx, idx, info[key]) for idx, (name, key) in enumerate(targets)]

def make_taskspec_reacher():
    """Build the TaskSpec for: go to red or green, then blue, then yellow, then red.

    ``tm[s, s_next, p] = 1`` encodes the FSA transition taken from state ``s``
    when proposition ``p`` holds. Each commented table below covers one source
    state: rows are next states and columns are propositions. State 4 (the
    last one) is the absorbing goal.

    Returns:
        (task_spec, safety_props): the TaskSpec and an empty list of safety
        proposition indices.
    """
    # go to R or G, then B, then Y, then R
    spec = 'F( (r|g) & F( b & F (y & F r)))'

    # NOTE(review): nP is 10, but the tables below only cover 5 props
    # (r g b y e). S0 sends every prop from index 2 upward back to S0 via
    # `2:`, while S1-S3 define no transition at all for props 5-9. Confirm
    # whether this builder predates the 10-prop layout used by the other
    # make_taskspec_reacher_* functions (where index 4 is 'c' and 'e' is 9).
    nF = 5
    nP = 10
    tm = np.zeros((nF, nF, nP))

    # S0
    #    r  g  b  y  e
    # 0  0  0  1  1  1
    # 1  1  1  0  0  0
    # 2  0  0  0  0  0
    # 3  0  0  0  0  0
    # G  0  0  0  0  0
    tm[0, 1, 0] = 1
    tm[0, 1, 1] = 1
    tm[0, 0, 2:] = 1  # all props from index 2 up (including 5-9) stay in S0
    # S1
    #    r  g  b  y  e
    # 0  0  0  0  0  0
    # 1  1  1  0  1  1
    # 2  0  0  1  0  0
    # 3  0  0  0  0  0
    # G  0  0  0  0  0
    tm[1, 1, 0] = 1
    tm[1, 1, 1] = 1
    tm[1, 2, 2] = 1
    tm[1, 1, 3] = 1
    tm[1, 1, 4] = 1
    # S2
    #    r  g  b  y  e
    # 0  0  0  0  0  0
    # 1  0  0  0  0  0
    # 2  1  1  1  0  1
    # 3  0  0  0  1  0
    # G  0  0  0  0  0
    tm[2, 2, 0] = 1
    tm[2, 2, 1] = 1
    tm[2, 2, 2] = 1
    tm[2, 3, 3] = 1
    tm[2, 2, 4] = 1
    # S3
    #    r  g  b  y  e
    # 0  0  0  0  0  0
    # 1  0  0  0  0  0
    # 2  0  0  0  0  0
    # 3  0  1  1  1  1
    # G  1  0  0  0  0
    tm[3, 4, 0] = 1
    tm[3, 3, 1] = 1
    tm[3, 3, 2] = 1
    tm[3, 3, 3] = 1
    tm[3, 3, 4] = 1
    # G
    #    r  g  b  y  e
    # 0  0  0  0  0  0
    # 1  0  0  0  0  0
    # 2  0  0  0  0  0
    # 3  0  0  0  0  0
    # G  1  1  1  1  1
    tm[4, 4, :] = 1  # goal is absorbing for every prop

    # One cost per FSA state; the goal state (last) costs 0.
    task_state_costs = [1, 1, 1, 1, 0]

    safety_props = []
    task_spec = TaskSpec(spec, tm, task_state_costs)

    return task_spec, safety_props

# composite task
# (F((r|g) & F(b & F y)) & G ! can) | (F((r|g) & F y) & F can) & G ! o
def make_taskspec_reacher_composite():
    """Build the composite task with optional cancellation.

    Go to A or B, then C, then HOME. If a cancel proposition is observed, going
    to A or B and then HOME is enough.

    ``tm[s, s_next, p] = 1`` encodes the FSA transition taken from state ``s``
    when proposition ``p`` holds. Each commented table below covers one source
    state: rows are next states and columns are propositions. State 6 is the
    absorbing goal.

    Returns:
        (task_spec, safety_props): the TaskSpec and the prop indices 4-8 (the
        c/cr/cg/cb/cy columns). How callers use safety_props is not visible
        here.
    """
    # go to A or B, then C, then HOME, unless C is CANceled in which case just go to A or B then HOME
    spec = '(F((a|b) & F(c & F home)) & G ! can) | (F((a|b) & F home) & F can) & G ! o'

    # prop order:
    # a b c home can cana canb canc canh o e
    # The tables below label the columns by colour. For indices 0-8 the
    # mapping is positional: r=a g=b b=c y=home c=can cr=cana cg=canb
    # cb=canc cy=canh. Column 9 is labelled 'e'.
    # NOTE(review): the list above names 11 props (including 'o'), but nP is
    # 10 and no column is assigned to 'o'. Confirm against the env.

    nF = 7
    nP = 10
    tm = np.zeros((nF, nF, nP))

    # S0
    #    r  g  b  y  c cr cg cb cy  e
    # 0  0  0  1  1  0  0  0  0  0  1
    # 1  1  1  0  0  0  0  0  0  0  0
    # 2  0  0  0  0  1  0  0  1  1  0
    # 3  0  0  0  0  0  1  1  0  0  0
    # 4  0  0  0  0  0  0  0  0  0  0
    # 5  0  0  0  0  0  0  0  0  0  0
    # G  0  0  0  0  0  0  0  0  0  0
    tm[0, 1, 0] = 1
    tm[0, 1, 1] = 1
    tm[0, 0, 2] = 1
    tm[0, 0, 3] = 1
    tm[0, 2, 4] = 1
    tm[0, 3, 5] = 1
    tm[0, 3, 6] = 1
    tm[0, 2, 7] = 1
    tm[0, 2, 8] = 1
    tm[0, 0, 9] = 1
    # S1
    #    r  g  b  y  c cr cg cb cy  e
    # 0  0  0  0  0  0  0  0  0  0  0
    # 1  1  1  0  0  0  0  0  0  0  1
    # 2  0  0  0  0  0  0  0  0  0  0
    # 3  0  0  1  0  1  1  1  1  0  0
    # 4  0  0  0  1  0  0  0  0  0  0
    # 5  0  0  0  0  0  0  0  0  0  0
    # G  0  0  0  0  0  0  0  0  1  0
    tm[1, 1, 0] = 1
    tm[1, 1, 1] = 1
    tm[1, 3, 2] = 1
    tm[1, 4, 3] = 1
    tm[1, 3, 4] = 1
    tm[1, 3, 5] = 1
    tm[1, 3, 6] = 1
    tm[1, 3, 7] = 1
    tm[1, 6, 8] = 1
    tm[1, 1, 9] = 1
    # S2
    #    r  g  b  y  c cr cg cb cy  e
    # 0  0  0  0  0  0  0  0  0  0  0
    # 1  0  0  0  0  0  0  0  0  0  0
    # 2  0  0  1  1  1  0  0  1  1  1
    # 3  1  1  0  0  0  1  1  0  0  0
    # 4  0  0  0  0  0  0  0  0  0  0
    # 5  0  0  0  0  0  0  0  0  0  0
    # G  0  0  0  0  0  0  0  0  0  0
    tm[2, 3, 0] = 1
    tm[2, 3, 1] = 1
    tm[2, 2, 2] = 1
    tm[2, 2, 3] = 1
    tm[2, 2, 4] = 1
    tm[2, 3, 5] = 1
    tm[2, 3, 6] = 1
    tm[2, 2, 7] = 1
    tm[2, 2, 8] = 1
    tm[2, 2, 9] = 1
    # S3
    #    r  g  b  y  c cr cg cb cy  e
    # 0  0  0  0  0  0  0  0  0  0  0
    # 1  0  0  0  0  0  0  0  0  0  0
    # 2  0  0  0  0  0  0  0  0  0  0
    # 3  1  1  1  0  1  1  1  1  0  1
    # 4  0  0  0  0  0  0  0  0  0  0
    # 5  0  0  0  0  0  0  0  0  0  0
    # G  0  0  0  1  0  0  0  0  1  0
    tm[3, 3, 0] = 1
    tm[3, 3, 1] = 1
    tm[3, 3, 2] = 1
    tm[3, 6, 3] = 1
    tm[3, 3, 4] = 1
    tm[3, 3, 5] = 1
    tm[3, 3, 6] = 1
    tm[3, 3, 7] = 1
    tm[3, 6, 8] = 1
    tm[3, 3, 9] = 1
    # S4
    #    r  g  b  y  c cr cg cb cy  e
    # 0  0  0  0  0  0  0  0  0  0  0
    # 1  0  0  0  0  0  0  0  0  0  0
    # 2  0  0  0  0  0  0  0  0  0  0
    # 3  0  0  0  0  0  0  0  0  0  0
    # 4  1  1  0  1  0  0  0  0  0  1
    # 5  0  0  1  0  0  0  0  0  0  0
    # G  0  0  0  0  1  1  1  1  1  0
    tm[4, 4, 0] = 1
    tm[4, 4, 1] = 1
    tm[4, 5, 2] = 1
    tm[4, 4, 3] = 1
    tm[4, 6, 4] = 1
    tm[4, 6, 5] = 1
    tm[4, 6, 6] = 1
    tm[4, 6, 7] = 1
    tm[4, 6, 8] = 1
    tm[4, 4, 9] = 1
    # S5
    #    r  g  b  y  c cr cg cb cy  e
    # 0  0  0  0  0  0  0  0  0  0  0
    # 1  0  0  0  0  0  0  0  0  0  0
    # 2  0  0  0  0  0  0  0  0  0  0
    # 3  0  0  0  0  0  0  0  0  0  0
    # 4  0  0  0  0  0  0  0  0  0  0
    # 5  1  1  1  0  0  0  0  0  0  1
    # G  0  0  0  1  1  1  1  1  1  0
    tm[5, 5, 0] = 1
    tm[5, 5, 1] = 1
    tm[5, 5, 2] = 1
    tm[5, 6, 3] = 1
    tm[5, 6, 4] = 1
    tm[5, 6, 5] = 1
    tm[5, 6, 6] = 1
    tm[5, 6, 7] = 1
    tm[5, 6, 8] = 1
    tm[5, 5, 9] = 1
    # G
    #    r  g  b  y  c cr cg cb cy  e
    # 0  0  0  0  0  0  0  0  0  0  0
    # 1  0  0  0  0  0  0  0  0  0  0
    # 2  0  0  0  0  0  0  0  0  0  0
    # 3  0  0  0  0  0  0  0  0  0  0
    # 4  0  0  0  0  0  0  0  0  0  0
    # 5  0  0  0  0  0  0  0  0  0  0
    # G  1  1  1  1  1  1  1  1  1  1
    tm[6, 6, :] = 1  # goal is absorbing for every prop

    # remember that these are multiplicative
    # (one cost per FSA state; the goal state, last, costs 0)
    task_state_costs = [1, 1, 1, 1, 1, 1, 0]

    safety_props = [4, 5, 6, 7, 8]
    task_spec = TaskSpec(spec, tm, task_state_costs)

    return task_spec, safety_props

# OR task
# F ((a | b) & F c) & G ! o
def make_taskspec_reacher_or():
    """Build the OR task: go to A or B, then C, while never satisfying o.

    Spec string: ``F ((a | b) & F c) & G ! o``.

    ``tm[s, s_next, p] = 1`` encodes the FSA transition taken from state ``s``
    when proposition ``p`` holds. Each commented table below covers one source
    state: rows are next states and columns are propositions. State 2 is the
    absorbing goal.

    Returns:
        (task_spec, safety_props): the TaskSpec and the prop indices 4-8 (the
        c/ca/cb/cc/ch columns).
    """
    # go to A or B, then C
    spec = 'F ((a | b) & F c) & G ! o'

    # prop order:
    # a b c home can cana canb canc canh o e
    # The tables below use 10 columns (a b c h c ca cb cc ch e), the same
    # layout as the other make_taskspec_reacher_* builders, so nP is 10.
    # Previously nP was 11, and index 10 had no transition outside the goal.
    # NOTE(review): the prop-order list above also names 'o', which has no
    # column here. Confirm against the env's prop layout.

    nF = 3
    nP = 10
    tm = np.zeros((nF, nF, nP))

    # S0
    #    a  b  c  h  c ca cb cc ch  e
    # 0  0  0  1  1  1  0  0  1  1  1
    # 1  1  1  0  0  0  1  1  0  0  0
    # G  0  0  0  0  0  0  0  0  0  0
    tm[0, 1, 0] = 1
    tm[0, 1, 1] = 1
    tm[0, 0, 2] = 1
    tm[0, 0, 3] = 1
    tm[0, 0, 4] = 1
    tm[0, 1, 5] = 1
    tm[0, 1, 6] = 1
    tm[0, 0, 7] = 1
    tm[0, 0, 8] = 1
    tm[0, 0, 9] = 1
    # S1
    #    a  b  c  h  c ca cb cc ch  e
    # 0  0  0  0  0  0  0  0  0  0  0
    # 1  1  1  0  1  1  1  1  0  1  1
    # G  0  0  1  0  0  0  0  1  0  0
    tm[1, 1, 0] = 1
    tm[1, 1, 1] = 1
    tm[1, 2, 2] = 1
    tm[1, 1, 3] = 1
    tm[1, 1, 4] = 1
    tm[1, 1, 5] = 1
    tm[1, 1, 6] = 1
    tm[1, 2, 7] = 1
    tm[1, 1, 8] = 1
    tm[1, 1, 9] = 1
    # G
    #    a  b  c  h  c ca cb cc ch  e
    # 0  0  0  0  0  0  0  0  0  0  0
    # 1  0  0  0  0  0  0  0  0  0  0
    # G  1  1  1  1  1  1  1  1  1  1
    tm[2, 2, :] = 1  # goal is absorbing for every prop

    # remember that these are multiplicative
    # (one cost per FSA state; the goal state, last, costs 0)
    task_state_costs = [1, 1, 0]

    safety_props = [4, 5, 6, 7, 8]
    task_spec = TaskSpec(spec, tm, task_state_costs)

    return task_spec, safety_props

# sequential task
# F(a & F (b & (F c & F h))) & G ! o
def make_taskspec_reacher_sequential():
    """Build the sequential task: go to A, then B, then C, then HOME.

    Spec string: ``F(a & F (b & (F c & F h))) & G ! o``.

    ``tm[s, s_next, p] = 1`` encodes the FSA transition taken from state ``s``
    when proposition ``p`` holds. Each commented table below covers one source
    state: rows are next states and columns are propositions. State 4 is the
    absorbing goal.

    Each of S0..S3 advances on its target column (r, g, b, y respectively) or
    on the matching 'c<colour>' column (cr, cg, cb, cy). Every other prop
    keeps the current state.

    NOTE(review): the spec string allows c and h in either order after b,
    but the tables require c before h (h keeps S2 in S2). Confirm which
    ordering is intended.

    Returns:
        (task_spec, safety_props): the TaskSpec and the prop indices 4-8 (the
        c/cr/cg/cb/cy columns).
    """
    # go to A, then B, then C, then HOME
    spec = 'F(a & F (b & (F c & F h))) & G ! o'

    # prop order:
    # a b c home can cana canb canc canh o e
    # The tables below label the columns by colour. For indices 0-8 the
    # mapping is positional (r=a g=b b=c y=home c=can cr=cana ...). Column 9
    # is labelled 'e'.
    # NOTE(review): the list above names 11 props (including 'o'), but nP is
    # 10 and no column is assigned to 'o'. Confirm against the env.

    nF = 5
    nP = 10
    tm = np.zeros((nF, nF, nP))

    # S0
    #    r  g  b  y  c cr cg cb cy  e
    # 0  0  1  1  1  1  0  1  1  1  1
    # 1  1  0  0  0  0  1  0  0  0  0
    # 2  0  0  0  0  0  0  0  0  0  0
    # 3  0  0  0  0  0  0  0  0  0  0
    # G  0  0  0  0  0  0  0  0  0  0
    tm[0, 1, 0] = 1
    tm[0, 0, 1] = 1
    tm[0, 0, 2] = 1
    tm[0, 0, 3] = 1
    tm[0, 0, 4] = 1
    tm[0, 1, 5] = 1
    tm[0, 0, 6] = 1
    tm[0, 0, 7] = 1
    tm[0, 0, 8] = 1
    tm[0, 0, 9] = 1
    # S1
    #    r  g  b  y  c cr cg cb cy  e
    # 0  0  0  0  0  0  0  0  0  0  0
    # 1  1  0  1  1  1  1  0  1  1  1
    # 2  0  1  0  0  0  0  1  0  0  0
    # 3  0  0  0  0  0  0  0  0  0  0
    # G  0  0  0  0  0  0  0  0  0  0
    tm[1, 1, 0] = 1
    tm[1, 2, 1] = 1
    tm[1, 1, 2] = 1
    tm[1, 1, 3] = 1
    tm[1, 1, 4] = 1
    tm[1, 1, 5] = 1
    tm[1, 2, 6] = 1
    tm[1, 1, 7] = 1
    tm[1, 1, 8] = 1
    tm[1, 1, 9] = 1
    # S2
    #    r  g  b  y  c cr cg cb cy  e
    # 0  0  0  0  0  0  0  0  0  0  0
    # 1  0  0  0  0  0  0  0  0  0  0
    # 2  1  1  0  1  1  1  1  0  1  1
    # 3  0  0  1  0  0  0  0  1  0  0
    # G  0  0  0  0  0  0  0  0  0  0
    tm[2, 2, 0] = 1
    tm[2, 2, 1] = 1
    tm[2, 3, 2] = 1
    tm[2, 2, 3] = 1
    tm[2, 2, 4] = 1
    tm[2, 2, 5] = 1
    tm[2, 2, 6] = 1
    tm[2, 3, 7] = 1
    tm[2, 2, 8] = 1
    tm[2, 2, 9] = 1
    # S3
    #    r  g  b  y  c cr cg cb cy  e
    # 0  0  0  0  0  0  0  0  0  0  0
    # 1  0  0  0  0  0  0  0  0  0  0
    # 2  0  0  0  0  0  0  0  0  0  0
    # 3  1  1  1  0  1  1  1  1  0  1
    # G  0  0  0  1  0  0  0  0  1  0
    tm[3, 3, 0] = 1
    tm[3, 3, 1] = 1
    tm[3, 3, 2] = 1
    tm[3, 4, 3] = 1
    tm[3, 3, 4] = 1
    tm[3, 3, 5] = 1
    tm[3, 3, 6] = 1
    tm[3, 3, 7] = 1
    tm[3, 4, 8] = 1
    tm[3, 3, 9] = 1
    # G
    #    r  g  b  y  c cr cg cb cy  e
    # 0  0  0  0  0  0  0  0  0  0  0
    # 1  0  0  0  0  0  0  0  0  0  0
    # 2  0  0  0  0  0  0  0  0  0  0
    # 3  0  0  0  0  0  0  0  0  0  0
    # G  1  1  1  1  1  1  1  1  1  1
    tm[4, 4, :] = 1  # goal is absorbing for every prop

    # remember that these are multiplicative
    # (one cost per FSA state; the goal state, last, costs 0)
    task_state_costs = [1, 1, 1, 1, 0]

    safety_props = [4, 5, 6, 7, 8]
    task_spec = TaskSpec(spec, tm, task_state_costs)

    return task_spec, safety_props

# IF task
# (F (r & F b) & G ! can) | (F r & F can) & G ! o
def make_taskspec_reacher_if():
    """Build the IF task: go to r then b, or, if a cancel occurs, just go to r.

    Spec string: ``(F (r & F b) & G ! can) | (F r & F can) & G ! o``.

    ``tm[s, s_next, p] = 1`` encodes the FSA transition taken from state ``s``
    when proposition ``p`` holds. Each commented table below covers one source
    state: rows are next states and columns are propositions. State 4 is the
    absorbing goal.

    NOTE(review): the tables do not obviously match the spec string. Please
    confirm intent:
      * In S0, 'b' (index 2) goes to S3, exactly like the cancel columns
        c/cg/cb/cy, so visiting b before r is handled like a cancellation.
      * S1 -> S2 on 'b', but S2 only reaches the goal on 'r' or a cancel
        column, so r-then-b alone does not complete the task.

    Returns:
        (task_spec, safety_props): the TaskSpec and the prop indices 4-8 (the
        c/cr/cg/cb/cy columns).
    """
    # Per the spec string: go to r, then b; if a cancel occurs, reaching r
    # is enough.
    spec = '(F (r & F b) & G ! can) | (F r & F can) & G ! o'

    # prop order:
    # a b c home can cana canb canc canh o e
    # The tables below label the columns by colour. For indices 0-8 the
    # mapping is positional (r=a g=b b=c y=home c=can cr=cana ...). Column 9
    # is labelled 'e'.
    # NOTE(review): the list above names 11 props (including 'o'), but nP is
    # 10 and no column is assigned to 'o'. Confirm against the env.

    nF = 5
    nP = 10
    tm = np.zeros((nF, nF, nP))

    # S0
    #    r  g  b  y  c cr cg cb cy  e
    # 0  0  1  0  1  0  0  0  0  0  1
    # 1  1  0  0  0  0  0  0  0  0  0
    # 2  0  0  0  0  0  0  0  0  0  0
    # 3  0  0  1  0  1  0  1  1  1  0
    # G  0  0  0  0  0  1  0  0  0  0
    tm[0, 1, 0] = 1
    tm[0, 0, 1] = 1
    tm[0, 3, 2] = 1
    tm[0, 0, 3] = 1
    tm[0, 3, 4] = 1
    tm[0, 4, 5] = 1
    tm[0, 3, 6] = 1
    tm[0, 3, 7] = 1
    tm[0, 3, 8] = 1
    tm[0, 0, 9] = 1
    # S1
    #    r  g  b  y  c cr cg cb cy  e
    # 0  0  0  0  0  0  0  0  0  0  0
    # 1  1  1  0  1  0  0  0  0  0  1
    # 2  0  0  1  0  0  0  0  0  0  0
    # 3  0  0  0  0  0  0  0  0  0  0
    # G  0  0  0  0  1  1  1  1  1  0
    tm[1, 1, 0] = 1
    tm[1, 1, 1] = 1
    tm[1, 2, 2] = 1
    tm[1, 1, 3] = 1
    tm[1, 4, 4] = 1
    tm[1, 4, 5] = 1
    tm[1, 4, 6] = 1
    tm[1, 4, 7] = 1
    tm[1, 4, 8] = 1
    tm[1, 1, 9] = 1
    # S2
    #    r  g  b  y  c cr cg cb cy  e
    # 0  0  0  0  0  0  0  0  0  0  0
    # 1  0  0  0  0  0  0  0  0  0  0
    # 2  0  1  1  1  0  0  0  0  0  1
    # 3  0  0  0  0  0  0  0  0  0  0
    # G  1  0  0  0  1  1  1  1  1  0
    tm[2, 4, 0] = 1
    tm[2, 2, 1] = 1
    tm[2, 2, 2] = 1
    tm[2, 2, 3] = 1
    tm[2, 4, 4] = 1
    tm[2, 4, 5] = 1
    tm[2, 4, 6] = 1
    tm[2, 4, 7] = 1
    tm[2, 4, 8] = 1
    tm[2, 2, 9] = 1
    # S3
    #    r  g  b  y  c cr cg cb cy  e
    # 0  0  0  0  0  0  0  0  0  0  0
    # 1  0  0  0  0  0  0  0  0  0  0
    # 2  0  0  0  0  0  0  0  0  0  0
    # 3  0  1  1  1  1  0  1  1  1  1
    # G  1  0  0  0  0  1  0  0  0  0
    tm[3, 4, 0] = 1
    tm[3, 3, 1] = 1
    tm[3, 3, 2] = 1
    tm[3, 3, 3] = 1
    tm[3, 3, 4] = 1
    tm[3, 4, 5] = 1
    tm[3, 3, 6] = 1
    tm[3, 3, 7] = 1
    tm[3, 3, 8] = 1
    tm[3, 3, 9] = 1
    # G
    #    r  g  b  y  c cr cg cb cy  e
    # 0  0  0  0  0  0  0  0  0  0  0
    # 1  0  0  0  0  0  0  0  0  0  0
    # 2  0  0  0  0  0  0  0  0  0  0
    # 3  0  0  0  0  0  0  0  0  0  0
    # G  1  1  1  1  1  1  1  1  1  1
    tm[4, 4, :] = 1  # goal is absorbing for every prop

    # remember that these are multiplicative
    # (one cost per FSA state; the goal state, last, costs 0)
    task_state_costs = [1, 1, 1, 1, 0]

    safety_props = [4, 5, 6, 7, 8]
    task_spec = TaskSpec(spec, tm, task_state_costs)

    return task_spec, safety_props

def make_safetyspecs_reacher():
    """Return the safety specifications for the reacher domain.

    None are defined, so every call returns a new, empty list.
    """
    no_specs = list()
    return no_specs

#################
# Construct Env #
#################
# Select the task to evaluate. Swap in another make_taskspec_reacher_*
# builder to evaluate a different task.
task_spec, safety_props = make_taskspec_reacher_sequential()
# task_spec, safety_props = make_taskspec_reacher_if()

safety_specs = make_safetyspecs_reacher()
cancel_chance = 1
env = ReacherGymEnvEval(
    task_spec,
    cancel_chance,
    training=False,
    env_config={'headless': False, 'horizon': 800},
)
# Subgoals are taken from the env before it is wrapped by Monitor.
subgoals = env.subgoals
# video_callable returns True for every episode id, so every episode is
# eligible for recording into ./video.
env = Monitor(env, './video', video_callable=lambda episode_id: True, force=True)

###############
# Run Rollout #
###############
def run_rollout(policy, env, num_episodes, max_num_steps=800):
    """Evaluate the option policy under the continuous metapolicy.

    For each episode, reset the env and build a fresh ContinuousMetaPolicy.
    Then alternate between two phases:

    (a) ask the metapolicy which option (colour goal) to pursue from the
        current FSA state;
    (b) step ``policy`` toward that goal until the FSA state changes, the
        option terminates, the episode ends, or ``max_num_steps`` env steps
        have been taken.

    Prints the return and final FSA state of each episode, then closes
    ``env``.

    Reads the module-level ``task_spec``, ``subgoals``, ``safety_props`` and
    ``safety_specs``.

    Args:
        policy: low-level policy exposing ``get_action(obs_tensor)``.
        env: reacher env supporting ``reset(color=...)``,
            ``step(action, color=...)``, ``render()``,
            ``set_task_done(bool)`` and an ``all_info`` dict holding 'ee_p'.
        num_episodes: number of episodes to roll out.
        max_num_steps: per-episode cap on env steps, so an env that never
            reports done cannot loop forever. The default of 800 matches the
            horizon given to the module-level env.
    """
    option_to_color = {0: 'r', 1: 'g', 2: 'b', 3: 'y'}
    goal_state = task_spec.nF - 1  # the last FSA state is the goal

    for i in range(num_episodes):
        task_done = False
        R = 0
        obs = env.reset(color='r')

        metapolicy = ContinuousMetaPolicy(subgoals, task_spec, safety_props, safety_specs, env, policy)

        f = 0  # FSA state as tracked by the metapolicy
        env_f = 0  # FSA state reported by the env in info['f']
        num_steps = 0
        stop_high_level = False
        prev_option = -1
        # Reset every episode so a termination flag from the previous
        # episode cannot carry over.
        terminated = False

        while not task_done and num_steps < max_num_steps:
            if not stop_high_level:
                option_idx = metapolicy.get_option(env, f)
                print(f, option_to_color[option_idx], env_f)
            # The metapolicy re-selected the option that just terminated.
            # Stop re-planning and keep executing that option from now on.
            if prev_option == option_idx and terminated:
                stop_high_level = True
            prev_option = option_idx

            prev_f = f
            # Low-level loop: execute the selected option until the FSA state
            # changes, the option terminates, the episode ends, or the step
            # budget runs out.
            while prev_f == f and not task_done and num_steps < max_num_steps:
                env.render()

                a = policy.get_action(torch.from_numpy(obs).float())
                color = option_to_color[option_idx]
                obs, reward, task_done, info = env.step(a, color=color)
                env_f = info['f']
                prev_f = f
                f = metapolicy.get_fsa_state(env, f)
                R += reward
                num_steps += 1

                if f == goal_state:
                    # Ask the env to flag the episode as done.
                    env.set_task_done(True)

                state = tuple(env.all_info['ee_p'])
                terminated = bool(metapolicy.is_terminated(env, state, option_idx))
                if terminated:
                    break

        print(f"Episode {i} return: {R} | FSA: {f}")

    env.close()

#######
# Run #
#######
# Roll out the loaded option policy for 10 episodes in the monitored env.
run_rollout(policy=option, env=env, num_episodes=10)