from utils.cs_model import CSModel
import numpy as np
import matplotlib.pyplot as plt

from tqdm import tqdm

def unif_simplex_measure(d):
    """Draw one point uniformly at random from the probability simplex.

    Normalizing i.i.d. Exp(1) draws gives a Dirichlet(1, ..., 1) sample,
    i.e. the uniform distribution on {x : x_i >= 0, sum_i x_i = 1}.

    Args:
        d: number of coordinates of the returned point.

    Returns:
        NumPy array of length ``d`` with nonnegative entries summing to 1.
    """
    weights = np.random.exponential(1, size=d)
    total = sum(weights)
    return weights / total

def generate_samples(samples, d, lower_bound=0.0, upper_bound=1.0, oversample_zero=False, oversample_ratio=0.2):
    """Generate random initial conditions, one per row, as ``[cost_param, mu_0]``.

    Each row has ``d + 1`` columns. Column 0 is a cost parameter drawn
    uniformly from ``[lower_bound, upper_bound)`` (defaults to [0, 1]).
    Columns 1..d are a point drawn uniformly from the probability simplex
    via ``unif_simplex_measure``.

    Args:
        samples: number of rows with a randomly drawn cost parameter.
        d: number of simplex coordinates per row.
        lower_bound: lower end of the cost-parameter range.
        upper_bound: upper end of the cost-parameter range.
        oversample_zero: if True, append extra rows whose cost parameter is
            exactly 0.
        oversample_ratio: number of extra rows is
            ``int(samples * oversample_ratio)`` (none if that is <= 0).

    Returns:
        Float array of shape ``(samples + n_extra, d + 1)``. Zero-cost rows,
        if any, are appended after the random rows (not shuffled in).
    """
    n_extra = max(int(samples * oversample_ratio), 0) if oversample_zero else 0

    # Allocate the final array once instead of growing it with a vstack per
    # extra row, which copied the whole array every time (quadratic cost).
    data = np.zeros((samples + n_extra, d + 1))

    # Keep the per-row draw order (uniform, then simplex) so a given seed
    # reproduces exactly the same samples as before.
    for k in range(samples):
        data[k, 0] = np.random.uniform(lower_bound, upper_bound)
        data[k, 1:] = unif_simplex_measure(d)

    # Oversampled rows keep the 0.0 cost parameter from np.zeros.
    for k in range(samples, samples + n_extra):
        data[k, 1:] = unif_simplex_measure(d)

    return data

def generate_flows(init_data, model):
    """Solve the model once per initial condition using Picard iteration.

    Args:
        init_data: 2-D array whose rows are ``[cost_param, mu_0...]`` (as
            produced by ``generate_samples``). Column 0 is passed as
            ``cost_param``; the remaining columns are passed as ``mu_0``.
        model: object exposing ``NS``, ``Nt`` and
            ``picard_iteration(mu_0, u_T, cost_param=...)``.

    Returns:
        Array of shape ``(len(init_data), model.NS, model.Nt + 1)``; slice
        ``i`` holds the transpose of the second value returned by
        ``model.picard_iteration`` for row ``i``.
    """
    n_rows = init_data.shape[0]
    flows = np.zeros((n_rows, model.NS, model.Nt + 1))

    for row_idx, row in enumerate(tqdm(init_data)):
        # A fresh zero array per solve, in case picard_iteration mutates it.
        # NOTE(review): presumably the terminal condition u(T) -- confirm.
        terminal = np.zeros(model.NS)
        _, solution = model.picard_iteration(row[1:], terminal, cost_param=row[0])
        # Assumes `solution` is laid out (Nt + 1, NS) -- TODO confirm; its
        # transpose must fit the (NS, Nt + 1) slice below.
        flows[row_idx, :] = solution.T

    return flows

# Example: generate a training set for the cybersecurity (CS) model.
if __name__ == '__main__':
    import os

    # np.savez does not create missing directories. Create the output
    # directory up front so a filesystem problem surfaces before the
    # expensive solves below, rather than after them.
    os.makedirs('data', exist_ok=True)

    cs_model = CSModel('gen_data')

    samples = 2000
    d = cs_model.NS
    lower_bound = 0.0
    upper_bound = 10.0

    init_data = generate_samples(samples, d, lower_bound, upper_bound)
    print(f'Sampled {samples} initializations, solving MFG systems')

    flow_data = generate_flows(init_data, cs_model)

    # Save data. Arrays are passed positionally, so they are stored under
    # the keys 'arr_0' (init_data) and 'arr_1' (flow_data); kept that way
    # for compatibility with existing loaders. '.npz' is appended.
    np.savez('data/cs_data', init_data, flow_data)
