from time import time
from util import *
from robomimic import DATASET_REGISTRY
import robomimic.utils.file_utils as FileUtils
import os
import sys

# Root of the DLM folder. Defaults to the original author's checkout; set the
# DLM_PATH environment variable to point at your own copy instead of editing
# this file (an unset or empty variable falls back to the default).
path_to_dlm = os.environ.get('DLM_PATH') or '/home/bho36/FILM/results'

# Make the DLM root and its bundled robomimic/robosuite checkouts importable.
# NOTE(review): these entries are added *after* the `util` and `robomimic`
# imports at the top of the file, so those imports cannot depend on them --
# confirm they resolve from the working directory / installed packages.
sys.path.extend([
    path_to_dlm,
    path_to_dlm + "/robomimic/",
    path_to_dlm + "/robosuite/",
])

# Per-task metadata built by get_folder (star-imported from util). The run_*
# functions unpack each entry as
# (task, dataset_type, hdf5_type, input_dim, download_folder, dataset_path).
input_dim_dict = get_folder(path_to_dlm)
# Output dimension handed to every model; presumably the robot action size
# -- confirm against the dataset.
output_dim = 7


def run_transformer(yes_train=True, yes_rollout=True):
    """Train and/or evaluate the transformer policy on the 'tool_hang' task.

    Args:
        yes_train: if True, train a model with ``train_transformer`` (saving
            under ``trainings/transformer``); otherwise build one with
            ``get_model(name='transformer', ...)``.
        yes_rollout: if True, run 2 evaluation rollouts of the final
            checkpoint through ``test_rollout``.
    """
    (_task, _dataset_type, _hdf5_type,
     input_dim, download_folder, dataset_path) = input_dim_dict['tool_hang']

    seq_len = 10  # same as prediction horizon
    denoising_steps = 100
    action_horizon = 8
    obs_horizon = 2
    # Train one epoch past the checkpoint we evaluate. This mirrors the
    # epochs/checkpoint pairing used elsewhere in this file (101 -> epoch_100,
    # 251 -> epoch_250, 51 -> epoch_50); with 50 epochs, epoch_50.pth may never
    # be written if epochs are counted from 0 -- confirm in train_transformer.
    num_epochs = 51
    model_name = [f"epoch_{num_epochs - 1}.pth"]
    save_path = "trainings/transformer"

    if yes_train:
        model = train_transformer(path_to_dlm, download_folder, input_dim, output_dim,
                                  num_epochs=num_epochs, seq_len=seq_len, batch_size=100,
                                  denoising_steps=denoising_steps, action_horizon=action_horizon,
                                  obs_horizon=obs_horizon, save_name=save_path)
    else:
        args = [input_dim, denoising_steps, action_horizon,
                obs_horizon, output_dim, seq_len]
        model = get_model(name='transformer', args=args)

    if yes_rollout:
        # Resolve checkpoints under the DLM root (as the other run_* drivers
        # do) so the lookup does not depend on the current working directory.
        test_rollout(model,
                     save_path=os.path.join(path_to_dlm, save_path),
                     dataset_path=dataset_path,
                     model_name=model_name,
                     obs_len=obs_horizon,
                     verbose="Transformer",
                     num_rollouts=2)


def run_diffusion(yes_train=True, yes_rollout=True):
    """Train and/or evaluate the diffusion policy on the 'tool_hang' task.

    Args:
        yes_train: when True, train a fresh model with ``train_diffusion``
            (101 epochs, saved under ``trainings/diffusion2``); when False,
            build the model with ``get_model(name='diffusion', ...)``.
        yes_rollout: when True, run 50 evaluation rollouts of the
            ``epoch_100.pth`` checkpoint through ``test_rollout``.
    """
    (_task, _dataset_type, _hdf5_type,
     input_dim, download_folder, dataset_path) = input_dim_dict['tool_hang']

    obs_horizon = 2
    action_horizon = 8
    denoising_steps = 100
    seq_len = 16  # equals the prediction horizon
    save_path = "trainings/diffusion2"

    if not yes_train:
        model_args = [input_dim, denoising_steps, action_horizon,
                      obs_horizon, output_dim, seq_len]
        model = get_model(name='diffusion', args=model_args)
    else:
        model = train_diffusion(
            path_to_dlm, download_folder, input_dim, output_dim,
            num_epochs=101,
            seq_len=seq_len,
            batch_size=100,
            denoising_steps=denoising_steps,
            action_horizon=action_horizon,
            obs_horizon=obs_horizon,
            save_name=save_path,
        )

    if not yes_rollout:
        return

    checkpoint_dir = os.path.join(path_to_dlm, save_path)
    test_rollout(
        model,
        save_path=checkpoint_dir,
        dataset_path=dataset_path,
        model_name=["epoch_100.pth"],
        obs_len=obs_horizon,
        verbose="Diffusion",
        num_rollouts=50,
    )


def run_task_diff(yes_train=True, yes_rollout=True):
    """Run the task-difficulty diffusion sweep and plot its results.

    For every task in the sweep and every denoising-step setting, this either
    trains a diffusion model or builds one with ``get_model``, optionally
    rolls out each listed checkpoint, and plots the per-run results with
    ``plot_fig``. Finally ``plot_fig_all`` produces a combined figure over
    ``trainings/task_diff/``.

    Args:
        yes_train: train each configuration with ``train_diffusion`` when
            True; otherwise build it with ``get_model(name='diffusion', ...)``.
        yes_rollout: evaluate the listed checkpoints with ``test_rollout``.
    """
    task_list = ['tool_hang']
    denoising_steps_arr = [100]
    action_horizon = 8
    obs_horizon = 2
    seq_len = 16

    # Checkpoints to evaluate, and how many epochs to train so the last one exists.
    checkpoints = ["epoch_50.pth", "epoch_100.pth", "epoch_150.pth",
                   "epoch_200.pth", "epoch_250.pth"]
    num_epochs = 251
    # Quicker alternative: checkpoints = ["epoch_50.pth"]; num_epochs = 51

    plot_args = []
    for task_name in task_list:
        (task, _dataset_type, _hdf5_type,
         input_dim, download_folder, dataset_path) = input_dim_dict[task_name]
        plot_args.append([task, num_epochs, denoising_steps_arr,
                          'Task Difficult ' + task])

        for denoising_steps in denoising_steps_arr:
            rel_dir = "/".join(["trainings/task_diff", task, str(denoising_steps)])
            # NOTE(review): this directory is created relative to the current
            # working directory, whereas rollout and plotting below use
            # path_to_dlm/rel_dir -- confirm which location train_diffusion
            # writes to.
            os.makedirs(rel_dir, exist_ok=True)
            print('Save Path: ', rel_dir)

            if not yes_train:
                model = get_model(name='diffusion',
                                  args=[input_dim, denoising_steps, action_horizon,
                                        obs_horizon, output_dim, seq_len])
            else:
                model = train_diffusion(
                    path_to_dlm, download_folder, input_dim, output_dim,
                    num_epochs=num_epochs,
                    seq_len=seq_len,
                    batch_size=100,
                    denoising_steps=denoising_steps,
                    action_horizon=action_horizon,
                    obs_horizon=obs_horizon,
                    save_name=rel_dir,
                )

            run_dir = os.path.join(path_to_dlm, rel_dir)
            if yes_rollout:
                test_rollout(
                    model,
                    save_path=run_dir,
                    dataset_path=dataset_path,
                    model_name=checkpoints,
                    obs_len=obs_horizon,
                    verbose="Task Diff Diffusion",
                )
            plot_fig(run_dir, num_epochs)

    plot_fig_all(path_to_dlm, "trainings/task_diff/", plot_args)


if __name__ == '__main__':
    # Stage switches: set yes_train to True to retrain before evaluating, or
    # yes_rollout to False to skip the evaluation rollouts.
    yes_train, yes_rollout = False, True

    # Active experiment.
    run_diffusion(yes_train=yes_train, yes_rollout=yes_rollout)

    # Other experiments (uncomment one to run it):
    # run_transformer(yes_train=yes_train, yes_rollout=yes_rollout)
    # run_task_diff(yes_train=yes_train, yes_rollout=yes_rollout)
    # run_mlp(yes_train=yes_train, yes_rollout=yes_rollout)
    # run_rnn(yes_train=yes_train, yes_rollout=yes_rollout)
    # run_task_qual(yes_train=yes_train, yes_rollout=yes_rollout, model_name='rnn')
    # run_seq_len(yes_train=yes_train, yes_rollout=yes_rollout, model_name='rnn')
