import os
import shutil
import sys

# NOTE(review): raised recursion limit; the reason is not visible in this file (presumably
# deep recursion during tracing or checkpoint serialisation) — confirm before changing.
sys.setrecursionlimit(4000)
# Disable XLA's up-front GPU memory preallocation. This is set before `import jax` below so it
# is already in the environment when the JAX backend initialises.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

from functools import partial
from copy import deepcopy

import jax
import jax.numpy as jnp
import einops as e

import flax
from flax.training import train_state
import optax as tx
import orbax.checkpoint as ocp


from frechetdist import frdist

import wandb
import hydra
import omegaconf
from omegaconf import DictConfig

import pyLasaDataset as lasa

from implicit_nonlinear_dynamics.data_preprocessing.lasa_handwriting import (
    process_lasa_multi_data,
    generate_character_image,
    process_lasa_multi_position_delta_data,
    process_lasa_multi_velocity_data,
)
from implicit_nonlinear_dynamics.evaluate.evaluate_ours_multitask import evaluate_demo_ours_multi
from implicit_nonlinear_dynamics.model_architectures.ours import init_our_multi_model
from implicit_nonlinear_dynamics.evaluate.utils import sample_within_threshold


@jax.jit
def train_step(model_state, env_states, img, target):
    """
    Run one optimisation step on a batch of demonstration sequences.

    The model is unrolled over time with ``jax.lax.scan`` (one scan per sequence, vectorised
    across sequences with ``jax.vmap``). The mean squared error between the squeezed stacked
    predictions and ``target`` is differentiated w.r.t. the parameters, and the resulting
    gradients are applied through ``model_state.apply_gradients``.

    Args:
        model_state: train state exposing ``params``, ``apply_fn`` and ``apply_gradients``.
        env_states: model inputs, mapped over axis 0 (sequences) and scanned over axis 1 (time).
        img: per-sequence conditioning image, mapped over axis 0.
        target: regression targets, compared against ``jnp.squeeze(predictions)``.

    Returns:
        ``(new_model_state, loss)`` where ``loss`` is the pre-update MSE.
    """
    # Width of the recurrent state carried through the scan; lax.scan requires apply_fn to
    # return a state of this same shape.
    rc_state_size = 5000

    def loss_fn(params):
        """Unroll the network over every sequence and return the MSE against ``target``."""
        initial_rc_state = jnp.zeros((1, rc_state_size))

        def compute_step(carry, env_state_var):
            rc_state, seq_img = carry
            prediction, rc_state = model_state.apply_fn(
                params, jnp.expand_dims(env_state_var, axis=0), rc_state, jnp.expand_dims(seq_img, axis=0)
            )
            return (rc_state, seq_img), prediction

        # The initial recurrent state is shared (unmapped); images and inputs are mapped over axis 0.
        _, predictions = jax.vmap(jax.lax.scan, in_axes=(None, (None, 0), 0))(
            compute_step, (initial_rc_state, img), env_states
        )
        return jnp.mean((jnp.squeeze(predictions) - target) ** 2)

    loss, grads = jax.value_and_grad(loss_fn)(model_state.params)
    model_state = model_state.apply_gradients(grads=grads)
    return model_state, loss


def evaluate_demo_multi_pos_delta(state, cfg):
    """
    Evaluate frechet distance against expert demonstrations, based on predicting position deltas.

    For every shape in ``cfg["dataset"]["shapes"]`` the model is rolled out from the first
    position of each expert demonstration (all demonstrations in parallel via ``jax.vmap``).
    The Frechet distance between each rollout and its demonstration is then computed.

    Args:
        state: train state exposing ``apply_fn`` and ``params``.
        cfg: config mapping; ``cfg["dataset"]["shapes"]`` lists the LASA shape names.

    Returns:
        Mean Frechet distance over every demonstration of every configured shape.

    Note:
        Not jit-compiled: it calls non-JAX helpers (dataset loading, image generation,
        ``frdist``) on concrete arrays.
    """
    # Number of recurrent model steps per rollout.
    # NOTE(review): value kept from the original; presumably tied to the predicted chunk
    # length so that a rollout spans a demonstration — confirm.
    num_rollout_steps = 83
    # Points per demonstration; used to split the concatenated demonstrations apart.
    demo_length = 1000

    def compute_step(carry, _):
        dynamics_state, img, env_state = carry
        prediction, dynamics_state = state.apply_fn(
            state.params, jnp.expand_dims(env_state, axis=0), dynamics_state, jnp.expand_dims(img, axis=0)
        )
        # Add the predicted position deltas to the current position to get the next position:
        # even prediction entries are x deltas, odd entries are y deltas.
        x_delta = jnp.sum(prediction[0, ::2], axis=0)
        y_delta = jnp.sum(prediction[0, 1::2], axis=0)
        env_state = env_state + jnp.vstack([x_delta, y_delta]).T
        return (dynamics_state, img, env_state[0, :]), env_state

    fr_dists = []
    for dataset in cfg["dataset"]["shapes"]:
        demos = lasa.DataSet.__getattr__(dataset).demos
        pos = jnp.hstack([demo.pos for demo in demos])
        pos = e.rearrange(
            pos, "points (demonstration timesteps) -> demonstration points timesteps", timesteps=demo_length
        )
        # Each rollout starts from the first recorded point of its demonstration.
        env_state = jnp.array(pos[:, :, 0])

        img = generate_character_image(dataset[0])
        dynamics_state = jnp.zeros((1, 5000))
        _, env_states = jax.vmap(jax.lax.scan, in_axes=(None, (None, None, 0), None, None))(
            compute_step, (dynamics_state, img, env_state), None, num_rollout_steps
        )

        # One Frechet distance per demonstration (axis 0 of `pos` is the demonstration axis).
        for i in range(pos.shape[0]):
            frechet_dist = frdist(*sample_within_threshold(env_states[i, :, -2:].T, pos[i, :, :demo_length]))
            fr_dists.append(frechet_dist)

    # Average over every configured shape, not just the first one.
    return jnp.mean(jnp.asarray(fr_dists))


def evaluate_demo_multi_vel(state, cfg):
    """
    Evaluate frechet distance against expert demonstrations, based on predicting velocity deltas.

    The rollout state is ``[x, y, vx, vy]``, initialised from the first point of each expert
    demonstration. At every step the predicted velocities are integrated over ``TIMESTEP`` to
    update the position. The velocity components are replaced by the last predicted velocity.

    Args:
        state: train state exposing ``apply_fn`` and ``params``.
        cfg: config mapping; ``cfg["dataset"]["shapes"]`` lists the LASA shape names.

    Returns:
        Mean Frechet distance over every demonstration of every configured shape.

    Note:
        Not jit-compiled: it calls non-JAX helpers (dataset loading, image generation,
        ``frdist``) on concrete arrays.
    """
    TIMESTEP = 0.003
    # Number of recurrent model steps per rollout.
    # NOTE(review): value kept from the original; presumably tied to the predicted chunk
    # length so that a rollout spans a demonstration — confirm.
    num_rollout_steps = 83
    # Points per demonstration; used to split the concatenated demonstrations apart.
    demo_length = 1000

    def compute_step(carry, _):
        dynamics_state, img, env_state = carry
        prediction, dynamics_state = state.apply_fn(
            state.params, jnp.expand_dims(env_state, axis=0), dynamics_state, jnp.expand_dims(img, axis=0)
        )
        # Integrate the predicted velocities over one timestep to get position deltas
        # (even entries are x, odd entries are y).
        displacement = prediction * TIMESTEP
        x_delta = jnp.sum(displacement[0, ::2], axis=0)
        y_delta = jnp.sum(displacement[0, 1::2], axis=0)

        # update the position components
        env_state = env_state.at[:2].set(env_state[:2] + jnp.squeeze(jnp.vstack([x_delta, y_delta]).T))
        # Update the velocity components with the last predicted velocity, *unscaled*. The
        # state is initialised below with raw demonstration velocities, so feeding back the
        # TIMESTEP-scaled value would mix units.
        env_state = env_state.at[2:].set(jnp.squeeze(prediction)[-2:])

        return (dynamics_state, img, env_state), env_state

    fr_dists = []
    for dataset in cfg["dataset"]["shapes"]:
        demos = lasa.DataSet.__getattr__(dataset).demos
        pos = jnp.hstack([demo.pos for demo in demos])
        pos = e.rearrange(
            pos, "points (demonstration timesteps) -> demonstration points timesteps", timesteps=demo_length
        )
        vel = jnp.hstack([demo.vel for demo in demos])
        vel = e.rearrange(
            vel, "points (demonstration timesteps) -> demonstration points timesteps", timesteps=demo_length
        )
        # Initial state per demonstration: first recorded position followed by first velocity.
        env_state = jnp.concatenate([jnp.array(pos[:, :, 0]), jnp.array(vel[:, :, 0])], axis=-1)

        img = generate_character_image(dataset[0])
        dynamics_state = jnp.zeros((1, 5000))
        _, env_states = jax.vmap(jax.lax.scan, in_axes=(None, (None, None, 0), None, None))(
            compute_step, (dynamics_state, img, env_state), None, num_rollout_steps
        )

        # One Frechet distance per demonstration, comparing only the position components.
        for i in range(pos.shape[0]):
            frechet_dist = frdist(*sample_within_threshold(env_states[i, :, :2].T, pos[i, :, :demo_length]))
            fr_dists.append(frechet_dist)

    # Average over every configured shape, not just the first one.
    return jnp.mean(jnp.asarray(fr_dists))


def _load_training_data(cfg):
    """
    Load ``(input_data, target_data, img_data)`` for the configured action space.

    Raises:
        ValueError: if ``cfg["architecture"]["action_space_type"]`` is not one of
            ``"position"``, ``"position_delta"`` or ``"velocity"``.
    """
    action_space_type = cfg["architecture"]["action_space_type"]
    shapes = cfg["dataset"]["shapes"]
    if action_space_type == "position":
        return process_lasa_multi_data(
            shapes, window_size=cfg["dataset"]["window_size"], stride=cfg["dataset"]["stride"]
        )
    if action_space_type == "position_delta":
        return process_lasa_multi_position_delta_data(shapes)
    if action_space_type == "velocity":
        return process_lasa_multi_velocity_data(shapes)
    raise ValueError("Invalid action space type")


def _evaluate_frechet(state, cfg):
    """
    Return the mean Frechet distance of ``state`` against the expert demonstrations.

    Raises:
        ValueError: if ``cfg["architecture"]["action_space_type"]`` is unsupported.
    """
    action_space_type = cfg["architecture"]["action_space_type"]
    if action_space_type == "position":
        _, _, frechet_dists, _ = evaluate_demo_ours_multi(state, cfg)
        return jnp.mean(jnp.asarray(frechet_dists))
    if action_space_type == "position_delta":
        return evaluate_demo_multi_pos_delta(state, cfg)
    if action_space_type == "velocity":
        return evaluate_demo_multi_vel(state, cfg)
    raise ValueError("Invalid action space type")


def _save_best_model(path, state):
    """Replace the checkpoint at ``path`` with ``state`` and block until it is fully written."""
    if os.path.exists(path):
        shutil.rmtree(path)
    checkpointer = ocp.StandardCheckpointer()
    checkpointer.save(path, state)
    # Newer Orbax releases save asynchronously. Wait so the checkpoint is complete before the
    # run ends or the directory is removed by a later improvement. The getattr guard keeps
    # this working with synchronous checkpointers that lack the method.
    wait_until_finished = getattr(checkpointer, "wait_until_finished", None)
    if wait_until_finished is not None:
        wait_until_finished()


@hydra.main(version_base=None, config_path="..")
def main(cfg: DictConfig) -> None:
    """
    Train the multi-task model on LASA handwriting shapes.

    Reads hyper-parameters from ``cfg["config"]`` and logs metrics to Weights & Biases. Every
    ``eval_every`` epochs it evaluates the Frechet distance to the expert demonstrations and
    keeps a standalone checkpoint of the best model so far. It also saves rolling checkpoints
    with an Orbax ``CheckpointManager``.

    Raises:
        ValueError: if ``cfg["config"]["architecture"]["action_space_type"]`` is unsupported.
    """
    cfg = cfg["config"]
    # epochs between Frechet-distance evaluations
    eval_every = 200

    # experiment tracking
    wandb.init(
        project=cfg["wandb"]["project"],
        entity=cfg["wandb"]["entity"],
        config=omegaconf.OmegaConf.to_container(cfg, resolve=False),
        tags=cfg["wandb"]["tags"],
        notes=cfg["wandb"]["notes"],
    )
    wandb.run.name = cfg["wandb"]["experiment_name"]

    # initialise model
    model, variables = init_our_multi_model(cfg["architecture"])

    # initialise optimiser for neural network params only
    def flattened_traversal(fn):
        def mask(tree):
            flat = flax.traverse_util.flatten_dict(tree)
            return flax.traverse_util.unflatten_dict({k: fn(k, v) for k, v in flat.items()})

        return mask

    # parameters under "dynamics_params" are frozen (their updates are zeroed)
    label_fn = flattened_traversal(lambda path, _: "none" if path[0] == "dynamics_params" else "adamw")

    lr_schedule = tx.warmup_cosine_decay_schedule(
        init_value=cfg["training"]["initial_lr"],
        peak_value=cfg["training"]["peak_lr"],
        warmup_steps=cfg["training"]["warmup_steps"],
        decay_steps=cfg["training"]["decay_steps"],
        end_value=cfg["training"]["end_lr"],
    )

    opt = tx.multi_transform(
        {
            "adamw": tx.adamw(learning_rate=lr_schedule),
            "none": tx.set_to_zero(),
        },
        label_fn,
    )

    opt = tx.chain(
        tx.clip_by_global_norm(cfg["training"]["clip_grad_norm"]),
        opt,
    )

    # create training state
    our_train_state = train_state.TrainState.create(
        apply_fn=model.apply,
        params=variables,
        tx=opt,
    )

    # model checkpointing
    os.makedirs(cfg["training"]["model_save_dir"], exist_ok=True)
    path = ocp.test_utils.erase_and_create_empty(cfg["training"]["model_save_dir"] + f"/{wandb.run.id}")
    options = ocp.CheckpointManagerOptions(max_to_keep=2, save_interval_steps=200)
    mngr = ocp.CheckpointManager(
        path,
        options=options,
    )
    # Resolved once; Orbax may reject relative checkpoint paths.
    best_model_path = os.path.abspath(
        os.path.join(cfg["training"]["model_save_dir"], wandb.run.id, "best_model")
    )

    input_data, target_data, img_data = _load_training_data(cfg)

    # begin training loop
    min_frechet_dist = float("inf")
    for i in range(cfg["training"]["num_epochs"]):
        our_train_state, loss = train_step(our_train_state, input_data, img_data, target_data)
        print(f"Loss for epoch: {i}: {loss}")

        metrics = {"train_loss": loss, "epoch": i, "lr": lr_schedule(i)}
        if i % eval_every == 0 and i != 0:
            frechet_dist = _evaluate_frechet(our_train_state, cfg)
            metrics["frechet_dist"] = frechet_dist
            wandb.log(metrics)
            if frechet_dist < min_frechet_dist:
                min_frechet_dist = frechet_dist
                _save_best_model(best_model_path, our_train_state)
        else:
            wandb.log(metrics)

        mngr.save(
            i,
            args=ocp.args.StandardSave(our_train_state),
        )
    mngr.wait_until_finished()
    wandb.finish()

    # uncomment to upload model to hugging face
    # print("Uploading model to hugging face...")
    # args = [f"{v}" for k, v in cfg["huggingface"].items()]
    # process = Popen(["python", f"{os.path.dirname(os.path.abspath(__file__))}/../huggingface/upload_model.py"] + args)
    # process.wait()


if __name__ == "__main__":
    main()
