import os
import numpy as np
import json

from train import Config
from train.task import Task, TaskType


def create_action(action_dict=None):
    """Return a fresh action dict, optionally overriding individual entries.

    Every entry starts at its no-op value (0, or a zero 2-vector for
    ``camera``). When ``action_dict`` is given, each of its key/value pairs
    is written on top of the no-op action; keys that are not part of the
    no-op action are added as well.

    Args:
        action_dict: optional mapping of action names to values.

    Returns:
        A newly created dict; nothing is shared between calls.
    """
    result = {
        "attack": 0,
        "back": 0,
        "camera": np.array([0.0, 0.0]),
        "craft": 0,  # Enum(none,torch,stick,planks,crafting_table)
        "equip": 0,  # Enum(none,air,wooden_axe,wooden_pickaxe,stone_axe,stone_pickaxe,iron_axe,iron_pickaxe)
        "forward": 0,
        "jump": 0,
        "left": 0,
        "nearbyCraft": 0,  # Enum(none,wooden_axe,wooden_pickaxe,stone_axe,stone_pickaxe,iron_axe,iron_pickaxe,furnace)
        "nearbySmelt": 0,  # Enum(none,iron_ingot,coal)
        "place": 0,  # Enum(none,dirt,stone,cobblestone,crafting_table,furnace,torch)
        "right": 0,
        "sneak": 0,
        "sprint": 0
    }
    if action_dict is None:
        return result
    # Feed the (key, value) pairs so the overrides are applied one by one.
    result.update(action_dict.items())
    return result


class TaskBuilder:
    """Build ``Task`` objects for a consensus sequence of subtasks.

    Tasks are produced either from the hard-coded rules in
    ``extract_dummy_tasks`` or from the JSON profile read by
    ``extract_json_tasks``.
    """

    def __init__(self, config, use_transition_checkpoints=False, use_imitation_learning=False):
        """Store the configuration and the switches that steer task creation.

        Args:
            config: configuration object; this class reads ``config.debug``
                and ``config.subtask``.
            use_transition_checkpoints: when true, task ids encode a
                transition (``None-<name>``, ``<previous>-<name>`` or
                ``<name>-End``) instead of the bare name.
            use_imitation_learning: when false, ``extract_dummy_tasks``
                creates only learning tasks via ``collect``.
        """
        self.config = config
        self.collect_data_dir_target = os.path.join(self.config.debug.subtask_outputdir,
                                                    self.config.debug.subtask_dir_target)
        self.collect_data_dir_transition = os.path.join(self.config.debug.subtask_outputdir,
                                                        self.config.debug.subtask_dir_transition)

        self.data_id = self.config.debug.env
        self.transition_corrections = self.config.debug.subtask_transition_corrections
        self.insert_imitation_actions = self.config.debug.subtask_insert_imitation_actions
        self.use_transition_checkpoints = use_transition_checkpoints
        self.use_imitation_learning = use_imitation_learning

    def _new_task(self, task_id, target, task_type, imitation_actions=None):
        """Create a Task wired to the data locations shared by all builders.

        With ``imitation_actions`` left as ``None`` the task is explicitly
        flagged ``imitation_ready = False``. Otherwise the actions are passed
        to ``Task.set_imitation_actions`` and ``imitation_ready`` is not
        assigned here.
        """
        task = Task(self.config,
                    id=task_id,
                    target=target,
                    task_type=task_type)
        task.data_dir_target = self.collect_data_dir_target
        task.data_dir_transition = self.collect_data_dir_transition
        task.data_id = self.data_id
        if imitation_actions is None:
            task.imitation_ready = False
        task.transition_data_ready = True
        task.target_data_ready = True
        if imitation_actions is not None:
            task.set_imitation_actions(imitation_actions)
        return task

    def _task_id(self, consensus, decoding, index, name):
        """Return the id of the task at ``index`` of ``consensus``.

        Without transition checkpoints the id is just ``name``. With them,
        the first entry becomes ``None-<name>``, the last ``<name>-End`` and
        every other entry ``<previous>-<name>``, where ``<previous>`` is the
        decoded name of the preceding consensus code.
        """
        if not self.use_transition_checkpoints:
            return '{}'.format(name)
        if index == 0:
            return 'None-{}'.format(name)
        if index == len(consensus) - 1:
            return '{}-End'.format(name)
        # Looked up only here so that unknown codes elsewhere cannot raise.
        previous_name = decoding[consensus[index - 1]]
        return '{}-{}'.format(previous_name, name)

    def imitation(self, id, target):
        """Return an imitation task without actions (``imitation_ready`` is False)."""
        return self._new_task(id, target, TaskType.Imitation)

    def collect(self, id, target):
        """Return a learning task (``imitation_ready`` is False)."""
        return self._new_task(id, target, TaskType.Learning)

    def craft(self, id, target):
        """Return an imitation task with a single action: craft ``target``."""
        return self._new_task(id, target, TaskType.Imitation,
                              [create_action({'craft': target})])

    def craft_place(self, id, target):
        """Return an imitation task: craft ``target``, then place it."""
        return self._new_task(id, target, TaskType.Imitation,
                              [create_action({'craft': target}),
                               create_action({'place': target})])

    def place_nearby_craft_equip(self, id, target):
        """Return an imitation task: place a crafting table, nearby-craft ``target``, equip it."""
        return self._new_task(id, target, TaskType.Imitation,
                              [create_action({'place': 'crafting_table'}),
                               create_action({'nearbyCraft': target}),
                               create_action({'equip': target})])

    def place_nearby_craft_place(self, id, target):
        """Return an imitation task: place a crafting table, nearby-craft ``target``, place it."""
        return self._new_task(id, target, TaskType.Imitation,
                              [create_action({'place': 'crafting_table'}),
                               create_action({'nearbyCraft': target}),
                               create_action({'place': target})])

    def nearby_smelt(self, id, target):
        """Return an imitation task with a single action: nearby-smelt ``target``."""
        return self._new_task(id, target, TaskType.Imitation,
                              [create_action({'nearbySmelt': target})])

    def extract_dummy_tasks(self, consensus):
        """Create one task per consensus code using the built-in rules.

        Args:
            consensus: indexable sequence of codes; every code must be a key
                of ``config.subtask.consensus_code``.

        Returns:
            Tuple ``(tasks, consensus_names)`` with one entry per code.
        """
        decoding = self.config.subtask.consensus_code
        tasks = []
        consensus_names = []
        for i, ch in enumerate(consensus):
            name = decoding[ch]
            task_id = self._task_id(consensus, decoding, i, name)

            # collection
            if not self.use_imitation_learning:
                task = self.collect(id=task_id, target=name)
            # create empty imitation task objects
            elif self.insert_imitation_actions:
                # according to consensus
                if ch in ("E", "N", "P", "V"):
                    task = self.imitation(task_id, name)
                else:
                    task = self.collect(id=task_id, target=name)
            # pre-fill imitation actions for debugging according to the data statistics
            elif ch in ("F",):
                task = self.place_nearby_craft_place(id=task_id, target=name)
            elif ch in ("W", "E", "C", "Y", "R", "N"):
                task = self.place_nearby_craft_equip(id=task_id, target=name)
            elif ch in ("H",):
                task = self.craft_place(id=task_id, target=name)
            elif ch in ("L", "P", "V"):
                task = self.craft(id=task_id, target=name)
            else:
                task = self.collect(id=task_id, target=name)
            consensus_names.append(name)
            tasks.append(task)

        # change tasks to learning due to the consensus transitions evaluated through the data statistics
        if self.transition_corrections:
            corrected_ids = ("wooden_pickaxe-crafting_table",
                             'stone_pickaxe-crafting_table')
            for t in tasks:
                if t.id in corrected_ids:
                    t.task_type = TaskType.Learning
                    t.target_data_ready = True
                    t.transition_data_ready = True
                    t.imitation_ready = False
        return tasks, consensus_names

    def build_task(self, data, id, target):
        """Create a Task from one profile entry.

        Args:
            data: mapping with a required ``'type'`` key (``'imitation'`` or
                ``'learning'``) and the optional keys ``'data_dir_target'``,
                ``'data_dir_transition'``, ``'data_id'``, ``'actions'``,
                ``'transition_data_ready'`` and ``'target_data_ready'``.
            id: id given to the task.
            target: target given to the task.

        Returns:
            The configured Task.

        Raises:
            NotImplementedError: if ``data['type']`` is neither
                ``'imitation'`` nor ``'learning'``.
        """
        kind = data['type']
        if kind == 'imitation':
            task_type = TaskType.Imitation
        elif kind == 'learning':
            task_type = TaskType.Learning
        else:
            raise NotImplementedError('Unknown task type!')

        task = Task(self.config,
                    id=id,
                    target=target,
                    task_type=task_type)
        if 'data_dir_target' in data:
            task.data_dir_target = data['data_dir_target']
        if 'data_dir_transition' in data:
            task.data_dir_transition = data['data_dir_transition']
        if 'data_id' in data:
            task.data_id = data['data_id']
        if 'actions' in data:
            task.set_imitation_actions(data['actions'])
            task.imitation_ready = True
        if 'transition_data_ready' in data:
            task.transition_data_ready = data['transition_data_ready']
        if 'target_data_ready' in data:
            task.target_data_ready = data['target_data_ready']
        return task

    def extract_json_tasks(self, consensus):
        """Create tasks for ``consensus`` from the JSON profile.

        The profile is the file at ``config.subtask.generated_profile``. Its
        ``'definitions'`` object maps a consensus code to a task entry, and
        its ``'transitions'`` object maps a concatenated code string to a
        replacement entry that is applied when transition corrections are
        enabled.

        Args:
            consensus: indexable sequence of codes.

        Returns:
            Tuple ``(tasks, consensus_names)``. Codes without a definition
            are skipped, so the result may be shorter than ``consensus``.
        """
        decoding = self.config.subtask.consensus_code
        name_to_code = {v: k for k, v in decoding.items()}
        tasks = []
        consensus_names = []

        # JSON text is UTF-8 by specification; do not depend on the locale.
        with open(self.config.subtask.generated_profile, encoding='utf-8') as file:
            data = json.load(file)
        for i, ch in enumerate(consensus):
            definitions = data['definitions']
            if ch not in definitions:
                # NOTE(review): undefined codes are dropped silently — confirm
                # this is intended, callers may expect one task per code.
                continue
            entry = definitions[ch]
            target = entry['target']
            task_id = self._task_id(consensus, decoding, i, target)

            tasks.append(self.build_task(entry, task_id, target))
            consensus_names.append(target)

        if self.transition_corrections:
            tasks_update = []
            consensus_names_update = []
            for t in tasks:
                transitions = data['transitions']
                # correct according to consensus transitions
                # NOTE(review): the id is split on '-' and mapped back to
                # codes, which assumes names contain no '-' and are all
                # values of the consensus code table — confirm.
                code = [name_to_code[part] for part in t.id.split('-')
                        if part not in ('None', 'End')]
                t_code = ''.join(code)
                if t_code in transitions:
                    replacement = transitions[t_code]
                    t = self.build_task(replacement, t.id, replacement['target'])
                tasks_update.append(t)
                consensus_names_update.append(t.target)
            tasks = tasks_update
            consensus_names = consensus_names_update
        return tasks, consensus_names


if __name__ == '__main__':
    # Smoke check: build the tasks described by the JSON profile for the
    # configured consensus and verify that one task was produced per entry.
    config = Config('configs/experiment/config.meta.json')
    demos = config.debug.subtask_consensus_demos
    selection = config.debug.subtask_consensus_selection
    consensus = demos[selection]

    builder = TaskBuilder(config)
    tasks, _names = builder.extract_json_tasks(consensus)

    assert len(tasks) == len(consensus)
