# This script restores a saved experiment and visualizes its encoder.
# You can load as many trajectories as you want.

#!/usr/bin/env python3
import tempfile

import dowel_wrapper

# Reference the module so linters don't flag the import above as unused.
# NOTE(review): the import is presumably needed for its side effects, since it
# is placed before `import dowel` — confirm before reordering these imports.
assert dowel_wrapper is not None
import dowel

import wandb

import argparse
import datetime
import functools
import os
import sys
import platform
import torch.multiprocessing as mp

# Select the MuJoCo rendering backend. This sits above the remaining imports,
# presumably because the rendering libraries read these variables when they
# are first imported.
# macOS keeps its default backend. Everywhere else, rendering is forced onto
# headless EGL. `sys.platform` is used because `platform.platform()` only
# contains "mac" on interpreters that report "macOS-…" (older ones report
# "Darwin-…").
if sys.platform != "darwin":
    os.environ["MUJOCO_GL"] = "egl"
    # NOTE: the EGL device is pinned to "0" even when SLURM_STEP_GPUS is set.
    # If per-job GPU selection is wanted, derive the id from SLURM_STEP_GPUS
    # here instead.
    os.environ["EGL_DEVICE_ID"] = "0"

import better_exceptions
import numpy as np

# Install better_exceptions' excepthook so uncaught exceptions are printed with
# its enhanced traceback formatting.
better_exceptions.hook()

import torch

from garage import wrap_experiment
from garage.experiment.deterministic import set_seed
from garage.torch.distributions import TanhNormal

from garagei.replay_buffer.path_buffer_ex import PathBufferEx
from garagei.experiment.option_local_runner import OptionLocalRunner
from garagei.envs.consistent_normalized_env import consistent_normalize
from garagei.sampler.option_multiprocessing_sampler import OptionMultiprocessingSampler
from garagei.torch.modules.with_encoder import WithEncoder, Encoder
from garagei.torch.modules.gaussian_mlp_module_ex import (
    GaussianMLPTwoHeadedModuleEx,
    GaussianMLPIndependentStdModuleEx,
    GaussianMLPModuleEx,
)
from garagei.torch.modules.gaussian_lstm_module_ex import (
    GaussianLSTMTwoHeadedModuleEx,
    GaussianLSTMIndependentStdModuleEx,
    GaussianLSTMModuleEx,
)
from garagei.torch.modules.parameter_module import ParameterModule
from garagei.torch.policies.policy_ex import PolicyEx, RecurrentPolicyEx
from garagei.torch.q_functions.continuous_mlp_q_function_ex import (
    ContinuousMLPQFunctionEx,
)
from garagei.torch.q_functions.continuous_lstm_q_function_ex import (
    ContinuousLSTMQFunctionEx,
)
from garagei.torch.optimizers.optimizer_group_wrapper import OptimizerGroupWrapper
from garagei.torch.utils import xavier_normal_ex
from iod.metra import METRA
from iod.recurrent_metra import RecurrentMETRA
from iod.dads import DADS
from iod.utils import get_normalizer_preset


import sys

# Root directory for experiment outputs: the scratch volume when running inside
# a SLURM job, otherwise a local "exp/" directory.
EXP_DIR = "/scratch/heatz123/exp/" if "SLURM_JOB_ID" in os.environ else "exp/"
# multiprocessing start method; defaults to "spawn" and can be overridden with
# the START_METHOD environment variable.
START_METHOD = os.environ.get("START_METHOD", "spawn")


from main import (
    get_argparser,
    get_exp_name,
    get_gaussian_module_construction,
    make_env,
    get_log_dir,
    get_runner,
)

# Parse the shared command line (the parser comes from main.get_argparser) at
# import time; `run` below reads these options (e.g. `args.max_path_length`).
# NOTE(review): being module-level, this also executes whenever the module is
# re-imported (e.g. by "spawn"-started worker processes) — confirm intended.
args = get_argparser().parse_args()
# Wall-clock start time in whole seconds since the Unix epoch.
# NOTE(review): not referenced elsewhere in this script; presumably kept for
# parity with main.py.
g_start_time = int(datetime.datetime.now().timestamp())


@wrap_experiment(log_dir=get_log_dir(), name=get_exp_name()[0])
def run(ctxt=None):
    """Restore a saved experiment from disk and draw its plots.

    Builds a runner from the parsed command-line ``args``, restores it from the
    checkpoint directory hard-coded below, and then calls the runner's
    algorithm's ``draw_one_plots`` with the runner.

    Args:
        ctxt: Experiment context (garage's ``wrap_experiment`` passes this in).

    Other checkpoint directories this script has been pointed at:
        exp/Exploration/sd000_1713081556_ant_hill_metra_hill_height-1-perpendicular-1024-rnd-exploration-check
        exp/Exploration/sd000_1713351204_ant_hill_metra_hill_height-1-perpendicular-sac50000000
    """
    runner = get_runner(args, ctxt)

    # Environment factory handed to `restore` as `make_env`, with the CLI args
    # and episode length bound in.
    env_factory = functools.partial(
        make_env,
        args=args,
        max_path_length=args.max_path_length,
    )
    checkpoint_dir = "exp/Exploration/sd000_1713366129_ant_hill_metra_hill_height-1-perpendicular-sac50000000-lowent"
    runner.restore(from_dir=checkpoint_dir, make_env=env_factory)

    runner._algo.draw_one_plots(runner)


def _main():
    """Script entry point: configure multiprocessing, then run the experiment."""
    # The start method may only be set once, and must be set before any worker
    # processes are created.
    mp.set_start_method(START_METHOD)
    run()


if __name__ == "__main__":
    _main()
