""" Run the script by:
    python TD/gen_data.py -a -1.0 -T 8 -B 4096 -runs 50
"""

import os
import gc
import argparse
import numpy as np
from tqdm import trange


def gen_CT_data(a, sigma, h0, N0, M, rng_obj, dtype=np.float16, verbose=False):
    """Generate M trajectories on a fine time grid as a proxy to a CT system.

    Each trajectory follows the recursion

        x(k) = exp(a * h0) * x(k - 1) + w(k - 1),    x(0) = 0,

    where the w(k) are i.i.d. zero-mean Gaussians with variance
    sigma^2 * (exp(2 * a * h0) - 1) / (2 * a) (or sigma^2 * h0 when a == 0).
    The noise sequence w(0), ..., w(N0 - 1) therefore yields the N0 + 1
    samples x(0), x(1), ..., x(N0).

    Args:
        a: drift parameter; a == 0 reduces the recursion to a running sum.
        sigma: noise intensity (scalar).
        h0: time step between consecutive samples.
        N0: number of noise samples (time steps) per trajectory.
        M: number of trajectories.
        rng_obj: ``numpy.random.Generator`` used to draw the noise.
        dtype: dtype of the returned array; the simulation itself runs in
            float32 and is cast only at the end.
        verbose: if True, show a tqdm progress bar over the time steps
            (only used when a != 0).

    Returns:
        Array of shape (N0 + 1, M) and dtype ``dtype``; row k holds x(k) for
        every trajectory and row 0 is all zeros.
    """
    if a != 0:
        # expm1 avoids the cancellation in exp(.) - 1 when 2*a*h0 is tiny.
        noise_std = sigma * np.sqrt(np.expm1(2 * a * h0) / (2 * a))
    else:
        noise_std = sigma * np.sqrt(h0)

    w = rng_obj.standard_normal(size=(N0, M), dtype=np.float32)
    # Scale in place with a plain Python float so the float32 buffer is not
    # promoted to float64 by a NumPy float64 scalar.
    w *= float(noise_std)

    x = np.zeros((N0 + 1, M), dtype=np.float32)
    if a != 0:
        a_eq = float(np.exp(a * h0))
        # Both modes must visit k = 1, ..., N0 so that x(N0) is filled in.
        if verbose:
            steps = trange(1, N0 + 1, desc="generate x(t)", leave=False)
        else:
            steps = range(1, N0 + 1)
        for k in steps:
            x[k, :] = a_eq * x[k - 1, :] + w[k - 1, :]
    else:
        # a = 0: x(k) is the running sum of the noise; row 0 stays x(0) = 0.
        np.cumsum(w, axis=0, out=x[1:])
    return x.astype(dtype)


# ---- command-line interface ------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument("-a", type=float, required=True, help="parameter a<0")
parser.add_argument("-T", type=float, required=True, help="horizon T")
parser.add_argument("-B", type=int, required=True, help="data budget B")
parser.add_argument("-runs", type=int, default=50, help="number of runs")
parser.add_argument(
    "-dtype",
    type=str,
    default="float16",
    # Restrict the accepted values: the string is embedded in the output
    # filename, so an unsupported value must not silently fall back to
    # float32 and produce a mislabelled file.
    choices=("float16", "float32"),
    help="data type (float16 or float32)",
)
args = parser.parse_args()

# ---- hyperparameters -------------------------------------------------------
a = args.a
T = args.T
N0 = 2**16  # number of time steps per trajectory
h0 = T / N0  # 2^-16 * T
B = args.B
num_runs = args.runs
imax = 16  # exponent used to size M0 (2**imax equals N0 here, so M0 == B)
M0 = int(B * (2**imax) / N0)  # longest number of trajectories that we need
sigma = 1
seed = 0  # base of seed, for reproducibility
# Map the validated CLI string onto the NumPy scalar type.
dtype = {"float16": np.float16, "float32": np.float32}[args.dtype]
print(
    f"Hyperparams: T={T},B={B},a={a},runs={num_runs},h0={h0},N0={N0},M0={M0},sigma={sigma},dtype={dtype}"
)

# ---- output directory ------------------------------------------------------
path = "data/share"
if not os.path.exists(path):
    print(f"Creating {path} since it did not exist")
# exist_ok guards against another process creating the directory between the
# check above and this call.
os.makedirs(path, exist_ok=True)

# ---- data generation -------------------------------------------------------
for j in trange(
    num_runs, desc="runs"
):  # outer loop of many runs to approximate the E[(Vh-V)^2]
    # store the data here for debugging
    fname = os.path.join(
        path, f"x_run_{j}_a_{a}_T_{T}_B_{B}_seed_{seed}_{args.dtype}.npy"
    )

    if os.path.exists(fname):
        print(f"{fname} already exists")
        continue

    # data not found: generate and save, with a different seed for each run
    rng_obj = np.random.default_rng(seed + j)
    x = gen_CT_data(
        a, sigma, h0, N0, M0, rng_obj, dtype=dtype, verbose=True
    )  # shape (N0+1, M0)
    with open(fname, "wb") as f:
        np.save(f, x)

    # Release the array before the next run allocates its own buffers, so
    # two runs' worth of data are never alive at the same time.
    del x
    gc.collect()
