#!/usr/bin/env python3
"""Generate a cooccurrence tensor containing the number of occurrences of an observation given the preceding observation and action."""

import ast
import os
import pickle
import sys
import numpy as np
from load_data import discretize_to_1d, encode, parallel_generator

# The number of values pixels are discretized to
DISCRETE_VALUES = 4

# Validate the command line before doing any work: the output path (argv[4]) is
# only read after the whole dataset has been processed, so a missing argument
# would otherwise crash the script at the very end and discard all the counting.
if len(sys.argv) < 5:
    sys.exit('Usage: {} <log_file> <data_list_file> <action_space> <output_pickle>'.format(sys.argv[0]))

# Start loading the log file that is passed as the first argument ("~" is expanded)
path = os.path.expanduser(sys.argv[1])
matrix_file = open(path)

def load_line(prefix, file=None):
    """Return the literal value that follows ``prefix`` on the next line containing it.

    Lines are consumed from the current position of ``file`` (by default the
    module-level ``matrix_file``). Every line read before the match is skipped
    and cannot be re-read, so successive calls must request prefixes in the
    order they appear in the file.

    Args:
        prefix: Marker text to search for, e.g. ``'Color bins:'``.
        file: Open text stream to read from; ``None`` means ``matrix_file``.

    Returns:
        The text after the first occurrence of ``prefix`` on the matching line,
        stripped and parsed with ``ast.literal_eval``.

    Raises:
        ValueError: If the end of the file is reached without finding ``prefix``.
            ``ast.literal_eval`` may also raise ``ValueError``/``SyntaxError`` if
            the text after the prefix is not a valid Python literal.
    """
    if file is None:
        file = matrix_file
    while True:
        line = file.readline()
        # readline() returns '' only at end of file; without this check the loop
        # would spin forever when the prefix is missing
        if not line:
            raise ValueError('Prefix {!r} not found in matrix file'.format(prefix))
        if prefix in line:
            # Split only on the first occurrence so a prefix repeated later in
            # the line does not truncate the value
            return ast.literal_eval(line.split(prefix, 1)[1].strip())

# Read the color bins and the image scalars from the log file, closing it even
# if parsing fails. load_line only reads forward, so the bins must come first.
try:
    # Load the array of color bins
    color_bins = np.array(load_line('Color bins:'))
    # Load the list of image scalars used for indexing
    print('Loading image scalars')
    image_scalars = load_line('Image scalars:')
finally:
    matrix_file.close()
# Map each scalar to its position in the list so it can be used as an index
image_scalars_dict = {scalar: i for i, scalar in enumerate(image_scalars)}
print('Finished loading image scalars')

# Use a dictionary with tuple keys as a sparse cooccurrence matrix, using the format (O_last, A, O_next)
cooccurrence_counts = {}

# The data list file is only needed while the generator is consumed; the `with`
# block closes it once every sequence has been processed.
with open(sys.argv[2]) as data_list_file:
    data_generator = parallel_generator(
        data_list_file,
        return_sequences=False,
        random_and_repeat=False,
        action_space=int(sys.argv[3]),
        return_filenames=True,
        use_downscaled_files=True,
    )
    # Each item unpacks as ((images, actions), <unused>, <filenames, unused here>)
    for (images, actions), _, _ in data_generator:
        # Remove the sequences axis and replace the one-hot actions with integers
        images = images[0]
        actions = actions[0].argmax(1)
        # Convert each discretized image to its index in the image scalars list
        image_indices = [
            image_scalars_dict[encode(values, DISCRETE_VALUES)]
            for values in discretize_to_1d(images, color_bins, DISCRETE_VALUES)
        ]

        # Increment the count for each (image, action, next image) transition;
        # the action at step t is paired with the images at steps t and t+1
        for transition_tuple in zip(image_indices[:-1], actions[:-1], image_indices[1:]):
            cooccurrence_counts[transition_tuple] = cooccurrence_counts.get(transition_tuple, 0) + 1

# Persist the sparse counts dictionary with pickle at the path given as the fourth argument
output_path = sys.argv[4]
with open(output_path, 'wb') as pickle_file:
    pickle.dump(cooccurrence_counts, pickle_file)
print('Saved cooccurrence matrix at {}'.format(output_path))
