import argparse
import os

# Command-line interface of the training script.
# Each entry is (flag, keyword arguments forwarded to parser.add_argument).
# Flags are registered in the order listed here.
_CLI_OPTIONS = (
    ("--latent_dim", dict(type=int, default=4)),
    ("--width", dict(type=int, default=16)),
    ("--activation", dict(type=str, default="sin")),
    ("--epochs", dict(type=int, default=5000)),
    ("--dataset", dict(type=str, default="singleIC")),
    ("--prefix", dict(type=str, default="")),
    ("--seed", dict(type=int, default=101)),
    ("--loss", dict(type=str, default="nmse")),
    ("--init_lr", dict(type=float, default=1e-3)),
    ("--final_lr", dict(type=float, default=2e-5)),
    ("--decay_steps", dict(type=int, default=1000)),
    ("--decay_rate", dict(type=float, default=0.975)),
    ("--num_samples", dict(type=int, default=8)),
    ("--autodecoder", dict(action=argparse.BooleanOptionalAction, default=True, help="Use autodecoder")),
    ("--max_step", dict(type=int, default=100)),
    ("--split_start", dict(type=int, default=10)),
    ("--decoder_arch", dict(type=str, default="mlp", choices=['mlp', 'hyper'])),
    ("--normalize", dict(action=argparse.BooleanOptionalAction, default=True)),
)

parser = argparse.ArgumentParser()
for _flag, _options in _CLI_OPTIONS:
    parser.add_argument(_flag, **_options)

args = parser.parse_args()

# Unpack the parsed namespace into the module-level names used by the rest of
# the script. Note that --width is exposed as `width_scale` and
# --decoder_arch as `arch`.

# Model configuration.
latent_dim, width_scale = args.latent_dim, args.width
activation, arch = args.activation, args.decoder_arch
autodecoder, normalize = args.autodecoder, args.normalize

# Data selection.
dataset_name, num_samples = args.dataset, args.num_samples
# NOTE(review): split_start is parsed but not referenced in the visible part
# of this script.
max_step, split_start = args.max_step, args.split_start

# Optimisation and reproducibility.
epochs, loss, seed = args.epochs, args.loss, args.seed
init_lr, final_lr = args.init_lr, args.final_lr
decay_steps, decay_rate = args.decay_steps, args.decay_rate

# Optional prefix for the run name.
prefix = args.prefix

import jax
import jax.numpy as jnp
import numpy as np
import equinox as eqx
import optax as optx
from PHIROM.modules.models import CROMOffline, DecoderArchEnum
from PHIROM.pde.data_utils import TrajectoryDataset, TimeWindowDatasetNPZ, NumpyLoader, JaxLoader
from functools import partial
from jaxtyping import PRNGKeyArray, Array
from PHIROM.training.baseline import CROMOfflineTrainer, CROMAutoDecoderTrainer
from PHIROM.utils.serial import make_CROMOffline, save_model, load_model
from PHIROM.training.callbacks import CheckpointCallback
from pathlib import Path
from datetime import datetime
from PHIROM.pde.burgers import BurgersDatasetTorch
from torch.utils.data import DataLoader

# Run name: encodes the main command-line settings, joined with underscores.
# It is built from the raw --dataset value, i.e. before dataset_name is
# rewritten further down.
name = "_".join(
    (
        f"burgers_{dataset_name}",
        f"seed={seed}",
        f"loss={loss}",
        f"nt={max_step}",
        f"n={num_samples}",
        f"ac={activation}",
        f"ld={latent_dim}",
        f"ws={width_scale}",
        f"ep={epochs}",
    )
)

# An optional user-supplied prefix goes in front of the run name.
if prefix:
    name = f"{prefix}_{name}"

# Only the "singleIC" dataset is wired up; anything else is rejected early.
if dataset_name != "singleIC":
    raise ValueError("Dataset not supported")
path = "data/burgers.h5"

# From here on dataset_name carries the "Burgers_" prefix (used later as the
# top-level component of the experiment directory).
dataset_name = f"Burgers_{dataset_name}"

crop_bnd = False

# Training dataset. The sixth positional argument follows the --autodecoder
# switch (True when set, False otherwise); the other positional arguments are
# fixed here -- see BurgersDatasetTorch for their meaning.
dataset_train = BurgersDatasetTorch(
    path, max_step, 1, 0, (0, num_samples), bool(autodecoder), crop_bnd
)

# Fixed batch size, no shuffling.
loader_train = DataLoader(dataset_train, batch_size=800, shuffle=False)

# Statistics of the training fields (forwarded to the model only when
# --normalize is enabled) and the size of the first axis of the dataset's X
# attribute (passed to the model as "num_sensors").
MEAN, STD = dataset_train.compute_mean_std_fields()
nx = dataset_train.X.shape[0]

# Field statistics are only handed to the model when --normalize is on.
mean_field, std_field = (MEAN, STD) if normalize else (None, None)

# Keyword arguments for the CROMOffline constructor. Key insertion order is
# kept stable because the same dict is also handed to the checkpoint callback
# and to save_model.
hyperparams = dict(
    latent_dim=latent_dim,
    num_sensors=nx,
    field_dim=1,
    spatial_dim=1,
    mean_field=mean_field,
    std_field=std_field,
    activation=activation,
)

# Decoder architecture: replace the CLI string in `arch` by the matching enum
# member and add the architecture-specific settings. --width feeds
# "width_scale" for the MLP decoder and "width" for the hyper decoder.
if arch == "mlp":
    arch = DecoderArchEnum.MLP
    hyperparams.update(width_scale=width_scale, decoder_arch=arch)
elif arch == "hyper":
    arch = DecoderArchEnum.HYPER
    hyperparams.update(
        decoder_arch=arch,
        width=width_scale,
        n_layers=4,
        input_scale=1.0,
    )

# Coordinate statistics depend on the activation: mean/std for the listed
# activations, min/max for "sin". Any other activation adds nothing here.
# Each statistic is wrapped in a one-element NumPy array.
if activation in ("softplus", "elu", "swish", "tanh"):
    mean_x, std_x = (
        np.array([stat]) for stat in dataset_train.compute_mean_std_coords()
    )
    hyperparams.update(mean_x=mean_x, std_x=std_x)
elif activation == "sin":
    min_x, max_x = (
        np.array([bound]) for bound in dataset_train.compute_min_max_coords()
    )
    hyperparams.update(min_x=min_x, max_x=max_x)

# Seed the JAX PRNG and immediately split off a sub-key for model
# construction; `key` is kept for later splits.
key, subkey = jax.random.split(jax.random.PRNGKey(seed))

model = CROMOffline(**hyperparams, key=subkey)

# Exponentially decaying learning rate from init_lr towards final_lr, applied
# in discrete steps (staircase=True), driving an Adam optimizer.
scheduler = optx.schedules.exponential_decay(
    init_value=init_lr,
    transition_steps=decay_steps,
    decay_rate=decay_rate,
    staircase=True,
    end_value=final_lr,
)
optimizer = optx.adam(scheduler)

# Tags that make up the experiment directory layout:
#   CROM_experiments/<dataset>/<method>/<decoder arch>/<evolve>/<run name>
method = "AD" if autodecoder else "AE"
evolve = "CROM"

# `arch` has been rebound to a DecoderArchEnum member by this point, and
# os.path.join only accepts str/bytes/PathLike components. Passing the member
# directly works only if the enum also subclasses str, so take its value
# explicitly. For a str-based enum this yields the same path as before; a
# plain string (no `.value` attribute) is passed through unchanged.
arch_dir = str(getattr(arch, "value", arch))

path = os.path.join(dataset_name, method, arch_dir, evolve)
path_experiment = os.path.join("CROM_experiments", path, name)

# Create the experiment directory (and any missing parents); re-running with
# the same settings reuses the existing directory.
Path(path_experiment).mkdir(parents=True, exist_ok=True)

path_checkpoint = os.path.join(path_experiment, "checkpoints")
# The trailing positional arguments (True, 500) are forwarded unchanged; see
# CheckpointCallback for their meaning.
callbacks = [CheckpointCallback(path_checkpoint, name, hyperparams, True, 500)]

# Fresh sub-key for the trainer.
key, subkey = jax.random.split(key)

# Only the auto-decoder training path is available in this script.
if not autodecoder:
    raise ValueError("Not implemented")

# Second Adam optimizer with its own schedule. Same decay parameters as the
# model schedule, but with a smooth (staircase=False) decay.
# NOTE(review): presumably this one updates the latent codes (name-based;
# confirm against CROMAutoDecoderTrainer), and the staircase setting differs
# from the model schedule -- confirm that is intended.
latent_schedule = optx.schedules.exponential_decay(
    init_value=init_lr,
    transition_steps=decay_steps,
    decay_rate=decay_rate,
    staircase=False,
    end_value=final_lr,
)
latent_optimizer = optx.adam(latent_schedule)

trainer = CROMAutoDecoderTrainer(
    model,
    optimizer,
    latent_optimizer,
    evolve_time=False,
    loss=loss,
    callbacks=callbacks,
    key=subkey,
    num_trajectories=num_samples,
    num_time_steps=max_step,
    latent_dim=latent_dim,
)

# Run training; fit returns the trained model, the optimizer state and the
# training history.
model, opt_state, history = trainer.fit(loader_train, epochs=epochs, warm_start=True)

# Persist the trained model together with the hyperparameters used to build it.
save_model(os.path.join(path_experiment, "model.eqx"), hyperparams, model, None)

# Convert the two loss histories to NumPy arrays, then store the whole history
# dict as a single .npz archive.
for metric in ("loss_reconstruction", "loss_time_stepping"):
    history[metric] = np.array(history[metric])
np.savez(os.path.join(path_experiment, "history.npz"), **history)

# In auto-decoder mode also dump the trainer's latent memory as an array.
if autodecoder:
    np.save(
        os.path.join(path_experiment, "latent_memory.npy"),
        np.array(trainer.latent_memory),
    )

