import os
import shutil
import sys

# Raise the interpreter's recursion limit to 4000.
# NOTE(review): what needs the deeper recursion is not visible in this file — confirm.
sys.setrecursionlimit(4000)
# Set before `import jax` below so the variable is already in the environment when
# JAX loads; per the JAX GPU memory docs, "false" turns off up-front preallocation.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

from copy import deepcopy

import jax
import jax.random as random
import jax.numpy as jnp
import einops as e

import flax
from flax.training import train_state
import optax as tx
import orbax.checkpoint as ocp

import numpy as np

import wandb
import hydra
import omegaconf
from omegaconf import DictConfig

from implicit_nonlinear_dynamics.data_preprocessing.lasa_handwriting import (
    process_lasa_multi_data,
)
from implicit_nonlinear_dynamics.evaluate.evaluate_feedforward_multitask import evaluate_demo_feedforward_multi
from implicit_nonlinear_dynamics.model_architectures.feedforward import (
    init_feedforward_multi_model,
)


@jax.jit
def train_step(model_state, state_data, img_data, control_data):
    """Apply one gradient update for a single batch.

    Args:
        model_state: training state exposing ``apply_fn``, ``params`` and
            ``apply_gradients`` (a flax ``TrainState`` in this script).
        state_data: batch forwarded to ``apply_fn`` as its second argument.
        img_data: batch forwarded to ``apply_fn`` as its third argument.
        control_data: regression targets subtracted from the squeezed prediction.

    Returns:
        A ``(new_model_state, loss)`` pair, where ``loss`` is the mean squared
        error of the batch evaluated at the parameters *before* the update.
    """

    def mean_squared_error(params):
        # NOTE(review): squeeze() without an axis removes every length-1
        # dimension, which would include a batch dimension of size 1 — confirm
        # the shapes produced by apply_fn make that impossible.
        prediction = jnp.squeeze(model_state.apply_fn(params, state_data, img_data))
        residual = prediction - control_data
        return jnp.mean(residual ** 2)

    batch_loss, gradients = jax.value_and_grad(mean_squared_error)(model_state.params)
    return model_state.apply_gradients(grads=gradients), batch_loss


def _flatten_dataset(input_data, target_data, img_data):
    """Collapse the (demos, segments) axes so that each row is one training sample.

    Inputs and targets are packed along their trailing axis into one array, and
    each demo's image is repeated once per segment so image rows line up with
    data rows.

    Returns:
        ``(flat_data, flat_imgs)`` with a shared leading ``demos * segments`` axis.
    """
    data, _ = e.pack([input_data, target_data], "demos segments *")
    imgs = jnp.repeat(jnp.expand_dims(img_data, axis=1), data.shape[1], axis=1)
    imgs = e.rearrange(imgs, "demos segments h w c -> (demos segments) h w c")
    data = e.rearrange(data, "demos segments datapoints -> (demos segments) datapoints")
    return data, imgs


def _epoch_batches(key, flat_data, flat_imgs, batch_size):
    """Shuffle the flattened samples and split them into full batches.

    A single index permutation is drawn and applied to both arrays, so each
    data row stays paired with its image. Trailing samples that do not fill a
    whole batch are dropped.

    Returns:
        ``(inputs, targets, imgs)``, each with leading axes
        ``(num_batches, batch_size)``.
    """
    num_samples = flat_data.shape[0]
    num_kept = (num_samples // batch_size) * batch_size
    order = random.permutation(key, num_samples)[:num_kept]

    batched_data = e.rearrange(
        flat_data[order],
        "(samples batch_size) datapoints -> samples batch_size datapoints",
        batch_size=batch_size,
    )
    batched_imgs = e.rearrange(
        flat_imgs[order],
        "(samples batch_size) h w c -> samples batch_size h w c",
        batch_size=batch_size,
    )
    # The first two packed columns are the model inputs, the rest are targets.
    # NOTE(review): assumes the inputs are exactly 2 values wide — confirm
    # against process_lasa_multi_data.
    return batched_data[:, :, :2], batched_data[:, :, 2:], batched_imgs


@hydra.main(version_base=None, config_path="..")
def main(cfg: DictConfig) -> None:
    """Train the multi-task feedforward model and checkpoint it.

    Reads everything from ``cfg["config"]``: sets up a wandb run, builds the
    model and optimiser, then trains for ``training.num_epochs`` epochs. Every
    200th epoch (except epoch 0) the model is evaluated and, if the mean
    Fréchet distance improved, saved under ``<model_save_dir>/<run id>/best_model``.
    Periodic checkpoints are written through an orbax ``CheckpointManager``.

    Raises:
        ValueError: if the dataset holds fewer samples than one batch, since
            no gradient update could ever run.
    """
    cfg = cfg["config"]
    train_cfg = cfg["training"]
    batch_size = train_cfg["batch_size"]
    eval_interval = 200

    # experiment tracking
    wandb.init(
        project=cfg["wandb"]["project"],
        entity=cfg["wandb"]["entity"],
        config=omegaconf.OmegaConf.to_container(cfg, resolve=False),
        tags=cfg["wandb"]["tags"],
        notes=cfg["wandb"]["notes"],
    )
    wandb.run.name = cfg["wandb"]["experiment_name"]

    # initialise model
    model, variables = init_feedforward_multi_model(cfg["architecture"])

    # initialise optimiser: AdamW for every parameter except those under the
    # top-level "dynamics_params" key, whose updates are set to zero
    def flattened_traversal(fn):
        def mask(tree):
            flat = flax.traverse_util.flatten_dict(tree)
            return flax.traverse_util.unflatten_dict({k: fn(k, v) for k, v in flat.items()})

        return mask

    label_fn = flattened_traversal(lambda path, _: "none" if path[0] == "dynamics_params" else "adamw")

    lr_schedule = tx.warmup_cosine_decay_schedule(
        init_value=train_cfg["initial_lr"],
        peak_value=train_cfg["peak_lr"],
        warmup_steps=train_cfg["warmup_steps"],
        decay_steps=train_cfg["decay_steps"],
        end_value=train_cfg["end_lr"],
    )

    opt = tx.multi_transform(
        {
            "adamw": tx.adamw(learning_rate=lr_schedule),
            "none": tx.set_to_zero(),
        },
        label_fn,
    )

    # gradients are clipped by global norm before reaching the per-label optimisers
    opt = tx.chain(
        tx.clip_by_global_norm(train_cfg["clip_grad_norm"]),
        opt,
    )

    # create training state
    our_train_state = train_state.TrainState.create(
        apply_fn=model.apply,
        params=variables,
        tx=opt,
    )

    # model checkpointing: each run gets its own (freshly emptied) directory
    os.makedirs(train_cfg["model_save_dir"], exist_ok=True)
    run_dir = os.path.join(train_cfg["model_save_dir"], wandb.run.id)
    path = ocp.test_utils.erase_and_create_empty(run_dir)
    options = ocp.CheckpointManagerOptions(max_to_keep=2, save_interval_steps=200)
    mngr = ocp.CheckpointManager(
        path,
        options=options,
    )

    # load data
    input_data, target_data, img_data = process_lasa_multi_data(
        cfg["dataset"]["shapes"], window_size=cfg["dataset"]["window_size"], stride=cfg["dataset"]["stride"]
    )

    # Flattening does not depend on the epoch, so it is done once up front;
    # only the shuffle changes from epoch to epoch.
    flat_data, flat_imgs = _flatten_dataset(input_data, target_data, img_data)
    if flat_data.shape[0] < batch_size:
        raise ValueError(
            f"Dataset has {flat_data.shape[0]} samples, fewer than batch_size={batch_size}; "
            "no full batch can be formed."
        )

    # begin training loop
    key = random.key(42)
    best_model = None
    min_frechet_dist = 1e6
    for i in range(train_cfg["num_epochs"]):
        # derive this epoch's shuffle key, then shuffle and batch the data
        key = random.fold_in(key, i)
        input_batches, target_batches, img_batches = _epoch_batches(key, flat_data, flat_imgs, batch_size)

        losses = []
        for j in range(input_batches.shape[0]):
            our_train_state, loss = train_step(
                our_train_state, input_batches[j], img_batches[j], target_batches[j]
            )
            losses.append(loss)
        epoch_loss = np.mean(losses)
        print(f"Loss for epoch: {i}: {epoch_loss}")

        # The optimiser advances the schedule once per gradient update, not once
        # per epoch, so the schedule is evaluated at the train state's update
        # counter to report the learning rate actually in effect.
        current_lr = float(lr_schedule(our_train_state.step))
        metrics = {"train_loss": epoch_loss, "epoch": i, "lr": current_lr}

        if (i % eval_interval == 0) and (i != 0):
            _, _, frechet_dists, _ = evaluate_demo_feedforward_multi(our_train_state, cfg)
            frechet_dist = jnp.mean(jnp.asarray(frechet_dists))
            metrics["frechet_dist"] = frechet_dist
            wandb.log(metrics)
            if frechet_dist < min_frechet_dist:
                min_frechet_dist = frechet_dist
                best_model = deepcopy(our_train_state)
                # replace any previously saved best model on disk
                best_model_path = os.path.join(run_dir, "best_model")
                if os.path.exists(best_model_path):
                    shutil.rmtree(best_model_path)
                # NOTE(review): if StandardCheckpointer saves asynchronously in
                # the installed orbax version, nothing waits for this write to
                # finish — confirm and add a wait if so.
                checkpointer = ocp.StandardCheckpointer()
                checkpointer.save(best_model_path, best_model)
        else:
            wandb.log(metrics)

        # the manager itself decides whether epoch i is due for a checkpoint
        mngr.save(
            i,
            args=ocp.args.StandardSave(our_train_state),
        )
    mngr.wait_until_finished()

    # uncomment to upload model to hugging face
    # print("Uploading model to hugging face...")
    # args = [f"{v}" for k, v in cfg["huggingface"].items()]
    # process = Popen(["python", f"{os.path.dirname(os.path.abspath(__file__))}/../huggingface/upload_model.py"] + args)
    # process.wait()


# Script entry point: main() is called without arguments because the
# @hydra.main decorator on it is what supplies `cfg`.
if __name__ == "__main__":
    main()
