import os
from film.util.helper import rollout, vlm_rollout, load_data_for_training
import imageio
import robomimic.utils.obs_utils as ObsUtils
import pickle

# Low-dimensional observation keys used by every loader/rollout helper below.
# Defined once so the robomimic obs spec and the key list passed to the
# helpers cannot drift apart.
obs_keys = ["object", "robot0_eef_pos", "robot0_eef_quat", "robot0_gripper_qpos"]

# Observation-modality spec for robomimic: all keys above are low-dim, no RGB.
# `list(...)` keeps the spec's list independent of `obs_keys` (as before,
# they are two distinct list objects).
obs_spec = dict(
    obs=dict(
        low_dim=list(obs_keys),
        rgb=[],
    ),
)
# Registers the modality spec globally with robomimic's ObsUtils at import time.
ObsUtils.initialize_obs_utils_with_obs_specs(obs_modality_specs=obs_spec)


def get_folder(path_to_dlm):
    """Locate the low-dim "ph" robomimic datasets for the three tasks.

    For each task (lift = EASY, square = MEDIUM, tool_hang = HARD, checked in
    that order) this creates ``<path_to_dlm>/<task>/ph`` if needed and
    verifies that it contains ``low_dim_v141.hdf5``.

    Args:
        path_to_dlm: Root directory that holds the downloaded datasets.

    Returns:
        dict keyed ``'square'``, ``'lift'``, ``'tool_hang'`` (in that order),
        each mapping to
        ``[task, 'ph', 'low_dim', dim, task_folder, hdf5_path]``.
        ``dim`` is presumably the flattened size of the ``obs_keys``
        observation for that task (TODO confirm against the env).

    Raises:
        FileNotFoundError: if a task's ``low_dim_v141.hdf5`` is missing.
    """
    # (task, dim) pairs in the order the datasets are validated.
    tasks = (("lift", 19), ("square", 23), ("tool_hang", 53))

    entries = {}
    for task, dim in tasks:
        task_folder = os.path.join(path_to_dlm, task + "/ph")
        os.makedirs(task_folder, exist_ok=True)
        hdf5_path = os.path.join(task_folder, "low_dim_v141.hdf5")
        # Explicit raise instead of `assert`, which is stripped under `python -O`.
        if not os.path.exists(hdf5_path):
            raise FileNotFoundError(
                "robomimic dataset not found: {}".format(hdf5_path))
        entries[task] = [task, "ph", "low_dim", dim, task_folder, hdf5_path]

    # Keep the key order of the dict returned by the original implementation.
    return {name: entries[name] for name in ("square", "lift", "tool_hang")}


def get_dataset(download_folder, seq_len=1, batch_size=100):
    """Build train/validation loaders for the dataset in ``download_folder``.

    Args:
        download_folder: Directory containing ``low_dim_v141.hdf5``.
        seq_len: Forwarded to ``load_data_for_training``.
        batch_size: Forwarded to ``load_data_for_training``.

    Returns:
        ``(dataset_path, train_loader, valid_loader)``.
    """
    hdf5_file = os.path.join(download_folder, "low_dim_v141.hdf5")

    loaders = load_data_for_training(
        dataset_path=hdf5_file,
        obs_keys=obs_keys,
        seq_len=seq_len,
        batch_size=batch_size,
    )
    # The helper is expected to yield exactly two loaders.
    train_loader, valid_loader = loaders

    return hdf5_file, train_loader, valid_loader

def try_rollout(model, save_path, model_file, video_path, dataset_path,
                num_rollouts=2, horizon=400):
    """Load a checkpoint, roll the policy out, and record it to ``rollout.mp4``.

    Args:
        model: Policy object exposing ``load(path)``; passed to ``rollout``.
        save_path: Directory containing the checkpoint.
        model_file: Checkpoint file name inside ``save_path``.
        video_path: Directory where ``rollout.mp4`` is written (20 fps).
        dataset_path: hdf5 dataset path forwarded to ``rollout``.
        num_rollouts: Number of rollouts to run (default 2, as before).
        horizon: Forwarded to ``rollout`` (default 400, as before);
            presumably the per-episode step limit.

    Returns:
        The success rate reported by ``rollout``.
    """
    model.load(os.path.join(save_path, model_file))

    video_file = os.path.join(video_path, "rollout.mp4")
    video_writer = imageio.get_writer(video_file, fps=20)
    try:
        success_rate = rollout(model,
                               dataset_path,
                               horizon=horizon,
                               video_writer=video_writer,
                               obs_keys=obs_keys,
                               num_rollouts=num_rollouts)
        print("Success rate over {} rollouts: {}".format(num_rollouts, success_rate))
    finally:
        # Always finalize the video file, even if the rollout raises.
        video_writer.close()

    return success_rate

def try_vlm_rollout(model, vlm, video_path, dataset_path, save_path=None,
                    subtask_ind=None, instruction=None, horizon=400,
                    obs_len=1, generate_img=False, generate_vid=False,
                    video_skip=5, img_skip=5, generate_rollout_vid=False,
                    num_rollouts=1, figure_path="", language_path="",
                    verbose=False):
    """Run ``vlm_rollout`` with the module's ``obs_keys`` and report results.

    All arguments except ``save_path`` are forwarded unchanged to
    ``vlm_rollout``.

    Args:
        save_path: If not None, ``[full_success, rollout_results]`` is
            pickled to ``<save_path>/data.pkl``.

    Returns:
        ``(success_rate, full_success, rollout_results)`` as produced by
        ``vlm_rollout``. Previously the function returned None.
    """
    success_rate, full_success, rollout_results = vlm_rollout(
        model,
        dataset_path,
        vlm=vlm,
        instruction=instruction,
        subtask_ind=subtask_ind,
        horizon=horizon,
        generate_img=generate_img,
        generate_vid=generate_vid,
        video_skip=video_skip,
        img_skip=img_skip,
        generate_rollout_vid=generate_rollout_vid,
        obs_keys=obs_keys,
        obs_len=obs_len,
        num_rollouts=num_rollouts,
        figure_path=figure_path,
        video_path=video_path,
        language_path=language_path,
        verbose=verbose,
    )
    print("Success rate over {} rollouts: {}".format(num_rollouts, success_rate))

    if save_path is not None:
        result_file = os.path.join(save_path, 'data.pkl')
        print("Saving Results Files on ", result_file)
        # NOTE: pickle output; only load it back from a trusted location.
        with open(result_file, "wb") as f:
            pickle.dump([full_success, rollout_results], f)

    return success_rate, full_success, rollout_results