import os
import shutil
import sys

# NOTE(review): the reason for raising the recursion limit is not visible in
# this file — presumably some deeply recursive model construction/serialisation;
# confirm before changing or removing.
sys.setrecursionlimit(4000)
# Stop JAX/XLA from preallocating most of the GPU memory up front. This must be
# in place before JAX initialises its backend, hence it is set before `import jax`.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

from copy import deepcopy

import jax
import jax.numpy as jnp

import flax
from flax.training import train_state
import optax as tx
import orbax.checkpoint as ocp


import wandb
import hydra
import omegaconf
from omegaconf import DictConfig

# project modules: dataset preprocessing, evaluation and model construction
from implicit_nonlinear_dynamics.data_preprocessing.lasa_handwriting import process_lasa_data_lagged
from implicit_nonlinear_dynamics.evaluate.evaluate_ours import evaluate_demo_ours
from implicit_nonlinear_dynamics.model_architectures.ours import init_our_model


@jax.jit
def train_step(model_state, env_states, target):
    """Run one optimisation step over a batch of input sequences.

    Each sequence in ``env_states`` is unrolled through the model with
    ``jax.lax.scan``, threading a recurrent state from one time step to the
    next. All sequences are processed in parallel with ``jax.vmap``. The loss
    is the mean squared error between the stacked predictions and ``target``.

    Args:
        model_state: A ``flax.training.train_state.TrainState`` (or any pytree
            exposing ``apply_fn``, ``params`` and ``apply_gradients``).
            ``apply_fn(params, x, rc_state)`` must return
            ``(prediction, new_rc_state)``.
        env_states: Batch of input sequences. Axis 0 is vmapped over and
            axis 1 is scanned over (presumably ``(batch, time, features)`` —
            confirm against ``process_lasa_data_lagged``).
        target: Targets compared against the squeezed predictions.

    Returns:
        ``(new_model_state, loss)``, where ``loss`` is the scalar MSE computed
        with the parameters *before* the update.
    """

    def loss_fn(params):
        """Unroll the model over every sequence and return the MSE loss."""
        # NOTE(review): the recurrent-state width is hard-coded and must match
        # the model built by ``init_our_model`` — confirm against the config.
        initial_rc_state = jnp.zeros((1, 5000))

        def compute_step(rc_state, env_state):
            # The model is called with a leading batch axis of size 1.
            prediction, next_rc_state = model_state.apply_fn(
                params, jnp.expand_dims(env_state, axis=0), rc_state
            )
            return next_rc_state, prediction

        def unroll(sequence):
            # Every sequence starts from the same zero recurrent state; only the
            # per-step predictions are kept, the final carry is discarded.
            _, sequence_predictions = jax.lax.scan(compute_step, initial_rc_state, sequence)
            return sequence_predictions

        predictions = jax.vmap(unroll)(env_states)
        # NOTE(review): ``squeeze`` drops *every* size-1 axis (including the
        # expand_dims axis above). If the feature/output dimension were 1 this
        # would silently broadcast against ``target``; fine for multi-dim outputs.
        return jnp.mean((jnp.squeeze(predictions) - target) ** 2)

    loss, grads = jax.value_and_grad(loss_fn)(model_state.params)
    model_state = model_state.apply_gradients(grads=grads)
    return model_state, loss


def _build_optimizer(training_cfg):
    """Build the optimiser and its learning-rate schedule from the training config.

    Parameters whose top-level key is ``"dynamics_params"`` receive zero
    updates (``set_to_zero``). Every other parameter is trained with AdamW on
    a warmup-cosine-decay schedule. Gradients are clipped by global norm before
    the per-group transforms run.

    Returns:
        ``(optimiser, lr_schedule)``. The schedule is returned so it can be logged.
    """

    def label_fn(tree):
        # Label each leaf by the first component of its flattened key path.
        flat = flax.traverse_util.flatten_dict(tree)
        return flax.traverse_util.unflatten_dict(
            {key_path: "none" if key_path[0] == "dynamics_params" else "adamw" for key_path in flat}
        )

    lr_schedule = tx.warmup_cosine_decay_schedule(
        init_value=training_cfg["initial_lr"],
        peak_value=training_cfg["peak_lr"],
        warmup_steps=training_cfg["warmup_steps"],
        decay_steps=training_cfg["decay_steps"],
        end_value=training_cfg["end_lr"],
    )

    optimiser = tx.chain(
        tx.clip_by_global_norm(training_cfg["clip_grad_norm"]),
        tx.multi_transform(
            {
                "adamw": tx.adamw(learning_rate=lr_schedule),
                "none": tx.set_to_zero(),
            },
            label_fn,
        ),
    )
    return optimiser, lr_schedule


def _save_best_model(best_model_path, state):
    """Replace the checkpoint at ``best_model_path`` with ``state`` and wait until it is written."""
    # The previous best is removed first (presumably because Orbax will not
    # overwrite an existing checkpoint directory).
    if os.path.exists(best_model_path):
        shutil.rmtree(best_model_path)
    checkpointer = ocp.StandardCheckpointer()
    try:
        checkpointer.save(best_model_path, state)
    finally:
        # StandardCheckpointer may write asynchronously; closing it waits for the
        # write so the next rmtree (or process exit) cannot race an unfinished save.
        checkpointer.close()


@hydra.main(version_base=None, config_path="..")
def main(cfg: DictConfig) -> None:
    """Hydra entry point: train the model described by ``cfg["config"]``.

    Logs metrics to Weights & Biases. Every 200 epochs it evaluates the
    mean Fréchet distance via ``evaluate_demo_ours`` and saves the best state
    so far to ``<model_save_dir>/<run id>/best_model``. Rolling checkpoints go
    to the same run directory through an Orbax ``CheckpointManager``. That
    manager saves every 200 steps, keeps 2, and always saves the final step.
    """
    cfg = cfg["config"]
    training_cfg = cfg["training"]

    # experiment tracking
    wandb.init(
        project=cfg["wandb"]["project"],
        entity=cfg["wandb"]["entity"],
        config=omegaconf.OmegaConf.to_container(cfg, resolve=False),
        tags=cfg["wandb"]["tags"],
        notes=cfg["wandb"]["notes"],
    )
    wandb.run.name = cfg["wandb"]["experiment_name"]

    # initialise model, optimiser (neural-network params only) and training state
    model, variables = init_our_model(cfg["architecture"])
    opt, lr_schedule = _build_optimizer(training_cfg)
    our_train_state = train_state.TrainState.create(
        apply_fn=model.apply,
        params=variables,
        tx=opt,
    )

    # model checkpointing: one directory per W&B run, erased and recreated at start-up.
    # Absolute paths are used because Orbax may reject relative checkpoint paths.
    os.makedirs(training_cfg["model_save_dir"], exist_ok=True)
    run_dir = os.path.abspath(os.path.join(training_cfg["model_save_dir"], wandb.run.id))
    path = ocp.test_utils.erase_and_create_empty(run_dir)
    options = ocp.CheckpointManagerOptions(max_to_keep=2, save_interval_steps=200)
    mngr = ocp.CheckpointManager(
        path,
        options=options,
    )
    best_model_path = os.path.join(run_dir, "best_model")

    # load data
    input_data, target_data = process_lasa_data_lagged(
        cfg["dataset"]["shape"], cfg["dataset"]["window_size"], cfg["dataset"]["stride"]
    )

    # begin training loop
    num_epochs = training_cfg["num_epochs"]
    eval_interval = 200
    min_frechet_dist = float("inf")
    for i in range(num_epochs):
        our_train_state, loss = train_step(our_train_state, input_data, target_data)
        print(f"Loss for epoch: {i}: {loss}")

        metrics = {"train_loss": loss, "epoch": i, "lr": lr_schedule(i)}
        if i % eval_interval == 0 and i != 0:
            _, _, frechet_dists, _, _, _ = evaluate_demo_ours(our_train_state, cfg)
            frechet_dist = jnp.mean(jnp.asarray(frechet_dists))
            metrics["frechet_dist"] = frechet_dist
            wandb.log(metrics)
            if frechet_dist < min_frechet_dist:
                min_frechet_dist = frechet_dist
                # Train states are immutable pytrees, so no copy is needed before saving.
                _save_best_model(best_model_path, our_train_state)
        else:
            wandb.log(metrics)

        # The manager only saves every `save_interval_steps`; force the last step
        # so the final trained state is always checkpointed.
        mngr.save(
            i,
            args=ocp.args.StandardSave(our_train_state),
            force=(i == num_epochs - 1),
        )
    mngr.wait_until_finished()
    mngr.close()

    # Close the W&B run so a subsequent run in the same process (e.g. a Hydra
    # multirun sweep) starts a fresh run instead of reusing this one.
    wandb.finish()

    # uncomment to upload model to hugging face (also requires `from subprocess import Popen`)
    # print("Uploading model to hugging face...")
    # args = [f"{v}" for k, v in cfg["huggingface"].items()]
    # process = Popen(["python", f"{os.path.dirname(os.path.abspath(__file__))}/../huggingface/upload_model.py"] + args)
    # process.wait()


if __name__ == "__main__":
    # `main` is called without arguments: the `hydra.main` decorator composes
    # the config and passes it in as `cfg`.
    main()
