import json
import os
import argparse


class StepDataLoader(object):
    """Split groundedSCAN episodes into per-step (input, output) samples.

    Each example's ``target_commands`` string is cut into steps (one
    movement/interaction chunk each, plus a final ``'end'`` step).  While
    iterating the steps, the agent position and the placed objects are
    simulated forward, so every sample records the world state *before*
    its step is executed.

    Heading convention used throughout (see ``_move``): 0 -> column + 1,
    1 -> row - 1, 2 -> column - 1, 3 -> row + 1.  A ``turn left`` adds 1
    and a ``turn right`` adds 3 (mod 4).
    """

    def _get_position(self, position):
        """Return ``[row, column]`` as ints from a ``{'row', 'column'}`` dict."""
        position_row = int(position['row'])
        position_col = int(position['column'])
        return [position_row, position_col]

    def _load_a_object(self, placed_object):
        """Convert one raw placed-object record into a tuple.

        Returns:
            ``(position, [shape, color, size], vector)`` where ``position``
            is ``[row, column]``, ``size`` is an int and ``vector`` is the
            record's ``'vector'`` value passed through unchanged.
        """
        position = self._get_position(placed_object['position'])
        attributes = placed_object['object']
        shape = attributes['shape']
        color = attributes['color']
        size = int(attributes['size'])
        return position, [shape, color, size], placed_object['vector']

    def _split_commands(self, target_commands, separater):
        """Split a comma-joined action string after every ``separater``.

        Every returned chunk ends with ``separater`` except possibly the last
        one, which holds whatever trails the final separator.  A leading
        comma left over from the split is stripped, and the empty trailing
        chunk produced when the input ends with the separator is dropped.

        Example:
            ``'turn left,walk,walk'`` split on ``'walk'`` gives
            ``['turn left,walk', 'walk']``.
        """
        assert len(target_commands) >= len(separater)
        steps = target_commands.split(separater)

        ret = []
        last = len(steps) - 1
        for i, step in enumerate(steps):
            # Re-attach the separator consumed by str.split, except after the
            # final piece, which had no separator following it.
            x = step + separater if i < last else step
            # Drop the comma that followed the previous separator.
            if x.startswith(','):
                x = x[1:]
            ret.append(x)

        if target_commands.endswith(separater):
            # str.split yields '' after a trailing separator; discard it.
            assert len(ret[-1]) == 0
            ret = ret[:-1]
        return ret

    def _split_steps(self, example):
        """Split an example's ``target_commands`` into per-step action strings.

        The split points depend on the example:
          * manner ``'hesitantly'``: after every ``'stay'``;
          * verb other than ``'walk'``: after every ``'walk'``, then the last
            chunk is split again after every occurrence of the verb;
          * otherwise: after every ``'walk'``.

        A terminal ``'end'`` step is always appended.
        """
        target_commands = example['target_commands']
        verb_in_command = example['verb_in_command']
        manner = example['manner']
        if manner == 'hesitantly':
            separater = ['stay']
        elif verb_in_command != 'walk':
            separater = ['walk', verb_in_command]
        else:
            separater = ['walk']

        steps = self._split_commands(target_commands, separater[0])
        # Only the last chunk is re-split on the verb; this assumes all verb
        # actions come after the final 'walk'.
        if len(separater) > 1 and separater[1] in steps[-1]:
            second_steps = self._split_commands(steps[-1], separater[1])
            steps = steps[:-1] + second_steps
        steps.append('end')
        return steps

    def _parse_output(self, example, command, step_actions, agent_position,
                      placed_objects):
        """Derive the prediction target for one step.

        Args:
            example: raw example; only ``example['manner']`` is read.
            command, agent_position, placed_objects: unused by this
                implementation.
            step_actions: comma-joined actions of this step, or ``'end'``.

        Returns:
            ``(direction, action, manner)``.  ``direction`` is the net
            rotation of the step (``turn left`` = +1, ``turn right`` = +3,
            mod 4).  ``action`` is the step's verb.  ``manner`` is the
            example's manner, with ``'while zigzagging'`` reported as ``''``.
        """
        # manner
        manner = example['manner']
        if manner == 'while zigzagging':
            manner = ''

        actions = step_actions.split(',')
        if actions[0] == 'end':
            return 0, 'end', manner

        # action: a hesitant step ends with 'stay', so the verb precedes it.
        if manner == 'hesitantly':
            action = actions[-2]
        else:
            action = actions[-1]

        # direction
        direction = 0
        for act in actions:
            if act == 'turn left':
                direction += 1
            elif act == 'turn right':
                direction += 3
        return direction % 4, action, manner

    def _get_a_step(self, example, command, step_actions, agent_position,
                    placed_objects, is_first):
        """Build one ``(x, y)`` sample.

        ``x`` is ``[command, agent_position, placed_objects, is_first]`` and
        ``y`` is the ``(direction, action, manner)`` tuple returned by
        ``_parse_output``.
        """
        x = [command, agent_position, placed_objects, is_first]
        y = self._parse_output(example, command, step_actions, agent_position,
                               placed_objects)
        return x, y

    def _move(self, row, col, direction, forward):
        """Return ``(row, col)`` shifted by ``forward`` cells along ``direction``.

        0 -> column + forward, 1 -> row - forward, 2 -> column - forward,
        3 -> row + forward.  Any other direction leaves the cell unchanged.
        """
        if direction == 0:
            col += forward
        elif direction == 1:
            row -= forward
        elif direction == 2:
            col -= forward
        elif direction == 3:
            row += forward
        return row, col

    def _get_target_pos(self, row, col, placed_objects):
        """Return the index of the placed object located at ``(row, col)``.

        Raises:
            ValueError: if no object occupies that cell.
        """
        for i, placed_object in enumerate(placed_objects):
            if placed_object[0][0] == row and placed_object[0][1] == col:
                return i
        # Explicit raise (not `assert`) so the check survives `python -O`.
        raise ValueError(
            'no placed object at row {}, column {}'.format(row, col))

    def _update_placed_objects(self, placed_objects, new_direction, row, col,
                               forward, partial_move):
        """Move the object under the agent together with the agent.

        Objects whose size is >= 3 need two consecutive push/pull actions.
        The first action only sets ``partial_move`` and moves nothing; the
        second performs the move and clears the flag.

        Args:
            placed_objects: list of ``(position, attributes, vector)``
                tuples.  It is not mutated.
            new_direction: heading to move along (see ``_move``).
            row, col: agent cell.  An object must occupy it.
            forward: ``1`` for push, ``-1`` for pull.
            partial_move: whether the previous action started a two-step
                move.

        Returns:
            ``(placed_objects, row, col, partial_move)`` after the action.
        """
        index = self._get_target_pos(row, col, placed_objects)
        if placed_objects[index][1][2] >= 3:
            if not partial_move:
                return placed_objects, row, col, True
            partial_move = False
        row, col = self._move(row, col, new_direction, forward)
        # Copy the list and rebuild the moved tuple so the snapshots already
        # stored in earlier samples stay unchanged.
        placed_objects = list(placed_objects)
        updated = list(placed_objects[index])
        updated[0] = [row, col]
        placed_objects[index] = tuple(updated)
        return placed_objects, row, col, partial_move

    def _simulate(self, agent_position, placed_objects, direction, action,
                  partial_move):
        """Apply one parsed step to the agent and the world.

        Args:
            agent_position: ``[row, col, heading]``.
            placed_objects: current object list (not mutated).
            direction: relative rotation from ``_parse_output``.
            action: ``'walk'``, ``'push'``, ``'pull'`` or ``'end'``.
            partial_move: pending two-step move flag.

        Returns:
            ``(new_agent_position, placed_objects, partial_move)``.

        Raises:
            ValueError: on an unrecognised action.
        """
        if action == 'end':
            return agent_position, placed_objects, partial_move
        row, col, old_direction = agent_position
        new_direction = (old_direction + direction) % 4
        if action == 'walk':
            row, col = self._move(row, col, new_direction, 1)
        elif action == 'push':
            placed_objects, row, col, partial_move = self._update_placed_objects(
                placed_objects, new_direction, row, col, 1, partial_move)
        elif action == 'pull':
            placed_objects, row, col, partial_move = self._update_placed_objects(
                placed_objects, new_direction, row, col, -1, partial_move)
        else:
            raise ValueError('unknown action: {!r}'.format(action))
        new_pos = [row, col, new_direction]
        return new_pos, placed_objects, partial_move

    def _load_a_episode(self, example):
        """Turn one raw example into its list of per-step samples.

        Returns:
            List of ``[x, y]`` pairs (see ``_get_a_step``), one per step
            including the final ``'end'`` step.  ``x`` captures the agent
            ``[row, col, heading]`` and the placed objects *before* the step
            is executed.
        """
        command = example['command']
        situation = example['situation']
        # The initial heading is fixed to 0.  NOTE(review): any agent
        # direction stored in the raw situation is not read; confirm this
        # matches the data.
        agent_position = self._get_position(situation['agent_position']) + [0]
        placed_objects = [self._load_a_object(o)
                          for o in situation['placed_objects'].values()]

        steps = self._split_steps(example)
        ret = []
        partial_move = False
        for i, step_actions in enumerate(steps):
            x, y = self._get_a_step(example, command, step_actions,
                                    agent_position, placed_objects, i == 0)
            ret.append([x, y])

            # Advance the world so the next sample sees the post-step state.
            direction, action, _ = y
            agent_position, placed_objects, partial_move = self._simulate(
                agent_position, placed_objects, direction, action,
                partial_move)
        return ret

    def _load_a_split(self, examples):
        """Convert a list of examples into samples.

        Returns:
            ``(data, lengths)``: all samples concatenated, and the number of
            samples contributed by each example.
        """
        data = []
        lengths = []
        for example in examples:
            samples = self._load_a_episode(example)
            data.extend(samples)
            lengths.append(len(samples))
        return data, lengths

    def load(self, fn):
        """Read a JSON dataset file and split every episode into steps.

        Sets ``self.grid_size`` and ``self.data_splits`` (split name -> list
        of ``[x, y]`` samples).  Also sets ``self.split_lengths`` (split
        name -> samples per example).  Splits that yield no samples are
        skipped.
        """
        # JSON text is UTF-8; do not depend on the platform's locale.
        with open(fn, 'r', encoding='utf-8') as f:
            data = json.load(f)
        print(fn + ' is loaded.')
        self.grid_size = data['grid_size']
        self.data_splits = {}
        self.split_lengths = {}
        for k, v in data['examples'].items():
            split, lengths = self._load_a_split(v)
            print(k, len(split), len(lengths))
            if len(split) > 0:
                self.data_splits[k] = split
                self.split_lengths[k] = lengths


class DigitalStepDataLoader(StepDataLoader):
    """Step loader that converts samples to integer indices and saves them.

    Vocabularies for command words, actions and manners are built from the
    ``'train'`` split only.  Every split is then converted with them.
    """

    def _convert_sample(self, sample):
        """Convert one ``[x, y]`` step sample to index form.

        Returns:
            ``(converted, command_length, world_length)``.  ``converted`` is
            ``[command_index, agent, world_pos, world_index, is_first,
            [direction, action_id, manner_id]]``.

        Raises:
            KeyError: if a word, action or manner is absent from the
                vocabularies built on the training split.
        """
        x, y = sample
        command, agent, world, is_first = x
        command = command.split(',')

        # command
        command_index = [self.word_index[w] for w in command]
        command_length = len(command)

        # world: element 0 of each object tuple is its [row, column].
        # Element 2 is the raw 'vector' field.  The [shape, color, size]
        # attributes (element 1) are not emitted.
        world_pos = [obj[0] for obj in world]
        world_index = [obj[2] for obj in world]
        world_length = len(world)

        is_first_index = int(is_first)

        rx = [command_index, agent, world_pos, world_index, is_first_index]
        ry = [y[0], self.action_index[y[1]], self.manner_index[y[2]]]
        return [*rx, ry], command_length, world_length

    def _add_dict(self, x, d):
        """Give ``x`` the next free index in ``d`` if it is not there yet."""
        if x not in d:
            d[x] = len(d)

    def _build_index(self):
        """Build word/action/manner vocabularies from the 'train' split.

        Indices are assigned in first-seen order.
        """
        self.word_index = {}
        self.action_index = {}
        self.manner_index = {}
        for x, y in self.data_splits['train']:
            for word in x[0].split(','):
                self._add_dict(word, self.word_index)
            self._add_dict(y[1], self.action_index)
            self._add_dict(y[2], self.manner_index)

    def _save(self, all_data, outfile):
        """Write ``all_data`` as JSON to ``outfile``, creating its directory."""
        print('Saving to file ' + outfile + '.')
        directory = os.path.dirname(outfile)
        # dirname() is '' for a bare filename, and os.makedirs('') raises.
        # exist_ok avoids the race between an exists() check and makedirs().
        if directory:
            os.makedirs(directory, exist_ok=True)
        with open(outfile, 'w', encoding='utf-8') as f:
            json.dump(all_data, f)
        print('Saved.')

    def load(self, fn, outfile):
        """Load ``fn``, convert every split to indices and save to ``outfile``.

        The output JSON holds ``data_splits`` (split -> converted samples)
        and ``split_lengths`` (split -> samples per example).  It also holds
        ``data_stats``: max command/world lengths, grid size and the three
        vocabularies.  ``self.data_splits`` is replaced by the converted
        splits.

        Raises:
            KeyError: if the dataset has no non-empty 'train' split.
        """
        super().load(fn)

        print('Building index.')
        self._build_index()

        print('Converting formats.')
        max_clen = 0
        max_wlen = 0
        digit_splits = {}
        for k, v in self.data_splits.items():
            print(k)
            converted_data = []
            for sample in v:
                spl, clen, wlen = self._convert_sample(sample)
                converted_data.append(spl)
                max_clen = max(max_clen, clen)
                max_wlen = max(max_wlen, wlen)
            digit_splits[k] = converted_data

        print('Converted formats.')
        self.data_splits = digit_splits

        self.data_stats = {
            'max_command_length': max_clen,
            'max_world_length': max_wlen,
            'grid_size': self.grid_size,
            'word_index': self.word_index,
            'action_index': self.action_index,
            'manner_index': self.manner_index,
        }
        all_data = {
            'data_splits': digit_splits,
            'split_lengths': self.split_lengths,
            'data_stats': self.data_stats,
        }
        self._save(all_data, outfile)


if __name__ == '__main__':
    # Script entry point: choose the gSCAN split from the CLI flag, then
    # convert it and write the processed file.
    cli = argparse.ArgumentParser()
    cli.add_argument('--length', action='store_true', default=False,
                     help='Use target length split.')
    cli_args = cli.parse_args()

    # (input dataset, output file), keyed by the value of --length.
    split_files = {
        True: ('groundedSCAN/data/target_length_split/dataset.txt',
               'processed_data/target_length.txt'),
        False: ('groundedSCAN/data/compositional_splits/dataset.txt',
                'processed_data/compositional.txt'),
    }
    input_file, output_file = split_files[cli_args.length]

    DigitalStepDataLoader().load(input_file, output_file)
