import metaworld
import random
import time
import gym
from pathlib import Path

from omegaconf import OmegaConf
from stable_baselines3 import PPO, SAC
from common.utils.metaworld_utils import MetaWorldEnv

# Output directories, relative to the current working directory.
# `parents=True` creates the intermediate `collect_data/metaworld/` levels as
# well, so the script also works on a fresh checkout where they do not exist.

# Evaluation logs and the best checkpoint of each run are written under here.
logdir = Path("collect_data/metaworld/logs")
logdir.mkdir(parents=True, exist_ok=True)

# Final trained models are written under here.
model_dir = Path("collect_data/metaworld/saved_models")
model_dir.mkdir(parents=True, exist_ok=True)


if __name__ == "__main__":
    # Train (or load) an agent on a single MetaWorld task, then roll it out for
    # `num_test_episodes` episodes, printing per-episode statistics and, when
    # not rendering on screen, saving the collected frames as a video.
    print(metaworld.ML1.ENV_NAMES)  # Check out the available environments
    MODEL_CLS = SAC

    # Defaults; every key can be overridden on the command line as `key=value`.
    args = OmegaConf.create({
        "task": "window-open-v2",
        "steps": 1e4,
        "test": False,
        "num_test_episodes": 10,
        "render": False,
    })
    args = OmegaConf.merge(args, OmegaConf.from_cli())

    env = MetaWorldEnv(args.task)

    # Per-task directory holding evaluation results and `best_model.zip`.
    eval_log_dir = logdir / f"{args.task}.log"

    if args.test:
        # Load the best checkpoint found during training. To evaluate the
        # final model instead, load `model_dir / args.task`.
        model = MODEL_CLS.load(
            path=str(eval_log_dir / "best_model.zip"),
            print_system_info=True,
        )
    else:
        # Imported here because they are only needed when training.
        from stable_baselines3.common.callbacks import EvalCallback
        from stable_baselines3.common.monitor import Monitor

        model = MODEL_CLS(
            policy="MlpPolicy",
            env=env,
            verbose=1,
        )
        # Periodic evaluation on a separate environment instance. This is the
        # explicit form of the `eval_env`/`eval_freq`/`eval_log_path` arguments
        # of `learn()`, which were deprecated and later removed from SB3.
        eval_callback = EvalCallback(
            Monitor(MetaWorldEnv(args.task)),
            n_eval_episodes=10,
            eval_freq=10000,
            log_path=str(eval_log_dir),
            best_model_save_path=str(eval_log_dir),
            verbose=1,
        )
        model.learn(
            total_timesteps=int(args.steps),  # default is the float 1e4
            callback=eval_callback,
            log_interval=20,
            progress_bar=True,
        )
        model.save(path=str(model_dir / args.task))

    frames = []  # Offscreen frames from all episodes, in order
    for episode in range(1, args.num_test_episodes + 1):
        start = time.perf_counter()
        obs = env.reset()  # Reset environment
        # NOTE(review): the returned frame is discarded; presumably this call
        # only warms up the offscreen renderer -- confirm it is still needed.
        env.render(offscreen=True)

        done = False
        t = 0
        cumulative_reward = 0
        while not done:
            # `predict` defaults to deterministic=False, i.e. the action is
            # sampled from the policy rather than taken greedily.
            a, _ = model.predict(obs)
            obs, reward, done, info = env.step(a)  # Apply the policy's action
            t += 1
            cumulative_reward += reward
            if args.render:
                env.render()
            else:
                frame = env.render(offscreen=True)
                frames.append(frame)
        elapsed_time = time.perf_counter() - start

        # Guard against a zero reading from the timer.
        fps = t / elapsed_time if elapsed_time > 0 else float("inf")
        # `info` is the dict returned by the last step of the episode.
        success = bool(info["success"])

        print(
            f"Episode {episode}/{args.num_test_episodes} finished after {t:d} steps"
        )
        print(f"Total Return {cumulative_reward:.2f}")
        print(f"Elapsed Time {elapsed_time:.4f} seconds")
        print(f"FPS {fps:.2f}")
        print(f"Success: {success}")
        print()

    if frames:
        # Bind the v2 namespace explicitly so `mimsave` is the v2 function
        # rather than whatever the top-level `imageio` package exposes.
        import imageio.v2 as imageio
        imageio.mimsave(f"{args.task}.mp4", frames, fps=60)
