#%%
import math
import os

import diffrax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import Image
from matplotlib.animation import FuncAnimation
from scipy.integrate import solve_ivp

from idncflow import RK4

# Detect whether this script runs inside an IPython/Jupyter session: the
# `__IPYTHON__` name only resolves there, so a NameError means plain Python.
try:
    __IPYTHON__  # noqa: F821
except NameError:
    _in_ipython_session = False
else:
    _in_ipython_session = True

## Parse the command-line arguments: the split (e.g. "train"), the save folder, the seed, and the verbosity level

import argparse


# Splits the rest of this script knows how to generate. "adapt_huge" used to be
# accepted here, but no environments, trajectory count or output filename are
# defined for it further down, so it crashed later with a NameError.
SUPPORTED_SPLITS = ("train", "test", "adapt", "adapt_test")

if _in_ipython_session:
    # Interactive session: use fixed defaults instead of parsing the command line.
    # args = argparse.Namespace(split='train', savepath='tmp/', seed=42)
    args = argparse.Namespace(split='train', savepath="./tmp/", seed=2026, verbose=1)
else:
    parser = argparse.ArgumentParser(description='ODE dataset generation script (dX/dt = X / (1 + (c*X)^2)).')
    parser.add_argument('--split', type=str, choices=SUPPORTED_SPLITS, help='Generate "train", "test", "adapt", or "adapt_test" data', default='train', required=False)
    parser.add_argument('--savepath', type=str, help='Folder where the generated .npz file is written', default='tmp/', required=False)
    parser.add_argument('--seed', type=int, help='Seed to generate the data', default=2026, required=False)
    parser.add_argument('--verbose', type=int, help='Print a summary of the parsed arguments unless 0', default=1, required=False)

    args = parser.parse_args()


split = args.split
# Explicit check instead of `assert` (stripped under `python -O`); it also covers
# the IPython branch, which bypasses argparse's `choices` validation.
if split not in SUPPORTED_SPLITS:
    raise ValueError(f"Split must be one of {SUPPORTED_SPLITS}, got {split!r}")

savepath = args.savepath
seed = args.seed

if args.verbose != 0:
    print("Running this script in ipython (Jupyter) session ?", _in_ipython_session)
    print('=== Parsed arguments to generate data ===')
    print(' Split:', split)
    print(' Savepath:', savepath)
    print(' Seed:', seed)
    print()


## Set numpy seed for reproducibility
# NOTE(review): the generation loop reseeds NumPy for every trajectory before
# drawing initial conditions, so this global seed may not influence the saved
# data at all — confirm whether --seed is meant to matter.
np.random.seed(seed)


#%%

# Image(filename="tmp/coda_dataset.png")


#%%


#%%

## Define the ODE
def forced_oscillator(t, X, params):
    """Right-hand side of the scalar ODE dX/dt = X / (1 + (c * X)**2).

    Args:
        t: Current time. Unused (the system is autonomous) but required by the
            ``fun(t, y, *args)`` calling convention of ``solve_ivp``.
        X: Current state (a length-1 array in this script's ``solve_ivp`` calls).
        params: The environment parameter ``c``.

    Returns:
        The time derivative dX/dt, with the same shape as ``X`` for scalar ``c``.
    """
    # Alternative right-hand side tried earlier: dXdt = X / (1 - params)
    denominator = 1 + (params * X) ** 2
    return X / denominator

def get_init_cond(low=1.0, high=2.0):
    """Draw a random initial state for the scalar ODE.

    Sampling uses NumPy's global random state, so callers control
    reproducibility with ``np.random.seed``.

    Args:
        low: Inclusive lower bound of the uniform distribution (default 1).
        high: Exclusive upper bound of the uniform distribution (default 2).
            Earlier experiments used ``low=-1, high=1``.

    Returns:
        A NumPy array of shape ``(1,)`` sampled from U[low, high).
    """
    return np.random.uniform(low, high, (1,))



# Environment parameters c (one environment per value) for each split family.
# Training/testing share one grid; the adaptation grid extends it to larger c.
if split in ("train", "test"):
    # environments = [(lambda t: 0.1*i*jnp.cos(i*t**2)) for i in range(0, 5)]
    # environments = [2, -2]
    environments = np.arange(1.25, 10, 1.25)
    # environments = np.arange(1.25, 10, 0.75)
    # environments = np.arange(1.25, 2.25, 0.1)
elif split in ("adapt", "adapt_test"):
    environments = np.arange(1.25, 20, 1.25)
    # environments = np.arange(1.25, 20, 0.75)
    # environments = np.arange(1.25, 4.5, 0.1)
else:
    # Fail here with a clear message instead of a NameError further down
    # (e.g. "adapt_huge", for which no grid is defined).
    raise ValueError(f"No environments are defined for split {split!r}")

# Number of trajectories (initial conditions) generated per environment.
N_TRAJ_PER_SPLIT = {
    "train": 4,        # training
    "test": 32,        # testing
    "adapt_test": 32,  # testing
    "adapt": 12,       # adaptation
}
n_traj_per_env = N_TRAJ_PER_SPLIT[split]


t_span = (0, 10)  # Shortened time span
dt_save = 0.5     # Target spacing between saved time points
# Number of saved frames, measured from t_span[0] so a non-zero start also works.
n_steps_per_traj = math.ceil((t_span[1] - t_span[0]) / dt_save)

# Saved time points start at t_span[0] and exclude t_span[1]; they are exactly
# dt_save apart when the span is a multiple of dt_save (true for (0, 10)).
t_eval = np.linspace(t_span[0], t_span[-1], n_steps_per_traj, endpoint=False)  # Fewer frames

# data[i, j, :, 0] holds trajectory j of environment i, sampled at t_eval.
data = np.zeros((len(environments), n_traj_per_env, n_steps_per_traj, 1))
max_seed = np.iinfo(np.int32).max

for j in range(n_traj_per_env):

    # Reseed per trajectory so trajectory j always gets the same initial state:
    # "train"/"adapt" use seed j, while "test"/"adapt_test" count down from
    # max_seed so their initial states differ from the training ones.
    # NOTE(review): this overrides the earlier global np.random.seed(seed), so
    # the --seed argument has no effect on the generated data — confirm intent.
    np.random.seed(j if split not in ("test", "adapt_test") else max_seed - j)
    initial_state = get_init_cond()

    # The same initial state is integrated in every environment.
    for i, selected_params in enumerate(environments):

        # Solve the ODE using SciPy's solve_ivp
        solution = solve_ivp(forced_oscillator, t_span, initial_state, args=(selected_params,), t_eval=t_eval, method='RK45')
        if not solution.success:
            # A failed solve may return fewer than len(t_eval) points, which
            # would otherwise surface as an opaque broadcasting error below.
            raise RuntimeError(
                f"solve_ivp failed for environment {i} (c={selected_params}), "
                f"trajectory {j}: {solution.message}"
            )
        data[i, j, :, :] = solution.y.T

        # # Alternative: diffrax with the Tsit5 integrator (untested here; the
        # # vector field previously referenced an undefined `gray_scott`).
        # solution = diffrax.diffeqsolve(diffrax.ODETerm(forced_oscillator),
        #                                diffrax.Tsit5(),
        #                                args=(selected_params),
        #                                t0=t_span[0],
        #                                t1=t_span[1],
        #                                dt0=1e-1,
        #                                y0=initial_state,
        #                                stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-6),
        #                                saveat=diffrax.SaveAt(ts=t_eval),
        #                                max_steps=4096*1)
        # data[i, j, :, :] = solution.ys
        # print("Stats", solution.stats['num_steps'])

        # # Alternative: fixed-step RK4 from idncflow
        # ys = RK4(forced_oscillator,
        #             (t_eval[0], t_eval[-1]),
        #             initial_state,
        #             *(selected_params,),
        #             t_eval=t_eval,
        #             subdivisions=100)
        # data[i, j, :, :] = ys






# Save t_eval and the solution to a npz file
OUTPUT_FILENAMES = {
    "train": 'train_data.npz',
    "test": 'test_data.npz',
    "adapt": 'adapt_train.npz',
    "adapt_test": 'adapt_test.npz',
}
if split not in OUTPUT_FILENAMES:
    # Previously `filename` was simply left undefined for such splits.
    raise ValueError(f"No output filename is defined for split {split!r}")
# os.path.join supplies the separator a savepath like "tmp" (no trailing slash)
# would otherwise be missing; "tmp/" and "./tmp/" behave exactly as before.
filename = os.path.join(savepath, OUTPUT_FILENAMES[split])

## Check if nan or inf in data
if not np.isfinite(data).all():
    print("NaN or Inf in data. Exiting without saving...")
else:
    if savepath:
        # Create the output folder on first use instead of failing in np.savez.
        os.makedirs(savepath, exist_ok=True)
    np.savez(filename, t=t_eval, X=data)




#%%

if _in_ipython_session:
    fig, ax = plt.subplots(figsize=(10, 6))

    # Quick visual check: one curve per environment, showing its first trajectory.
    first_trajectories = data[:, 0, :, 0]
    for c_value, curve in zip(environments, first_trajectories):
        ax.plot(t_eval, curve, label=f'c={c_value:.2f}')

    ax.set(xlabel='Time', ylabel='State',
           title='Exemplification of a divergent power series')
    ax.legend()
    plt.show()


# %%