import datetime
import os
import sys
import random
import time
import copy
import json
import glob
from os.path import join as pjoin

import numpy as np

from alfworld.info import ALFWORLD_DATA
import alfworld.agents.environment
import alfworld.agents.modules.generic as generic
from alfworld.agents.agent import TextDAggerAgent

from promptrl.envs.alfworld_viz_env import AlfworldVizEnv

# Presumably silences the HuggingFace `tokenizers` fork/parallelism warning
# (confirm it is still needed).
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Controller written to config['controller']['type'] by collect_data().
# ALFWorld's stock controllers are 'oracle', 'oracle_astar', 'mrcnn' and
# 'mrcnn_astar'; 'viz_sref' is not one of them and is presumably a custom
# controller handled by AlfworldVizEnv -- TODO confirm.
controller_type = 'viz_sref'

# Passed to AlfworldVizEnv as `train_eval`. collect_data() points
# config['dataset']['data_path'] at the "train" split only when this is
# "train"; any other value selects "valid_seen". Set to "train" to play
# training games.
train_or_eval = "eval_out_of_distribution"

def _task_description(feedback):
    """Return the task text contained in an episode's opening feedback.

    NOTE(review): assumes ALFWorld-style wording where the reset feedback
    contains "Your task is to: <task>" -- confirm for AlfworldVizEnv. When
    that marker is absent, the whole (stripped) feedback is returned so a
    description is still recorded.
    """
    _, marker, task = feedback.partition("Your task is to: ")
    return task.strip() if marker else feedback.strip()


def _expert_actions(infos, batch_size):
    """Return the next expert-plan action per game, or "look" when none exists."""
    has_plan = "expert_plan" in infos
    actions = []
    for b in range(batch_size):
        if has_plan and len(infos["expert_plan"][b]) > 0:
            actions.append(infos["expert_plan"][b][0])
        else:
            actions.append("look")
    return actions


def collect_data(task_types, max_steps=200):
    """Interactively play ALFWorld games and collect the resulting trajectories.

    Loads the config via ``generic.load_config()``, forces batch size 1 and
    builds an ``AlfworldVizEnv`` for the module-level ``train_or_eval`` split.
    For every game, the expert plan's next action is printed and a command
    is read from stdin. Special commands:

      * ``ipdb``       -- drop into the debugger (does not use a step);
      * ``admissible`` -- print the admissible commands (does not use a step);
      * ``reset``      -- abandon the current game and go to the next one.

    Blank input is ignored; any other text is sent to the environment.

    Args:
        task_types: task-type ids written to ``config['env']['task_types']``.
        max_steps: maximum number of environment actions per game. Episodes
            that use the whole budget are not added to the result.

    Returns:
        A list with one dict per kept episode: ``g`` (the two path components
        preceding the game file name), ``g_id`` (episode index), ``task``
        (task description) and ``steps`` (one dict per executed action with
        ``feedback`` seen before the action, ``action`` sent, and the
        suggested ``expert_action``).
    """
    time_1 = datetime.datetime.now()
    config = generic.load_config()
    config['general']['training']['batch_size'] = 1
    config['general']['evaluate']['batch_size'] = 1
    config['general']['observation_pool_capacity'] = 5
    config['general']['training_method'] = 'dagger'
    config['env']['task_types'] = task_types
    config['controller']['type'] = controller_type
    config['controller']['debug'] = True
    # config['env']['expert_type'] = "downward"

    # NOTE(review): for eval splits ALFWorld envs may read a separate eval
    # data path rather than data_path -- confirm which one AlfworldVizEnv uses.
    if train_or_eval == "train":
        config['dataset']['data_path'] = pjoin(ALFWORLD_DATA, "json_2.1.1", "train")
    else:
        config['dataset']['data_path'] = pjoin(ALFWORLD_DATA, "json_2.1.1", "valid_seen")

    alfred_env = AlfworldVizEnv(config, train_eval=train_or_eval)
    env = alfred_env.init_env(batch_size=1)
    num_game = alfred_env.num_games
    env.seed(42)
    np.random.seed(42)

    episode_no = 0
    collected_data = []

    while episode_no < num_game:
        obs, infos = env.reset()
        print('obs shape: ', obs[0].shape)
        print(infos['feedback'][0])
        game_names = infos["extra.gamefile"]
        batch_size = len(obs)
        # Previously referenced but never defined (NameError hidden by a
        # bare except), which meant no episode was ever collected.
        task_desc_strings = [_task_description(infos['feedback'][b]) for b in range(batch_size)]

        episode_data = [[] for _ in range(batch_size)]
        actions_taken = 0

        # Only commands actually sent to the env count toward max_steps.
        while actions_taken < max_steps:
            expert_actions = _expert_actions(infos, batch_size)
            print(f'Expert: {expert_actions}')
            cmd = input(' >> ')
            if cmd == "ipdb":
                from ipdb import set_trace
                set_trace()
                continue
            elif cmd == "admissible":
                print(infos["admissible_commands"][0])
                continue
            elif cmd == "reset":
                break
            elif not cmd.strip():
                # Don't send an accidental blank line to the environment.
                continue

            for b in range(batch_size):
                episode_data[b].append({
                    "feedback": infos['feedback'][b],
                    "action": cmd,
                    "expert_action": expert_actions[b],
                })

            # The env is built with batch_size=1, so a single command suffices.
            obs, _, dones, infos = env.step([cmd])
            actions_taken += 1

            print('obs shape: ', obs[0].shape)
            print(infos['feedback'][0])

            # Stop once every game in the batch reports done.
            if all(dones):
                break

        time_2 = datetime.datetime.now()
        for b in range(batch_size):
            print("Episode: {:3d} | {:s} | time spent: {:s} | used steps: {:s}".format(episode_no + b, game_names[b], str(time_2 - time_1).rsplit(".")[0], str(len(episode_data[b]))))

        for b in range(batch_size):
            if len(collected_data) >= num_game:
                break
            # Episodes that exhausted the step budget are treated as unfinished.
            if len(episode_data[b]) >= max_steps:
                continue
            collected_data.append({
                "g": "/".join(game_names[b].split("/")[-3:-1]),
                "g_id": episode_no,
                "task": task_desc_strings[b],
                "steps": episode_data[b],
            })

        # finish game
        episode_no += batch_size

    return collected_data


if __name__ == '__main__':
    # Run the interactive collector over task-type ids 1 through 6.
    collect_data(list(range(1, 7)))
