import json
import os

import hydra
import imageio
import metaworld
import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import (DummyVecEnv, VecEnv, VecMonitor,
                                              VecVideoRecorder)

from experiments.envs.metaworld import (MetaWorldSafetySpeedWrapper,
                                        MetaWorldSawyerEnv,
                                        MetaWorldSawyerImageWrapper)


def _to_json_safe(value):
    """Recursively convert numpy scalars/arrays into JSON-serializable Python objects.

    ``json.dump`` cannot encode numpy types such as ``np.int32``, ``np.float32``,
    ``np.int64``, ``np.bool_`` or ``np.ndarray``. Dicts, lists and tuples are
    walked recursively. Any other value is returned unchanged.
    """
    if isinstance(value, np.generic):
        # .item() yields the equivalent builtin (int / float / bool).
        return value.item()
    if isinstance(value, np.ndarray):
        return value.tolist()
    if isinstance(value, dict):
        return {k: _to_json_safe(v) for k, v in value.items()}
    if isinstance(value, (list, tuple)):
        # json would emit tuples as lists anyway.
        return [_to_json_safe(v) for v in value]
    return value


def _make_env(cfg, identity):
    """Build a MetaWorld Sawyer env for ``cfg.env`` wrapped with the safety/speed wrapper."""
    env = MetaWorldSawyerEnv(cfg.env)
    return MetaWorldSafetySpeedWrapper(env, identity)


def _train_policies(cfg, identity, output_dir):
    """Train one PPO model and save a checkpoint after each of ``cfg.num_policies`` rounds.

    The same model keeps training across rounds, so checkpoint ``i`` has seen
    ``(i + 1) * cfg.timesteps`` timesteps. The training env is closed before
    returning; the returned model is only used for ``predict`` afterwards.

    Returns:
        The trained PPO model.
    """
    env = _make_env(cfg, identity)
    try:
        # 'cuda' stays the default; a config may override it (e.g. device: cpu).
        model = PPO('MlpPolicy', env, verbose=1, device=cfg.get("device", "cuda"))
        for i in range(cfg.num_policies):
            model.learn(total_timesteps=cfg.timesteps)
            model.save(os.path.join(output_dir, f"{cfg.env}_{identity}_{i}"))
    finally:
        env.close()
    return model


def _rollout(cfg, identity, model, num_steps=250):
    """Run ``model`` in a fresh env for ``num_steps`` steps.

    The env is reset whenever an episode terminates or is truncated.

    Returns:
        ``(frames, infos)``: ``frames`` holds rendered images (empty unless
        ``cfg.render``); ``infos`` holds the JSON-safe per-step info dicts.
    """
    env = _make_env(cfg, identity)
    frames = []
    infos = []
    try:
        obs, _ = env.reset()
        img = env.render() if cfg.render else None
        for _ in range(num_steps):
            if cfg.render:
                # Record the frame for the observation the action is chosen from.
                frames.append(img)
            action, _ = model.predict(obs)
            obs, _reward, terminated, truncated, info = env.step(action)
            infos.append(_to_json_safe(info))

            if terminated or truncated:
                obs, _ = env.reset()

            if cfg.render:
                img = env.render()
    finally:
        env.close()
    return frames, infos


@hydra.main(version_base=None, config_path="../config", config_name="metaworld_train_policies")
def main(cfg):
    """Train PPO policies for each safety/speed identity, then evaluate the final one.

    For identity in (1, 2):
      * train a PPO model, saving ``cfg.num_policies`` checkpoints into the Hydra
        run's output directory;
      * roll out the final model for 250 steps;
      * write ``videos/<env>_<identity>_infos.json`` with the per-step infos, and
        ``videos/<env>_<identity>_final.gif`` when ``cfg.render`` is set.
    The ``videos/`` path is relative to the current working directory.
    """
    hydra_output = hydra.core.hydra_config.HydraConfig.get().runtime.output_dir
    video_folder = 'videos/'

    for identity in [1, 2]:
        model = _train_policies(cfg, identity, hydra_output)

        os.makedirs(video_folder, exist_ok=True)
        images, infos = _rollout(cfg, identity, model)

        if cfg.render:
            gif_path = os.path.join(video_folder, f"{cfg.env}_{identity}_final.gif")
            imageio.mimsave(gif_path, images, fps=30)

        with open(os.path.join(video_folder, f"{cfg.env}_{identity}_infos.json"), 'w') as f:
            json.dump(infos, f)


if __name__ == "__main__":
    # Hydra's decorator parses the CLI and injects `cfg`, so call with no args.
    main()
