import unittest
import pickle
import os
import sys

sys.path.append('./film/robomimic/')
sys.path.append('./film/robosuite/')


from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import h5py

import robomimic.utils.file_utils as FileUtils
from robomimic import DATASET_REGISTRY
import robomimic.utils.obs_utils as ObsUtils
from film.util import helper
from film.util.util import get_folder, get_dataset, try_vlm_rollout
from IPython.display import Video
import robomimic.utils.obs_utils as ObsUtils
from film.rl.model_util import get_model
from film.rl.model import mlp, rnn, diffusion


def print_obs(dataset_path):
    """Print a short summary of the HDF5 dataset at *dataset_path*.

    Reports the number of demonstration groups stored under ``"data"`` and,
    for the first demonstration only, the shape of its ``"actions"`` array
    and of every entry in its ``"obs"`` group.

    Args:
        dataset_path: Path of the HDF5 file to open read-only.

    Raises:
        IndexError: If the file contains no demonstrations under ``"data"``.
    """
    # Use a context manager so the HDF5 handle is always released, even when
    # one of the lookups below raises.
    with h5py.File(dataset_path, "r") as f:
        # Each demonstration is a group under "data", named "demo_#" where #
        # is a number starting from 0.
        demos = list(f["data"].keys())
        num_demos = len(demos)

        print("hdf5 file {} has {} demonstrations".format(dataset_path, num_demos))

        # Only the first demonstration is inspected.
        demo_grp = f["data/{}".format(demos[0])]

        # actions is a numpy array, expected to be of shape (time, action dim)
        actions = demo_grp["actions"][:]
        print("shape of actions {}".format(actions.shape))

        # The "obs" group maps each observation key to a dataset, expected to
        # be of shape (time, obs modality dim).
        print("observations:")
        for obs_key, obs_data in demo_grp["obs"].items():
            print("{} - shape {}".format(obs_key, obs_data.shape))


class TestCaseBase(unittest.TestCase):
    """Shared fixture setup and helper assertions for the rollout tests."""

    def setUp(self):
        """Define the input/output locations and the test matrix.

        Creates the video, fixture, figure, download and results folders if
        they do not exist yet.  The instruction-step folder is only recorded
        here, not created.
        """
        file_path = Path(__file__).parent
        self.vid_path = file_path / "video"  # Where rollout videos are saved
        self.fix_path = file_path / "fixture"  # Where saved model checkpoints live
        self.figure_path = file_path / "figure"  # Where generated figures are saved
        # Dataset download folder; relative, so resolved against the current
        # working directory rather than this file's folder.
        self.download_path = 'results/datasets'
        self.result_path = file_path / "results"  # Where rollout results are saved
        # Base folder handed to rollouts as ``instruction_step_path``
        self.instruction_step_path = file_path / "instruction"
        self.vlm_path = 'vlm'  # VLM data
        # Folder holding the per-task "<task>.txt" instruction files
        self.instruction_path = 'instruction'

        os.makedirs(self.vid_path, exist_ok=True)
        os.makedirs(self.fix_path, exist_ok=True)
        os.makedirs(self.figure_path, exist_ok=True)
        os.makedirs(self.download_path, exist_ok=True)
        os.makedirs(self.result_path, exist_ok=True)

        # Test matrix: which tasks, models and VLMs the tests iterate over.
        self.task_list = ['tool_hang']
        self.task_list_subtask_ind = {'tool_hang': [42, 43]}
        self.model_list = ['rnn']
        self.vlm_list = ['gpt4']

    def assertDownload(self, task="lift"):
        """Ensure the low-dim 'ph' dataset for *task* exists, then summarise it.

        The dataset is downloaded from the URL registered in
        ``DATASET_REGISTRY`` only when ``low_dim_v141.hdf5`` is not already
        present in the task's download folder.

        Args:
            task: Key of the task in ``DATASET_REGISTRY``.

        Raises:
            AssertionError: If the dataset file is still missing afterwards.
        """
        dataset_type = "ph"
        hdf5_type = "low_dim"

        # set download folder
        download_folder = os.path.join(self.download_path, task, dataset_type)
        os.makedirs(download_folder, exist_ok=True)

        # download the dataset if it is not cached yet
        dataset_path = os.path.join(download_folder, "low_dim_v141.hdf5")
        if not os.path.exists(dataset_path):
            FileUtils.download_url(
                url=DATASET_REGISTRY[task][dataset_type][hdf5_type]["url"],
                download_dir=download_folder,
            )

        # Enforce that the dataset exists.  A unittest assertion is used
        # instead of a bare ``assert`` so the check survives ``python -O`` and
        # reports which path was missing.
        self.assertTrue(
            os.path.exists(dataset_path),
            msg="dataset not found at {}".format(dataset_path),
        )
        print('\nTask Description: ', task)
        print_obs(dataset_path)

    def assertModelLoad(self, model_name, input_dim, save_path):
        """Build a model and load its weights from ``<save_path>/model.pth``.

        Args:
            model_name: Name passed to ``get_model`` to select the model.
            input_dim: Input dimension passed to ``get_model``.
            save_path: Folder containing the ``model.pth`` checkpoint.

        Returns:
            The model returned by ``get_model`` after ``load`` was called on it.
        """
        model = get_model(model_name, input_dim)
        model.load(os.path.join(save_path, "model.pth"))
        return model



class TestRobomimic(TestCaseBase):
    """Rollout tests for the models listed in the base fixture.

    Only ``test_no_vlm`` is active; the other test variants below are kept
    commented out.
    """

    def _check_model_type(self, model, model_name, model_horizon, horizon):
        """Assert *model* has the class expected for *model_name*.

        For the 'rnn' and 'diffusion' models, ``set_horizon(horizon)`` is also
        called when *model_horizon* is positive.  Any other model name is left
        unchecked.
        """
        if model_name == 'mlp':
            self.assertIsInstance(model, mlp.DLMMLP)
            return
        if model_name == 'rnn':
            self.assertIsInstance(model, rnn.DLM_RNN)
        elif model_name == 'diffusion':
            self.assertIsInstance(model, diffusion.DLM_Diffusion)
        else:
            return
        if model_horizon > 0:
            model.set_horizon(horizon)

    # def test_image_gen(self):
    #     input_dim_dict = get_folder(self.download_path)
    #     horizon = 10
    #     model_horizon = 0
    #     num_rollouts = 2
    #     generate_img=True
    #     vlm = None
    #     for model_name in self.model_list:
    #         print("Algorithm to be used: ", model_name)
    #         for task in self.task_list:
    #             task, dataset_type, hdf5_type, input_dim, download_folder, dataset_path = input_dim_dict[
    #                 task]

    #             #Get the model save path
    #             save_path = os.path.join(self.fix_path, task, 'ph', model_name)
    #             model = self.assertModelLoad(model_name, input_dim, save_path)
    #             if model_name == 'mlp':
    #                 self.assertIsInstance(model, mlp.DLMMLP)
    #             elif model_name == 'rnn':
    #                 self.assertIsInstance(model, rnn.DLM_RNN)
    #                 if model_horizon > 0:
    #                     model.set_horizon(horizon)
    #             elif model_name == 'diffusion':
    #                 self.assertIsInstance(model, diffusion.DLM_Diffusion)
    #                 if model_horizon > 0:
    #                     model.set_horizon(horizon)
    #             model.load(os.path.join(save_path, "model.pth"))

    #             with self.subTest("Testing Videos"):
    #                 video_path = os.path.join(self.vid_path, task, 'ph',
    #                                           model_name)
    #                 figure_path = os.path.join(self.figure_path, task, 'ph',
    #                                            model_name)
    #                 os.makedirs(figure_path, exist_ok=True)
    #                 try_vlm_rollout(model,
    #                                 vlm,
    #                                 video_path,
    #                                 dataset_path,
    #                                 generate_img=generate_img,
    #                                 horizon=horizon,
    #                                 num_rollouts=num_rollouts,
    #                                 figure_path=figure_path)

    # def test_vid_gen(self):
    #     input_dim_dict = get_folder(self.download_path)
    #     horizon = 10
    #     model_horizon = 0
    #     num_rollouts = 2
    #     generate_vid=True
    #     vlm = None
    #     for model_name in self.model_list:
    #         print("Algorithm to be used: ", model_name)
    #         for task in self.task_list:
    #             task, dataset_type, hdf5_type, input_dim, download_folder, dataset_path = input_dim_dict[
    #                 task]

    #             #Get the model save path
    #             save_path = os.path.join(self.fix_path, task, 'ph', model_name)
    #             model = self.assertModelLoad(model_name, input_dim, save_path)
    #             if model_name == 'mlp':
    #                 self.assertIsInstance(model, mlp.DLMMLP)
    #             elif model_name == 'rnn':
    #                 self.assertIsInstance(model, rnn.DLM_RNN)
    #                 if model_horizon > 0:
    #                     model.set_horizon(horizon)
    #             elif model_name == 'diffusion':
    #                 self.assertIsInstance(model, diffusion.DLM_Diffusion)
    #                 if model_horizon > 0:
    #                     model.set_horizon(horizon)
    #             model.load(os.path.join(save_path, "model.pth"))

    #             with self.subTest("Testing Videos"):
    #                 video_path = os.path.join(self.vid_path, task, 'ph',
    #                                           model_name)
    #                 figure_path = os.path.join(self.figure_path, task, 'ph',
    #                                            model_name)
    #                 instruction_step_path = os.path.join(self.instruction_step_path, task, 'ph',
    #                                            model_name)
    #                 os.makedirs(figure_path, exist_ok=True)
    #                 try_vlm_rollout(model,
    #                                 vlm,
    #                                 video_path,
    #                                 dataset_path,
    #                                 generate_vid=generate_vid,
    #                                 horizon=horizon,
    #                                 num_rollouts=num_rollouts,
    #                                 figure_path=figure_path,
    #                                 instruction_step_path=instruction_step_path)



    # def test_data(self):
    #     input_dim_dict = get_folder(self.download_path)
    #     horizon = 10
    #     model_horizon = 0
    #     num_rollouts = 2
    #     generate_img=False
    #     vlm = None
    #     for model_name in self.model_list:
    #         print("Algorithm to be used: ", model_name)
    #         for task in self.task_list:
    #             task, dataset_type, hdf5_type, input_dim, download_folder, dataset_path = input_dim_dict[
    #                 task]

    #             #Get the model save path
    #             save_path = os.path.join(self.fix_path, task, 'ph', model_name)
    #             model = self.assertModelLoad(model_name, input_dim, save_path)
    #             if model_name == 'mlp':
    #                 self.assertIsInstance(model, mlp.DLMMLP)
    #             elif model_name == 'rnn':
    #                 self.assertIsInstance(model, rnn.DLM_RNN)
    #                 if model_horizon > 0:
    #                     model.set_horizon(horizon)
    #             elif model_name == 'diffusion':
    #                 self.assertIsInstance(model, diffusion.DLM_Diffusion)
    #                 if model_horizon > 0:
    #                     model.set_horizon(horizon)
    #             model.load(os.path.join(save_path, "model.pth"))

    #             with self.subTest("Testing Videos"):
    #                 video_path = os.path.join(self.vid_path, task, 'ph',
    #                                           model_name)
    #                 figure_path = os.path.join(self.figure_path, task, 'ph',
    #                                            model_name)
    #                 result_path = os.path.join(self.result_path, task, 'ph',
    #                                            model_name)
    #                 os.makedirs(figure_path, exist_ok=True)
    #                 os.makedirs(result_path, exist_ok=True)
    #                 try_vlm_rollout(model,
    #                                 vlm,
    #                                 video_path,
    #                                 dataset_path,
    #                                 result_path=result_path,
    #                                 generate_img=generate_img,
    #                                 horizon=horizon,
    #                                 num_rollouts=num_rollouts,
    #                                 figure_path=figure_path)

    #             result_file = os.path.join(result_path,'data.pkl')
    #             with open(result_file, "rb") as f:
    #                 result_data = pickle.load(f)
    #             results = result_data[1]
    #             replan_data = results[0]['replan']
    #             self.assertEqual(len(replan_data),horizon)

    #             subtask_data = results[0]['subtask']
    #             self.assertEqual(len(subtask_data),horizon)

    # def test_no_vlm(self):
    #     input_dim_dict = get_folder(self.download_path)
    #     horizon = 10
    #     model_horizon = horizon
    #     num_rollouts = 2

    #     # TODO: Load VLM
    #     vlm_name = self.vlm_list[0]
    #     vlm = None
    #     for model_name in self.model_list:
    #         print("Algorithm to be used: ", model_name)
    #         for task in self.task_list:
    #             task, dataset_type, hdf5_type, input_dim, download_folder, dataset_path = input_dim_dict[
    #                 task]

    #             #Get the subtask index of the observation for a specific task
    #             subtask_ind = self.task_list_subtask_ind[task]

    #             #Get the instructions
    #             instruction_path = os.path.join(self.instruction_path,task+'.txt')
    #             with open(instruction_path, 'r') as file:
    #                 content = file.read()
    #             instruction = content.split('\n')

    #             #Get the model save path
    #             save_path = os.path.join(self.fix_path, task, 'ph', model_name)
    #             model = self.assertModelLoad(model_name, input_dim, save_path)
    #             if model_name == 'mlp':
    #                 self.assertIsInstance(model, mlp.DLMMLP)
    #             elif model_name == 'rnn':
    #                 self.assertIsInstance(model, rnn.DLM_RNN)
    #                 if model_horizon > 0:
    #                     model.set_horizon(horizon)
    #                     self.assertEqual(horizon, model.hidden_state_horizon)
    #             elif model_name == 'diffusion':
    #                 self.assertIsInstance(model, diffusion.DLM_Diffusion)
    #                 if model_horizon > 0:
    #                     model.set_horizon(horizon)
    #                     self.assertEqual(horizon, model.action_horizon)
    #             model.load(os.path.join(save_path, "model.pth"))


    #             with self.subTest("Testing Videos"):
    #                 video_path = os.path.join(self.vid_path, task, 'ph',
    #                                           model_name)
    #                 figure_path = os.path.join(self.figure_path, task, 'ph',
    #                                            model_name)
    #                 result_path = os.path.join(self.result_path, task, 'ph',
    #                                            model_name)
    #                 os.makedirs(figure_path, exist_ok=True)
    #                 os.makedirs(result_path, exist_ok=True)
    #                 try_vlm_rollout(model,
    #                                 vlm,
    #                                 video_path,
    #                                 dataset_path,
    #                                 result_path,
    #                                 subtask_ind=subtask_ind,
    #                                 instruction = instruction,
    #                                 horizon=horizon,
    #                                 num_rollouts=num_rollouts,
    #                                 figure_path=figure_path)

    # def test_vid_gen(self):
    #     input_dim_dict = get_folder(self.download_path)
    #     horizon = 50
    #     model_horizon = 0
    #     num_rollouts = 2
    #     generate_vid=True
    #     vlm = None
    #     for model_name in self.model_list:
    #         print("Algorithm to be used: ", model_name)
    #         for task in self.task_list:
    #             task, dataset_type, hdf5_type, input_dim, download_folder, dataset_path = input_dim_dict[
    #                 task]

    #             #Get the subtask index of the observation for a specific task
    #             subtask_ind = self.task_list_subtask_ind[task]

    #             #Get the instructions
    #             instruction_path = os.path.join(self.instruction_path,task+'.txt')
    #             with open(instruction_path, 'r') as file:
    #                 content = file.read()
    #             instruction = content.split('\n')

    #             #Get the model save path
    #             save_path = os.path.join(self.fix_path, task, 'ph', model_name)
    #             model = self.assertModelLoad(model_name, input_dim, save_path)
    #             if model_name == 'mlp':
    #                 self.assertIsInstance(model, mlp.DLMMLP)
    #             elif model_name == 'rnn':
    #                 self.assertIsInstance(model, rnn.DLM_RNN)
    #                 if model_horizon > 0:
    #                     model.set_horizon(horizon)
    #             elif model_name == 'diffusion':
    #                 self.assertIsInstance(model, diffusion.DLM_Diffusion)
    #                 if model_horizon > 0:
    #                     model.set_horizon(horizon)
    #             model.load(os.path.join(save_path, "model.pth"))

    #             with self.subTest("Testing Videos"):
    #                 video_path = os.path.join(self.vid_path, task, 'ph',
    #                                           model_name)
    #                 figure_path = os.path.join(self.figure_path, task, 'ph',
    #                                            model_name)
    #                 result_path = os.path.join(self.result_path, task, 'ph',
    #                                            model_name)
    #                 instruction_step_path = os.path.join(self.instruction_step_path, task, 'ph',
    #                                            model_name)
    #                 os.makedirs(figure_path, exist_ok=True)
    #                 try_vlm_rollout(model,
    #                                 vlm,
    #                                 video_path,
    #                                 dataset_path,
    #                                 result_path,
    #                                 subtask_ind=subtask_ind,
    #                                 instruction = instruction,
    #                                 generate_vid=generate_vid,
    #                                 horizon=horizon,
    #                                 num_rollouts=num_rollouts,
    #                                 figure_path=figure_path,
    #                                 instruction_step_path=instruction_step_path)


    def test_no_vlm(self):
        """Roll out every configured model on every configured task with ``vlm=None``.

        Image generation is switched on and video generation off.  Because
        ``model_horizon`` is 0, ``set_horizon`` is never called here.
        """
        horizon, model_horizon, num_rollouts = 400, 0, 2
        generate_img, generate_vid = True, False
        vlm = None
        input_dim_dict = get_folder(self.download_path)

        for model_name in self.model_list:
            print("Algorithm to be used: ", model_name)
            for task_key in self.task_list:
                # Only the task name, input dimension and dataset path of the
                # six-field entry are needed below.
                task, _, _, input_dim, _, dataset_path = input_dim_dict[task_key]

                # Subtask index of the observation for this task
                subtask_ind = self.task_list_subtask_ind[task]

                # One instruction per line of "<task>.txt"
                instruction_file = os.path.join(self.instruction_path, task + '.txt')
                with open(instruction_file, 'r') as handle:
                    instruction = handle.read().split('\n')

                # Build the model from its saved checkpoint and check its type
                save_path = os.path.join(self.fix_path, task, 'ph', model_name)
                model = self.assertModelLoad(model_name, input_dim, save_path)
                self._check_model_type(model, model_name, model_horizon, horizon)
                # Load the checkpoint again after the (optional) horizon change
                model.load(os.path.join(save_path, "model.pth"))

                with self.subTest("Testing Videos"):
                    # Every output folder follows <base>/<task>/ph/<model_name>
                    video_path, figure_path, result_path, instruction_step_path = (
                        os.path.join(base_dir, task, 'ph', model_name)
                        for base_dir in (self.vid_path,
                                         self.figure_path,
                                         self.result_path,
                                         self.instruction_step_path))
                    os.makedirs(figure_path, exist_ok=True)
                    try_vlm_rollout(
                        model,
                        vlm,
                        video_path,
                        dataset_path,
                        result_path,
                        subtask_ind=subtask_ind,
                        instruction=instruction,
                        generate_img=generate_img,
                        generate_vid=generate_vid,
                        horizon=horizon,
                        num_rollouts=num_rollouts,
                        figure_path=figure_path,
                        instruction_step_path=instruction_step_path,
                    )



if __name__ == '__main__':
    unittest.main()
