import json
import random


def get_data_loader(name, args):
    """Instantiate the data loader registered under ``name``.

    Args:
        name: one of ``'step'``, ``'interactive'`` or ``'rotation'``.
        args: configuration object forwarded to the loader constructor.

    Returns:
        A new, not-yet-loaded loader instance (call ``load_saved`` on it).

    Raises:
        ValueError: if ``name`` is not a known loader. This used to be an
            ``assert``, which ``python -O`` strips, silently returning None.
    """
    # Plain if-chain (rather than a module-level table) so each class name is
    # only resolved when it is actually requested.
    if name == 'step':
        return DigitalStepDataLoader(args)
    if name == 'interactive':
        return InteractiveStepDataLoader(args)
    if name == 'rotation':
        return RotationStepDataLoader(args)
    raise ValueError(
        f"unknown data loader {name!r}; "
        "expected one of 'step', 'interactive', 'rotation'")


class DigitalStepDataLoader(object):
    """Load a preprocessed JSON dataset, pad every sample, and serve batches.

    ``args`` must provide ``k_shot`` (read in :meth:`load_saved`).
    :meth:`load_saved` must be called before any ``get_*`` accessor, because
    the attributes those accessors read are only created there.
    """

    def __init__(self, args):
        self.args = args

    def get_data(self, name, batch_size=-1):
        """Return split ``name`` transposed into one tuple per sample field.

        When ``batch_size >= 0`` the samples are first drawn uniformly *with
        replacement* via ``random.choices``, so a batch may repeat samples.
        With the default ``-1`` the whole split is used in stored order.
        The result is ``list(zip(*samples))``, so an empty selection yields
        ``[]``.
        """
        data_split = self.data_splits[name]
        if batch_size >= 0:
            data_split = random.choices(data_split, k=batch_size)
        return list(zip(*data_split))

    def get_train_data(self, batch_size):
        """Return ``batch_size`` random training samples (with replacement)."""
        return self.get_data('train', batch_size)

    def get_test_data(self, name):
        """Return ``(whole split name transposed, split_lengths[name])``."""
        return self.get_data(name), self.split_lengths[name]

    def get_names(self):
        """Return the split names in the order they appear in the file."""
        return list(self.data_splits.keys())

    def get_data_stats(self):
        """Return the ``data_stats`` dict (with ``'<PAD>'`` added)."""
        return self.data_stats

    def load_saved(self, fn):
        """Load the JSON dataset at ``fn`` and pad every sample of every split.

        Sets ``all_data``, ``split_lengths``, ``data_stats``, the pad values
        (``command_pad``, ``position_pad``, ``world_pad``) and replaces
        ``data_splits`` with the padded samples.
        """
        # JSON is UTF-8 by spec; don't depend on the platform default encoding.
        with open(fn, 'r', encoding='utf-8') as f:
            self.all_data = json.load(f)
        print(fn + ' is loaded.')
        raw_splits = self.all_data['data_splits']
        self.split_lengths = self.all_data['split_lengths']
        self.data_stats = self.all_data['data_stats']

        # The command pad token takes index len(word_index). This is a fresh
        # index only if the vocabulary is indexed 0..len-1 (assumed, not
        # checked here).
        self.command_pad = len(self.data_stats['word_index'])
        self.data_stats['word_index']['<PAD>'] = self.command_pad

        # Pad rows are all-zero vectors as wide as the first training
        # sample's position / world rows (requires a non-empty train split).
        example = raw_splits['train'][0]
        self.position_pad = [0] * len(example[2][0])
        self.world_pad = [0] * len(example[3][0])

        processed_data = {}
        for k, v in raw_splits.items():
            data_split = []
            for sample in v:
                ps = self._pad_a_sample(sample)
                # Training samples whose target ends with label 3 are appended
                # k_shot times (k_shot == 0 drops them). What label 3 means is
                # not visible here - presumably a rare class being
                # oversampled; confirm against the preprocessing script.
                if k == 'train' and ps[-1][-1] == 3:
                    for _ in range(self.args.k_shot):
                        data_split.append(ps)
                else:
                    data_split.append(ps)
            processed_data[k] = data_split
            print('processed', k)
        self.data_splits = processed_data

    def _pad_a_slot(self, x, index, max_length, pad):
        """Right-pad ``x[index]`` with ``pad`` up to ``max_length`` in place.

        Slots already at or beyond ``max_length`` are left untouched (no
        truncation). ``[pad] * n`` repeats the *same* pad object, so list
        pads must not be mutated in place afterwards.
        """
        slot = x[index]
        if len(slot) < max_length:
            slot += [pad] * (max_length - len(slot))
            x[index] = slot

    def _pad_a_sample(self, sample):
        """Pad one raw sample in place and append its unpadded lengths.

        Lengths are measured before padding: ``command_length`` from
        ``sample[0]`` and ``world_length`` from ``sample[2]``. Padding rules:
        ``sample[0]`` is padded with ``command_pad`` to ``max_command_length``.
        ``sample[2]`` is padded with ``position_pad`` to ``max_world_length``.
        ``sample[3]`` has each row's items converted to ``int``, then is
        padded with ``world_pad`` to ``max_world_length``.

        Returns:
            A new list ``[*sample[:-1], command_length, world_length,
            sample[-1]]``.

        Raises:
            ValueError: if ``sample[2]`` and ``sample[3]`` differ in length.
        """
        command_length = len(sample[0])
        world_length = len(sample[2])
        if world_length != len(sample[3]):
            raise ValueError(
                f'sample has {world_length} positions but '
                f'{len(sample[3])} world rows')

        self._pad_a_slot(sample, 0, self.data_stats['max_command_length'],
                         self.command_pad)
        self._pad_a_slot(sample, 2, self.data_stats['max_world_length'],
                         self.position_pad)

        sample[3] = [[int(b) for b in a] for a in sample[3]]
        self._pad_a_slot(sample, 3, self.data_stats['max_world_length'],
                         self.world_pad)
        return [*sample[:-1], command_length, world_length, sample[-1]]


class InteractiveStepDataLoader(DigitalStepDataLoader):
    """Step loader that relabels terminal ``end`` actions for push/pull.

    When a sample's target action (``target[1]``) is ``end`` and the padded
    command's first word is ``push`` or ``pull``, the target action is
    rewritten to the matching ``push``/``pull`` action index.
    """

    def _pad_a_sample(self, sample):
        padded = super()._pad_a_sample(sample)
        stats = self.data_stats
        target = padded[-1]
        target_action = target[1]
        first_word = padded[0][0]
        if target_action != stats['action_index']['end']:
            return padded
        # Checked in this order; the first matching verb wins.
        for verb in ('push', 'pull'):
            if first_word == stats['word_index'][verb]:
                target[1] = stats['action_index'][verb]
                break
        return padded


class RotationStepDataLoader(InteractiveStepDataLoader):
    """Interactive loader that expresses object positions relative to the agent.

    Each position in ``ret[2]`` is translated so the agent sits at the origin
    and then rotated according to the agent's direction. The agent slot
    ``ret[1]`` is replaced by a copy of ``[0, 0, 0]``.
    """

    def _rotate(self, row, col, direction):
        """Rotate the offset ``(row, col)`` by ``direction`` steps.

        Direction 1 maps ``(row, col) -> (col, -row)``; directions 2 and 3
        apply that map two and three times; direction 0 is the identity.

        Returns:
            The rotated ``(row, col)`` tuple.

        Raises:
            ValueError: if ``direction`` is not 0, 1, 2 or 3. This used to be
                ``assert False``, which under ``python -O`` fell through to an
                UnboundLocalError.
        """
        if direction == 0:
            return row, col
        if direction == 1:
            return col, -row
        if direction == 2:
            return -row, -col
        if direction == 3:
            return -col, row
        raise ValueError(f'invalid direction {direction!r}; expected 0-3')

    def _pad_a_sample(self, sample):
        ret = super()._pad_a_sample(sample)

        # agent: unpacking requires exactly three values (row, col, direction)
        arow, acol, direction = ret[1]

        # objects
        # NOTE(review): ret[2] already contains the padding rows added by the
        # base class, so those are rotated too. Pads therefore become the
        # rotated (-arow, -acol) instead of staying zero. This is harmless only
        # if consumers mask by world_length - confirm.
        rotated = [self._rotate(row - arow, col - acol, direction)
                   for row, col in ret[2]]

        # Give every sample its own list so they don't all alias zero_agent.
        ret[1] = list(self.zero_agent)
        ret[2] = rotated
        return ret

    def load_saved(self, fn):
        # zero_agent must exist before the base class pads (and rotates) samples.
        self.zero_agent = [0, 0, 0]
        super().load_saved(fn)


if __name__ == '__main__':
    # Smoke test: load a processed dataset and print the number of fields in
    # one training batch. The loader needs an ``args`` object with ``k_shot``;
    # the original call omitted ``args`` entirely and always raised TypeError.
    import argparse

    parser = argparse.ArgumentParser(
        description='Load a processed dataset with RotationStepDataLoader.')
    parser.add_argument('--data', default='processed_data/compositional.txt',
                        help='path to the processed JSON dataset')
    parser.add_argument('--k_shot', type=int, default=1,
                        help='copies of each label-3 training sample to keep')
    args = parser.parse_args()

    dl = RotationStepDataLoader(args)
    dl.load_saved(args.data)
    print(len(dl.get_train_data(100)))
