"""Load episodes from RL agents either from individual files or folders."""

import multiprocessing
import os
import queue
import random
import cv2
import numpy as np
from zipfile import BadZipFile
from human_load import human_data_generator

def load_single_file(file_path, verbose=True, return_rewards=False, resize_to_84=False, return_actions_only=False):
    """Load the observations and actions of one episode from a single NumPy file.

    The file must provide the keys ``'observ'`` and ``'action'``, plus
    ``'reward'`` when ``return_rewards`` is set.

    Args:
        file_path: Location of the file; passed straight to ``np.load``.
        verbose: Print a message once the file has been opened.
        return_rewards: Also return the ``'reward'`` array.
        resize_to_84: Resize every frame to 84x84 with ``cv2.resize``.
        return_actions_only: Return only the ``'action'`` array and skip all
            image processing.

    Returns:
        The action array if ``return_actions_only`` is set; otherwise
        ``(observations, actions)``, or ``(observations, actions, rewards)``
        when ``return_rewards`` is set.
    """
    data = np.load(file_path)
    try:
        if verbose:
            print('File', file_path, 'loaded')
        if return_actions_only:
            return data['action']
        # Convert the images from RGB to grayscale if they aren't already grayscale
        obs = data['observ']
        if obs.ndim == 4:
            obs = (obs[:, :, :, 0] * 0.299) + (obs[:, :, :, 1] * 0.587) + (obs[:, :, :, 2] * 0.114)
        # Optionally downscale the images to 84x84
        if resize_to_84:
            full_obs = obs
            obs = np.zeros(shape=(full_obs.shape[0], 84, 84))
            for i, frame in enumerate(full_obs):
                obs[i] = cv2.resize(frame, (84, 84))
        # Optionally return the rewards in addition to the observations and actions
        if return_rewards:
            return obs, data['action'], data['reward']
        return obs, data['action']
    finally:
        # An .npz archive keeps its file handle open until closed; release it
        # explicitly instead of waiting for garbage collection. Plain arrays
        # have no close() method, so only close what can be closed.
        if hasattr(data, 'close'):
            data.close()

def _one_hot(numeric, action_space):
    """One-hot encode a 2-D array of integer action indices.

    Returns a float array of shape ``numeric.shape + (action_space,)`` holding
    a 1 at each action's index and 0 everywhere else.
    """
    encoded = np.zeros(shape=(numeric.shape[0], numeric.shape[1], action_space))
    # Nothing to mark in an empty batch
    if numeric.size:
        rows, columns = np.indices(numeric.shape)
        encoded[rows, columns, numeric] = 1
    return encoded


def data_generator(filename_queue, action_space, return_sequences=True, return_filenames=False, return_rewards=False, resize_to_84=False, return_counts_and_filenames_only=False, return_scalars_and_filenames_only=False, color_bins=None, return_episode_metadata=False, discrete_values=None, use_downscaled_files=False, use_human_data=False, human_data_game=None, human_cycle=True):
    """Load images and actions in real time with a generator, from any number of folders.

    Args:
        filename_queue: Queue of episode file paths; read with
            ``get(True, 100)``. Unused when ``use_human_data`` is set.
        action_space: Number of discrete actions (size of the one-hot axis).
        return_sequences: Split each episode into non-overlapping sequences of
            100 steps; otherwise use the whole episode minus its final step.
        return_filenames: Append the list of loaded paths to each yield.
        return_rewards: Append the rewards array to each yield.
        resize_to_84: Forwarded to the episode loaders.
        return_counts_and_filenames_only: Yield only
            ``(number of actions in the first loaded episode, paths)``.
        return_scalars_and_filenames_only: Yield only ``(scalars, paths)``,
            where ``scalars`` holds one ``encode`` result per discretized frame
            of the first loaded episode.
        color_bins: Forwarded to ``discretize_to_1d``.
        return_episode_metadata: Append
            ``(total_reward, episode_length, sequence_lengths)`` to each yield.
        discrete_values: Forwarded to ``discretize_to_1d`` and ``encode``.
        use_downscaled_files: Load ``~/ais/<config>_downscaled/<file_id>``
            instead of the queued path, where ``<config>`` and ``<file_id>``
            are the last two ``/``-separated components of that path.
        use_human_data: Take episodes from ``human_data_generator`` instead of
            the filename queue.
        human_data_game: Forwarded to ``human_data_generator``.
        human_cycle: Forwarded to ``human_data_generator`` as ``cycle``.

    Yields:
        By default a list ``[(x_observations, x_actions), y_actions, ...]``
        with the optional extras appended in the order filenames, rewards,
        metadata. In file mode a single ``None`` is yielded when no filename
        arrives within 100 seconds, after which the generator stops. In human
        data mode the reported paths are an empty list.
    """
    # If human data loading is enabled, create a generator
    if use_human_data:
        human_data_gen = human_data_generator(human_data_game, resize_to_84=resize_to_84, cycle=human_cycle)
    # Loop forever, generating data
    while True:
        # If human data loading is enabled, get data from that
        if use_human_data:
            try:
                all_images, all_rewards, all_actions = next(human_data_gen)
            except StopIteration:
                print('Human data generator has exited.')
                break
            # Human episodes do not come from the filename queue, so there are
            # no paths to report; define the name so the filename-returning
            # options below do not fail on an unbound variable
            paths = []
        else:
            # Keep trying to load files until one is successfully loaded; there may be corrupt files
            while True:
                try:
                    # Pop a file name off the queue and load it
                    try:
                        paths = [filename_queue.get(True, 100) for _ in range(1)]
                        # Optionally use downscaled files instead
                        if use_downscaled_files:
                            downscaled_paths = []
                            for path in paths:
                                # Raises ValueError (handled below) if the path has fewer than two components
                                config, file_id = path.split('/')[-2:]
                                directory = config if 'downscaled' in config else f'{config}_downscaled'
                                downscaled_paths.append(os.path.expanduser(f'~/ais/{directory}/{file_id}'))
                            paths = downscaled_paths
                    # If we hit the end of the list for 100 seconds, stop iteration
                    except queue.Empty:
                        print('Hit end of filename generator.')
                        # None is the sentinel consumers use to detect the end of the data
                        yield None
                        return
                    all_images, all_actions, all_rewards = zip(*(load_single_file(path, verbose=False, return_rewards=True, resize_to_84=resize_to_84) for path in paths))
                    break
                except (BadZipFile, FileNotFoundError, ValueError) as error:
                    print('Failed to load file; error', error)

        if return_counts_and_filenames_only:
            yield all_actions[0].shape[0], paths
            continue

        if return_scalars_and_filenames_only:
            images = discretize_to_1d(all_images[0], color_bins=color_bins, discrete_values=discrete_values)
            scalars = [encode(image, discrete_values) for image in images]
            yield scalars, paths
            continue

        # The network's inputs are going to be: a sequence of observations, and a sequence of actions taken immediately before those corresponding observations
        sequence_length = 100
        x_observations = []
        x_actions = []
        rewards = []
        sequence_lengths = []
        total_reward = 0
        episode_length = 0
        # The outputs of the network are also actions, offset by one after the input actions
        y_actions = []
        for images, actions, rewards_loop in zip(all_images, all_actions, all_rewards):
            total_reward += np.sum(rewards_loop)
            episode_length += len(rewards_loop)
            # If there are no sequences then just add the whole episode to the lists
            if not return_sequences:
                x_observations.append(images[:-1])
                x_actions.append(actions[:-1])
                y_actions.append(actions[1:])
                rewards.append(rewards_loop[:-1])
            # Otherwise, take sequence_length of observations and actions for the input, and the following actions for the outputs
            else:
                # NOTE(review): this stop value skips a final full sequence when
                # len(images) == k * sequence_length + 1; left as-is so the
                # yielded data does not change
                for i in range(0, len(images) - sequence_length - 1, sequence_length):
                    x_observations.append(images[i:(i + sequence_length)])
                    x_actions.append(actions[i:(i + sequence_length)])
                    y_actions.append(actions[(i + 1):(i + sequence_length + 1)])
                    rewards.append(rewards_loop[i:(i + sequence_length)])
                    sequence_lengths.append(sequence_length)

        # An episode too short to produce a single sequence leaves the lists
        # empty, which cannot be one-hot encoded; skip it rather than crash
        if not x_observations:
            print('Episode produced no sequences; skipping')
            continue

        # Convert X and Y to NumPy arrays
        x_observations = np.array(x_observations)
        x_actions = np.array(x_actions)
        y_actions = np.array(y_actions)
        rewards = np.array(rewards)
        # Add a 1-length channel dimension to the image observations
        x_observations = np.expand_dims(x_observations, axis=-1)

        # One-hot encode the categorical action data
        x_actions = _one_hot(x_actions, action_space)
        y_actions = _one_hot(y_actions, action_space)

        # Yield the observations, input actions, output actions, and optionally filenames and rewards
        return_values = [(x_observations, x_actions), y_actions]
        if return_filenames:
            return_values.append(paths)
        if return_rewards:
            return_values.append(rewards)
        if return_episode_metadata:
            return_values.append((total_reward, episode_length, sequence_lengths))
        yield return_values

def filename_generator(directory_index, random_and_repeat):
    """Generate filenames from an index on disk.

    Every line of the index that starts with ``f:`` is a file entry; the rest
    of the line, stripped of surrounding whitespace, is the path. Other lines
    are ignored.

    Args:
        directory_index: Open index file (anything with ``readlines()``).
        random_and_repeat: Shuffle the paths before yielding them. Despite the
            name, nothing is repeated: the generator ends once every path has
            been yielded.

    Yields:
        One path at a time, in ascending sorted order unless shuffled.
    """
    # Split on the first 'f:' only, so a path that itself contains 'f:' stays intact
    files = sorted(
        (line.split('f:', 1)[1].strip() for line in directory_index.readlines() if line.startswith('f:')),
        reverse=True,
    )
    # Shuffle the list if this feature is enabled
    if random_and_repeat:
        random.shuffle(files)
    # Yield filenames one at a time, taking them from the end of the list
    while files:
        yield files.pop()
    print('Ran out of filenames')

def parallel_generator(directory_index, num_parallel=1, random_and_repeat=True, **kwargs):
    """Load data with several worker processes at once and yield results as they arrive.

    One daemon process feeds filenames from ``directory_index`` into a bounded
    queue; ``num_parallel`` further daemon processes each run a
    ``data_generator`` (built with ``**kwargs``) on that queue and push their
    output into a second bounded queue, which this generator drains. The
    generator stops as soon as it receives ``None`` from the data queue.
    """

    def fire_hose(source, sink):
        """Push every item of the generator ``source`` into the queue ``sink``."""
        for item in source:
            sink.put(item)

    # Start the process that fills the filename queue
    filename_queue = multiprocessing.Queue(16)
    filename_feeder = multiprocessing.Process(
        target=fire_hose,
        args=(filename_generator(directory_index, random_and_repeat), filename_queue),
        daemon=True,
    )
    filename_feeder.start()

    # Start the processes that turn filenames into data
    data_queue = multiprocessing.Queue(16)
    for worker_index in range(num_parallel):
        print('Starting generator process', worker_index)
        worker = multiprocessing.Process(
            target=fire_hose,
            args=(data_generator(filename_queue, **kwargs), data_queue),
            daemon=True,
        )
        worker.start()

    # Hand over loaded data until a worker signals the end with None
    while True:
        print('Filenames:', filename_queue.qsize())
        print('Data:', data_queue.qsize())
        item = data_queue.get()
        # None means the filename generator has stopped and a worker timed out waiting
        if item is None:
            return
        yield item

def downscale(images, size):
    """Shrink a batch of images to ``size`` x ``size`` by resizing, then block-averaging.

    Each image is converted to float32 and resized to ``size**2`` x ``size**2``
    with ``cv2.resize`` (ignoring the aspect ratio). Every ``size`` x ``size``
    block of the resized image is then replaced by its mean.

    Returns:
        An array of shape ``(number of images, size, size)``.
    """
    side = size ** 2
    frames = images.astype(np.float32)
    # Resize every frame to side x side
    resized = np.zeros(shape=(frames.shape[0], side, side))
    for index, frame in enumerate(frames):
        resized[index] = cv2.resize(frame, (side, side))
    # Split both spatial axes into (block index, offset within block), move the
    # two offset axes to the end, and average over them
    blocks = resized.reshape((-1, size, size, size, size)).transpose((0, 1, 3, 2, 4))
    return blocks.reshape((-1, size, size, side)).mean(-1)

def discretize(images, discrete_values, return_bins=False, color_bins=None):
    """Given an array of images, discretize them with a relatively even distribution of discrete values.

    Args:
        images: Array indexed as ``(image, row, column)``. Images need not be
            square.
        discrete_values: Number of discrete levels; ``discrete_values - 1`` bin
            edges are computed per pixel. Only used when ``color_bins`` is not
            given.
        return_bins: Also return the bin edges that were used.
        color_bins: Optional precomputed bin edges, indexed as
            ``(row, column, edge)``.

    Returns:
        The integer array of discretized images, or
        ``(discrete_images, color_bins)`` when ``return_bins`` is set.
    """
    # Use both spatial extents so that non-square images are fully covered
    height, width = images.shape[1], images.shape[2]
    # If color bins are not provided, calculate them
    if color_bins is None:
        # The color bins should each contain an equal number of values
        bin_percentiles = np.linspace(0, 100, discrete_values, endpoint=False)[1:]
        color_bins = np.zeros(shape=(height, width, bin_percentiles.shape[0]))
        for x, y in np.ndindex(height, width):
            # Calculate the percentiles using unique values only, as we want to overrepresent rarely occurring values
            unique_pixel_values = np.unique(images[:, x, y])
            color_bins[x, y] = np.percentile(unique_pixel_values, bin_percentiles)
    # Each of the pixels must be discretized separately, as NumPy does not support separate bins for each pixel
    discrete_images = np.zeros(shape=images.shape, dtype=int)
    for x, y in np.ndindex(height, width):
        discrete_images[:, x, y] = np.digitize(images[:, x, y], color_bins[x, y])
    # Optionally return the color bins along with the discretized images
    if return_bins:
        return discrete_images, color_bins
    return discrete_images

def discretize_to_1d(images, color_bins, discrete_values):
    """Discretize an array of images and flatten each image into a single row."""
    quantized = discretize(images, discrete_values=discrete_values, color_bins=color_bins)
    image_count = quantized.shape[0]
    # Keep the image axis and collapse everything else into one dimension
    return quantized.reshape((image_count, -1))

def encode(vec, discrete_values):
    """Create a positional scalar encoding for an image vector.

    Element ``i`` of ``vec`` is treated as digit ``i`` (least significant
    first) of a number in base ``discrete_values``.

    Args:
        vec: One-dimensional NumPy array of values.
        discrete_values: The base of the encoding.

    Returns:
        The sum of ``discrete_values ** i * vec[i]`` over all ``i``; ``0`` for
        an empty vector.
    """
    # Don't use NumPy so that we can have arbitrarily large integers without overflow.
    # That includes the base: a NumPy integer would make the powers fixed-width
    # and wrap around silently, so convert it to a Python int first.
    base = int(discrete_values) if isinstance(discrete_values, np.integer) else discrete_values
    return sum((base ** i) * value for i, value in enumerate(vec.tolist()))
