"""
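Generate a fixed set of 4-layer architectures and randomly split them into train/test sets.

For ManiSkill environments with layer sizes [16, 32, 64]:
- 4-layer: 3^4 = 81 architectures in total
- Filter: remove architectures in which 3 or 4 of the layers are 16 (too narrow)
- After filtering: 72 architectures (81 - 9 = 72)
  - Removed: 1 arch with 4×16 + 8 archs with exactly 3×16
    (4 positions for the wider layer × 2 choices of 32/64) = 9 total
- Split (random seed 42): train 64, test 8

Writes train_arch_4layer.json, test_arch_4layer.json and all_arch_4layer.json
to the directory containing this script.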
"""

import itertools
import json
import os
import random


def generate_4layer_architectures(allowable_layers, filter_narrow=True,
                                  narrow_size=16, max_narrow_layers=2):
    """Generate all possible 4-layer architectures.

    Every combination of four layer widths drawn (with repetition) from
    ``allowable_layers`` is produced, in the same order as four nested loops
    with the first layer varying slowest.

    Args:
        allowable_layers: Iterable of allowed layer sizes (e.g., [16, 32, 64]).
        filter_narrow: If True, drop architectures containing more than
            ``max_narrow_layers`` layers of width ``narrow_size``. With the
            defaults this removes architectures where 3 or 4 layers are 16
            (too narrow).
        narrow_size: Layer width the filter treats as "narrow" (default 16).
        max_narrow_layers: Largest number of narrow layers an architecture may
            contain and still be kept (default 2).

    Returns:
        List of architectures, each a list of 4 ints.
    """
    # Materialize once so one-shot iterables (e.g. generators) are accepted.
    sizes = tuple(allowable_layers)

    return [
        list(combo)
        # product(..., repeat=4) yields tuples in nested-loop order, so the
        # result (and any seeded shuffle of it downstream) is unchanged.
        for combo in itertools.product(sizes, repeat=4)
        if not (filter_narrow and combo.count(narrow_size) > max_narrow_layers)
    ]



def split_architectures(architectures, num_test=8, rng=None):
    """Randomly split architectures into train/test sets.

    Args:
        architectures: Iterable of architectures to split; the input is not
            modified (a shallow copy is shuffled).
        num_test: Number of architectures for the test set (default: 8).
            Must be between 0 and the number of architectures, inclusive.
        rng: Optional ``random.Random`` instance used for shuffling. When
            None, the module-level ``random`` generator is used, so the result
            depends on any earlier ``random.seed(...)`` call.

    Returns:
        Tuple ``(train_archs, test_archs)``. The first ``num_test`` shuffled
        architectures form the test set and the remainder the training set.

    Raises:
        ValueError: If ``num_test`` is negative or larger than the number of
            architectures (negative values would otherwise slice from the end
            and silently produce a wrong split).
    """
    shuffled_archs = list(architectures)

    if not 0 <= num_test <= len(shuffled_archs):
        raise ValueError(
            f"num_test must be between 0 and {len(shuffled_archs)}, got {num_test}"
        )

    # Same call as before when rng is None, so seeded runs are reproducible.
    shuffler = rng if rng is not None else random
    shuffler.shuffle(shuffled_archs)

    test_archs = shuffled_archs[:num_test]
    train_archs = shuffled_archs[num_test:]
    return train_archs, test_archs


def save_architectures_json(architectures, filename, output_dir):
    """Save architectures to JSON file"""
    # Create the output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)

    # Create the full file path
    filepath = os.path.join(output_dir, filename)

    # Sort by sum of layer sizes for consistency
    sorted_archs = sorted(architectures, key=lambda x: (len(x), sum(x)))

    with open(filepath, 'w') as f:
        json.dump(sorted_archs, f, indent=2)

    print(f"Saved {len(sorted_archs)} architectures to {filepath}")

# Output directory (relative to this script)
script_dir = os.path.dirname(os.path.abspath(__file__))
output_dir = script_dir

# Experiment configuration, defined once so the log lines below always
# describe exactly what is generated and split.
# NOTE(review): this runs on import; consider moving it into main() behind an
# `if __name__ == "__main__":` guard if nothing imports these globals.
LAYER_SIZES = [16, 32, 64]
NUM_TEST = 8
SEED = 42

print(f"Generating 4-layer architectures with layer sizes: {LAYER_SIZES}")
print("Filter: remove architectures where 3 or 4 layers are 16 (too narrow)")
print(f"Test set size: {NUM_TEST} architectures")
print(f"Random seed: {SEED}")
print(f"Output directory: {output_dir}")
print()

# Seed the global RNG so the train/test shuffle is reproducible across runs.
random.seed(SEED)

# Generate all 4-layer architectures (with filtering)
all_architectures = generate_4layer_architectures(LAYER_SIZES, filter_narrow=True)
print(f"Generated {len(all_architectures)} total 4-layer architectures (after filtering)")

# Split into train/test
train_archs, test_archs = split_architectures(all_architectures, num_test=NUM_TEST)

print(f"Train architectures: {len(train_archs)}")
print(f"Test architectures: {len(test_archs)}")
print()

# Save to JSON files
save_architectures_json(train_archs, "train_arch_4layer.json", output_dir)
save_architectures_json(test_archs, "test_arch_4layer.json", output_dir)
save_architectures_json(all_architectures, "all_arch_4layer.json", output_dir)

print()
print("Architecture generation complete!")
print(f"Files saved to: {output_dir}")
