"""Preprocessing utilities for the LASA Human Handwriting Dataset."""
import io
import os

import numpy as np
import jax
import jax.numpy as jnp
import einops as e

import matplotlib
import matplotlib.pyplot as plt
from PIL import Image

import pyLasaDataset as lasa

# Select the non-interactive "Agg" (raster, no GUI) backend so that figures created by
# generate_character_image can be rendered without a display.
matplotlib.use("Agg")


def lasa_window(traj, window_size, stride=1):
    """
    Generates sliding windows along the first axis of a trajectory.

    Args:
        traj (jnp.ndarray): The input trajectory array with shape (N, ...), where N is the number of
            time steps. Any trailing axes are carried along unchanged.
        window_size (int): The number of time steps in each window (a static Python int, 1 <= window_size <= N).
        stride (int, optional): The offset between the starts of consecutive windows (>= 1). Default is 1.
    Returns:
        jnp.ndarray: An array of sliding windows with shape (num_windows, window_size, ...), where
        num_windows = (N - window_size) // stride + 1. Trailing time steps that do not fill a complete
        window are dropped.
    Raises:
        ValueError: If window_size is not in [1, N] or stride is smaller than 1.
    """
    num_steps = traj.shape[0]
    if not 1 <= window_size <= num_steps:
        raise ValueError(f"window_size must be in [1, {num_steps}], got {window_size}")
    if stride < 1:
        raise ValueError(f"stride must be >= 1, got {stride}")

    num_windows = (num_steps - window_size) // stride + 1
    start_idx = jnp.arange(0, num_windows * stride, stride)
    # dynamic_slice_in_dim slices only axis 0 and keeps any trailing axes whole. A plain
    # dynamic_slice with one start index and a 1-tuple slice size only accepts 1-D input.
    return jax.vmap(lambda start: jax.lax.dynamic_slice_in_dim(traj, start, window_size, axis=0))(start_idx)


# Apply lasa_window over the two leading axes of its first argument. The outer vmap maps
# axis 0 (demonstrations, as used in this module) and the inner vmap maps the next axis
# (points). window_size and stride are passed through unmapped, so a
# (demonstration, points, timesteps) input yields (demonstration, points, num_windows, window_size).
lasa_window_fn = jax.vmap(
    jax.vmap(
        lasa_window,
        in_axes=(0, None, None),
    ),
    in_axes=(0, None, None),
)


def process_lasa_data(dataset, window_size=12, stride=12):
    """
    Loads one LASA handwriting shape and cuts its demonstrations into position windows.

    Every demonstration is split into windows of ``window_size`` time steps whose starts are
    ``stride`` steps apart. Each window is flattened to interleaved per-step coordinates
    (x0, y0, x1, y1, ...); the first position becomes the input and the remaining positions
    become the target.

    Args:
        dataset (str): Name of the LASA shape to load (looked up on ``lasa.DataSet``).
        window_size (int, optional): Number of time steps per window. Defaults to 12.
        stride (int, optional): Offset between the starts of consecutive windows. Defaults to 12.
    Returns:
        tuple: ``(inputs, targets)`` where
            - inputs (jnp.ndarray): first position of every window, shape (demonstrations, segments, 2).
            - targets (jnp.ndarray): remaining positions of every window, shape
              (demonstrations, segments, 2 * (window_size - 1)) (assumes 2-D positions).
    """
    trajectories = [demo.pos for demo in lasa.DataSet.__getattr__(dataset).demos]

    # (points, demonstration * timesteps) -> (demonstration, points, timesteps)
    stacked = e.rearrange(
        jnp.hstack(trajectories),
        "points (demonstration timesteps) -> demonstration points timesteps",
        timesteps=1000,
    )

    # (demonstration, points, segments, window) -> (demonstration, segments, window * points)
    windows = e.rearrange(
        lasa_window_fn(stacked, window_size, stride),
        "demonstrations points segments windows -> demonstrations segments (windows points)",
    )

    first_position, remaining_positions = windows[:, :, :2], windows[:, :, 2:]
    return first_position, remaining_positions


def process_lasa_data_lagged(dataset, window_size=12, stride=12):
    """
    Processes LASA handwriting dataset by creating lagged windows of position data.

    The demonstrations are windowed several times, each time starting at a different time
    offset ("lag"). Lags 0 .. (timesteps - window_size) % stride (inclusive) are used. Exactly
    these offsets give the same number of windows as lag 0, so the results can be stacked.

    Args:
        dataset (str): The name of the dataset to process (looked up on ``lasa.DataSet``).
        window_size (int, optional): The size of the window to use for segmenting the data. Defaults to 12.
        stride (int, optional): The stride to use for segmenting the data. Defaults to 12.
    Returns:
        tuple: A tuple containing (rows ordered lag-major, then demonstration):
            - inputs (jnp.ndarray): The first position of every window, shape
              (lags * demonstrations, segments, 2).
            - targets (jnp.ndarray): The remaining ``window_size - 1`` positions of every window,
              flattened as interleaved coordinates, shape
              (lags * demonstrations, segments, 2 * (window_size - 1)) (assumes 2-D positions).
    """
    demos = lasa.DataSet.__getattr__(dataset).demos

    # position data: (points, demonstration * timesteps) -> (demonstration, points, timesteps)
    pos = jnp.hstack([demo.pos for demo in demos])
    pos = e.rearrange(pos, "points (demonstration timesteps) -> demonstration points timesteps", timesteps=1000)

    # Every lag in 0..remainder *inclusive* gives the same window count as lag 0.
    # range(remainder) would drop the last valid lag. When the windows tile the trajectory
    # exactly (remainder == 0) it would produce no lags at all, so jnp.stack would fail on an empty list.
    remainder = (pos.shape[-1] - window_size) % stride
    windowed_data = jnp.stack(
        [lasa_window_fn(pos[:, :, lag:], window_size, stride) for lag in range(remainder + 1)],
        axis=0,
    )
    pos_data_windowed = e.rearrange(
        windowed_data, "lags demonstrations points segments windows -> (lags demonstrations) segments (windows points)"
    )

    # need to segment into starting position and target position
    inputs = pos_data_windowed[:, :, :2]
    targets = pos_data_windowed[:, :, 2:]

    return inputs, targets


def generate_character_image(char):
    """
    Generates an image of a given character and returns it as a numpy array.

    The character is drawn in a bold monospace font (fontsize 150) at the centre of a hidden-axes
    matplotlib figure of 2x2 inches at 100 dpi. The figure is saved as PNG with a tight bounding
    box into an in-memory buffer, so no temporary file is written to the working directory. The
    PNG is then decoded with PIL, converted to RGB and resized to 64x64 pixels.

    Parameters:
    char (str): The character to be rendered into an image.

    Returns:
    numpy.ndarray: A 64x64 RGB image of the character, shape (64, 64, 3).
    """
    fontsize = 150

    fig, ax = plt.subplots(figsize=(2, 2), dpi=100)
    try:
        ax.axis("off")

        # (0.5, 0.5) is the centre of the axes only in axes-fraction coordinates. Without
        # transform=ax.transAxes it is read as data coordinates, and the glyph lands on the
        # axes' bottom-left corner.
        ax.text(
            0.5,
            0.5,
            char,
            transform=ax.transAxes,
            fontsize=fontsize,
            ha="center",
            va="center",
            fontweight="bold",
            family="monospace",
        )

        # Render to memory instead of a fixed "character.png" in the CWD. That avoids clobbering
        # an existing file and leaving it behind if an exception occurs.
        buffer = io.BytesIO()
        fig.savefig(buffer, format="png", bbox_inches="tight", pad_inches=0)
    finally:
        plt.close(fig)

    buffer.seek(0)
    with Image.open(buffer) as png:
        rgb = png.convert("RGB").resize((64, 64))

    return np.asarray(rgb)


def process_lasa_multi_data(datasets, window_size=12, stride=12):
    """
    Processes multiple LASA handwriting datasets by windowing the position data and generating task conditioning images.

    For each dataset, every demonstration is cut into windows of ``window_size`` steps whose starts
    are ``stride`` steps apart (via ``lasa_window_fn``). Each window is flattened to interleaved
    per-step coordinates (x0, y0, x1, y1, ...) and split into its first position (input) and its
    remaining positions (target). An image of the dataset name's first character is rendered and
    repeated once per demonstration, so row ``i`` of all three returned arrays refers to the same
    demonstration.

    Args:
        datasets (list): List of dataset names to be processed (looked up on ``lasa.DataSet``).
        window_size (int, optional): Size of the window for segmenting the position data. Defaults to 12.
        stride (int, optional): Stride for the windowing function. Defaults to 12.
    Returns:
        tuple: A tuple containing three JAX arrays, stacked over all demonstrations of all datasets in order:
            - inputs_data (jax.numpy.ndarray): shape (num_demos, segments, 2); the first position of each window.
            - targets_data (jax.numpy.ndarray): shape (num_demos, segments, 2 * (window_size - 1));
              the remaining positions of each window (assumes 2-D positions).
            - imgs_data (jax.numpy.ndarray): shape (num_demos, 64, 64, 3); the task conditioning image.
    Raises:
        ValueError: If ``datasets`` is empty.
    """
    if not datasets:
        raise ValueError("datasets must contain at least one dataset name")

    inputs_data = []
    targets_data = []
    imgs_data = []
    for dataset in datasets:
        # process raw demonstration data
        demos = lasa.DataSet.__getattr__(dataset).demos

        pos = jnp.hstack([demo.pos for demo in demos])
        pos = e.rearrange(pos, "points (demonstration timesteps) -> demonstration points timesteps", timesteps=1000)
        pos_data_windowed = lasa_window_fn(pos, window_size, stride)
        pos_data_windowed = e.rearrange(
            pos_data_windowed, "demonstrations points segments windows -> demonstrations segments (windows points)"
        )
        inputs_data.append(pos_data_windowed[:, :, :2])
        targets_data.append(pos_data_windowed[:, :, 2:])

        # Generate the image of the character for task conditioning: one copy per demonstration
        # (derived from the data rather than a hard-coded 7) so image rows align with trajectory rows.
        char_img = generate_character_image(dataset[0])
        imgs_data.append(jnp.repeat(jnp.expand_dims(char_img, axis=0), repeats=pos.shape[0], axis=0))

    # collapse different tasks into one dataset along the demonstration axis
    inputs_data = jnp.concatenate(inputs_data, axis=0)
    targets_data = jnp.concatenate(targets_data, axis=0)
    imgs_data = jnp.concatenate(imgs_data, axis=0)

    return inputs_data, targets_data, imgs_data


def process_lasa_multi_position_delta_data(datasets, window_size=12, stride=12):
    """
    Processes LASA handwriting datasets to generate inputs, targets, and images for model training.

    For each dataset, every demonstration is cut into windows of ``window_size`` steps whose starts
    are ``stride`` steps apart (via ``lasa_window_fn``). Each window is flattened to interleaved
    per-step coordinates (x0, y0, x1, y1, ...). The input is the window's first position. The
    target is the position differences p[k] - p[k-1] for k = 2 .. window_size - 1. An image of the
    dataset name's first character is repeated once per demonstration, so rows stay aligned.

    Args:
        datasets (list): List of dataset names to process (looked up on ``lasa.DataSet``).
        window_size (int, optional): Size of the window for segmenting the data. Defaults to 12.
        stride (int, optional): Stride for the window segmentation. Defaults to 12.
    Returns:
        tuple: A tuple containing (stacked over all demonstrations of all datasets in order):
            - inputs_data (jax.numpy.ndarray): shape (num_demos, segments, 2); the first position of each window.
            - targets_data (jax.numpy.ndarray): shape (num_demos, segments, 2 * (window_size - 2));
              consecutive position deltas (assumes 2-D positions).
            - imgs_data (jax.numpy.ndarray): shape (num_demos, 64, 64, 3); the task conditioning image.
    Raises:
        ValueError: If ``datasets`` is empty.
    """
    if not datasets:
        raise ValueError("datasets must contain at least one dataset name")

    inputs_data = []
    targets_data = []
    imgs_data = []
    for dataset in datasets:
        # process raw demonstration data
        demos = lasa.DataSet.__getattr__(dataset).demos

        pos = jnp.hstack([demo.pos for demo in demos])
        pos = e.rearrange(pos, "points (demonstration timesteps) -> demonstration points timesteps", timesteps=1000)
        pos_data_windowed = lasa_window_fn(pos, window_size, stride)
        pos_data_windowed = e.rearrange(
            pos_data_windowed, "demonstrations points segments windows -> demonstrations segments (windows points)"
        )
        inputs_data.append(pos_data_windowed[:, :, :2])
        # NOTE(review): [4:] - [2:-2] gives p[k] - p[k-1] for k >= 2 only. The step from the
        # input position p[0] to p[1] is not included. Confirm this is intended; [2:] - [:-2]
        # would include it.
        targets_data.append(pos_data_windowed[:, :, 4:] - pos_data_windowed[:, :, 2:-2])

        # Generate the image of the character for task conditioning: one copy per demonstration
        # (derived from the data rather than a hard-coded 7) so image rows align with trajectory rows.
        char_img = generate_character_image(dataset[0])
        imgs_data.append(jnp.repeat(jnp.expand_dims(char_img, axis=0), repeats=pos.shape[0], axis=0))

    # collapse different tasks into one dataset along the demonstration axis
    inputs_data = jnp.concatenate(inputs_data, axis=0)
    targets_data = jnp.concatenate(targets_data, axis=0)
    imgs_data = jnp.concatenate(imgs_data, axis=0)

    return inputs_data, targets_data, imgs_data


def process_lasa_multi_velocity_data(datasets, window_size=12, stride=12):
    """
    Processes multiple LASA handwriting datasets to generate inputs, targets, and images for training.

    For each dataset, both the position and the velocity trajectories of every demonstration are
    cut into windows of ``window_size`` steps whose starts are ``stride`` steps apart (via
    ``lasa_window_fn``). Each window is flattened to interleaved per-step coordinates
    (x0, y0, x1, y1, ...). The input is the window's first position concatenated with its first
    velocity. The target is the velocity at the remaining window steps. An image of the dataset
    name's first character is repeated once per demonstration, so rows stay aligned.

    Args:
        datasets (list): List of dataset names to process (looked up on ``lasa.DataSet``).
        window_size (int, optional): Size of the window for segmenting the data. Defaults to 12.
        stride (int, optional): Stride for the window segmentation. Defaults to 12.
    Returns:
        tuple: A tuple containing (stacked over all demonstrations of all datasets in order):
            - inputs_data (jax.numpy.ndarray): shape (num_demos, segments, 4); the current position followed
              by the current velocity (assumes 2-D trajectories).
            - targets_data (jax.numpy.ndarray): shape (num_demos, segments, 2 * (window_size - 1));
              the velocities at the remaining window steps.
            - imgs_data (jax.numpy.ndarray): shape (num_demos, 64, 64, 3); the task conditioning image.
    Raises:
        ValueError: If ``datasets`` is empty.
    """
    if not datasets:
        raise ValueError("datasets must contain at least one dataset name")

    inputs_data = []
    targets_data = []
    imgs_data = []
    for dataset in datasets:
        # process raw demonstration data
        demos = lasa.DataSet.__getattr__(dataset).demos

        pos = jnp.hstack([demo.pos for demo in demos])
        vel = jnp.hstack([demo.vel for demo in demos])
        pos = e.rearrange(pos, "points (demonstration timesteps) -> demonstration points timesteps", timesteps=1000)
        pos_data_windowed = lasa_window_fn(pos, window_size, stride)
        pos_data_windowed = e.rearrange(
            pos_data_windowed, "demonstrations points segments windows -> demonstrations segments (windows points)"
        )
        vel = e.rearrange(vel, "points (demonstration timesteps) -> demonstration points timesteps", timesteps=1000)
        vel_data_windowed = lasa_window_fn(vel, window_size, stride)
        vel_data_windowed = e.rearrange(
            vel_data_windowed, "demonstrations points segments windows -> demonstrations segments (windows points)"
        )

        # current position and velocity
        inputs_data.append(jnp.concatenate([pos_data_windowed[:, :, :2], vel_data_windowed[:, :, :2]], axis=-1))
        # target velocity
        targets_data.append(vel_data_windowed[:, :, 2:])

        # Generate the image of the character for task conditioning: one copy per demonstration
        # (derived from the data rather than a hard-coded 7) so image rows align with trajectory rows.
        char_img = generate_character_image(dataset[0])
        imgs_data.append(jnp.repeat(jnp.expand_dims(char_img, axis=0), repeats=pos.shape[0], axis=0))

    # collapse different tasks into one dataset along the demonstration axis
    inputs_data = jnp.concatenate(inputs_data, axis=0)
    targets_data = jnp.concatenate(targets_data, axis=0)
    imgs_data = jnp.concatenate(imgs_data, axis=0)

    return inputs_data, targets_data, imgs_data
