from time import time
from film.util.util import *
from film.util.helper import generate_vlm_demos
from robomimic import DATASET_REGISTRY
import robomimic.utils.file_utils as FileUtils
import os
import sys

from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import h5py


import robomimic.utils.file_utils as FileUtils
from robomimic import DATASET_REGISTRY
import robomimic.utils.obs_utils as ObsUtils
from film.util import helper
from film.util.util import get_folder, get_dataset, try_vlm_rollout
from IPython.display import Video
import robomimic.utils.obs_utils as ObsUtils
from film.rl.model_util import get_model
from film.rl.model import mlp, rnn, diffusion


# Root of the DLM/FILM checkout. Set the FILM_DLM_PATH environment variable to
# point elsewhere; otherwise the hard-coded default below is used.
path_to_dlm = os.environ.get('FILM_DLM_PATH', '/home/bho36/FILM/results')
# NOTE(review): these entries are appended *after* the film/robomimic imports at
# the top of this file, so they cannot influence how those modules are resolved;
# they only affect imports performed later. Move them above the imports if they
# are actually required.
sys.path.append(path_to_dlm)
sys.path.append(path_to_dlm + "/robomimic/")
sys.path.append(path_to_dlm + "/robosuite/")


output_dim = 7  # NOTE(review): not referenced elsewhere in this file — confirm it is still needed
result_path = "results"
download_path = os.path.join(result_path, 'datasets')
instruction_path = os.path.join('instruction')
fixture_path = os.path.join('tests/fixture')


# Ensure the output and fixture directories exist (idempotent thanks to exist_ok).
os.makedirs(result_path, exist_ok=True)
os.makedirs(download_path, exist_ok=True)
os.makedirs(fixture_path, exist_ok=True)


# Per-task dataset metadata built by the project helper get_folder from
# download_path. run_rollout indexes it by task name and unpacks an entry as
# (task, dataset_type, hdf5_type, input_dim, download_folder, dataset_path).
input_dim_dict = get_folder(download_path)


def ModelLoad(model_name, input_dim, save_path):
    """Construct a model via ``get_model`` and restore its checkpoint.

    Args:
        model_name: model identifier passed to ``get_model`` (this script uses
            'mlp', 'rnn' and 'diffusion').
        input_dim: input dimension passed to ``get_model``.
        save_path: directory containing the checkpoint file ``model.pth``.

    Returns:
        The object returned by ``get_model``, after its ``load`` method has been
        called with the checkpoint path.
    """
    net = get_model(model_name, input_dim)
    checkpoint_file = os.path.join(save_path, "model.pth")
    net.load(checkpoint_file)
    return net


def make_videos(task='tool_hang',
                dataset_path='/home/bho36/FILM/results/datasets/tool_hang/ph/low_dim_v141.hdf5',
                num_rollouts=-1):
    """Render expert-demonstration output for a dataset via ``generate_vlm_demos``.

    Output is written to ``<result_path>/figure/<task>/ph/expert``, which is
    created if missing.

    Args:
        task: task name used only to build the output directory.
        dataset_path: HDF5 demonstration file handed to ``generate_vlm_demos``.
            The default is the original hard-coded tool_hang dataset; pass a
            matching path when changing ``task``.
        num_rollouts: forwarded to ``generate_vlm_demos``. The default of -1
            presumably means "all demonstrations" — confirm in the helper.
    """
    figure_path = os.path.join(result_path, 'figure', task, 'ph', 'expert')
    os.makedirs(figure_path, exist_ok=True)
    generate_vlm_demos(figure_path, dataset_path, num_rollouts)


def _instruction_files(task, vlm_name, use_vlm):
    """Return the instruction/prompt file paths passed to try_vlm_rollout."""
    files = [os.path.join(instruction_path, f"{task}.txt")]
    if not use_vlm:
        return files
    files.append(os.path.join(instruction_path, f"{task}_openai.txt"))
    if '1pass' in vlm_name:
        # Single-pass VLM configuration uses one combined prompt file.
        files.append(os.path.join(instruction_path, f"{task}_1pass_prompt.txt"))
    else:
        # Default (two-pass) configuration uses two prompt files.
        files.append(os.path.join(instruction_path, f"{task}_new_prompt.txt"))
        files.append(os.path.join(instruction_path, f"{task}_new_prompt2.txt"))
    return files


def run_rollout(vlm_name=None, model_name='rnn', num_rollouts=1, horizon=700,
                model_horizon=0, generate_img=False, generate_vid=False,
                generate_rollout_vid=False, img_skip=5, video_skip=5,
                verbose=False, task='tool_hang', subtask_ind=None):
    """Load a trained policy and evaluate it with ``try_vlm_rollout``.

    The checkpoint is read from ``<fixture_path>/<task>/ph/<model_name>/model.pth``.
    Outputs go under
    ``<result_path>/<task>/<model_name>/<vlm_name>/horizon_<model_horizon>/``
    in the ``figure``, ``video`` and ``language`` sub-directories.

    Args:
        vlm_name: VLM configuration label, or None to run without a VLM (recorded
            as 'noVLM' in the output path). A label containing '1pass' selects
            the single-pass prompt files; any other label selects the two-pass set.
        model_name: policy type. 'rnn' and 'diffusion' honour ``model_horizon``;
            'diffusion' also supplies ``obs_len`` from ``model.obs_horizon``.
        model_horizon: if > 0, passed to ``model.set_horizon`` for 'rnn' and
            'diffusion' models. Also part of the output path.
        num_rollouts, horizon, generate_img, generate_vid, generate_rollout_vid,
        img_skip, video_skip, verbose: forwarded unchanged to ``try_vlm_rollout``.
        task: key into ``input_dim_dict``. Also used for the instruction file
            names (``instruction/<task>*.txt``).
        subtask_ind: forwarded to ``try_vlm_rollout``. Defaults to ``[42, 43]``,
            the value previously hard-coded for tool_hang.
    """
    if subtask_ind is None:
        subtask_ind = [42, 43]

    if vlm_name is None:
        vlm_name = 'noVLM'
        vlm = None
    else:
        # TODO: Integrate VLM -- a truthy placeholder stands in for a real client.
        vlm = True

    instruction = _instruction_files(task, vlm_name, use_vlm=vlm is not None)

    # Dataset metadata; only the task name, input dim and dataset path are used here.
    (task, _dataset_type, _hdf5_type, input_dim, _download_folder,
     dataset_path) = input_dim_dict[task]

    save_path = os.path.join(fixture_path, task, 'ph', model_name)
    model = ModelLoad(model_name, input_dim, save_path)
    if model_name in ('rnn', 'diffusion') and model_horizon > 0:
        model.set_horizon(model_horizon)
    # Diffusion models expose the observation window length they expect; other
    # models receive one observation at a time. Read after set_horizon, as before.
    obs_len = model.obs_horizon if model_name == 'diffusion' else 1

    model_path = os.path.join(result_path, task, model_name,
                              vlm_name, 'horizon_' + str(model_horizon))
    figure_path = os.path.join(model_path, 'figure')
    video_path = os.path.join(model_path, 'video')
    language_path = os.path.join(model_path, 'language')
    # makedirs creates the intermediate task/model directories as well.
    for out_dir in (figure_path, video_path, language_path):
        os.makedirs(out_dir, exist_ok=True)

    try_vlm_rollout(model,
                    vlm,
                    video_path,
                    dataset_path,
                    save_path=model_path,
                    subtask_ind=subtask_ind,
                    instruction=instruction,
                    generate_img=generate_img,
                    generate_vid=generate_vid,
                    video_skip=video_skip,
                    img_skip=img_skip,
                    generate_rollout_vid=generate_rollout_vid,
                    horizon=horizon,
                    obs_len=obs_len,
                    num_rollouts=num_rollouts,
                    figure_path=figure_path,
                    language_path=language_path,
                    verbose=verbose)


if __name__ == '__main__':
    # Experiment presets: label -> keyword arguments for run_rollout.
    # Only the preset named by `active_experiment` is executed; change that
    # label to run another configuration. Call make_videos() instead to render
    # the expert demonstrations.
    experiments = {
        # No VLM, MLP, nominal horizon
        'noVLM_mlp': dict(vlm_name=None, model_name='mlp', num_rollouts=50,
                          horizon=700, model_horizon=0),
        # No VLM, RNN, model horizon 10
        'noVLM_rnn': dict(vlm_name=None, model_name='rnn', num_rollouts=50,
                          horizon=700, model_horizon=10),
        # No VLM, RNN, long horizon
        'noVLM_rnn_long': dict(vlm_name=None, model_name='rnn',
                               num_rollouts=50, horizon=700,
                               model_horizon=701),
        # No VLM, Diffusion, nominal horizon
        'noVLM_diffusion': dict(vlm_name=None, model_name='diffusion',
                                num_rollouts=50, horizon=700,
                                model_horizon=0),
        'noVLM_diffusion_vid': dict(vlm_name=None, model_name='diffusion',
                                    num_rollouts=50, horizon=700,
                                    model_horizon=0,
                                    generate_rollout_vid=True, verbose=True),
        # No VLM, Diffusion, long horizon
        'noVLM_diffusion_long': dict(vlm_name=None, model_name='diffusion',
                                     num_rollouts=50, horizon=700,
                                     model_horizon=701),
        'noVLM_diffusion_long_vid': dict(vlm_name=None,
                                         model_name='diffusion',
                                         num_rollouts=50, horizon=700,
                                         model_horizon=700,
                                         generate_rollout_vid=True,
                                         verbose=True),
        # GPT-4 two-pass, Diffusion, nominal horizon
        'openai_2pass_diffusion': dict(vlm_name='OpenAI',
                                       model_name='diffusion',
                                       num_rollouts=50, horizon=700,
                                       model_horizon=0, generate_img=True,
                                       img_skip=20,
                                       generate_rollout_vid=True),
        # GPT-4 one-pass, Diffusion, nominal horizon
        'openai_1pass_diffusion': dict(vlm_name='OpenAI_1pass',
                                       model_name='diffusion',
                                       num_rollouts=50, horizon=700,
                                       model_horizon=0, generate_img=True,
                                       img_skip=20,
                                       generate_rollout_vid=True),
        # GPT-4 two-pass, RNN (model_horizon=0, i.e. the model's own horizon)
        'openai_2pass_rnn': dict(vlm_name='OpenAI', model_name='rnn',
                                 num_rollouts=50, horizon=700,
                                 model_horizon=0, generate_img=True,
                                 img_skip=20, generate_rollout_vid=True),
    }

    active_experiment = 'openai_1pass_diffusion'
    run_rollout(**experiments[active_experiment])
    
