import itertools

import numpy as np
from torch.utils import data


class DelayCopyDataTiming(data.Dataset):
    """Delay-copy task data generator.

    Every item is a freshly sampled episode (the index is ignored). A string of
    ``length`` symbols is presented at evenly spaced times. After a delay, a
    recall cue follows. The targets are the string itself, to be reproduced at
    evenly spaced times after the cue.
    """

    def __init__(self, dt=1., seed=3000, length=10, width=3, initial_delay=0,
                 delay=100, batch_size=10, input_spacing_factor=0.1, binary_encoding=True, output_repeat_factor=1):
        """Initializes the data generator.

        Args:
            dt: Base time step. Consecutive inputs and consecutive targets are
                ``dt * input_spacing_factor`` apart.
            seed: Seed for the internal ``np.random.default_rng``.
            length: Number of symbols in each string.
            width: Number of channels per symbol.
            initial_delay: Time stamp of the first input symbol.
            delay: Extra time added after the string (which spans
                ``length * dt * input_spacing_factor``) before the recall cue.
            batch_size: Only used to scale ``__len__``.
            input_spacing_factor: Multiplier on ``dt`` giving the symbol spacing.
            binary_encoding: If True, symbols are random non-zero bit patterns;
                otherwise they are one-hot vectors.
            output_repeat_factor: Multiplier on the number of target time stamps.
        """
        self.dt = dt
        self.rng = np.random.default_rng(seed)
        self.length = length
        self.width = width
        self.initial_delay = initial_delay
        self.delay = delay
        self.input_spacing_factor = input_spacing_factor
        self.batch_size = int(batch_size)
        self.binary_encoding = binary_encoding
        self.output_repeat_factor = output_repeat_factor

    @staticmethod
    def _to_bits(values, n_bits):
        """Return the ``n_bits`` low bits of each integer, most-significant first, as uint8."""
        shifts = np.arange(n_bits - 1, -1, -1)
        return ((np.asarray(values)[:, None] >> shifts) & 1).astype(np.uint8)

    def _sample_string(self):
        """Draw the (length, width) symbol matrix for one episode."""
        if not self.binary_encoding:
            # One-hot encoding: each position gets a uniformly random symbol id.
            symbols = self.rng.integers(0, self.width, size=(1, self.length))
            str_ = np.zeros((self.length, self.width))
            str_[np.arange(self.length), symbols] = 1
            return str_

        # The first (up to) 8 bits are drawn from 1..2**k - 1, so every symbol is non-zero.
        head_width = min(self.width, 8)
        head = self._to_bits(self.rng.choice(range(1, 2 ** head_width), self.length), head_width)
        if self.width <= 8:
            return head
        # Remaining bits are unconstrained. Converting via shifts (rather than a
        # uint8 round-trip) keeps all of them when width - 8 > 8.
        tail_width = self.width - head_width
        tail = self._to_bits(self.rng.choice(range(0, 2 ** tail_width), self.length), tail_width)
        return np.concatenate((head, tail), axis=1)

    def __getitem__(self, index):
        """Sample one episode; ``index`` is ignored.

        Returns:
            dict of float64 arrays:
              'input_times'  (length + 1,)
              'inputs'       (length + 1, width + 1)  -- last row is the recall
                             cue: zero symbol channels, 1 in the extra channel
              'target_times' (length * output_repeat_factor,)
              'targets'      (length, width)
        """
        str_ = self._sample_string()

        delay_before_str = self.initial_delay
        delay_after_str = self.delay

        # Symbols use the same dt * input_spacing_factor spacing that the recall
        # cue and target times below assume.
        input_times = [delay_before_str + ii * self.dt * self.input_spacing_factor
                       for ii in range(self.length)]
        recall_time = delay_before_str + self.length * self.dt * self.input_spacing_factor + delay_after_str
        input_times.append(recall_time)

        inputs = np.zeros((self.length + 1, self.width + 1))
        inputs[:self.length, :self.width] = str_
        inputs[self.length, self.width] = 1.  # recall symbol

        target_times = [recall_time + self.dt * self.input_spacing_factor + ii * self.dt * self.input_spacing_factor
                        for ii in range(self.length * self.output_repeat_factor)]
        # NOTE(review): there are length * output_repeat_factor target times but
        # only `length` targets; confirm how targets should be repeated when
        # output_repeat_factor > 1.
        targets = np.array(str_, dtype=float)

        episode = {}
        episode['input_times'] = np.array(input_times, dtype=float)
        episode['inputs'] = inputs
        episode['target_times'] = np.array(target_times, dtype=float)
        episode['targets'] = targets

        return episode

    def __len__(self):  # denotes the total number of samples
        return 100000 * self.batch_size


def get_delay_timing_data_full_batch(dt, length, width, initial_delay, delay, input_spacing_factor,
                                     output_repeat_factor=1):
    """Build every binary-encoded delay-copy episode as a single batch.

    Enumerates all ``(2 ** width - 1) ** length`` strings whose symbols are
    non-zero ``width``-bit patterns. They appear in lexicographic order of
    their flattened bits. Symbols are presented ``dt * input_spacing_factor``
    apart starting at ``initial_delay``. A recall cue follows after ``delay``,
    and targets are spaced the same way after the cue.

    Returns:
        (input_times, inputs, target_times, targets) with shapes
        (N, length + 1), (N, length + 1, width + 1),
        (N, length * output_repeat_factor), (N, length, width).
        The last input row of every episode is the recall cue: zero symbol
        channels and a 1 in the extra final channel.
    """
    # All non-zero symbols in lexicographic order. Their product enumerates the
    # valid strings directly, instead of filtering all 2**(length*width) bit strings.
    symbols = [bits for bits in itertools.product((0, 1), repeat=width) if any(bits)]
    n_strings = len(symbols) ** length
    strings = np.array(list(itertools.product(symbols, repeat=length)), dtype=float).reshape(
        n_strings, length, width)

    delay_before_str = initial_delay
    delay_after_str = delay

    # Timing does not depend on the string, so one row is built and tiled.
    input_times = [delay_before_str + ii * dt * input_spacing_factor for ii in range(length)]
    recall_time = delay_before_str + length * dt * input_spacing_factor + delay_after_str
    input_times.append(recall_time)
    target_times = [recall_time + dt * input_spacing_factor + ii * dt * input_spacing_factor
                    for ii in range(length * output_repeat_factor)]

    inputs = np.zeros((n_strings, length + 1, width + 1))
    inputs[:, :length, :width] = strings
    inputs[:, length, width] = 1.  # recall symbol

    input_times_list = np.tile(np.array(input_times), (n_strings, 1))
    target_times_list = np.tile(np.array(target_times), (n_strings, 1))
    # NOTE(review): targets hold each string once even when output_repeat_factor > 1,
    # so they have fewer entries than target_times; confirm intended repetition.
    return input_times_list, inputs, target_times_list, strings


def get_delay_timing_onehot_data_full_batch(dt, length, width, initial_delay, delay, input_spacing_factor,
                                            output_repeat_factor=1):
    """Build every one-hot delay-copy episode as a single batch.

    Enumerates all ``width ** length`` strings in which every position holds a
    one-hot symbol. They are ordered by symbol ids, first position most
    significant. The previous construction set only one position per string,
    which tripped the "every position non-zero" assertion whenever
    ``length > 1``. Timing matches the binary full-batch layout: symbols are
    ``dt * input_spacing_factor`` apart from ``initial_delay``, the recall cue
    follows after ``delay``, and targets are spaced the same way after the cue.

    Returns:
        (input_times, inputs, target_times, targets) with shapes
        (N, length + 1), (N, length + 1, width + 1),
        (N, length * output_repeat_factor), (N, length, width).
        The last input row of every episode is the recall cue.
    """
    n_strings = width ** length
    symbol_ids = np.array(list(itertools.product(range(width), repeat=length)), dtype=int)
    strings = np.eye(width)[symbol_ids.reshape(n_strings, length)]

    delay_before_str = initial_delay
    delay_after_str = delay

    # Timing does not depend on the string, so one row is built and tiled.
    input_times = [delay_before_str + ii * dt * input_spacing_factor for ii in range(length)]
    recall_time = delay_before_str + length * dt * input_spacing_factor + delay_after_str
    input_times.append(recall_time)
    target_times = [recall_time + dt * input_spacing_factor + ii * dt * input_spacing_factor
                    for ii in range(length * output_repeat_factor)]

    inputs = np.zeros((n_strings, length + 1, width + 1))
    inputs[:, :length, :width] = strings
    inputs[:, length, width] = 1.  # recall symbol

    input_times_list = np.tile(np.array(input_times), (n_strings, 1))
    target_times_list = np.tile(np.array(target_times), (n_strings, 1))
    # NOTE(review): targets hold each string once even when output_repeat_factor > 1.
    return input_times_list, inputs, target_times_list, strings


if __name__ == '__main__':
    # dt, length, width, initial_delay, delay, input_spacing_factor, output_repeat_factor
    # (output_repeat_factor is passed explicitly; omitting it raised TypeError.)
    input_times_list, inputs_list, target_times_list, targets_list = get_delay_timing_data_full_batch(
        1., 2, 3, 1, 2, 0.1, 1)
    print(input_times_list.shape)
