import os
import gym
import torch
import importlib
import tqdm
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from torch.utils.data._utils.collate import default_collate

from train.pytorch_wrapper.prediction import Predictor

from train.behavioral_cloning.spaces.action_spaces import BINARY_ACTIONS
from train.behavioral_cloning.run_train import parse_args, get_arg_string, build_model, load_data


def get_dataset(args):
    """Import and return the dataset module referenced by ``args.dataset``.

    ``args.dataset`` is a path to a Python source file relative to the import
    root (e.g. ``"pkg/sub/mod.py"``). The extension is stripped and path
    separators are turned into dots (``"pkg.sub.mod"``) before importing.

    :param args: object with a ``dataset`` attribute holding the file path.
    :return: the imported module object.
    """
    # normpath removes "./" prefixes and doubled separators; without it
    # "./pkg/mod.py" would become "..pkg.mod", which import_module rejects as a
    # relative name. It also converts "/" to os.sep on Windows.
    module_path = os.path.splitext(os.path.normpath(args.dataset))[0]
    module_name = module_path.replace(os.sep, ".").replace("/", ".")
    return importlib.import_module(module_name)


def get_noop():
    """Return a MineRL no-op action template with every entry set to zero.

    The template comes from ``obtain_action_space.no_op()`` (deep-copied so the
    source object is never mutated). Every key is overwritten with ``0``, then
    ``"camera"`` is replaced by a 0-d float32 zero array.
    """
    # TODO: fix this!
    # NOTE(review): np.zeros(()) is 0-d, not a 2-vector -- confirm this is the
    # camera shape logits_to_dict expects.
    from copy import deepcopy
    from minerl.env import obtain_action_space

    template = deepcopy(obtain_action_space.no_op())
    for name in list(template):
        template[name] = 0
    template["camera"] = np.zeros((), dtype=np.float32)
    return template


def _predict_actions(predictor, inputs, action_space):
    """Run the model once on ``inputs`` and decode one action dict per step."""
    out_dict = predictor.forward(inputs)
    actions = []
    for step in tqdm.tqdm(range(out_dict["logits"].shape[0])):
        # slice (not index) so every output keeps its leading axis; the [0]
        # below then takes the single decoded action for this step
        step_out_dict = {key: value[step:step + 1] for key, value in out_dict.items()}
        # fresh template every step in case logits_to_dict fills it in place
        actions.append(action_space.logits_to_dict(get_noop(), step_out_dict)[0])
    return actions


def _stack_actions(actions):
    """Stack action dicts into ``(binary, camera)`` arrays.

    ``binary`` has shape ``(len(actions), len(BINARY_ACTIONS))`` and ``camera``
    has shape ``(len(actions), 2)``.
    """
    binary = np.zeros((len(actions), len(BINARY_ACTIONS)))
    camera = np.zeros((len(actions), 2))
    for i_step, action in enumerate(actions):
        binary[i_step, :] = np.asarray([action[ba] for ba in BINARY_ACTIONS])
        # assumes "camera" holds two values (or a scalar that broadcasts)
        camera[i_step, :] = action["camera"].flatten()
    return binary, camera


def _save_camera_plot(predicted_cam, truth_cam, fig_file):
    """Save predicted (blue) vs. ground-truth (magenta) camera actions, one subplot per column."""
    fig = plt.figure("Camera Actions", figsize=(60, 20))
    fig.clf()
    for column in range(2):
        plt.subplot(2, 1, column + 1)
        plt.plot(predicted_cam[:, column], "b-")
        plt.plot(truth_cam[:, column], "m-", alpha=0.5)
        plt.grid()
    fig.savefig(fig_file)
    # release the (large) figure; otherwise it lives until the process exits
    plt.close(fig)


def _save_binary_plot(predicted_bin, truth_bin, fig_file):
    """Save ground-truth (top) and predicted (bottom) binary actions as images, one row per action."""
    fig = plt.figure("Discrete Actions", figsize=(60, 20))
    fig.clf()
    for i_plot, actions in enumerate((truth_bin, predicted_bin)):
        plt.subplot(2, 1, i_plot + 1)
        plt.imshow(actions.T, interpolation="nearest", vmin=0, vmax=1, aspect="auto")
    fig.savefig(fig_file)
    plt.close(fig)


def eval_model(args, num_samples=1000):
    """Evaluate a trained model on the first samples of the train or validation split.

    Selects the training set when ``args.evalset == "train"`` and the
    validation set otherwise. Predicts an action per step using the parameter
    file under ``args.param_root``, and saves two figures comparing predicted
    and ground-truth actions (camera and binary) to the user's home directory.

    :param args: parsed command-line arguments (see ``parse_args``).
    :param num_samples: maximum number of samples to evaluate; clipped to the
        size of the selected split.
    :raises ValueError: if the selected split is empty.
    """

    # load data
    # ---------
    print("Loading data ...")
    train_dataset, valid_dataset, _, _ = load_data(args)
    data_set = train_dataset if args.evalset == "train" else valid_dataset
    dataset = get_dataset(args)

    # build a single fixed batch (indices 0..n-1, no shuffling) so predictions
    # line up with the raw ground-truth actions fetched for the same indices.
    # NOTE(review): assumes data_set implements __len__ -- confirm.
    n_samples = min(num_samples, len(data_set))
    if n_samples == 0:
        raise ValueError("evaluation set %r is empty" % args.evalset)
    batch = default_collate([data_set[i] for i in range(n_samples)])
    # presumably get_raw_observation returns (observation, action) -- TODO confirm
    truth_actions = [data_set.get_raw_observation(i)[1] for i in range(n_samples)]

    # init model
    # ----------
    model = build_model(args)

    # init prediction
    # ---------------
    arg_str = get_arg_string(args)
    if args.checkpoint:
        arg_str = "%s_cp%06d" % (arg_str, args.checkpoint)
    param_file = os.path.join(args.param_root, "params_%s.npy" % arg_str)
    print("param_file:", param_file)
    predictor = Predictor(model, param_file=param_file)

    # run evaluation
    # --------------
    print("Running evaluation ...")
    predicted_actions = _predict_actions(predictor, batch["inputs"], dataset.ACTION_SPACE)

    # stack predicted and ground-truth actions independently, so a length
    # mismatch cannot cause an IndexError
    selected_bin_actions, selected_cam_actions = _stack_actions(predicted_actions)
    truth_bin_actions, truth_cam_actions = _stack_actions(truth_actions)

    out_dir = Path.home()
    _save_camera_plot(selected_cam_actions, truth_cam_actions,
                      out_dir / ("cam_actions_%s.png" % arg_str))
    _save_binary_plot(selected_bin_actions, truth_bin_actions,
                      out_dir / ("binary_actions_%s.png" % arg_str))


if __name__ == "__main__":
    # Script entry point: parse command-line arguments, then run the evaluation.
    eval_model(parse_args())
