#!/usr/bin/env python3
import logging
import os
import coloredlogs
import skimage.io

# Configure colored log output at INFO level.
# NOTE(review): this runs before the remaining imports, presumably so that any
# logging they do at import time is already formatted -- confirm before moving.
coloredlogs.install(level=logging.INFO)

from torch.utils.data._utils.collate import default_collate
import torch
from train.behavioral_cloning.datasets.agent_state import AgentState
from train.behavioral_cloning.datasets.minerl_dataset import MineRLDataset
from train.behavioral_cloning.datasets.ds import ds_s32_orig_minerl_diamond_binact_softcam_48_aug0 as dataset
from train.envs.minerl_env import make_minerl
from tqdm import tqdm

# Torch device for this script: CUDA when available, otherwise the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

def main(n_cpu=1, n_steps=1500):
    """Step a MineRL environment with no-op actions and print the last reward.

    Builds the agent states, the dataset object and the environment from the
    imported ``dataset`` configuration, resets the environment, submits
    ``n_steps`` no-op actions and prints the reward returned by the final
    step. The environment is always closed, even if a reset or step fails.

    Args:
        n_cpu: Forwarded to ``make_minerl`` as ``n_cpu``; also the number of
            no-op actions submitted per step and the number of ``AgentState``
            objects created.
        n_steps: Number of environment steps to take.
    """
    env_id = 'MineRLObtainDiamond-v0'
    action_space = dataset.ACTION_SPACE
    input_space = dataset.INPUT_SPACE

    # Build one AgentState per worker. Multiplying a one-element tuple would
    # repeat a reference to the *same* object n_cpu times.
    # NOTE(review): agent_states and lessons_buffer are not used below; they
    # are kept because their constructors may have side effects -- confirm
    # whether they can be dropped.
    agent_states = tuple(
        AgentState(sequence_length=dataset.SEQ_LENGTH,
                   data_transform=dataset.DATA_TRANSFORM)
        for _ in range(n_cpu)
    )
    lessons_buffer = MineRLDataset(root=None, sequence_len=dataset.SEQ_LENGTH, train=True, prepare=False,
                                   experiment=env_id,
                                   input_space=input_space,
                                   action_space=action_space,
                                   data_transform=dataset.DATA_TRANSFORM)

    env = make_minerl(env_id=env_id,
                      n_cpu=n_cpu,
                      seq_len=dataset.SEQ_LENGTH,
                      transforms=dataset.DATA_TRANSFORM,
                      input_space=input_space)
    try:
        env.reset()

        # Stays None when n_steps is 0, so the print below cannot hit an
        # unbound name.
        reward = None
        for _ in tqdm(range(n_steps)):
            # Take one env step, submitting a no-op action n_cpu times.
            _, reward, _, _ = env.step((env.action_space.noop(),) * n_cpu)

        print(reward)
    finally:
        # Release the environment even when reset/step raised.
        env.close()


if __name__ == "__main__":
    main()
