from importlib.metadata import distribution
import numpy as np

# --- reproducibility -------------------------------------------------------
# Seed NumPy's global RNG so every run draws the same numbers.
seed = 2022
np.random.seed(seed)

# --- experiment dimensions -------------------------------------------------
n = 100        # base size; only used here to scale the horizon T
m = 300        # number of distinct items
T = 200 * n    # number of arrivals in each sample path
repeat = 10    # number of independent sample paths per setting

# --- fixed reference distribution over the m items -------------------------
# Draw m uniform(0, 1) weights, then rescale them so they sum to one.
s = np.random.uniform(size=m)
s = s / s.sum()

def sample_all_arrivals_iid_finite_item_set(s, T=1000, repeat=10, m=None):
    '''
    Sample `repeat` i.i.d. arrival paths of length `T` over a finite item set.

    Parameters
    ----------
    s : sequence of float or None
        Probability of each item (must sum to one, as required by
        np.random.choice).  If None, items are drawn uniformly and the
        number of items must be given through `m`.
    T : int
        Number of arrivals in each sample path.
    repeat : int
        Number of sample paths.
    m : int or None
        Number of items.  Only needed when `s` is None; if `s` is given and
        `m` is also given, it must equal len(s).

    Returns
    -------
    numpy.ndarray of shape (repeat, T) holding item indices in [0, m).

    Raises
    ------
    ValueError
        If `s` is None and `m` is missing, or if `m` disagrees with len(s).
    '''
    if s is None:
        # Uniform sampling: the item count cannot be inferred from s, so it
        # has to be passed explicitly (len(None) would raise TypeError).
        if m is None:
            raise ValueError('m (number of items) is required when s is None')
        return np.random.choice(m, size=(repeat, T))

    n_items = len(s)
    if m is not None and m != n_items:
        raise ValueError(
            'm={} does not match len(s)={}'.format(m, n_items)
        )
    return np.random.choice(n_items, size=(repeat, T), p=s)

def sample_all_arrivals_periodic_finite_item_set(
    s, T=T, repeat=repeat, plen=5
):
    '''
    Sample `repeat` arrival paths of length `T` whose item distribution
    cycles with period `plen`.

    The items 0..len(s)-1 are split into `plen` contiguous chunks of size
    len(s)//plen (the last chunk also takes the remainder).  Position k of
    every period draws one item from chunk k, with probabilities
    proportional to `s` restricted to that chunk.

    Note that the reference distribution here is not exactly s: each chunk
    is visited once per period regardless of how much mass s puts on it.

    Parameters
    ----------
    s : sequence of float
        Nonnegative item weights; each chunk must have positive total weight.
    T : int
        Number of arrivals in each sample path.
    repeat : int
        Number of sample paths.
    plen : int
        Period length, 1 <= plen <= len(s).

    Returns
    -------
    numpy.ndarray of shape (repeat, T) holding item indices.

    Raises
    ------
    ValueError
        If `plen` is out of range or some chunk of `s` has no mass.
    '''
    # Use the size of the given distribution, not a module-level constant.
    m = len(s)
    if plen < 1 or plen > m:
        raise ValueError(
            'plen must satisfy 1 <= plen <= len(s); got plen={}, len(s)={}'
            .format(plen, m)
        )

    # One extra period so the concatenated draws always cover T items.
    n_periods = int(T // plen) + 1

    # generate plen distributions that cycle over chunks of items
    distributions = np.zeros(shape=(plen, m))
    chunksize = m // plen
    for k in range(plen):
        start = chunksize * k
        end = chunksize * (k + 1) if k < plen - 1 else m
        distributions[k][start:end] = s[start:end]
        chunk_mass = np.sum(distributions[k])
        if not chunk_mass > 0:
            # Dividing by zero would yield NaN probabilities and only fail
            # later, inside np.random.choice, with a less helpful message.
            raise ValueError(
                'items {}:{} have no probability mass in s'.format(start, end)
            )
        distributions[k] /= chunk_mass

    all_sample_paths = []
    for _ in range(repeat):
        # Draw item by item, period by period.  All n_periods * plen draws
        # are made before truncating to T so the random stream consumed per
        # path (and hence seeded output) stays the same.
        sample_path = [
            np.random.choice(m, p=distributions[k])
            for _period in range(n_periods)
            for k in range(plen)
        ]
        all_sample_paths.append(sample_path[:T])

    return np.array(all_sample_paths)

import os
from collections import defaultdict

# Period lengths for which periodic sample paths are generated.
all_plen = (5, 10, 20, 50, 100, 200)

# Maps a period length (or the key 'iid') to an array with one sample path
# per row.
sample_paths_periodic = defaultdict(list)
for plen in all_plen:
    sample_paths_periodic[plen] = sample_all_arrivals_periodic_finite_item_set(
        s=s, T=T, repeat=repeat, plen=plen
    )
sample_paths_periodic['iid'] = sample_all_arrivals_iid_finite_item_set(
    s=s, T=T, repeat=repeat
)

# save them to file

# np.savez_compressed does not create missing directories.
os.makedirs('instances', exist_ok=True)

# One file per (setting, sample-path index); the array is stored under
# NumPy's default key for positional arguments.
for plen in ('iid',) + all_plen:
    for sd in range(repeat):
        np.savez_compressed(
            'instances/all_sample_paths_periodic_{}_{}.npz'.format(plen, sd),
            sample_paths_periodic[plen][sd]
        )

# np.savez_compressed(
#     'instances/all_sample_paths_periodic.npz', 
#     *(sample_paths_periodic[plen] for plen in ('iid',) + all_plen)
# )


# sample_paths_periodic[plen][0][:100]


# import numpy as np
# iterable_of_sample_paths = np.load('instances/all_sample_paths_periodic.npz').values()
# loaded_dict = {
#     plen: sp for (plen, sp) in zip(
#         all_plen, iterable_of_sample_paths
#     )
# }