import datetime
import os
import sys
import random
import time
import copy
import json
import glob
import math
from os.path import join as pjoin

import numpy as np
from tqdm import tqdm

from alfworld.info import ALFWORLD_DATA
import alfworld.agents.environment
import alfworld.agents.modules.generic as generic

# Export TOKENIZERS_PARALLELISM=false into this process's environment.
# NOTE(review): presumably read by the HuggingFace tokenizers library to turn
# off its internal parallelism -- confirm against the libraries actually used.
os.environ.update({"TOKENIZERS_PARALLELISM": "false"})

# Column order of the generated CSV; also the keys of every row dict.
_CSV_COLUMNS = (
    'mode', 'demo_id', 'task_id', 'task_idx', 'task_type',
    'goal', 'obj', 'recep1', 'recep2',
)


def _parse_game(mode, ob, game_name):
    """Return the CSV row (a dict of strings) describing one game.

    Args:
        mode: Name of the split the game was loaded from.
        ob: Initial observation text returned by ``env.reset()``.
        game_name: Path of the game file, using ``/`` as separator.

    Raises:
        AssertionError: If the third ``-``-separated field of the task id
            is not the literal string ``'None'``.
        ValueError: If the task id does not have exactly five
            ``-``-separated fields.
        IndexError: If the observation has fewer than three
            blank-line-separated paragraphs.
    """
    # The goal is the third paragraph of the observation text.
    goal_instr = ob.split('\n\n')[2]
    # Third-to-last and second-to-last path components.
    task_id, demo_id = game_name.split('/')[-3:-1]
    task_type, obj, recep2, recep1, task_idx = task_id.split('-')
    # Explicit raise instead of `assert` so the check survives `python -O`.
    if recep2 != 'None':
        raise AssertionError(
            f"unexpected recep2 {recep2!r} in task id {task_id!r}"
        )
    return {
        'mode': mode,
        'demo_id': demo_id,
        'task_id': task_id,
        'task_idx': task_idx,
        'task_type': task_type,
        'goal': goal_instr,
        'obj': obj,
        'recep1': recep1,
        'recep2': recep2,
    }


def _write_rows(rows, path):
    """Write a header line and one ``', '``-joined line per row to *path*."""
    # NOTE: values are written verbatim (no quoting/escaping) to keep the
    # historical file format; a value containing ', ' would misalign its row.
    with open(path, 'w', encoding='utf-8') as f:
        f.write(', '.join(_CSV_COLUMNS) + '\n')
        f.writelines(
            ', '.join(row[c] for c in _CSV_COLUMNS) + '\n' for row in rows
        )


def collect_data(batch_size=1, output_path='task_data.csv'):
    """Dump per-game task metadata for every ALFWorld split to a CSV-like file.

    Loads the config via ``generic.load_config()``, restricts it to task
    types 1-6, then for each of the ``train``, ``eval_in_distribution`` and
    ``eval_out_of_distribution`` splits builds an ``AlfredTWEnv``, resets it
    once per batch and records one row per game.

    Args:
        batch_size: Number of games per ``env.reset()`` call; also written
            into the training and evaluation sections of the config.
        output_path: Destination file. Defaults to ``'task_data.csv'`` in the
            current working directory.

    Raises:
        AssertionError: If a task id carries a non-``'None'`` second
            receptacle, or if the number of collected rows differs from the
            sum of ``num_games`` over all splits.
    """
    config = generic.load_config()
    config['env']['task_types'] = [1, 2, 3, 4, 5, 6]
    config['general']['training']['batch_size'] = batch_size
    config['general']['evaluate']['batch_size'] = batch_size

    data = []
    total_games = 0
    for mode in ['train', 'eval_in_distribution', 'eval_out_of_distribution']:
        print(f'analyzing {mode}')
        alfred_env = alfworld.agents.environment.AlfredTWEnv(config, train_eval=mode)
        env = alfred_env.init_env(batch_size=batch_size)
        num_game = alfred_env.num_games
        total_games += num_game
        env.seed(42)
        np.random.seed(42)

        i_game = 0
        for _ in tqdm(range(math.ceil(num_game / batch_size))):
            obs, infos = env.reset()
            game_names = infos["extra.gamefile"]
            for ob, game_name in zip(obs, game_names):
                # The batch count is rounded up, so the final batch may hold
                # more entries than there are games left to record.
                if i_game >= num_game:
                    break
                i_game += 1
                data.append(_parse_game(mode, ob, game_name))

    # Explicit raise instead of `assert` so the check survives `python -O`.
    if len(data) != total_games:
        raise AssertionError(
            f"collected {len(data)} rows but expected {total_games} games"
        )
    print('saving...')
    _write_rows(data, output_path)

# Run the collection with its default arguments when executed as a script.
if __name__ == "__main__":
    collect_data()
