import json
import os
import pickle
import shutil
import sys
from pathlib import Path

from PIL import Image

from Config import get_data_dir, id_from_path
from Data import *

# Handle system arguments: argv[1] is the dataset name; the part of it before
# the first '-' selects the experiment mode.
if len(sys.argv) < 2:
    # Fail with a usage message instead of an IndexError traceback
    sys.exit('Usage: {} <name>  (name must start with "complex" or "dc")'.format(sys.argv[0]))
name = sys.argv[1]
mode = name.split('-')[0]

# General config
max_attempts = 10000  # candidate blindspots tried before restarting the whole set
verbose = False

# Define the Experimental Configuration.
# num_buckets may be set to an int here to override the default computed below.
num_buckets = None
if mode == 'complex':
    num_features = np.random.randint(1, 4)
    # num_options is reduced by num_features (kept from the original scheme)
    num_options = np.random.randint(5, 8)
    num_options -= num_features
    # Between 1 and 3 blindspots, each of size 4-6, in ascending size order
    blindspot_sizes = sorted(np.random.randint(4, 7, size = np.random.randint(1, 4)))
elif mode == 'dc':  # Dataset Complexity
    num_features = np.random.randint(1, 4)
    num_options = np.random.randint(3, 7)
    num_options -= num_features
    blindspot_sizes = [3]
else:
    # Report the error on stderr and exit with a non-zero status so callers
    # (shell scripts, job runners) can detect the failure.
    sys.exit('Unrecognized "mode"')

# Both modes always include a Background and a Square, plus num_features
# distinct shapes sampled from the remaining candidates.
features = [Background(), Square()]
features.extend(random.sample([Rectangle(), Circle(), Text()], num_features))

if num_buckets is None:
    num_buckets = 2**(num_options + num_features + 1)

# Build the dataset object from the selected features.
d = Dataset(features)

# Turn on num_options additional options, one d.enable() call per option.
for _ in range(num_options):
    d.enable()

# Register the meta-feature hooks (presumably provided by `from Data import *`).
d.set_meta_features(add_meta_features, compute_meta_features)

# Generate a set of irreducible blindspots.
# Blindspots are built one at a time in blindspot_sizes order. Before each one,
# the blindspots accepted so far are registered with d.set_blindspots, and each
# candidate is accepted only if d.check_validity approves it. If no valid
# candidate turns up within max_attempts tries, the whole set is discarded and
# generation restarts from the first blindspot.
# NOTE(review): there is no global cap, so an unsatisfiable configuration
# would loop forever.
blindspots = []
while len(blindspots) < len(blindspot_sizes):
    target_size = blindspot_sizes[len(blindspots)]
    d.set_blindspots(blindspots)
    for _attempt in range(max_attempts):
        # Start from the default blindspot, add target_size features to it,
        # then roll the feature values.
        candidate = d.get_default_blindspot()
        for _ in range(target_size):
            d.add_feature(candidate)
        candidate = d.realize_blindspot(candidate)
        if d.check_validity(candidate):
            if verbose:
                print(candidate)
            blindspots.append(candidate)
            break
    else:
        # Exhausted max_attempts without a valid candidate: start over.
        if verbose:
            print('Resetting')
        blindspots = []

d.set_blindspots(blindspots)

# Show the finished dataset
if verbose:
    d.print()

# Setup: start every run from an empty output directory.
base_dir = '{}/{}'.format(get_data_dir(), name)
# Remove any previous output without going through a shell. `name` comes from
# the command line, so interpolating it into `rm -rf` allowed shell injection
# and broke on paths containing spaces. This mirrors `rm -rf`: delete a real
# directory tree, remove a file or symlink at that path, ignore a missing path.
if os.path.isdir(base_dir) and not os.path.islink(base_dir):
    shutil.rmtree(base_dir)
elif os.path.lexists(base_dir):
    os.remove(base_dir)
Path(base_dir).mkdir(parents = True, exist_ok = True)

# Save this configuration (the Dataset object, pickled)
with open('{}/dataset.pkl'.format(base_dir), 'wb') as f:
    pickle.dump(d, f)

# Build the class maps (class name -> index, and index -> class name) and save
# them together to maps.json.
names = ['object']
name2index = {class_name: index for index, class_name in enumerate(names)}
index2name = list(name2index)

with open('{}/maps.json'.format(base_dir), 'w') as f:
    json.dump([name2index, index2name], f)

# Process the splits. For each split this generates num_images[split] images,
# saves them as JPEGs, and writes:
#   images.json   - per-image file path, label, metadata and `contained` value
#   name2ids.json - ids of the images whose label is [1]
num_images = {'train': 400 * num_buckets, 'val': 50 * num_buckets, 'test': 50 * num_buckets}
for split in ['test', 'val', 'train']:  # `split`, so the module-level `mode` is not shadowed
    split_dir = '{}/{}'.format(base_dir, split)
    image_dir = '{}/images'.format(split_dir)
    # Create the split and image directories directly instead of via
    # os.system('mkdir ...'), which passed a command-line-derived path to a shell.
    Path(image_dir).mkdir(parents = True, exist_ok = True)
    if verbose:
        print('Generating data in: ', split_dir)

    images = {}
    positive_examples = []
    for i in range(num_images[split]):
        img_id = str(i)
        img_path = '{}/{}.jpg'.format(image_dir, img_id)

        img_numpy, metadata, _bboxes = d.generate()
        Image.fromarray(img_numpy).save(img_path)

        # train/val are labeled with d.get_blindspot_label, test with
        # d.get_true_label.
        if split in ['val', 'train']:
            label, contained = d.get_blindspot_label(metadata)
        else:
            label, contained = d.get_true_label(metadata)
        label = [label]

        if label == [1]:
            positive_examples.append(img_id)

        images[img_id] = {'file': img_path, 'label': label, 'metadata': metadata, 'contained': contained}

    with open('{}/images.json'.format(split_dir), 'w') as f:
        json.dump(images, f)

    name2ids = {'object': positive_examples}
    with open('{}/name2ids.json'.format(split_dir), 'w') as f:
        json.dump(name2ids, f)
        