#!/usr/bin/env python3
"""Loading code for the human Atari data."""

import itertools
import tarfile
from pathlib import Path
import cv2
import imageio
import numpy as np

def _int_or_raw(raw_value):
    """Return ``raw_value`` converted to int if it parses as one, else unchanged."""
    try:
        return int(raw_value)
    except ValueError:
        return raw_value


def _index_image_folders(folder):
    """Map each trial image folder's numeric id to its path.

    The id is taken from the third underscore-separated field of the folder
    name. The ``highscore`` folder, plain files, and folders whose names do
    not follow that pattern are ignored. On duplicate ids, the first folder
    (in sorted order) wins.
    """
    index = {}
    for item in sorted(folder.iterdir()):
        if not item.is_dir() or item.name == 'highscore':
            continue
        try:
            folder_id = int(item.name.split('_')[2])
        except (IndexError, ValueError):
            continue
        index.setdefault(folder_id, item)
    return index


def _parse_row(line):
    """Parse one CSV row into ``(frame_id, reward, action)``.

    Row format is [frame_id, episode_id, score, duration(ms), unclipped_reward,
    action, gaze_positions...]. Numeric reward/action fields become ints.
    Returns None for blank or short rows and for rows where any of the three
    fields is ``null`` (an invalid timestep).
    """
    # Strip the line terminator so a trailing 'null' field compares equal to 'null'.
    raw_fields = line.rstrip('\r\n').split(',')
    if len(raw_fields) < 6:
        return None
    frame_id = raw_fields[0]
    reward = _int_or_raw(raw_fields[4])
    action = _int_or_raw(raw_fields[5])
    if 'null' in (frame_id, reward, action):
        return None
    return frame_id, reward, action


def _load_frame(image_path, grayscale, resize_to_84):
    """Read one frame image, optionally converting it to grayscale and resizing it to 84x84."""
    image = imageio.imread(image_path)
    if grayscale and image.ndim == 3:
        # Weighted RGB -> luma conversion (ITU-R BT.601 coefficients); yields a float array.
        image = (image[:, :, 0] * 0.299) + (image[:, :, 1] * 0.587) + (image[:, :, 2] * 0.114)
    if resize_to_84:
        image = cv2.resize(image, (84, 84))
    return image


def _load_trial(path, folder_index, grayscale, resize_to_84):
    """Load every valid timestep of one trial CSV as ``(image, reward, action)`` tuples.

    Raises:
        FileNotFoundError: if a row's frame id refers to an image folder that
            is not present in ``folder_index``.
    """
    timesteps = []
    with open(path) as csv_file:
        # Skip the heading line; the default tolerates a completely empty file.
        next(csv_file, None)
        for line in csv_file:
            row = _parse_row(line)
            if row is None:
                continue
            frame_id, reward, action = row
            # The second underscore-separated field of the frame id names its image folder.
            image_folder_id = int(frame_id.split('_')[1])
            try:
                image_folder_path = folder_index[image_folder_id]
            except KeyError:
                raise FileNotFoundError(
                    f'No image folder with id {image_folder_id} for frame {frame_id!r} in {path}'
                ) from None
            image_path = (image_folder_path / frame_id).with_suffix('.png')
            timesteps.append((_load_frame(image_path, grayscale, resize_to_84), reward, action))
    return timesteps


def human_data_generator(game, resize_to_84=False, grayscale=True, return_paths=False, cycle=True,
                         data_dir=None):
    """Load the human data from the Atari-HEAD dataset.

    Args:
        game: Name of the game's sub-folder inside the data directory.
        resize_to_84: If True, resize every frame to 84x84 with ``cv2.resize``.
        grayscale: If True, convert RGB frames to a single luma channel.
        return_paths: If True, append the trial's CSV path to each trial.
        cycle: If True, return an endlessly cycling iterator over the trials.
        data_dir: Root directory containing one folder per game. Defaults to
            ``~/human_atari``.

    Returns:
        An iterator over trials, in sorted CSV-path order. Each trial is
        ``[[images], [rewards], [actions]]`` (plus the CSV path when
        ``return_paths``). ``images``, ``rewards``, and ``actions`` are NumPy
        arrays. Trials with no valid timesteps are skipped.
    """
    base_dir = Path.home() / 'human_atari' if data_dir is None else Path(data_dir)
    folder = base_dir / game
    # Build the folder-id -> folder lookup once instead of rescanning it per CSV row.
    folder_index = _index_image_folders(folder)
    trials = []
    for path in sorted(folder.glob('*.txt')):
        timesteps = _load_trial(path, folder_index, grayscale, resize_to_84)
        if not timesteps:
            # np.stack cannot stack an empty sequence, so a trial without usable rows is dropped.
            print('Skipping empty trial', path)
            continue
        images = np.stack([image for (image, _, _) in timesteps])
        rewards = np.array([reward for (_, reward, _) in timesteps])
        actions = np.array([action for (_, _, action) in timesteps])
        # Wrap in single-element lists because of main data loading code
        trial = [[images], [rewards], [actions]]
        if return_paths:
            trial.append(path)
        trials.append(trial)
        print('Loaded trial', path)
    # Cycle through all of the trials
    if cycle:
        return itertools.cycle(trials)
    return iter(trials)

# Run the loading code for testing if this script is executed directly
if __name__ == '__main__':
    # Smoke test: load every trial of one game; the returned iterator is discarded.
    smoke_test_game = 'montezuma_revenge'
    human_data_generator(smoke_test_game)
