import os
from random import sample

import numpy as np

''' An instance is characterized by
    seed: random seed
    m: number of items
    data generative model (DGM): iid, mildly corrupted, markov, or periodic
    s: a distribution over the m items (if DGM == iid, every arrival is drawn from s)
    item arrivals: an array with length == T (one item index per time step) '''

# ---- instance configuration ----
seed = 2022
# seed NumPy's global RNG from the constant above so the two cannot drift apart
np.random.seed(seed)
# m is the number of items; in this section n is only used to set the horizon T
n, m = 100, 300
T = 200 * n  # number of arrivals per sample path
repeat = 10  # number of sample paths per data generative model

# label of the data generative model; not read anywhere else in this section
data_generative_model = 'Mildly Corrupted'
# base distribution over the m items: uniform draws normalized to sum to 1
s = np.random.uniform(size=m)
s = s / np.sum(s)

def sample_all_arrivals_iid_finite_item_set(s, T=1000, repeat=10, m=None):
    ''' Draw `repeat` i.i.d. sample paths of item arrivals, each of length T.

        s: probability vector over the items (its length is the number of
           items), or None to sample items uniformly
        T: number of arrivals per path
        repeat: number of sample paths
        m: number of items; only consulted when s is None, because the item
           count cannot be read off a missing distribution
        returns: integer array of item indices with shape (repeat, T)
        raises: TypeError if both s and m are None '''
    if s is None:
        if m is None:
            raise TypeError('s is None: pass the number of items via m to sample uniformly')
        return np.random.choice(m, size=(repeat, T))
    return np.random.choice(len(s), size=(repeat, T), p=s)

# all_sample_paths = sample_all_arrivals_mild_corrupt_finite_item_set(s, T=n*200, repeat=10)
def sample_all_arrivals_mild_corrupt_finite_item_set(s, T=1000, repeat=10):
    ''' Sample arrival paths from a randomly generated sequence of
        distributions that are perturbed versions of s.

        At step t (0-based) a random perturbation is centered to sum to zero,
        scaled to 1-norm 1/(t+1) and added to s; the result is clipped at 0
        and renormalized. All `repeat` paths draw their step-t item from that
        same distribution.

        s: probability vector over the m = len(s) items
        T: number of arrivals per path
        repeat: number of sample paths
        returns: integer array of item indices with shape (repeat, T) '''
    m = len(s)
    if T <= 0:
        # nothing to sample; still honor the (repeat, T) output shape
        return np.empty((repeat, 0), dtype=int)
    all_sample_paths = []
    for t in range(T):
        # random magnitudes with a random subset of entries flipped to negative
        perturbation = np.random.uniform(size=m)
        num_neg = np.random.randint(m + 1)
        perturbation[np.random.choice(m, num_neg, replace=False)] *= -1
        # center so that adding the perturbation keeps the total mass unchanged
        perturbation = perturbation - np.mean(perturbation)
        l1_norm = np.linalg.norm(perturbation, 1)
        if l1_norm > 0:
            # rescale so the perturbation's 1-norm is 1/(t+1)
            perturbation = (perturbation / l1_norm) / (t + 1)
        # else: the centered vector is all zeros (e.g. m == 1); dividing by the
        # zero norm would produce NaN probabilities, so leave s unperturbed
        curr_distr = np.maximum(s + perturbation, 0)
        curr_distr = curr_distr / np.sum(curr_distr)
        # one draw per sample path from the step-t distribution
        j_all_repeats = np.random.choice(m, size=repeat, p=curr_distr)
        all_sample_paths.append(j_all_repeats)
    return np.array(all_sample_paths).T  # dim = (num_repeat, T)

def sample_all_arrivals_markov_finite_item_set(s, T=1000, repeat=10):
    ''' Sample arrival paths from a randomly generated Markov chain over items.

        A random transition matrix with zero diagonal (no self-transitions)
        and rows normalized to sum to 1 is drawn once and shared by all paths.
        Each path starts from an item drawn from s and then follows the chain.

        s: initial-state distribution over the m = len(s) items
        T: number of arrivals per path
        repeat: number of sample paths
        returns: integer array of item indices with shape (repeat, T)
        raises: ValueError if T > 1 and there are fewer than 2 items, since a
                chain without self-transitions then has nowhere to move '''
    m = len(s)
    if T <= 0:
        # no arrivals requested: do not emit the initial state either
        return np.empty((repeat, 0), dtype=int)
    if T > 1 and m < 2:
        raise ValueError('a Markov chain without self-transitions needs at least 2 items')

    # transition matrix: zero diagonal, then each row scaled to sum to 1
    P = np.random.uniform(size=(m, m))
    P[np.arange(m), np.arange(m)] = 0
    if m > 1:
        # with a single item the only row is all zeros and P is never used
        P = (P.T / np.sum(P, axis=1)).T

    all_sample_paths = []
    for _ in range(repeat):
        # initial item is drawn from s
        sample_path = [np.random.choice(m, p=s)]
        for _ in range(1, T):
            # markov chain transition from the most recent item
            sample_path.append(np.random.choice(m, p=P[sample_path[-1]]))
        # append to the list of all sample paths
        all_sample_paths.append(sample_path)
    return np.array(all_sample_paths)

def sample_all_arrivals_periodic_finite_item_set(s, T=1000, repeat=10):
    ''' Sample arrival paths whose item distributions repeat every period.

        The period length is plen = max(T//500 + 1, 100). One distribution per
        position in the period is drawn once (normalized exponential draws) and
        shared by all paths. For every period, one item is drawn per position,
        the items of that period are shuffled, and the path is finally
        truncated to T arrivals.

        s: only its length is used (the number of items m); its values are
           ignored
        T: number of arrivals per path
        repeat: number of sample paths
        returns: array of item indices with shape (repeat, T) '''
    # take the item count from s, like the other samplers, rather than from a
    # module-level variable
    m = len(s)
    plen = max(int(T // 500) + 1, 100)
    n_periods = int(T // plen) + 1  # enough whole periods to cover T arrivals
    # one distribution over items per position in the period (rows sum to 1)
    distributions = np.random.exponential(size=(plen, m))
    distributions = (distributions.T / distributions.sum(axis=1)).T
    all_sample_paths = []
    for _ in range(repeat):
        sample_path = []
        for _ in range(n_periods):
            curr_period_items = [
                np.random.choice(m, p=distributions[k]) for k in range(plen)
            ]
            # random shuffle of the arrivals within this period
            np.random.shuffle(curr_period_items)
            sample_path.extend(curr_period_items)
        # drop the surplus arrivals of the last period
        all_sample_paths.append(sample_path[:T])
    return np.array(all_sample_paths)

# generate the sample paths for all four arrival models; every sampler draws
# from NumPy's global RNG, so the call order below affects the generated data
all_sample_paths_iid = sample_all_arrivals_iid_finite_item_set(s, T=T, repeat=repeat)
all_sample_paths_mild_corrupt = sample_all_arrivals_mild_corrupt_finite_item_set(s, T=T, repeat=repeat)
all_sample_paths_markov = sample_all_arrivals_markov_finite_item_set(s, T=T, repeat=repeat)
all_sample_paths_periodic = sample_all_arrivals_periodic_finite_item_set(s, T=T, repeat=repeat)

# sanity check: print the shape of each set of sample paths
for data in (
    all_sample_paths_iid, all_sample_paths_markov, all_sample_paths_mild_corrupt, all_sample_paths_periodic
):
    print(data.shape)

# save them to file; create the output directory first so the sampling work is
# not lost to a missing folder
os.makedirs('instances', exist_ok=True)
# the arrays are passed positionally, so they are stored under the keys
# arr_0 ... arr_3 in the order given here
np.savez_compressed(
    'instances/all_sample_paths.npz', 
    all_sample_paths_iid, 
    all_sample_paths_mild_corrupt, 
    all_sample_paths_markov, 
    all_sample_paths_periodic
)

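# to load the saved sample paths later (the arrays come back in the order they
# were saved; adjust the path to wherever the .npz file lives):
# import numpy as np
# (
#     all_sample_paths_iid, 
#     all_sample_paths_mild_corrupt, 
#     all_sample_paths_markov, 
#     all_sample_paths_periodic
# ) = np.load('instances/all_sample_paths.npz').values()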