import gym

from agents.distral_1col0.distral1col import trainDistral
import environments
from environments import ReacherMultistageFixedTaskEnv
from environments.wrapper import DiscreteWrapper

# Goal type forwarded to every environment through the `goal_type` keyword.
g = 0


def _make_task_env(task_id):
    """Return a discretised ReacherMultistageFixedTask env for *task_id*.

    The base environment is created through ``gym.make`` with the given
    ``task_id`` and the module-level goal type ``g``; it is then wrapped in
    ``DiscreteWrapper`` using the action count and the ``disc2cont`` mapping
    exposed by ``ReacherMultistageFixedTaskEnv``.
    """
    base_env = gym.make(
        "ReacherMultistageFixedTask-v0",
        task_id=task_id,
        goal_type=g,
    )
    return DiscreteWrapper(
        base_env,
        n_actions=ReacherMultistageFixedTaskEnv.get_n_discrete_actions(),
        disc2cont=ReacherMultistageFixedTaskEnv.disc2cont,
    )


# Train on task ids 0, 1 and 2 (one wrapped environment per task) and keep
# the four values returned by trainDistral.
models, policy, episode_rewards, episode_durations = trainDistral(
    list_of_envs=[_make_task_env(task_id) for task_id in range(3)]
)
