import os
import pickle
from film.util.helper import train, rollout, load_data_for_training
import imageio
from film.rl.model.diffusion import DLM_Diffusion
from film.rl.model.transformer import DLM_Transformer
import robomimic.utils.obs_utils as ObsUtils
from IPython.display import Video
import numpy as np
import matplotlib.pyplot as plt
import torch
import pickle

# Low-dimensional observation keys consumed by every model/rollout in this
# module; no RGB camera observations are used.
obs_keys = [
    "object",
    "robot0_eef_pos",
    "robot0_eef_quat",
    "robot0_gripper_qpos",
]

# Modality spec handed to robomimic: the same four keys as low_dim, no rgb.
obs_spec = {
    "obs": {
        "low_dim": list(obs_keys),
        "rgb": [],
    },
}

# Initialize robomimic's observation utilities with the spec above.
ObsUtils.initialize_obs_utils_with_obs_specs(obs_modality_specs=obs_spec)


def get_folder(path_to_dlm):
    """Prepare dataset folders for the robomimic tasks and describe them.

    Creates ``datasets/<task>/ph`` below *path_to_dlm* for the square, lift
    and tool_hang tasks. The lift (EASY) and tool_hang (HARD)
    ``low_dim_v141.hdf5`` files must already exist; this is checked with
    ``assert``. The square file is not checked.

    Returns:
        dict mapping 'square' and 'tool_hang' to the list
        ``[task, 'ph', 'low_dim', input_dim, folder, hdf5_path]``.
        Lift is validated but not included in the result.
    """
    hdf5_name = "low_dim_v141.hdf5"

    def _prepare(task):
        # Create the task folder (if needed) and return (folder, hdf5 path).
        folder = os.path.join(path_to_dlm, 'datasets/' + task + '/ph')
        os.makedirs(folder, exist_ok=True)
        return folder, os.path.join(folder, hdf5_name)

    square_folder, square_path = _prepare('square')

    # EASY task: only its existence is enforced.
    _, lift_path = _prepare('lift')
    assert os.path.exists(lift_path)

    # HARD task.
    tool_hang_folder, tool_hang_path = _prepare('tool_hang')
    assert os.path.exists(tool_hang_path)

    return {
        'square': ['square', 'ph', "low_dim", 23, square_folder, square_path],
        'tool_hang': ['tool_hang', 'ph', "low_dim", 53, tool_hang_folder, tool_hang_path],
    }


def get_dataset(download_folder, seq_len=1, batch_size=100):
    """Build training and validation loaders for one task folder.

    Args:
        download_folder: folder containing ``low_dim_v141.hdf5``.
        seq_len: sequence length forwarded to ``load_data_for_training``.
        batch_size: batch size forwarded to ``load_data_for_training``.

    Returns:
        ``(dataset_path, train_loader, valid_loader)``.
    """
    dataset_path = os.path.join(download_folder, "low_dim_v141.hdf5")
    loaders = load_data_for_training(dataset_path=dataset_path,
                                     obs_keys=obs_keys,
                                     seq_len=seq_len,
                                     batch_size=batch_size)
    train_loader, valid_loader = loaders
    return dataset_path, train_loader, valid_loader


def get_model(name='mlp', args=None):
    """Construct a DLM policy chosen by substring match on *name*.

    Args:
        name: model identifier. If it contains 'diffusion' a DLM_Diffusion is
            built; otherwise, if it contains 'transformer', a DLM_Transformer.
        args: sequence ``(input_dim, denoising_steps, action_horizon,
            obs_horizon, output_dim, seq_len)``; ``seq_len`` is passed as the
            model's ``prediction_horizon``.

    Returns:
        The constructed model.

    Raises:
        ValueError: if *name* matches no known model type (this previously
            surfaced as an UnboundLocalError, e.g. for the default 'mlp'),
            or if *args* does not have exactly six entries.
    """
    # 'diffusion' is checked first, matching the original precedence.
    if 'diffusion' in name:
        model_cls = DLM_Diffusion
    elif 'transformer' in name:
        model_cls = DLM_Transformer
    else:
        raise ValueError(
            f"Unknown model name {name!r}: expected it to contain "
            "'diffusion' or 'transformer'")

    if args is None:
        args = []
    input_dim, denoising_steps, action_horizon, obs_horizon, output_dim, seq_len = args
    return model_cls(
        output_dim=output_dim,
        input_dim=input_dim,
        denoising_steps=denoising_steps,
        prediction_horizon=seq_len,
        action_horizon=action_horizon,
        obs_horizon=obs_horizon
    )


def try_rollout(model, save_path, name, dataset_path, obs_len=1,
                video_path="rollout.mp4", num_rollouts=2):
    """Load a checkpoint and run a few recorded rollouts as a quick check.

    Args:
        model: policy object providing ``load(path)``.
        save_path: folder containing the checkpoint.
        name: checkpoint file name inside *save_path*.
        dataset_path: dataset passed to ``rollout``.
        obs_len: observation history length forwarded to ``rollout``.
        video_path: where the rollout video is written (20 fps).
        num_rollouts: number of episodes to run; keep small while debugging.

    Returns:
        The success rate reported by ``rollout``.
    """
    model.load(os.path.join(save_path, name))

    video_writer = imageio.get_writer(video_path, fps=20)
    # Ensure the writer is closed (and the file finalized) even if a
    # rollout raises.
    try:
        success_rate = rollout(model,
                               dataset_path,
                               horizon=700,
                               video_writer=video_writer,
                               obs_keys=obs_keys,
                               num_rollouts=num_rollouts,
                               obs_len=obs_len)
        print("Success rate over {} rollouts: {}".format(num_rollouts, success_rate))
    finally:
        video_writer.close()
    return success_rate


def train_transformer(path_to_dlm, dataset_path, input_dim, output_dim,
                      num_epochs=51, seq_len=16, batch_size=100,
                      denoising_steps=100, action_horizon=8, obs_horizon=2,
                      save_name="trainings/transformer"):
    """Train a DLM_Transformer from scratch and store its loss history.

    Args:
        path_to_dlm: project root; outputs go to ``path_to_dlm/save_name``.
        dataset_path: folder containing ``low_dim_v141.hdf5``. Despite the
            name, it is passed to ``get_dataset`` as its download folder.
        input_dim, output_dim: model input/output sizes.
        num_epochs: total number of epochs to train.
        seq_len: sequence length for the data loaders; also used as the
            model's ``prediction_horizon``.
        batch_size: loader batch size.
        denoising_steps, action_horizon, obs_horizon: model hyper-parameters.
        save_name: output subfolder relative to *path_to_dlm*.

    Returns:
        The trained model.

    Side effects:
        Creates the output folder and writes ``loss.pkl`` there, containing
        ``[valid_losses, train_losses]``.
    """
    _, train_loader, valid_loader = get_dataset(
        dataset_path, seq_len, batch_size)

    model = DLM_Transformer(
        output_dim=output_dim,
        input_dim=input_dim,
        denoising_steps=denoising_steps,
        prediction_horizon=seq_len,
        action_horizon=action_horizon,
        obs_horizon=obs_horizon
    )

    # train() receives save_path for its checkpoints (the original note says
    # epoch_x.pth every 50 epochs -- confirm against train()).
    save_path = os.path.join(path_to_dlm, save_name)
    os.makedirs(save_path, exist_ok=True)
    print(save_path)

    # Training always starts from epoch 0; existing checkpoints in save_path
    # are not loaded here.
    start_epoch = 0

    # train_losses and valid_losses are lists of (loss, epoch) tuples.
    train_losses, valid_losses = train(
        model, train_loader, valid_loader, start_epoch=start_epoch,
        num_epochs=num_epochs - start_epoch, save_path=save_path)

    with open(os.path.join(save_path, "loss.pkl"), "wb") as outfile:
        pickle.dump([valid_losses, train_losses], outfile)

    return model


def _latest_checkpoint_epoch(save_path):
    """Return the largest N among ``epoch_N.pth`` files in *save_path*, or None."""
    prefix, suffix = 'epoch_', '.pth'
    epochs = []
    for fname in os.listdir(save_path):
        if fname.startswith(prefix) and fname.endswith(suffix):
            digits = fname[len(prefix):-len(suffix)]
            if digits.isdecimal():
                epochs.append(int(digits))
    return max(epochs) if epochs else None


def train_diffusion(path_to_dlm, dataset_path, input_dim, output_dim,
                    num_epochs=51, seq_len=16, batch_size=100,
                    denoising_steps=100, action_horizon=8, obs_horizon=2,
                    save_name="trainings/diffusion"):
    """Train a DLM_Diffusion, resuming from the newest checkpoint if present.

    Args:
        path_to_dlm: project root; outputs go to ``path_to_dlm/save_name``.
        dataset_path: folder containing ``low_dim_v141.hdf5``. Despite the
            name, it is passed to ``get_dataset`` as its download folder.
        input_dim, output_dim: model input/output sizes.
        num_epochs: total number of epochs, counting any already completed.
        seq_len: sequence length for the data loaders; also used as the
            model's ``prediction_horizon``.
        batch_size: loader batch size.
        denoising_steps, action_horizon, obs_horizon: model hyper-parameters.
        save_name: output subfolder relative to *path_to_dlm*.

    Returns:
        The trained model.

    Side effects:
        Creates the output folder. If it already holds ``epoch_N.pth`` files,
        the one with the largest N is loaded and training restarts at epoch N.
        Writes ``loss.pkl`` containing ``[valid_losses, train_losses]``.
    """
    _, train_loader, valid_loader = get_dataset(
        dataset_path, seq_len, batch_size)

    model = DLM_Diffusion(
        output_dim=output_dim,
        input_dim=input_dim,
        denoising_steps=denoising_steps,
        prediction_horizon=seq_len,
        action_horizon=action_horizon,
        obs_horizon=obs_horizon
    )

    # train() receives save_path for its checkpoints (the original note says
    # epoch_x.pth every 50 epochs -- confirm against train()).
    save_path = os.path.join(path_to_dlm, save_name)
    os.makedirs(save_path, exist_ok=True)
    print(save_path)

    cp = 0
    latest = _latest_checkpoint_epoch(save_path)
    if latest is not None:
        # Resume from the checkpoint that actually exists on disk (a
        # hard-coded epoch here previously ignored what was found).
        cp = latest
        cp_path = os.path.join(save_path, f'epoch_{cp}.pth')
        print("Loading Checkpoint: ", cp_path)
        model.load(cp_path)
        model.epoch = cp

    # train_losses and valid_losses are lists of (loss, epoch) tuples.
    # Clamp so a checkpoint beyond num_epochs never yields a negative count.
    train_losses, valid_losses = train(
        model, train_loader, valid_loader, start_epoch=cp,
        num_epochs=max(num_epochs - cp, 0), save_path=save_path)

    with open(os.path.join(save_path, "loss.pkl"), "wb") as outfile:
        pickle.dump([valid_losses, train_losses], outfile)

    return model


def test_rollout(model, save_path, dataset_path, model_name, obs_len=1, verbose="", num_rollouts=50):
    """Evaluate each listed checkpoint by rolling it out in the environment.

    Args:
        model: policy object providing ``load(path)``; it is re-loaded for
            every checkpoint.
        save_path: folder holding the checkpoints (created if missing);
            results are written here too.
        dataset_path: dataset passed to ``rollout`` and printed as the task.
        model_name: iterable of checkpoint file names inside *save_path*.
        obs_len: observation history length forwarded to ``rollout``.
        verbose: free-form label printed in each header.
        num_rollouts: episodes per checkpoint.

    Side effects:
        If ``num_rollouts >= 50``, pickles the success rates, preceded by a
        leading ``0`` entry, to ``save_path/success.pkl``.
    """
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    separator = '-----------------------------'
    # Leading 0 entry -- presumably a baseline point for the plotted curve.
    rates = [0]
    for ckpt in model_name:
        print(separator)
        print(f"Task Name {dataset_path} Model Name {ckpt} ")
        print("Model: ", verbose)
        print('Save Folder', save_path)
        print(separator)

        model.load(os.path.join(save_path, ckpt))
        rate = rollout(model,
                       dataset_path,
                       horizon=700,
                       video_writer=None,
                       obs_keys=obs_keys,
                       num_rollouts=num_rollouts,
                       obs_len=obs_len)
        rates.append(rate)
        print(f"Success rate over {num_rollouts} rollouts: {rate}")

    # Only persist results from full-size evaluations.
    if num_rollouts >= 50:
        with open(os.path.join(save_path, "success.pkl"), "wb") as outfile:
            pickle.dump(rates, outfile)


def plot_fig(save_path, epoch_per_iter=50):
    """Plot loss curves and success rate for one run and save ``result.png``.

    Reads ``save_path/loss.pkl`` and ``save_path/success.pkl``:

    * ``loss.pkl`` holds ``[valid_losses, train_losses]``, each a list of
      ``(loss, epoch)`` pairs.
    * ``success.pkl`` holds a list of success rates, plotted at multiples of
      *epoch_per_iter*.

    Both files are trusted local outputs; never point this at pickles from
    untrusted sources.
    """
    with open(os.path.join(save_path, "loss.pkl"), "rb") as f:
        # The file is written as [valid_losses, train_losses]; unpacking in
        # the other order would swap the two curves.
        valid_losses, train_losses = pickle.load(f)
    vl = torch.tensor(valid_losses).cpu().numpy()
    tl = torch.tensor(train_losses).cpu().numpy()

    with open(os.path.join(save_path, "success.pkl"), "rb") as f:
        ss = np.array(pickle.load(f))
    epoch = epoch_per_iter * np.arange(len(ss))

    fig, ax = plt.subplots(1, 2, figsize=(12, 4))
    # Column 1 is the epoch, column 0 the loss value.
    ax[0].plot(tl[:, 1], tl[:, 0], linewidth=2, label='Training Loss')
    ax[0].plot(vl[:, 1], vl[:, 0], linewidth=2, label='Validation Loss')
    ax[0].set_xlabel("Epochs")
    ax[0].set_ylabel("Loss")
    ax[0].set_title("Training & Validation Loss (BC)")
    ax[0].legend()

    ax[1].plot(epoch, ss, linewidth=2)
    ax[1].set_xlabel("Epochs")
    ax[1].set_ylabel("Success Rate")
    ax[1].set_title("Success Rate")
    ax[1].set_ylim([0, 1.05])

    fig.show()
    fig.savefig(os.path.join(save_path, 'result.png'))


def plot_fig_all(path_to_dlm, save_name, args):
    """Plot success-rate curves for several tasks side by side.

    Args:
        path_to_dlm: project root.
        save_name: subfolder of *path_to_dlm* holding the runs; the figure
            is saved there as ``result.png``.
        args: sequence of ``(task_name, num_epochs, param, title_name)``
            tuples, one subplot each. For every value ``d`` in ``param``, the
            success rates in ``<save_path>/<task_name>/<d>/success.pkl`` are
            spread evenly over ``[0, num_epochs - 1]`` and drawn as one curve.
    """
    save_path = os.path.join(path_to_dlm, save_name)
    n_panels = len(args)

    # squeeze=False keeps a 2-D Axes array even for a single panel, so
    # indexing works for len(args) == 1.
    fig, axes = plt.subplots(1, n_panels, figsize=(6 * n_panels, 6), squeeze=False)
    for ii, (task_name, num_epochs, param, title_name) in enumerate(args):
        ax = axes[0, ii]
        for d in param:
            success_file = os.path.join(save_path, task_name, str(d), "success.pkl")
            # Trusted local output; do not load pickles from untrusted sources.
            with open(success_file, "rb") as f:
                ss = np.array(pickle.load(f))
            epoch = np.linspace(0, num_epochs - 1, len(ss))
            ax.plot(epoch, ss, linewidth=2, label=str(d))

        ax.set_xlabel("Epochs", fontsize=16)
        ax.set_title(title_name)
        ax.set_ylim([0, 1.05])
        # Only the left-most panel carries the y axis label and ticks.
        if ii == 0:
            ax.set_ylabel("Success Rate", fontsize=16)
        else:
            ax.set_yticks([])
        if param:
            ax.legend()

    fig.tight_layout()
    fig.savefig(os.path.join(save_path, 'result.png'))
    print('Saving Figure')
