import datetime
import os
import sys
import random
import time
import copy
import pickle
import glob
from os.path import join as pjoin
import atexit
from tqdm import tqdm

import numpy as np
import blosc2

from alfworld.info import ALFWORLD_DATA
import alfworld.agents.environment
import alfworld.agents.modules.generic as generic

from promptrl.envs.alfworldviz.alfworld_viz_env import AlfworldVizEnv

# Disable `tokenizers` parallelism; presumably to avoid the HuggingFace
# fork-after-parallelism warning/deadlock — confirm it is still needed.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Number of games per env.reset(): written to
# config['general']['training']['batch_size'] and passed to init_env().
BATCH_SIZE = 1
# Written to config['controller']['type']. Other values used with ALFWorld:
# 'oracle', 'oracle_astar', 'mrcnn', 'mrcnn_astar'. 'viz' is presumably the
# controller AlfworldVizEnv expects — TODO confirm.
controller_type = 'viz'
# Data split passed to AlfworldVizEnv(train_eval=...) and embedded in the
# output pickle filename.
train_or_eval = "eval_out_of_distribution"
# Directory the output pickles are written to (created if missing).
OUTPUT_DIR = 'data_fixed/img_seq2seq_data_ood'

def _expert_actions(infos, batch_size):
    """Return, per game, the first action of its expert plan, or "look" if none is available."""
    actions = []
    for b in range(batch_size):
        plan = infos["expert_plan"][b] if "expert_plan" in infos else []
        actions.append(plan[0] if len(plan) > 0 else "look")
    return actions


def _new_episode_record(game_file, reset_feedback, episode_id):
    """Build the empty trajectory record for one game from its reset feedback text."""
    # The reset feedback must split into exactly three '\n\n'-separated sections
    # (intro, initial observation, goal); otherwise the unpacking raises ValueError.
    _intro, obs_oracle, goal = reset_feedback.split('\n\n')
    return {
        # The two parent directory names of the game file path.
        'g': '/'.join(game_file.split('/')[-3:-1]),
        'g_id': episode_id,
        'goal': goal,
        # oracle[i] is the feedback text available before actions[i] is taken.
        'oracle': [obs_oracle],
        'obs': [],          # blosc2-packed observation arrays, one per action
        'actions': [],
        'admissible': [],   # admissible commands before each action
        'success': False,
    }


def _run_episode(env, first_episode_no, max_steps):
    """Reset ``env`` and follow the expert plan for at most ``max_steps`` steps.

    Returns:
        ``(game_names, records)``: the game file paths from the reset infos and
        one trajectory record per game in the batch.
    """
    obs, infos = env.reset()
    game_names = infos["extra.gamefile"]
    batch_size = len(obs)
    records = [
        _new_episode_record(game_names[i], infos['feedback'][i], first_episode_no + i)
        for i in range(batch_size)
    ]
    still_running = [True] * batch_size

    for step_no in range(max_steps):
        actions = _expert_actions(infos, batch_size)

        # Log the pre-step state of every game that has not finished yet.
        for b, record in enumerate(records):
            if not still_running[b]:
                continue
            if step_no > 0:
                record['oracle'].append(infos['feedback'][b])
            record['obs'].append(blosc2.pack_array(obs[b]))
            record['actions'].append(actions[b])
            record['admissible'].append(infos['admissible_commands'][b])

        obs, _, dones, infos = env.step(actions)

        # Read 'won' AFTER the step: the winning action also ends the episode,
        # so a check made only before each step would never observe the win.
        for b, record in enumerate(records):
            if still_running[b]:
                record['success'] = record['success'] or bool(infos['won'][b])

        still_running = [not done for done in dones]
        if not any(still_running):
            break

    return game_names, records


def collect_data(task_types, max_steps=200):
    """Roll out the expert plan on every game of ``task_types`` and pickle the trajectories.

    Builds an AlfworldVizEnv from ``generic.load_config()`` with this module's
    BATCH_SIZE / controller_type / train_or_eval overrides. It replays each game by
    always executing the first action of its expert plan ("look" when no plan
    is available). The list of per-game records is written to
    ``OUTPUT_DIR/tw_alfred_seq2seq_<train_or_eval>_task<ids>_hc.pkl``.

    Each record is a dict with keys 'g', 'g_id', 'goal', 'oracle', 'obs',
    'actions', 'admissible' and 'success'.

    Args:
        task_types: value written to ``config['env']['task_types']``. Its items
            are also joined with '-' into the output filename (e.g. ``[1]``).
        max_steps: maximum number of env steps per game (default 200).
    """
    time_1 = datetime.datetime.now()
    config = generic.load_config()
    config['general']['training']['batch_size'] = BATCH_SIZE
    config['general']['evaluate']['batch_size'] = 1
    config['general']['observation_pool_capacity'] = 5
    config['general']['training_method'] = 'dagger'
    config['env']['task_types'] = task_types
    config['controller']['type'] = controller_type

    alfred_env = AlfworldVizEnv(config, train_eval=train_or_eval)
    env = alfred_env.init_env(batch_size=BATCH_SIZE)

    def close_env():
        print('Shutting down ai2thor environment')
        env.close()

    # atexit is only a safety net. The finally below closes the env as soon as
    # this call ends, so repeated calls (one per task type) do not leave
    # environments running until interpreter exit.
    atexit.register(close_env)
    try:
        num_game = alfred_env.num_games
        env.seed(42)
        np.random.seed(42)
        os.makedirs(OUTPUT_DIR, exist_ok=True)

        episode_no = 0
        collected_data = []
        with tqdm(total=num_game) as pbar:
            while episode_no < num_game:
                game_names, records = _run_episode(env, episode_no, max_steps)
                if not records:
                    # Without this guard an empty batch would loop forever.
                    raise RuntimeError('env.reset() returned no games')
                batch_size = len(records)

                time_2 = datetime.datetime.now()
                for b, record in enumerate(records):
                    print("Episode: {:3d} | {:s} | time spent: {:s} | used steps: {:s}".format(episode_no + b, game_names[b], str(time_2 - time_1).rsplit(".")[0], str(len(record['obs']))))

                # A final batch may overshoot num_game; keep exactly num_game records.
                remaining = max(0, num_game - len(collected_data))
                collected_data.extend(records[:remaining])

                episode_no += batch_size
                pbar.update(batch_size)

        task_tag = "-".join(str(item) for item in task_types)
        out_path = pjoin(OUTPUT_DIR, "tw_alfred_seq2seq_" + train_or_eval + "_task" + task_tag + "_hc.pkl")
        # Write before shutting the env down so a failing close cannot lose the data.
        with open(out_path, 'wb') as f:
            pickle.dump(collected_data, f)
    finally:
        atexit.unregister(close_env)
        close_env()


if __name__ == '__main__':
    # Produce one output pickle per task-type id, 1 through 6, in order.
    for task_id in range(1, 7):
        collect_data([task_id])
