
import os
from film.util import helper
from film.util.helper import rollout
import imageio
from film.rl.model.mlp import DLMMLP
from film.rl.model.rnn import DLM_RNN
from film.rl.model.diffusion import DLM_Diffusion
import robomimic.utils.obs_utils as ObsUtils
import pickle

# Observation keys fed to every model built in this file. They are all
# registered as "low_dim" modalities below; no "rgb" keys are used.
obs_keys = ["object", "robot0_eef_pos", "robot0_eef_quat", "robot0_gripper_qpos"]

# robomimic observation-modality spec derived from ``obs_keys`` so the two
# cannot drift apart. A copy is used so the spec and ``obs_keys`` are
# independent list objects.
obs_spec = {
    "obs": {
        "low_dim": list(obs_keys),
        "rgb": [],
    },
}
ObsUtils.initialize_obs_utils_with_obs_specs(obs_modality_specs=obs_spec)

def get_model(model_name, input_dim):
    """Build a policy model whose class is chosen by a substring of ``model_name``.

    The substrings ``'mlp'``, ``'rnn'`` and ``'diffusion'`` are checked in that
    order (case-sensitive); the first one found selects the model. Every model
    is constructed with ``output_dim=7``.

    Args:
        model_name: Name containing ``'mlp'``, ``'rnn'`` or ``'diffusion'``.
        input_dim: Input dimensionality forwarded to the model constructor.

    Returns:
        A ``DLMMLP``, ``DLM_RNN`` or ``DLM_Diffusion`` instance.

    Raises:
        ValueError: If ``model_name`` contains none of the supported substrings.
    """
    if 'mlp' in model_name:
        return DLMMLP(input_dim=input_dim, hidden_dims=[1024] * 5, output_dim=7,
                      obs_keys=obs_keys)
    if 'rnn' in model_name:
        return DLM_RNN(input_dim=input_dim, hidden_dim=400, num_layers=2,
                       output_dim=7, obs_keys=obs_keys, rnn_horizon=8)
    if 'diffusion' in model_name:
        # NOTE(review): unlike the MLP/RNN branches, obs_keys is not passed
        # here — confirm DLM_Diffusion does not need it.
        return DLM_Diffusion(
            output_dim=7,
            input_dim=input_dim,
            denoising_steps=100,
            prediction_horizon=16,
            action_horizon=8,
            obs_horizon=2,
        )
    # Previously this fell through to `return model` with `model` unbound,
    # surfacing as an opaque UnboundLocalError.
    raise ValueError(
        "Unknown model_name {!r}: expected it to contain 'mlp', 'rnn' or "
        "'diffusion'".format(model_name)
    )

def try_rollout(model, save_path, dataset_path, num_rollouts=2, horizon=400):
    """Evaluate ``model`` with ``rollout`` and record videos into ``save_path``.

    Writes ``rollout.mp4`` (frames from ``rollout`` via an imageio writer at
    20 fps) and ``playback.mp4`` (from ``helper.playback_demos``) into
    ``save_path``, creating the directory if it does not exist.

    Args:
        model: Policy passed straight through to ``rollout``.
        save_path: Output directory for both videos. An empty string means
            the current working directory.
        dataset_path: Dataset path passed to ``rollout`` and
            ``helper.playback_demos``.
        num_rollouts: Forwarded as ``num_rollouts`` to both ``rollout`` and
            ``helper.playback_demos`` (default 2; keep small while debugging).
        horizon: Forwarded as ``horizon`` to ``rollout`` (default 400;
            presumably the per-rollout step limit — confirm in helper).

    Returns:
        The success rate returned by ``rollout``.
    """
    if save_path:
        os.makedirs(save_path, exist_ok=True)

    # Previously "rollout.mp4" was written to the CWD regardless of save_path.
    rollout_video_path = os.path.join(save_path, "rollout.mp4")
    video_writer = imageio.get_writer(rollout_video_path, fps=20)
    try:
        success_rate = rollout(model,
                               dataset_path,
                               horizon=horizon,
                               video_writer=video_writer,
                               obs_keys=obs_keys,
                               num_rollouts=num_rollouts)
    finally:
        # The writer must be closed to flush/finalize the video file; the
        # original never closed it.
        video_writer.close()
    print("Success rate over {} rollouts: {}".format(num_rollouts, success_rate))

    # Write playback of the dataset trajectories to video.
    playback_video_path = os.path.join(save_path, "playback.mp4")
    helper.playback_demos(playback_video_path, dataset_path,
                          num_rollouts=num_rollouts)
    return success_rate