import os
import itertools
import random
from .util import *

def augmentation_samples(
        program_location = 's3://metagen-datasets/data/graph_v2/processed/',
        program_type: ProgramType = ProgramType.GRAPH,
        max_product_samples = 1000):
    """Build ordered pairs of seed ``.json`` files for augmentation.

    Every directory under ``program_location`` that directly contains at
    least one ``.json`` file forms a group. For every ordered pair of groups
    (including a group paired with itself), all ordered pairs of distinct
    files are formed. If a group pair yields more than ``max_product_samples``
    pairs, a random subset of that size is kept, drawn with the global
    ``random`` RNG.

    Args:
        program_location: Root directory to scan. It is walked with
            ``os.walk``, so it must be a local filesystem path. A remote URL
            such as the ``s3://`` default will normally match nothing, and
            the result is then ``[]``.
        program_type: Currently unused; kept for interface compatibility.
        max_product_samples: Maximum number of pairs kept per group pair.

    Returns:
        A flat list of ``(path1, path2)`` tuples. Each path is the file path
        with the normalized absolute ``program_location`` prefix sliced off,
        so it typically keeps a leading separator (e.g. ``'/sub/a.json'``).
    """
    # NOTE(review): scan root is assumed to be ``program_location`` — confirm
    # with callers.
    seed_dir = os.path.normpath(os.path.abspath(program_location))
    seed_groups = []
    for dirpath, dirnames, filenames in os.walk(seed_dir):
        # Sort in place so traversal order (and therefore sampling under a
        # fixed random seed) is reproducible across filesystems.
        dirnames.sort()
        seeds = sorted(f for f in filenames if f.endswith('.json'))
        if seeds:
            seed_groups.append(
                [os.path.join(dirpath, seed)[len(seed_dir):] for seed in seeds])
    samples = []
    for g1, g2 in itertools.product(seed_groups, repeat=2):
        # Paths from different directories always differ, so this only drops
        # (s, s) self-pairs when a group is paired with itself.
        pairs = [(s1, s2) for s1, s2 in itertools.product(g1, g2) if s1 != s2]
        if len(pairs) > max_product_samples:
            pairs = random.sample(pairs, max_product_samples)
        samples.extend(pairs)
    return samples