from utils.quadratic_model import QuadModel
import numpy as np
import matplotlib.pyplot as plt

from tqdm import tqdm

# Sample uniformly from the probability simplex
def unif_simplex_measure(d):
    """Draw one point uniformly at random from the probability simplex in R^d.

    Normalising i.i.d. Exp(1) draws gives a Dirichlet(1, ..., 1) sample, i.e.
    the uniform distribution on {x in R^d : x_i >= 0, sum_i x_i = 1}.

    Args:
        d: Number of components of the returned vector.

    Returns:
        A length-``d`` NumPy array with non-negative entries summing to 1.
    """
    weights = np.random.exponential(1.0, size=d)
    # ndarray.sum runs in native code; the builtin ``sum`` would iterate the
    # array element by element in Python.
    return weights / weights.sum()

# Generate samples from the simplex and \calK, where \calK = [lower_bound, upper_bound]^d ([0,1]^d by default)
def generate_samples(samples, d, lower_bound=0.0, upper_bound=1.0, oversample_zero=False, oversample_ratio=0.2):
    """Build a matrix of random initial conditions.

    Each row is ``[theta, mu_0]``: ``theta`` (first ``d`` columns) is drawn
    uniformly from ``[lower_bound, upper_bound]^d`` and ``mu_0`` (last ``d``
    columns) uniformly from the probability simplex.

    Args:
        samples: Number of regular rows to draw.
        d: Dimension of both ``theta`` and ``mu_0``.
        lower_bound: Lower bound of each ``theta`` coordinate.
        upper_bound: Upper bound of each ``theta`` coordinate.
        oversample_zero: If True, append ``int(samples * oversample_ratio)``
            extra rows whose ``theta`` part is exactly zero.
        oversample_ratio: Fraction of ``samples`` used for the extra rows.

    Returns:
        Array of shape ``(samples + n_extra, 2 * d)``. The extra zero-theta
        rows, if any, are appended at the end (not shuffled in).
    """
    data = np.zeros((samples, 2 * d))
    # Draw theta then mu_0 row by row so the RNG stream is consumed in the
    # same order as before (keeps seeded runs reproducible).
    for i in range(samples):
        data[i, :d] = np.random.uniform(lower_bound, upper_bound, d)
        data[i, d:] = unif_simplex_measure(d)

    if oversample_zero:
        n_extra = int(samples * oversample_ratio)
        if n_extra > 0:
            # Preallocate and stack once; a vstack per row is quadratic in copies.
            extra = np.zeros((n_extra, 2 * d))
            for i in range(n_extra):
                extra[i, d:] = unif_simplex_measure(d)
            data = np.vstack([data, extra])

    return data

# Generate sample trajectories using Picard iteration
def generate_flows(init_data, model):
    """Run ``model.picard_iteration`` for every row of ``init_data``.

    Each row is split into ``theta`` (first ``model.d`` entries) and ``mu_0``
    (remaining entries). The solver is called with a zero terminal condition
    ``u_T``, and the transpose of the second value it returns is stored.

    Args:
        init_data: 2-D array of initial conditions, one per row.
        model: Object exposing ``d``, ``Nt`` and ``picard_iteration``.

    Returns:
        Array of shape ``(n_rows, model.d, model.Nt + 1)``.
    """
    dim = model.d
    trajectories = np.zeros((init_data.shape[0], dim, model.Nt + 1))

    for row_idx, row in enumerate(tqdm(init_data)):
        theta, mu_0 = row[:dim], row[dim:]
        terminal = np.zeros(dim)
        # Only the second return value is kept; presumably u has shape
        # (Nt + 1, d), given the transpose into a (d, Nt + 1) slot — confirm.
        u = model.picard_iteration(mu_0, terminal, theta=theta)[1]
        trajectories[row_idx, :] = u.T

    return trajectories

if __name__ == '__main__':
    import os

    d = 3
    Nt = 200
    out_dir = 'data'

    # Create the output directory before the long solve loop: np.savez raises
    # FileNotFoundError for a missing directory, which would discard every
    # computed flow.
    os.makedirs(out_dir, exist_ok=True)

    cs_model = QuadModel(name=f'gen_data_d={d}', d=d, Nt=Nt)

    samples = 4000
    lower_bound = 0.0
    upper_bound = 1.0

    init_data = generate_samples(samples, d, lower_bound, upper_bound)
    print(f'Sampled {samples} initializations, solving MFG systems')

    flow_data = generate_flows(init_data, cs_model)

    # Save data. Arrays are passed positionally, so the archive keys are
    # arr_0 (init_data) and arr_1 (flow_data); np.savez appends ".npz".
    np.savez(f'{out_dir}/quad_data_d={d}', init_data, flow_data)
