from datasets import Dataset
import os
import json
import random

from tqdm import tqdm
from termcolor import colored

import textworld
import textworld.agents
import textworld.gym

from alfworld.agents.utils.misc import Demangler
from alfworld.agents.expert import HandCodedTWAgent, HandCodedAgentTimeout


# Numeric task-type ids (as listed in config['env']['task_types']) mapped to
# the task-type names that are compared against traj_data['task_type'].
TASK_TYPES = {
    1: "pick_and_place_simple",
    2: "look_at_obj_in_light",
    3: "pick_clean_then_place_in_recep",
    4: "pick_heat_then_place_in_recep",
    5: "pick_cool_then_place_in_recep",
    6: "pick_two_obj_and_place",
}


class AlfredDemangler(textworld.core.Wrapper):
    """Wrapper that renames game entities with demangled names on load.

    Args:
        shuffle: Forwarded unchanged to ``Demangler`` each time a game is
            loaded.
    """

    def __init__(self, *args, shuffle=False, **kwargs):
        super().__init__(*args, **kwargs)
        self.shuffle = shuffle

    def load(self, *args, **kwargs):
        """Load the game, then overwrite every entity name in place."""
        super().load(*args, **kwargs)

        entity_infos = self._entity_infos
        name_demangler = Demangler(game_infos=entity_infos, shuffle=self.shuffle)
        for entity in entity_infos.values():
            # The new name is derived from the entity id, not its old name.
            entity.name = name_demangler.demangle_alfred_name(entity.id)


class AlfredInfos(textworld.core.Wrapper):
    """Wrapper that exposes the loaded game file in the reset state.

    The first positional argument given to ``load`` is remembered and stored
    under the ``"extra.gamefile"`` key of every state returned by ``reset``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Stays None until load() has been called.
        self._gamefile = None

    def load(self, *args, **kwargs):
        """Load the game and remember which file it came from."""
        super().load(*args, **kwargs)
        # NOTE(review): assumes the game file is passed positionally; a
        # keyword-only call would raise IndexError here.
        self._gamefile = args[0]

    def reset(self, *args, **kwargs):
        """Reset the wrapped environment and tag the state with the game file."""
        initial_state = super().reset(*args, **kwargs)
        initial_state["extra.gamefile"] = self._gamefile
        return initial_state


class AlfredExpertType:
    """Namespace of plain-string constants for the supported expert types."""

    HANDCODED = "handcoded"
    PLANNER = "planner"


class AlfredTWEnv(object):
    """
    Interface for Textworld Env

    On construction the game files of the requested split are collected into
    ``self.game_files`` (and counted in ``self.num_games``).

    Args:
        config: Nested configuration mapping. The code reads
            ``config['env']['goal_desc_human_anns_prob']``,
            ``config['env']['task_types']``, the data path and game-count
            entries under ``config['dataset']`` and, for ``get_game_logic``,
            ``config['logic']['domain']`` / ``config['logic']['grammar']``.
        train_eval: One of ``"train"``, ``"eval_in_distribution"`` or
            ``"eval_out_of_distribution"``.
    """

    # Supported `train_eval` splits -> key in config['dataset'] holding the
    # directory that is walked for that split.
    _DATA_PATH_KEYS = {
        "train": "data_path",
        "eval_in_distribution": "eval_id_data_path",
        "eval_out_of_distribution": "eval_ood_data_path",
    }

    def __init__(self, config, train_eval="train"):
        print("Initializing AlfredTWEnv...")
        self.config = config
        self.train_eval = train_eval

        if config["env"]["goal_desc_human_anns_prob"] > 0:
            msg = ("Warning! Changing `goal_desc_human_anns_prob` should be done with"
                   " the script `alfworld-generate`. Ignoring it and loading games as they are.")
            print(colored(msg, "yellow"))

        self.collect_game_files()

    def collect_game_files(self, verbose=False):
        """Populate ``self.game_files`` with the usable games of the split.

        Every directory below the split's data path that contains a
        ``traj_data.json`` is considered. A game is kept only if its path does
        not contain ``movable`` or ``Sliced``, its task type is enabled in
        ``config['env']['task_types']``, a ``game.tw-pddl`` file exists next to
        the trajectory, and that file has a truthy ``solvable`` entry. The
        list is then truncated to ``num_train_games`` / ``num_eval_games``
        when that setting is positive.

        Args:
            verbose: When true, print the reason for most skipped games.

        Raises:
            ValueError: If ``self.train_eval`` is not a supported split.
        """
        def log(info):
            if verbose:
                print(info)

        self.game_files = []

        # Fail early with a clear message instead of hitting an unbound
        # `data_path` further down.
        if self.train_eval not in self._DATA_PATH_KEYS:
            raise ValueError(
                "Unknown train_eval split %r; expected one of %s"
                % (self.train_eval, sorted(self._DATA_PATH_KEYS)))
        path_key = self._DATA_PATH_KEYS[self.train_eval]
        data_path = os.path.expandvars(self.config['dataset'][path_key])

        log("Collecting solvable games...")

        # get task types (ids missing from TASK_TYPES are silently ignored)
        assert len(self.config['env']['task_types']) > 0
        task_types = [TASK_TYPES[tt_id]
                      for tt_id in self.config['env']['task_types']
                      if tt_id in TASK_TYPES]

        # The walk is materialised so that tqdm can show a total.
        for root, _dirs, files in tqdm(list(os.walk(data_path, topdown=False))):
            if 'traj_data.json' not in files:
                continue

            # Filenames
            json_path = os.path.join(root, 'traj_data.json')
            game_file_path = os.path.join(root, "game.tw-pddl")

            if 'movable' in root or 'Sliced' in root:
                log("Movable & slice trajs not supported %s" % (root))
                continue

            # Get goal description
            with open(json_path, 'r') as f:
                traj_data = json.load(f)

            # Check for any task_type constraints
            if traj_data['task_type'] not in task_types:
                log("Skipping task type")
                continue

            # Check if a game file exists
            if not os.path.exists(game_file_path):
                log(f"Skipping missing game! {game_file_path}")
                continue

            # The game file is parsed as JSON to read its `solvable` flag.
            with open(game_file_path, 'r') as f:
                gamedata = json.load(f)

            # Check if previously checked if solvable
            if 'solvable' not in gamedata:
                # Printed regardless of `verbose`.
                print(f"-> Skipping missing solvable key! {game_file_path}")
                continue

            if not gamedata['solvable']:
                log("Skipping known %s, unsolvable game!" % game_file_path)
                continue

            # Add to game file list
            self.game_files.append(game_file_path)

        print(f"Overall we have {len(self.game_files)} games in split={self.train_eval}")

        # A non-positive limit means "keep every collected game".
        if self.train_eval == "train":
            limit = self.config['dataset']['num_train_games']
            summary = "Training with %d games"
        else:
            limit = self.config['dataset']['num_eval_games']
            summary = "Evaluating with %d games"
        if limit > 0:
            self.game_files = self.game_files[:limit]
        self.num_games = len(self.game_files)
        print(summary % (len(self.game_files)))

    def get_game_logic(self):
        """Read the PDDL domain and grammar files into ``self.game_logic``.

        Both paths come from ``config['logic']`` and have environment
        variables expanded. The files are closed as soon as they are read.
        """
        domain_path = os.path.expandvars(self.config['logic']['domain'])
        grammar_path = os.path.expandvars(self.config['logic']['grammar'])
        with open(domain_path) as domain_file:
            pddl_domain = domain_file.read()
        with open(grammar_path) as grammar_file:
            grammar = grammar_file.read()
        self.game_logic = {
            "pddl_domain": pddl_domain,
            "grammar": grammar
        }

    # use expert to check the game is solvable
    def is_solvable(self, env, game_file_path,
                    random_perturb=True, random_start=10, random_prob_after_state=0.15):
        """Play the game with the expert and report how it went.

        Args:
            env: Environment offering ``load``, ``reset``, ``step`` and an
                ``expert_type`` attribute.
            game_file_path: Game file handed to ``env.load``.
            random_perturb: When true, some expert actions are replaced by a
                randomly chosen admissible command.
            random_start: While ``steps <= random_start`` every command is
                random (only if ``random_perturb``).
            random_prob_after_state: Probability of a random command on later
                steps (only if ``random_perturb``).

        Returns:
            For the planner expert, the ``extra.expert_plan`` of the initial
            state. Otherwise the list of commands issued until the game
            reported done. ``None`` if any exception occurred on the way.
        """
        done = False
        steps = 0
        trajectory = []
        try:
            env.load(game_file_path)
            game_state = env.reset()
            if env.expert_type == AlfredExpertType.PLANNER:
                return game_state["extra.expert_plan"]

            while not done:
                expert_action = game_state['extra.expert_plan'][0]
                # Drawn on every step, even when the expert action is used.
                random_action = random.choice(game_state.admissible_commands)

                perturb = random_perturb and (
                    steps <= random_start
                    or random.random() < random_prob_after_state)
                command = random_action if perturb else expert_action

                game_state, _, done = env.step(command)
                trajectory.append(command)
                steps += 1
        except Exception as e:
            # Deliberate best effort: any failure marks the game unsolvable.
            print("Unsolvable: %s (%s)" % (str(e), game_file_path))
            return None

        return trajectory

def _build_data(game_files, game_root="./verl/environments/alfworld"):
    """Turn tagged game files into dataset rows.

    Args:
        game_files: Iterable of ``(used_for_actor, task, game_file)`` tuples.
        game_root: Directory that each ``game_file`` is joined onto. Because
            ``os.path.join`` is used, an absolute ``game_file`` is kept as is
            and ``game_root`` is ignored for it.

    Returns:
        A list with one dict per game, in input order. ``index``, ``uid`` and
        ``task_id`` are derived from the position in ``game_files``.
    """
    return [
        {
            "data_source": "alfworld",
            # Placeholder text; the row carries no real prompt.
            "prompt": [{"content": "prompt_with_chat_template"}],
            "extra_info": {
                "index": idx,
                "uid": f"alfworld_{idx}",
                "env_name": "alfworld",
                "task_name": task,
                "used_for_actor": used_for_actor,
                "game_file": os.path.join(game_root, game_file),
                "task_id": f"alfworld_{idx}",
            },
        }
        for idx, (used_for_actor, task, game_file) in enumerate(game_files)
    ]


def main():
    """Export the ALFWorld train and test game lists as parquet files.

    Reads ``./alfred_config.yaml``, collects the ``train`` and
    ``eval_out_of_distribution`` game files, keeps a seeded random sample of
    at most 800 training games, tags every game with a short task name and
    writes ``train_small.parquet`` and ``test.parquet`` into
    ``../../../data/alfworld`` (relative to the working directory).
    """
    import yaml

    # Short task names; each one is also the directory-name prefix it matches.
    # Anything that matches none of them is labelled "look_at".
    task_prefixes = ("pick_and_place", "pick_clean", "pick_cool",
                     "pick_two_obj", "pick_heat")

    def _load_config(config_file):
        assert os.path.exists(config_file), "Invalid config file"
        with open(config_file) as reader:
            config = yaml.safe_load(reader)
        return config

    def _task_name(game_file):
        # The task is encoded in the third-from-last path component; os.path
        # is used so the lookup does not depend on "/" being the separator.
        task_dir = os.path.basename(os.path.dirname(os.path.dirname(game_file)))
        for prefix in task_prefixes:
            if task_dir.startswith(prefix):
                return prefix
        return "look_at"

    def _tag(game_files, used_for_actor):
        # Build the (used_for_actor, task, game_file) tuples _build_data expects.
        return [(used_for_actor, _task_name(path), path) for path in game_files]

    config = _load_config("./alfred_config.yaml")

    # os.walk yields directories in arbitrary, filesystem-dependent order, so
    # sort first: otherwise the seeded shuffle (and the row indices) would not
    # be reproducible from one machine to the next.
    alfworld_env_train = sorted(AlfredTWEnv(config, train_eval='train').game_files)
    random.seed(1)
    random.shuffle(alfworld_env_train)
    alfworld_env_train = alfworld_env_train[:800]
    alfworld_env_test = sorted(
        AlfredTWEnv(config, train_eval='eval_out_of_distribution').game_files)

    train_data = _build_data(_tag(alfworld_env_train, True))
    dev_data = _build_data(_tag(alfworld_env_test, False))

    train_dataset = Dataset.from_list(train_data)
    test_dataset = Dataset.from_list(dev_data)

    output_path = "../../../data/alfworld"
    # Make sure the target directory exists before writing into it.
    os.makedirs(output_path, exist_ok=True)
    train_dataset.to_parquet(os.path.join(output_path, 'train_small.parquet'))
    test_dataset.to_parquet(os.path.join(output_path, 'test.parquet'))

# Run the dataset export only when this file is executed as a script.
if __name__ == "__main__":
    main()