import gym
from gym.wrappers.time_limit import TimeLimit

from gen_rl.envs.wrapper_vec_envs import create_vector_environment
from gen_rl.policy.env_models import Model


def launch_env(args, if_train=True):
    """Build the training and evaluation environments named by ``args["env_name"]``.

    The branch is chosen from ``args["env_name"]``:

    * ``"paint"`` (case-insensitive): the PaintGym ``Paint`` environment.
    * names starting with ``"classic"`` (case-insensitive): the prefix is
      removed, the remainder is passed to ``gym.make`` and ``args["env_name"]``
      is overwritten with that remainder.
    * ``"Pendulum-v0"`` (exact match): the project's ``PendulumEnv``.
    * ``"recsim"`` (case-insensitive): the RecSim interest-evolution environment.
    * anything else: treated as a MuJoCo task name, with an optional
      ``"mujoco-"`` prefix.

    Args:
        args: Mutable configuration dict. Besides ``"env_name"`` this reads
            ``"gaussian_noise_std"`` and, depending on the branch,
            ``"env_seed"``, ``"max_episode_steps"`` and ``"recsim_dim_embed"``.
            It is updated in place with ``"state_dim"``, ``"action_dim"``,
            ``"max_action"`` and, depending on the branch,
            ``"random_act_fn"``, ``"env_decode_fn"`` and ``"act_embed"``.
            ``"gaussian_noise_std"`` is multiplied by ``"max_action"``.
        if_train: Currently unused; kept for backward compatibility.

    Returns:
        Tuple ``(env, eval_env, args)``; ``args`` is the same dict object that
        was passed in.

    Raises:
        ValueError: If the name falls through to the MuJoCo branch but does not
            match a known MuJoCo task (raised when ``make_env`` is invoked).
    """
    env_name = args["env_name"]
    env_key = env_name.lower()

    if env_key == "paint":
        from gen_rl.envs.paintGym.env import Paint
        env = Paint(args=args)
        eval_env = Paint(args=args)
        env.seed(seeds=args["env_seed"])
        args["env_decode_fn"] = env.decode
        args["state_dim"] = env.observation_space.shape[0]
        args["action_dim"] = env.action_space.shape[-1]
        args["max_action"] = float(env.action_space.high[0][0])
        args["random_act_fn"] = env.action_space.sample
    elif env_key.startswith("classic"):
        # Remove the "classic" prefix and one optional "-" separator.
        # str.strip("classic-") must not be used here: it strips any of those
        # characters from BOTH ends and ignores a differently-cased prefix.
        gym_id = env_name[len("classic"):]
        if gym_id.startswith("-"):
            gym_id = gym_id[1:]
        args["env_name"] = gym_id

        # Throwaway instance used only to read the space dimensions.
        env = gym.make(args["env_name"])
        args["state_dim"] = env.observation_space.shape[0]
        args["action_dim"] = env.action_space.shape[0]
        args["max_action"] = float(env.action_space.high[0])

        def make_env():
            env = gym.make(args["env_name"])
            return TimeLimit(env=env, max_episode_steps=args["max_episode_steps"])

        env = create_vector_environment(make_env=make_env, args=args, if_train=True)
        args["random_act_fn"] = env.env.action_space.sample
        eval_env = create_vector_environment(make_env=make_env, args=args, if_train=False)
    elif env_name == "Pendulum-v0":  # this is for Ground-truth dynamics and reward models
        from gen_rl.envs.pendulum import PendulumEnv

        # Throwaway instance used only to read the space dimensions.
        env = PendulumEnv(args=args)
        args["state_dim"] = env.observation_space.shape[0]
        args["action_dim"] = env.action_space.shape[0]
        args["max_action"] = float(env.action_space.high[0])

        def make_env():
            env = PendulumEnv(args=args)
            return TimeLimit(env=env, max_episode_steps=args["max_episode_steps"])

        env = create_vector_environment(make_env=make_env, args=args, if_train=True)
        args["random_act_fn"] = env.env.action_space.sample
        eval_env = create_vector_environment(make_env=make_env, args=args, if_train=False)
    elif env_key == "recsim":
        from copy import deepcopy

        from gen_rl.envs.recsim.environments.interest_evolution_generic import (
            create_vector_environment as make_env,
        )
        env = make_env(args=args)
        # The evaluation environment is a deep copy rather than a second build.
        eval_env = deepcopy(env)
        # Temporary placeholder; note that it zeroes the noise scaling below.
        args["max_action"] = 0.0
        args["state_dim"] = args["recsim_dim_embed"]
        args["action_dim"] = args["recsim_dim_embed"]
        args["act_embed"] = env.act_embedding
    else:
        from gen_rl.envs.mujoco_envs.ant_v4 import AntEnv
        from gen_rl.envs.mujoco_envs.half_cheetah_v4 import HalfCheetahEnv
        from gen_rl.envs.mujoco_envs.hopper_v4 import HopperEnv
        from gen_rl.envs.mujoco_envs.humanoid_v4 import HumanoidEnv
        from gen_rl.envs.mujoco_envs.pusher_v4 import PusherEnv
        from gen_rl.envs.mujoco_envs.reacher_v4 import ReacherEnv
        from gen_rl.envs.mujoco_envs.swimmer_v4 import SwimmerEnv
        from gen_rl.envs.mujoco_envs.walker2d_v4 import Walker2dEnv

        # Task name (after removing "mujoco-" and lower-casing) -> env class.
        mujoco_envs = {
            "ant": AntEnv,
            "halfcheetah": HalfCheetahEnv,
            "cheetah": HalfCheetahEnv,
            "hopper": HopperEnv,
            "humanoid": HumanoidEnv,
            "pusher": PusherEnv,
            "reacher": ReacherEnv,
            "swimmer": SwimmerEnv,
            "walker2d": Walker2dEnv,
        }

        def make_env():
            # The name is resolved at call time, as in the original code.
            task = args["env_name"].replace("mujoco-", "").lower()
            if task not in mujoco_envs:
                raise ValueError(f"Unknown environment name: {args['env_name']!r}")
            env = mujoco_envs[task](frame_size=(64, 64))
            return TimeLimit(env=env, max_episode_steps=args["max_episode_steps"])

        env = create_vector_environment(make_env=make_env, args=args, if_train=True)
        eval_env = create_vector_environment(make_env=make_env, args=args, if_train=False)

        # Dimensions are read from the wrapped vector env: the observation
        # shape's last axis and the first entry of the action space.
        args["state_dim"] = env.env.observation_space.shape[-1]
        args["action_dim"] = env.env.action_space[0].shape[0]
        args["max_action"] = float(env.env.action_space[0].high[0])
        args["random_act_fn"] = env.env.action_space.sample

    # Scale the exploration-noise std by the action bound.
    args["gaussian_noise_std"] *= args["max_action"]
    return env, eval_env, args


def launch_models(env, args):
    """Create (or fetch) the state-transition and reward models for ``env``.

    Reward model:
        * ``args["if_train_reward_model"]`` truthy: a new ``Model`` with
          ``dim_out=1``, moved to ``args["device"]``.
        * otherwise, for ``args["env_name"] == "Pendulum-v0"``: the
          ``reward_model`` attribute of ``env.env.envs[0]``.
        * otherwise ``None``.

    State model:
        * ``args["if_train_state_model"]`` truthy: a ``cVAE`` for the
          ``"paint"`` environment, else a ``Model`` with
          ``dim_out=args["state_dim"]``; both are moved to ``args["device"]``.
        * otherwise, for ``"Pendulum-v0"``: the ``state_transition_model``
          attribute of ``env.env.envs[0]``.
        * otherwise the first element returned by
          ``launch_state_reward_models``.

    Args:
        env: Environment wrapper; only dereferenced as ``env.env.envs[0]`` on
            the ``"Pendulum-v0"`` path, otherwise passed through to
            ``launch_state_reward_models``.
        args: Configuration dict providing the keys mentioned above plus
            ``"action_dim"`` and ``"max_episode_steps"`` where needed.

    Returns:
        Tuple ``(state_model, reward_model)``; ``reward_model`` may be ``None``.
    """
    is_pendulum = args["env_name"] == "Pendulum-v0"

    reward_model = None
    if args["if_train_reward_model"]:
        reward_model = Model(dim_out=1, args=args).to(args["device"])
    elif is_pendulum:
        reward_model = env.env.envs[0].reward_model

    if args["if_train_state_model"]:
        if args["env_name"].lower() == "paint":
            from gen_rl.policy.cvae_cnn import cVAE
            from gen_rl.policy.env_models import DIM_LATENT
            # Size is fixed by the PaintGym Env!
            state_model = cVAE(shape=(3, 128, 128),
                               dim_action=args["action_dim"],
                               dim_out_latent=DIM_LATENT,
                               dim_condition=32,
                               device=args["device"]).to(args["device"])
        else:
            state_model = Model(dim_out=args["state_dim"], args=args).to(args["device"])
        print(state_model)
    elif is_pendulum:
        # Use the env's own transition model directly. The original code first
        # built the differentiable-MuJoCo models and then discarded them here.
        state_model = env.env.envs[0].state_transition_model
    else:
        from gen_rl.envs.differentiable_mujoco.torch_block import launch_state_reward_models
        state_model, _ = launch_state_reward_models(env=env, max_ep_steps=args["max_episode_steps"],
                                                    device=args["device"])
    print("state_model", state_model)
    print("reward_model", reward_model)
    return state_model, reward_model
