"""
generate data from a time dependent ou process
"""

import numpy as np
import matplotlib.pyplot as plt
import sdeint
import torch
import argparse


# Command-line interface: --train selects which data split is generated.
parser = argparse.ArgumentParser(
    description='Generate sample paths of a time-dependent OU process.')
parser.add_argument(
    '--train', type=int, choices=[0, 1], default=1,
    help='1 (default): generate the training split; '
         '0: generate the validation split')
args = parser.parse_args()




train = args.train

# SDE coefficients: drift mu*t - theta*y, diffusion sigma + phi*t.
mu = 0.2
theta = 0.1
sigma = 0.6
phi = 0.15

# Simulation time window.
tstart = 0
tend = 10

# Both splits share the number of paths per starting point and the time grid.
# They differ only in how many starting points are drawn, the output file tag,
# and the RNG seed.
nsamples = 45
ntimes = 101
if train:
    ninitial, name = 200, 'train'
    np.random.seed(982846947)  # arbitrary fixed seed, typed in by hand
else:
    ninitial, name = 20, 'val'
    np.random.seed(535617672)  # arbitrary fixed seed, typed in by hand

times = np.linspace(tstart, tend, ntimes)


def f(y, t):
    """Drift coefficient of the SDE: mu*t - theta*y.

    The -theta*y term pulls y back towards the moving level mu*t/theta.
    That moving mean-reversion level is what makes this a time-dependent
    OU process.
    """
    forcing = mu * t
    reversion = theta * y
    return forcing - reversion


def g(y, t):
    """Diffusion coefficient of the SDE: sigma + phi*t.

    The noise is additive: the value ignores the state y and changes only
    linearly with t. The Ito and Stratonovich readings of the SDE therefore
    coincide.
    """
    noise_scale = sigma + phi * t
    return noise_scale


def solve_sde(y0):
    """Simulate one sample path of the SDE with drift f and diffusion g.

    Integrates from the initial value y0 over the module-level `times` grid
    with sdeint's Stratonovich solver (stratint). Returns the trajectory
    array produced by sdeint, indexed by time point along its first axis.
    """
    path = sdeint.stratint(f, g, y0, times)
    return path


# Plot a handful of sample paths as a visual sanity check.
# These draws come from the same seeded global RNG as the dataset below.
# Removing or resizing this loop would change the generated data, so keep it
# as-is for reproducibility.
for i in range(50):
    y0 = 6*np.random.rand()-3
    y = solve_sde(y0)
    plt.plot(times, y, c='b', alpha=0.2)
# Without this the figure is built and then silently discarded.
# Write it next to the tensors and free it.
plt.savefig('ou_samples_'+name+'.png')
plt.close()




# Build the dataset: ninitial random starting points, each simulated
# nsamples times. The result ys has shape (ninitial, nsamples, ntimes, 1).
ys = np.empty((ninitial, nsamples, ntimes, 1))

for i in range(ninitial):
    # One starting value per group, drawn uniformly from [-3, 3).
    # Every path in the group starts there and differs only in its noise.
    y0 = 6*np.random.rand()-3
    for j in range(nsamples):
        y = solve_sde(y0)
        ys[i, j] = y

# Convert to float32 tensors; times gets a trailing axis -> (ntimes, 1).
ys = torch.from_numpy(ys).float()
times = torch.from_numpy(times).float().unsqueeze(-1)

torch.save(ys, f'ou_y_{name}.pt')
torch.save(times, f'ou_t_{name}.pt')



