#!/usr/bin/env python3
import logging
import os
import coloredlogs
import skimage.io

# Configure colored log output at INFO level.
# NOTE(review): this runs before the torch/project imports below, presumably
# so that messages logged while those modules import are already formatted —
# confirm before moving it below the imports.
coloredlogs.install(logging.INFO)

from torch.utils.data._utils.collate import default_collate
import torch
from train.behavioral_cloning.datasets.agent_state import AgentState
from train.behavioral_cloning.datasets.minerl_dataset import MineRLDataset
from train.behavioral_cloning.datasets.ds import ds_s32_orig_minerl_diamond_binact_softcam_48_aug0 as dataset
from train.behavioral_cloning.models.LightCnnLstmBinActsSoftCam import Network as net
from train.envs.minerl_env import make_minerl

# Device used for the model and its input tensors: CUDA when available,
# otherwise the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


def plot(run):
    """Roll out the global ``model`` in the global ``env`` and dump frames as PNGs.

    Runs 250 environment steps and, after every step, writes one image per
    entry along the first axis of the observation to
    ``imgs/<run>/<index>_<step>_<reward>.png``. The environment is reset
    before and after the rollout.

    This function reads the module-level names ``env`` and ``model``; in this
    file they are only assigned inside the ``__main__`` section, so it cannot
    be called before that section has run.

    Args:
        run: Identifier of this rollout; used as the output sub-directory name.
    """
    obs = env.reset()

    # The target directory does not depend on the step, so create it once
    # instead of on every iteration.
    os.makedirs("imgs/{}".format(run), exist_ok=True)

    for i in range(250):
        # NOTE(review): `predict` looks like a stable-baselines style API —
        # confirm that the object bound to the global `model` provides it.
        action, _ = model.predict(obs)
        obs, rewards, _, _ = env.step(action)
        for x in range(obs.shape[0]):
            # Only the channels from index 9 onwards of the last axis are
            # saved — presumably the newest frame of a stacked observation;
            # TODO confirm against the environment wrapper.
            skimage.io.imsave("imgs/{}/{}_{:05d}_{}.png".format(run, x, i, rewards[x]), obs[x, :, :, 9:])
    env.reset()


# multithreaded environment
if __name__ == "__main__":
    # Script entry point: build the model, the lessons buffer and the MineRL
    # environment(s), then run the model in the environment for 1500 steps,
    # printing the reward returned by each step.
    #
    # NOTE: this is deliberately kept at module level rather than wrapped in
    # a main() function, because plot() reads `env` and `model` as globals.
    n_cpu = 2
    env_id = 'MineRLObtainDiamond-v0'

    # Build one AgentState per environment. `(AgentState(...),) * n_cpu` would
    # repeat a single shared instance n_cpu times instead of creating n_cpu
    # independent ones.
    agent_states = tuple(
        AgentState(sequence_length=dataset.SEQ_LENGTH, data_transform=dataset.DATA_TRANSFORM)
        for _ in range(n_cpu)
    )
    action_space = dataset.ACTION_SPACE
    input_space = dataset.INPUT_SPACE
    model = net().to(DEVICE)
    lessons_buffer = MineRLDataset(root=None, sequence_len=dataset.SEQ_LENGTH, train=True, prepare=False,
                                   experiment=env_id,
                                   input_space=input_space,
                                   action_space=action_space,
                                   data_transform=dataset.DATA_TRANSFORM)
    env = make_minerl(env_id=env_id,
                      n_cpu=n_cpu,
                      seq_len=dataset.SEQ_LENGTH,
                      transforms=dataset.DATA_TRANSFORM,
                      input_space=dataset.INPUT_SPACE)

    # Always close the environment, even when a step raises, so that it is
    # not left open after a failed run.
    try:
        state = env.reset()

        for _step in range(1500):
            # collate, to tensor, to device
            input_dict = default_collate(state)
            # Only existing keys are reassigned, so updating the mapping while
            # iterating over its items is safe.
            for key, tensor in input_dict.items():
                input_dict[key] = tensor.to(DEVICE)

            # predict next action; call the module itself (not .forward) so
            # that registered nn.Module hooks are honoured.
            # NOTE(review): this runs with autograd enabled — confirm that is
            # intended once the outputs below are actually used.
            out_dict = model(input_dict)

            # translate predictions to environment actions
            actions = action_space.logits_to_dict(env.action_space.noop(), out_dict)

            # evaluate model (the results are not consumed yet)
            value, action_log_probs, camera_log_probs, action_entropy, camera_entropy = \
                action_space.evaluate_actions(out_dict, actions)

            # take env step
            state, reward, done, info = env.step(actions)
            print(reward)

            # TODO: collect full trajectories, then dump into lessons buffer
            # lessons_buffer.add_sequence((state[j], actions[j], reward[j], None, done[j]))
    finally:
        env.close()
