#%%
# %load_ext autoreload
# %autoreload 2

## Do not preallocate GPU memory. The variable is set before the idncflow
## import below (which presumably imports JAX) so that it is already in the
## environment when JAX starts up.
import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

# Star import: jax, jnp, eqx, diffrax, optax, time and the project classes used
# below (DataLoader, IDContextParams, LinearControlVectorField, Learner, Trainer,
# VisualTester, unflatten_pytree, ...) are not imported anywhere else in this
# file, so they presumably all come from here.
from idncflow import *
# jax.config.update("jax_debug_nans", True)

## Execute jax on CPU
# jax.config.update("jax_platform_name", "cpu")

import warnings
# NOTE(review): this silences *all* Python warnings for the whole process,
# including deprecation warnings from JAX/optax.
warnings.filterwarnings("ignore")



#%%

# Base seed: builds mother_key and is passed as `key` to the DataLoaders,
# Learner and Trainer; seed+1 seeds the targets of the adaptation environments.
seed = 2026

context_pool_size = 2   # number of rows of all_ctx_s sampled per loss evaluation (loss_fn_ctx)
context_size = 2        # context dimension (IDContextParams and Control)
init_lr = 1e-3          # initial learning rate of every optimiser in this script
sched_factor = 0.25     # scale applied at 1/3 and 2/3 of nb_outer_steps (see bd_scales)

## Proximal training: forwarded to trainer.train_proximal
nb_outer_steps = 5000
nb_inner_steps_max = 10
proximal_beta = 1e1     # passed as proximal_reg
inner_tol_node = 2e-11
inner_tol_ctx = 1e-10

taylor_train_order = 12   # taylor_order of LinearControlVectorField during training
taylor_adapt_order = 0    # taylor_order passed to trainer.adapt

print_error_every = 100   # logging period for both training and adaptation

# train=True: train a new model; train=False: restore it from run_folder.
train = True
# run_folder = "./runs/240703-195125-Test/"
# run_folder = "./runs/240717-201143-T3-INF/"
run_folder = None         # None: a new timestamped folder is created under ./runs/
generate_data = True      # build the training data (and, with train, a new run folder)

nb_train_starts = 12      # start points per training environment
nb_train_targets = 10     # number of training environments (one target each)
nb_adapt_targets = 16     # number of adaptation (out-of-distribution) environments

save_trainer = True       # if True, run_folder is passed as the trainer's save_path

nb_epochs_adapt = 5000
adapt_test = True         # run the adaptation / OOD-test cells
adapt_restore = False     # restore adapted contexts instead of adapting them

# ODE integrator class and solver options handed to the Learner; the exact
# meaning of each ivp_args key is defined by idncflow (presumably initial step,
# tolerances, step cap and sub-steps per interval -- confirm there).
integrator = diffrax.Dopri5
# integrator = RK4
ivp_args = {"dt_init":1e-4, "rtol":1e-3, "atol":1e-6, "max_steps":4000, "subdivisions":2}

#%%


if train and generate_data:
    import shutil  # only needed to snapshot the code of a fresh run

    # Make sure the parent './runs' folder exists.
    os.makedirs('./runs', exist_ok=True)

    # Unless a folder was given explicitly, name the run after the current time.
    if run_folder is None:
        run_folder = './runs/'+time.strftime("%y%m%d-%H%M%S")+'/'
    os.makedirs(run_folder, exist_ok=True)
    print("Run folder created successfuly:", run_folder)

    # Snapshot this script and the idncflow package into the run folder so the
    # run can be reproduced. shutil avoids spawning a shell (no breakage on paths
    # with spaces) and copies the script from its real location rather than from
    # the current working directory. The snapshot is a convenience, so copy
    # failures are reported without aborting the run.
    script_name = os.path.basename(__file__)
    try:
        shutil.copy(__file__, run_folder)
    except OSError as err:
        print("Could not copy", script_name, "to the run folder:", err)
    try:
        shutil.copytree("../../idncflow", os.path.join(run_folder, "idncflow"), dirs_exist_ok=True)
    except OSError as err:
        print("Could not copy ../../idncflow to the run folder:", err)
    print("Completed copied scripts ")

else:
    print("No training. Loading data and results from:", run_folder)

# Every later cell writes into (or restores from) run_folder: fail early with a
# clear message rather than a TypeError on the string concatenation below.
if run_folder is None:
    raise ValueError("run_folder is None: set it to an existing run folder "
                     "(e.g. './runs/<id>/') when not training on freshly generated data.")

## Create a folder for the adaptation results
adapt_folder = run_folder+"adapt/"
if not os.path.exists(adapt_folder):
    os.mkdir(adapt_folder)

# %%

mother_key = jax.random.PRNGKey(seed)

def make_dataset(traj_key,
                 env_key=None,
                 nb_trajs=nb_train_starts,
                 nb_envs=nb_train_targets,
                 vmin=-1.,
                 vmax=1.,
                 start_min=-1.,
                 start_max=1.):
    """Build a (start, target) dataset for the control problem.

    ``nb_trajs`` 2-D start points are drawn uniformly in ``[start_min, start_max]^2``
    from ``traj_key`` and shared by all environments. Each of the ``nb_envs``
    environments gets one 2-D target drawn uniformly in ``[vmin, vmax]^2``.

    Args:
        traj_key: PRNG key for the start points.
        env_key: PRNG key for the targets. When None, the global ``mother_key``
            is used, so the targets are identical across calls.
        nb_trajs: number of start points per environment.
        nb_envs: number of environments (targets).
        vmin, vmax: bounds of the target distribution.
        start_min, start_max: bounds of the start distribution.

    Returns:
        dataset: array of shape (nb_envs, nb_trajs, 2, 2), where
            ``dataset[e, j, 0]`` is start ``j`` and ``dataset[e, j, 1]`` is the
            target of environment ``e``.
        t_eval: ``jnp.array([0., 1.])``, the times of those two states.
    """
    starts = jax.random.uniform(traj_key, shape=(nb_trajs, 2), minval=start_min, maxval=start_max)

    if env_key is None:
        env_key = mother_key
    targets = jax.random.uniform(env_key, shape=(nb_envs, 2), minval=vmin, maxval=vmax)

    # Pair every start with every target in one vectorised step: broadcast
    # starts along the environment axis and targets along the trajectory axis,
    # then stack them on a new "time" axis (index 0 = start, 1 = target).
    start_grid = jnp.broadcast_to(starts[None, :, :], (nb_envs, nb_trajs, 2))
    target_grid = jnp.broadcast_to(targets[:, None, :], (nb_envs, nb_trajs, 2))
    dataset = jnp.stack([start_grid, target_grid], axis=2)

    t_eval = jnp.array([0., 1.])

    return dataset, t_eval



## Training data: nb_train_targets environments (one target each) that share
## nb_train_starts start points.
## NOTE(review): with generate_data False this cell defines neither dataset,
## t_eval nor start_key; the cells below then depend on values left over from an
## earlier execution in the same session.
if generate_data == True:
    # NOTE(review): mother_key is split here but is also used directly as the
    # targets' env_key (and later to initialise Control). Reusing one key for
    # several draws correlates them; changing it would alter existing runs' data.
    start_key, _ = jax.random.split(mother_key)
    dataset, t_eval = make_dataset(start_key, mother_key, nb_trajs=nb_train_starts, nb_envs=nb_train_targets)

    print("Data shape:", dataset.shape)
    # print("Data:\n", dataset)
    print("Evaluation times:", t_eval)




#%%

## Define dataloader for training and validation
train_dataloader = DataLoader(dataset, t_eval=t_eval, batch_size=-1, shuffle=True, key=seed)

# Dataset statistics as reported by the dataloader.
nb_envs = train_dataloader.nb_envs
nb_trajs_per_env = train_dataloader.nb_trajs_per_env
nb_steps_per_traj = train_dataloader.nb_steps_per_traj
data_size = train_dataloader.data_size

print(f"Stats: nb_envs={nb_envs}, nb_trajs_per_env={nb_trajs_per_env}, nb_steps_per_traj={nb_steps_per_traj}, data_size={data_size}")

# Validation set: the same targets as training (same env_key and nb_envs), but
# 32 fresh start points per environment drawn from a newly split start_key.
start_key, _ = jax.random.split(start_key)
val_dat, _ = make_dataset(start_key, env_key=mother_key, nb_trajs=32, nb_envs=nb_train_targets)
val_dataloader = DataLoader(val_dat, t_eval=t_eval, shuffle=False)

#%%

## Define model and loss function for the learner


class Control(eqx.Module):
    """Context-conditioned neural control with an output of shape (1,).

    Two stages:
      * ``layers_data`` encodes ``[t, x0]`` (size ``1 + data_size``) into
        ``hidden_size`` features;
      * ``layers_shared`` maps those features concatenated with the context
        (size ``hidden_size + context_size``) to the control value.

    The context is not passed as a vector but as the flat parameter array of a
    time-dependent context function, rebuilt on every call from ``ctx_utils``
    and then evaluated at time ``t``.
    """
    layers_data: list
    layers_shared: list
    # activations: list
    # Triple (ctx_shapes, ctx_treedef, ctx_static) used to rebuild the context
    # function from its flat parameters. Annotated `object` because the builtin
    # `any` previously used here is a function, not a type.
    ctx_utils: object

    def __init__(self, data_size, hidden_size, context_size, ctx_utils=None, key=None):
        """Create the two MLP stages.

        Args:
            data_size: dimension of the initial state ``x0``.
            hidden_size: width of the hidden layers.
            context_size: dimension of the vector returned by the context function.
            ctx_utils: (ctx_shapes, ctx_treedef, ctx_static); it must be set
                before the module is called.
            key: JAX PRNG key for weight initialisation.
        """
        self.ctx_utils = ctx_utils

        # Only some of these keys are used; the fixed indices keep the
        # initialisation reproducible for a given key.
        keys = jax.random.split(key, num=12)

        self.layers_data = [eqx.nn.Linear(1+data_size, hidden_size, key=keys[3]), jax.nn.softplus,
                            # eqx.nn.Linear(hidden_size, hidden_size, key=keys[4]), jax.nn.softplus,
                            eqx.nn.Linear(hidden_size, hidden_size, key=keys[5])]

        # self.layers_shared = [eqx.nn.Linear(1+data_size+context_size, hidden_size, key=keys[6]), jax.nn.softplus,
        self.layers_shared = [eqx.nn.Linear(hidden_size+context_size, hidden_size, key=keys[6]), jax.nn.softplus,
                              eqx.nn.Linear(hidden_size, hidden_size, key=keys[7]), jax.nn.softplus,
                              eqx.nn.Linear(hidden_size, hidden_size, key=keys[8]), jax.nn.softplus,
                              eqx.nn.Linear(hidden_size, 1, key=keys[9])]

    def __call__(self, t, x0, ctx_arr):
        """Evaluate the control at time ``t`` for initial state ``x0``.

        Args:
            t: scalar time.
            x0: 1-D initial state of length ``data_size``.
            ctx_arr: flat parameter array of the context function.

        Returns:
            Array of shape (1,).
        """
        # Rebuild the context function from its flat parameters and evaluate it at t.
        ctx_shapes, ctx_treedef, ctx_static = self.ctx_utils
        ctx_params = unflatten_pytree(ctx_arr, ctx_shapes, ctx_treedef)
        ctx_fun = eqx.combine(ctx_params, ctx_static)

        t_arr = jnp.array([t])
        ctx = ctx_fun(t_arr)

        # Stage 1: encode time and initial state.
        u = jnp.concatenate([t_arr, x0], axis=0)
        for layer in self.layers_data:
            u = layer(u)

        # Stage 2: append the context features and map them to the control value.
        u = jnp.concatenate([u, ctx], axis=0)
        # u = jnp.concatenate([t_arr, x0, ctx], axis=0)
        for layer in self.layers_shared:
            u = layer(u)

        return u


# class ContextFlowVectorField(eqx.Module):
#     control: eqx.Module

#     def __init__(self, control):
#         self.control = control

#     def __call__(self, t, x, args):
#         x0, ctx, ctx_ = args

#         A = jnp.array([[0., 1.], [1., 0.]])
#         B = jnp.array([[1.], [0.]])

#         vf = lambda xi: A@x + B@self.control(t, x0, xi)

#         gradvf = lambda xi_: eqx.filter_jvp(vf, (xi_,), (ctx-xi_,))[1]
#         scd_order_term = eqx.filter_jvp(gradvf, (ctx_,), (ctx-ctx_,))[1]

#         return vf(ctx_) + 1.5*gradvf(ctx_) + 0.5*scd_order_term


def loss_fn_ctx(model, trajs, t_eval, ctx, all_ctx_s, key):
    """Final-state reconstruction loss of one environment under sampled contexts.

    ``key`` shuffles the row indices of ``all_ctx_s``, and the first
    ``context_pool_size`` of them are kept (sampling without replacement).
    ``model`` is then evaluated once per kept row via ``jax.vmap``. Every
    evaluation starts from ``trajs[:, 0, :]`` and receives this environment's
    own ``ctx``.

    Returns:
        ``(loss, aux)``:
          * ``loss`` is the mean squared error between the true and predicted
            final states (index -1 on the time axis);
          * ``aux`` is a tuple of three items: the sum of the model's
            ``nb_steps`` output divided by the number of sampled rows, the same
            MSE, and ``mean(|ctx|)``. The last item is reported only and is not
            part of the loss.
    """
    # Random subset of candidate contexts, without replacement.
    shuffled_rows = jax.random.permutation(key, all_ctx_s.shape[0])
    pool = all_ctx_s[shuffled_rows[:context_pool_size], :]

    # One model evaluation per sampled context; the other inputs are shared.
    initial_states = trajs[:, 0, :]
    per_context_model = jax.vmap(model, in_axes=(None, None, None, 0))
    predictions, steps = per_context_model(initial_states, t_eval, ctx, pool)

    # Repeat the ground truth along the sampled-context axis before comparing.
    reference = jnp.broadcast_to(trajs, predictions.shape)
    final_sq_err = (reference[..., -1, :] - predictions[..., -1, :]) ** 2
    reconstruction = jnp.mean(final_sq_err)
    ctx_magnitude = jnp.mean(jnp.abs(ctx))

    steps_per_context = jnp.sum(steps) / pool.shape[0]
    return reconstruction, (steps_per_context, reconstruction, ctx_magnitude)







## Build the learnable pieces: per-environment contexts, the control network,
## the vector field wrapping it, and the Learner that ties them to the loss.
# contexts = ContextParams(nb_envs=nb_envs, context_size=context_size)
contexts = IDContextParams(nb_envs=nb_envs, context_size=context_size, hidden_size=32, depth=2, key=None)       ##TODO Randomize here at start !
# NOTE(review): Control is initialised with mother_key, the same key that drew
# the training targets in make_dataset. data_size=2 is hard-coded although the
# dataloader reports data_size -- confirm the two always agree.
control = Control(data_size=2, hidden_size=32, context_size=context_size, key=mother_key, ctx_utils=contexts.ctx_utils)
vectorfield = LinearControlVectorField(control, taylor_order=taylor_train_order)

learner = Learner(vectorfield, contexts, loss_fn_ctx, integrator, ivp_args, key=seed)

# Parameter counts: array leaves of the vector field, and the entries of the
# context parameters (assumes contexts.params is 2-D -- confirm).
print("\n\nTotal number of parameters in the vector field:", sum(x.size for x in jax.tree_util.tree_leaves(eqx.filter(vectorfield,eqx.is_array)) if x is not None))
print("Total number of parameters in the contexts:", contexts.params.shape[0]*contexts.params.shape[1], "\n")


#%%

## Define optimisers and the trainer.

# Learning-rate schedule: both boundaries (1/3 and 2/3 of the outer steps) use
# the scale sched_factor. The adaptation cell reuses bd_scales as well.
bd_scales = {nb_outer_steps//3:sched_factor, 2*nb_outer_steps//3:sched_factor}

# Separate (identically configured) schedules/optimisers for the vector field
# ("node") and for the contexts.
sched_node = optax.piecewise_constant_schedule(init_value=init_lr, boundaries_and_scales=bd_scales)
sched_ctx = optax.piecewise_constant_schedule(init_value=init_lr,boundaries_and_scales=bd_scales)

opt_node = optax.adam(sched_node)
opt_ctx = optax.adam(sched_ctx)

trainer = Trainer(train_dataloader, learner, (opt_node, opt_ctx), key=seed)

#%%

# Where the trainer saves its state: the run folder, or False to disable saving.
trainer_save_path = run_folder if save_trainer == True else False
if train == True:
    # Proximal training loop; the hyper-parameters come from the config cell,
    # and val_dataloader is handed to the trainer for validation.
    trainer.train_proximal(nb_outer_steps_max=nb_outer_steps, nb_inner_steps_max=nb_inner_steps_max, proximal_reg=proximal_beta, inner_tol_node=inner_tol_node, inner_tol_ctx=inner_tol_ctx,
                            print_error_every=print_error_every, save_path=trainer_save_path, key=seed, val_dataloader=val_dataloader, int_prop=1.0)

else:
    # No training: load a previously trained model from run_folder.
    # print("\nNo training, attempting to load model and results from "+ run_folder +" folder ...\n")

    restore_folder = run_folder
    # restore_folder = "./runs/27012024-155719/finetune_193625/"
    trainer.restore_trainer(path=restore_folder)




#%%


# ## Test if context produces constant values (terminal states) 
# ts = jnp.linspace(0., 1., 100)
# vals = jax.vmap(trainer.learner.contexts)(ts)

# ax = sbplot(ts, vals[...,0], "ro-", x_label="Time", title="Contexts over time")
# sbplot(ts, vals[...,1], "g+-", x_label="Time", title="Contexts over time", ax=ax)








#%%

## Test and visualise the results on a test dataloader.
## In-domain test: the validation set (training targets, fresh start points).

test_dataloader = val_dataloader
visualtester = VisualTester(trainer)
# ans = visualtester.trainer.nb_steps_node
# print(ans.shape)

# criterion = lambda x, x_hat: jnp.mean((x[...,-1]-x_hat[...,-1])**2)
# ind_crit holds the test result; it is not used further in this file.
ind_crit = visualtester.test(test_dataloader, int_cutoff=1.0)

# One figure per training environment, saved as <run_folder>results_in_domain_env_<e>.png
savefigdir = run_folder+"results_in_domain_"
for e in range(nb_train_targets):
    filename = savefigdir + f"env_{e}.png"
    visualtester.visualizeControl(test_dataloader, int_cutoff=1.0, save_path=filename, e=e);



#%%

























## Adaptation to new (out-of-distribution) environments. The dataloader gets a
## fixed data_id to help with restoration later on.

if adapt_test:
    # The adaptation data is built whenever adaptation is tested, including when
    # restoring: the DataLoader below needs it in both branches. Building it
    # unconditionally also advances start_key identically in both modes, so the
    # OOD test set of the next cell is the same whether the adapted contexts
    # were just trained or restored.
    # Targets come from PRNGKey(seed+1) in [-2, 2]^2 with a single start each.
    start_key, _ = jax.random.split(start_key)
    test_dat, _ = make_dataset(start_key, nb_trajs=1, env_key=jax.random.PRNGKey(seed+1), nb_envs=nb_adapt_targets, vmin=-2., vmax=2.)

    adapt_dataloader = DataLoader(test_dat, t_eval=t_eval, adaptation=True, data_id="170846", key=seed)

    # Fresh optimiser for the adaptation phase.
    # NOTE(review): bd_scales is derived from nb_outer_steps, not nb_epochs_adapt;
    # the two coincide with the current config -- confirm if they diverge.
    sched_ctx_new = optax.piecewise_constant_schedule(init_value=init_lr, boundaries_and_scales=bd_scales)
    opt_adapt = optax.adabelief(sched_ctx_new)

    if not adapt_restore:
        trainer.adapt(adapt_dataloader,
                      nb_epochs=nb_epochs_adapt,
                      optimizer=opt_adapt,
                      print_error_every=print_error_every,
                      taylor_order=taylor_adapt_order,
                      save_path=adapt_folder)
    else:
        print("Save_id for restoring trained adapation model:", adapt_dataloader.data_id)
        trainer.restore_adapted_trainer(path=adapt_folder, data_loader=adapt_dataloader)


#%%
if adapt_test:
    # OOD test set: the same env_key (seed+1), nb_envs and range as the
    # adaptation data, hence the same targets, with 32 fresh starts per target.
    start_key, _ = jax.random.split(start_key)
    test_dat_adapt, _ = make_dataset(start_key, env_key=jax.random.PRNGKey(seed+1), nb_trajs=32, nb_envs=nb_adapt_targets, vmin=-2., vmax=2.)
    # NOTE(review): reuses data_id "170846" from the adaptation dataloader,
    # presumably so the adapted contexts of these environments are used -- confirm.
    adapt_dataloader_test = DataLoader(test_dat_adapt, t_eval=t_eval, adaptation=True, data_id="170846", key=seed)

    ood_crit = visualtester.test(adapt_dataloader_test, int_cutoff=1.0)      ## It's the same visualtester as before during training. It knows trainer

    # One figure per adaptation environment, saved as <adapt_folder>results_ood_env_<e>.png
    savefigdir = adapt_folder+"results_ood_"
    for e in range(nb_adapt_targets):
        filename = savefigdir + f"env_{e}.png"
        visualtester.visualizeControl(adapt_dataloader_test, int_cutoff=1.0, save_path=filename, e=e);




#%%
## If the nohup.log file exists, copy it to the run folder.
## This is skipped inside IPython/Jupyter, where the name __IPYTHON__ is defined.
try:
    __IPYTHON__  ## in a jupyter notebook
except NameError:
    # Plain script run (e.g. under nohup): keep the log next to the results.
    # shutil.copy avoids spawning a shell, which broke on paths with spaces.
    if os.path.exists("nohup.log"):
        import shutil
        shutil.copy("nohup.log", run_folder)


#%%
