# %%
"""
Generates pretraining tasks to teach models XOR

This pretraining task implements a "difference" rule: the output is True if and only if all elements are pairwise distinct.
Let W, X, Y, Z be independent elements that correspond to square, triangle, cross, circle.
Then
diff(W, X) = True
diff(W, W) = False
diff(W, Z) = True
diff(Z, Z) = False
etc.

Also:
diff(W, X, Y) = True
diff(W, W, Y) = False
diff(W, W, W) = False
etc.

And:
diff(W, X, Y, Z) = True
diff(W, X, X, Z) = False
diff(W, W, W, W) = False
"""
from random import choice
import itertools
import numpy as np
import argparse
import h5py
import os

# Command-line interface. `prog` is left to argparse so the usage line shows the
# actual script name instead of a hard-coded (and misleading) './main.py'.
parser = argparse.ArgumentParser(description='Generate pretraining tasks to teach NNs XOR reasoning')
# Only 2-4 elements are meaningful: there are 4 distinct symbols, so with 5+
# elements every trial would be False. `choices` enforces the documented range.
parser.add_argument('--num_elements', type=int, default=2, choices=[2, 3, 4],
                    help='Number of elements per puzzle (min=2; max=4)')
parser.add_argument('--filename', type=str, default='../data/pretraining/xor_pretraining_samples',
                    help='prefix of filename to save to')


def run(args):
    """
    Generate every XOR ("all different") pretraining trial and save it to HDF5.

    Every multiset of ``num_elements`` symbols (out of 4) is placed at every
    ordered selection of ``num_elements`` distinct cells on a 4x4 grid. Inputs
    that contain a repeated symbol therefore appear more than once, because
    swapping the positions of two identical symbols yields the same grid. These
    positional variations are kept because they may still help a network
    generalize the XOR rule.

    How to run::

        python generate_pretraining_xor.py --num_elements 2
        python generate_pretraining_xor.py --num_elements 3
        python generate_pretraining_xor.py --num_elements 4

    Output is written to ``<filename>_<num_elements>_elements_per_input.h5``
    (default location ``../data/pretraining/``). It has two datasets:

    * ``inputs``: float array of shape (trials, 4, 4, 4), one-hot over symbols.
    * ``outputs``: bool array of shape (trials,).

    If the file already exists, those two datasets are replaced. Anything else
    in the file is left untouched.

    Args:
        args (argparse.Namespace): parsed CLI arguments with attributes
            ``num_elements`` (int, number of input elements per trial, 2-4)
            and ``filename`` (str, output path prefix).

    Raises:
        ValueError: if ``num_elements`` is outside the supported range 2-4.
    """
    num_elements = args.num_elements
    filename = args.filename

    # dimensions of puzzle/grid
    grid_x = 4
    grid_y = 4
    # number of actual symbols: circle, cross, square, triangle
    num_symbols = 4

    # With more elements than symbols every trial is False (pigeonhole), and
    # fewer than 2 elements makes the rule trivially True.
    if not 2 <= num_elements <= num_symbols:
        raise ValueError(
            'num_elements must be between 2 and %d, got %r' % (num_symbols, num_elements))

    print('Generating pretraining trials with', num_elements, 'per input')

    inputs, outputs = _build_trials(num_elements, grid_x, grid_y, num_symbols)

    out_path = filename + '_' + str(num_elements) + '_elements_per_input.h5'
    print('... saving file to', out_path)
    _save_trials(out_path, inputs, outputs)

    print('done')


def _build_trials(num_elements, grid_x, grid_y, num_symbols):
    """
    Build every (symbol multiset, ordered placement) trial.

    Trials are ordered symbol-combination-major: all placements for the first
    symbol multiset come first, then all placements for the next one, and so on.

    Returns:
        (inputs, outputs): ``inputs`` is a float64 array of shape
        (trials, grid_x, grid_y, num_symbols) holding a 1 at
        [trial, x, y, symbol] for each placed element. ``outputs`` is a bool
        array with the ``compute_xor`` value of each trial's symbols.
    """
    # create a list of all possible coordinates to sample from (x-major order)
    print('Creating all possible coordinates in input space')
    all_coords = list(itertools.product(range(grid_x), range(grid_y)))

    # Ordered placements: position order matters because the i-th coordinate
    # receives the i-th symbol.
    coord_permutations = list(itertools.permutations(all_coords, num_elements))
    # Unordered symbol multisets. Orderings of the symbols are already covered
    # by the coordinate permutations, so permutations here would only add duplicates.
    symbol_combos = list(itertools.combinations_with_replacement(range(num_symbols), num_elements))

    print('Creating all inputs/outputs')
    num_placements = len(coord_permutations)
    num_trials = len(symbol_combos) * num_placements
    # Preallocate once instead of collecting per-trial arrays and copying them.
    inputs = np.zeros((num_trials, grid_x, grid_y, num_symbols))
    outputs = np.zeros(num_trials, dtype=bool)

    # coords[p, i] is the (x, y) cell of element i in placement p
    coords = np.asarray(coord_permutations, dtype=np.intp).reshape(num_placements, num_elements, 2)
    rows = np.arange(num_placements)

    for combo_idx, symbol_list in enumerate(symbol_combos):
        block = slice(combo_idx * num_placements, (combo_idx + 1) * num_placements)
        block_inputs = inputs[block]  # basic slice -> view; writes land in `inputs`
        for i, sym in enumerate(symbol_list):
            # Set element i's symbol channel in every placement at once.
            block_inputs[rows, coords[:, i, 0], coords[:, i, 1], sym] = 1
        # The truth value depends only on the symbols, not on where they are placed.
        outputs[block] = compute_xor(symbol_list)

    return inputs, outputs


def _save_trials(path, inputs, outputs):
    """
    Write ``inputs``/``outputs`` to the HDF5 file at ``path``.

    Any existing datasets with those names are replaced. Missing parent
    directories of ``path`` are created.
    """
    out_dir = os.path.dirname(path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    # 'a' opens or creates the file without discarding unrelated content. Only
    # the two datasets written here are deleted and recreated.
    with h5py.File(path, 'a') as h5f:
        for name, data in (('inputs', inputs), ('outputs', outputs)):
            if name in h5f:
                del h5f[name]
            h5f.create_dataset(name, data=data)



        
def compute_xor(arr):
    """
    Return True when no two elements of ``arr`` compare equal.

    This is the "all different" rule used as the XOR target. For two elements
    it is the same as ``a != b``. For longer tuples, every pair must differ.
    Tuples with fewer than two elements have no pairs and so return True.

    Args:
        arr (tuple): the symbols to compare (any length)

    Returns:
        statement (bool): False as soon as a repeated element is found,
        otherwise True
    """
    # combinations(arr, 2) yields each pair (arr[i], arr[j]) with i < j exactly
    # once, in the same order as a nested i/j loop. This skips self-comparisons
    # and stops at the first equal pair.
    return not any(first == second
                   for first, second in itertools.combinations(arr, 2))
        
if __name__ == '__main__':
    # Script entry point: parse the CLI flags defined on `parser` and generate the data.
    run(parser.parse_args())
