"""
Evaluate trained models on demonstration data.
"""
import os

# Set before `import jax` further down; presumably this stops XLA from
# preallocating accelerator memory at start-up — confirm against the JAX docs.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
import sys

# NOTE(review): raises the interpreter recursion limit to 4000; the code path
# that needs the deeper stack is not visible in this file.
sys.setrecursionlimit(4000)
from functools import partial
import time

import numpy as np
import jax
import jax.numpy as jnp
import einops as e

import flax
from flax.training import train_state
import optax as tx
import orbax.checkpoint as ocp

from frechetdist import frdist
import pyLasaDataset as lasa

import hydra
from omegaconf import DictConfig

from implicit_nonlinear_dynamics.model_architectures.feedforward import (
    init_feedforward_model,
)
from implicit_nonlinear_dynamics.evaluate.utils import (
    sample_within_threshold,
    deque_append,
    evaluate_mean_absolute_jerk,
)


def load_feedforward(cfg):
    """
    Rebuild the feedforward train state and restore it from the saved best-model checkpoint.

    The model and optimiser are constructed from ``cfg["architecture"]`` and ``cfg["training"]`` so that
    the freshly created train state can act as the structural template for the checkpoint restore.

    Args:
        cfg: config providing ``architecture``, the ``training`` optimiser settings and
            ``model_save_dir``, and ``evaluation.run_id``.

    Returns:
        The result of restoring ``<model_save_dir>/<run_id>/best_model`` against that template.
    """
    # build the network and its initial variables
    network, initial_variables = init_feedforward_model(cfg["architecture"])
    training_cfg = cfg["training"]

    def label_params(tree):
        # Tag every leaf by its top-level key: anything under "dynamics_params"
        # is routed to the zero-update transform, everything else to AdamW.
        labels = {}
        for path in flax.traverse_util.flatten_dict(tree):
            labels[path] = "none" if path[0] == "dynamics_params" else "adamw"
        return flax.traverse_util.unflatten_dict(labels)

    schedule = tx.warmup_cosine_decay_schedule(
        init_value=training_cfg["initial_lr"],
        peak_value=training_cfg["peak_lr"],
        warmup_steps=training_cfg["warmup_steps"],
        decay_steps=training_cfg["decay_steps"],
        end_value=training_cfg["end_lr"],
    )

    # gradient clipping first, then the per-label optimiser
    optimiser = tx.chain(
        tx.clip_by_global_norm(training_cfg["clip_grad_norm"]),
        tx.multi_transform(
            {
                "adamw": tx.adamw(learning_rate=schedule),
                "none": tx.set_to_zero(),
            },
            label_params,
        ),
    )

    # template train state whose structure describes what to restore
    template_state = train_state.TrainState.create(
        apply_fn=network.apply,
        params=initial_variables,
        tx=optimiser,
    )
    restore_target = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, template_state)

    checkpoint_path = os.path.join(training_cfg["model_save_dir"], cfg["evaluation"]["run_id"], "best_model")
    return ocp.StandardCheckpointer().restore(
        checkpoint_path,
        args=ocp.args.StandardRestore(restore_target),
    )


def evaluate_prediction_latency(state, cfg, num_iterations=1000):
    """
    Measure the mean wall-clock time of one jitted forward pass.

    An all-ones input of shape ``cfg["architecture"]["init_shapes"]["input"]`` is passed through
    ``state.apply_fn`` with ``state.params``. One untimed warm-up call triggers compilation, then
    ``num_iterations`` synchronous calls are timed.

    Args:
        state: object exposing ``apply_fn`` and ``params``.
        cfg: config providing ``architecture.init_shapes.input``.
        num_iterations: number of timed forward passes; must be positive.

    Returns:
        Mean latency of a single forward pass in seconds.

    Raises:
        ValueError: if ``num_iterations`` is smaller than 1.
    """
    if num_iterations < 1:
        raise ValueError(f"num_iterations must be a positive integer, got {num_iterations}")

    # Build the input once, outside the timed region, so only the forward pass is measured
    # (and so the config never has to be hashed as a static jit argument on every call).
    dummy_input = jnp.ones(cfg["architecture"]["init_shapes"]["input"])

    # The jitted function must return the prediction: a computation whose result is
    # discarded has no live outputs and can be eliminated by the compiler.
    predict = jax.jit(state.apply_fn)

    # Warm-up call so tracing and compilation are not counted as latency.
    jax.block_until_ready(predict(state.params, dummy_input))

    start = time.perf_counter()
    for _ in range(num_iterations):
        # JAX dispatches asynchronously; wait for the result so execution time is measured.
        jax.block_until_ready(predict(state.params, dummy_input))
    end = time.perf_counter()

    return (end - start) / num_iterations


def evaluate_demo_feedforward(state, cfg, write_to_file=False):
    """
    Roll the model out from the first point of every demonstration and score the rollouts.

    For each temporal-ensembling weight in a fixed sweep the model is unrolled for 990 steps per
    demonstration. Each rollout is compared with its demonstration via the Frechet distance, the mean
    absolute jerk of both output coordinates is computed, and a rollout latency is estimated as the
    per-step latency times the mean number of sampled predictions.

    The results directory is always created, even when nothing is written to it.

    Args:
        state: train state providing ``apply_fn`` and ``params``.
        cfg: config with the ``architecture.inference_params``, ``evaluation.noise_scale``,
            ``dataset.shape`` and ``wandb.experiment_name`` entries read below.
        write_to_file: when True, save the demonstrations, predictions, Frechet distances and latency
            of every ensemble weight as ``.npy`` files in the results directory. Jerk values are
            returned but not saved.

    Returns:
        Tuple ``(demo_vals, predicted_vals, fr_dists, latency, jerk_xs, jerk_ys)`` of numpy arrays,
        holding the values of the last ensemble weight in the sweep only.
    """
    ACTION_CHUNKS = cfg["architecture"]["inference_params"]["action_chunks"]
    PREDICTION_HORIZON = cfg["architecture"]["inference_params"]["prediction_horizon"]
    NUM_PREDICTED_ACTIONS = cfg["architecture"]["inference_params"]["num_predicted_actions"]
    ENSEMBLE_WEIGHT = cfg["architecture"]["inference_params"]["ensemble_weight"]
    NOISE_SCALE = cfg["evaluation"]["noise_scale"]

    # Keep this f-string on a single line: a backslash continuation inside the literal would
    # embed the next line's indentation in the path.
    results_dir = f"{os.path.dirname(os.path.abspath(__file__))}/../.results/{cfg['wandb']['experiment_name']}"

    @partial(jax.jit, static_argnums=(1, 2, 3, 4))
    def generate_predictions(
        state, cfg, action_chunks=ACTION_CHUNKS, prediction_horizon=PREDICTION_HORIZON, ensemble_weight=ENSEMBLE_WEIGHT
    ):
        def compute_step(carry, x):
            # `x` is the (unused) per-step scan input; the rollout is driven by the carry alone.
            env_state, horizon, is_first, key = carry
            key, subkey = jax.random.split(key)
            prediction = state.apply_fn(state.params, env_state)

            # add random noise to prediction
            prediction += jax.random.normal(subkey, prediction.shape) * NOISE_SCALE

            # perform temporal ensembling: push the newest prediction into the buffer
            # (buffer semantics are defined by `deque_append` in evaluate.utils)
            memory, rear, front, num_elements = deque_append(
                horizon["memory"],
                horizon["rear"],
                horizon["front"],
                horizon["n_elements"],
                jnp.squeeze(prediction),
                is_first,
            )

            updated_horizon = {
                "memory": memory,
                "front": front,
                "rear": rear,
                "n_elements": num_elements,
            }

            # row i carries the weight exp(-ensemble_weight * i), repeated across the prediction width
            ensemble_weights = jnp.asarray(
                [
                    jnp.repeat(jnp.exp(-ensemble_weight * i), (ACTION_CHUNKS * 2), axis=-1)
                    for i in range(PREDICTION_HORIZON)
                ]
            )
            # from buffer row i keep only the i-th window of NUM_PREDICTED_ACTIONS entries
            memory_filtered = jnp.asarray(
                [
                    memory[i, i * (NUM_PREDICTED_ACTIONS) : i * (NUM_PREDICTED_ACTIONS) + (NUM_PREDICTED_ACTIONS)]
                    for i in range(PREDICTION_HORIZON)
                ]
            )
            ensemble_weights_filtered = jnp.asarray(
                [
                    ensemble_weights[
                        i, i * (NUM_PREDICTED_ACTIONS) : i * (NUM_PREDICTED_ACTIONS) + (NUM_PREDICTED_ACTIONS)
                    ]
                    for i in range(PREDICTION_HORIZON)
                ]
            )
            # weighted average over the buffer rows
            output_prediction = jnp.sum(memory_filtered * ensemble_weights_filtered, axis=0) / jnp.sum(
                ensemble_weights_filtered, axis=0
            )

            # the last two entries become the next model input and the recorded output
            return (output_prediction[-2:], updated_horizon, False, key), output_prediction[-2:]

        raw_data = lasa.DataSet.__getattr__(cfg["dataset"]["shape"])
        demos = raw_data.demos
        pos = jnp.hstack([demo.pos for demo in demos])
        # split the concatenated time axis back into one block of 1000 timesteps per demonstration
        pos = e.rearrange(pos, "points (demonstration timesteps) -> demonstration points timesteps", timesteps=1000)
        env_state = jnp.array(pos[:, :, 0])
        key = jax.random.PRNGKey(0)
        keys = jax.random.split(key, pos.shape[0])

        # one ensembling buffer per demonstration
        dummy = jnp.zeros((pos.shape[0], ACTION_CHUNKS * 2))
        horizon = {
            "memory": jnp.zeros((jnp.shape(dummy)[0], PREDICTION_HORIZON, *jnp.shape(dummy)[1:])),
            "front": jnp.zeros(jnp.shape(dummy)[0]),
            "rear": jnp.ones(jnp.shape(dummy)[0]) * -1,
            "n_elements": jnp.zeros(jnp.shape(dummy)[0]),
        }

        # vmap the 990-step scan over demonstrations; the `is_first` flag is shared by all of them
        _, env_states = jax.vmap(jax.lax.scan, in_axes=(None, (0, 0, None, 0), None, None))(
            compute_step, (env_state, horizon, True, keys), None, 990
        )

        return jnp.squeeze(env_states), jnp.asarray(pos)

    os.makedirs(results_dir, exist_ok=True)

    # The per-step latency depends only on `state` and `cfg`, so measure it once for the whole sweep.
    step_latency = evaluate_prediction_latency(state, cfg)

    for ensemble_weight in [0.00001, 0.0001, 0.001, 0.01, 0.1]:
        predicted_vals, demo_vals = generate_predictions(state, cfg, ensemble_weight=ensemble_weight)

        fr_dists = []
        jerk_xs = []
        jerk_ys = []
        num_predictions = []
        # iterate over however many demonstrations the dataset provides
        for i in range(demo_vals.shape[0]):
            sample_predicted, sample_demo = sample_within_threshold(predicted_vals[i, :, -2:].T, demo_vals[i, :, :1000])
            fr_dists.append(frdist(sample_predicted.T, sample_demo.T))
            jerk_xs.append(evaluate_mean_absolute_jerk(sample_predicted[0]))
            jerk_ys.append(evaluate_mean_absolute_jerk(sample_predicted[1]))
            num_predictions.append(sample_predicted.shape[1])

        # rollout latency = per-step latency x mean number of sampled predictions
        latency = np.asarray(step_latency * np.mean(num_predictions))
        demo_vals = np.asarray(demo_vals)
        predicted_vals = np.asarray(predicted_vals)
        fr_dists = np.asarray(fr_dists)
        jerk_xs = np.asarray(jerk_xs)
        jerk_ys = np.asarray(jerk_ys)

        if write_to_file:
            np.save(f"{results_dir}/temporal_demo_vals_{ensemble_weight}.npy", demo_vals)
            np.save(f"{results_dir}/temporal_predicted_vals_{ensemble_weight}.npy", predicted_vals)
            np.save(f"{results_dir}/temporal_fr_dists_{ensemble_weight}.npy", fr_dists)
            np.save(f"{results_dir}/temporal_latency_{ensemble_weight}.npy", latency)

    return demo_vals, predicted_vals, fr_dists, latency, jerk_xs, jerk_ys


def output_representations_feedforward(state, cfg, write_to_file=False):
    """
    Roll the model out from the first point of every demonstration and collect hidden activations.

    At every one of the 990 rollout steps the first captured ``__call__`` output of the ``Dense_1``
    submodule is recorded; the stacked activations of all demonstrations are returned.

    Args:
        state: train state providing ``apply_fn`` and ``params``.
        cfg: config with the ``architecture.inference_params``, ``evaluation.noise_scale``,
            ``dataset.shape`` and ``wandb.experiment_name`` entries read below.
        write_to_file: when True, save the activations as ``representations.npy`` in the results
            directory (created on demand).

    Returns:
        Numpy array of the recorded activations with singleton dimensions squeezed out.
    """
    ACTION_CHUNKS = cfg["architecture"]["inference_params"]["action_chunks"]
    PREDICTION_HORIZON = cfg["architecture"]["inference_params"]["prediction_horizon"]
    NUM_PREDICTED_ACTIONS = cfg["architecture"]["inference_params"]["num_predicted_actions"]
    ENSEMBLE_WEIGHT = cfg["architecture"]["inference_params"]["ensemble_weight"]
    NOISE_SCALE = cfg["evaluation"]["noise_scale"]

    @partial(jax.jit, static_argnums=(1, 2, 3, 4))
    def generate_predictions(
        state, cfg, action_chunks=ACTION_CHUNKS, prediction_horizon=PREDICTION_HORIZON, ensemble_weight=ENSEMBLE_WEIGHT
    ):
        def compute_step(carry, x):
            # `x` is the (unused) per-step scan input; the rollout is driven by the carry alone.
            env_state, horizon, is_first, key = carry
            key, subkey = jax.random.split(key)
            prediction, intermediates = state.apply_fn(
                state.params,
                env_state,
                capture_intermediates=True,
                mutable=["intermediates"],
            )

            # add random noise to prediction
            prediction += jax.random.normal(subkey, prediction.shape) * NOISE_SCALE

            # perform temporal ensembling: push the newest prediction into the buffer
            # (buffer semantics are defined by `deque_append` in evaluate.utils)
            memory, rear, front, num_elements = deque_append(
                horizon["memory"],
                horizon["rear"],
                horizon["front"],
                horizon["n_elements"],
                jnp.squeeze(prediction),
                is_first,
            )

            updated_horizon = {
                "memory": memory,
                "front": front,
                "rear": rear,
                "n_elements": num_elements,
            }

            # row i carries the weight exp(-ensemble_weight * i), repeated across the prediction width
            weights = jnp.asarray(
                [
                    jnp.repeat(jnp.exp(-ensemble_weight * i), (ACTION_CHUNKS * 2), axis=-1)
                    for i in range(PREDICTION_HORIZON)
                ]
            )
            # NOTE(review): averages whole buffer rows, whereas evaluate_demo_feedforward first selects
            # a per-row window — confirm the difference is intended.
            current_prediction = jnp.sum(memory * weights, axis=0) / jnp.sum(weights, axis=0)
            output_prediction = current_prediction[NUM_PREDICTED_ACTIONS - 2 : NUM_PREDICTED_ACTIONS]

            # carry the selected prediction forward; emit the captured Dense_1 activation
            return (output_prediction, updated_horizon, False, key), intermediates["intermediates"]["Dense_1"][
                "__call__"
            ][0]

        raw_data = lasa.DataSet.__getattr__(cfg["dataset"]["shape"])
        demos = raw_data.demos
        pos = jnp.hstack([demo.pos for demo in demos])
        # split the concatenated time axis back into one block of 1000 timesteps per demonstration
        pos = e.rearrange(pos, "points (demonstration timesteps) -> demonstration points timesteps", timesteps=1000)
        env_state = jnp.array(pos[:, :, 0])
        key = jax.random.PRNGKey(0)
        keys = jax.random.split(key, pos.shape[0])

        # one ensembling buffer per demonstration
        dummy = jnp.zeros((pos.shape[0], ACTION_CHUNKS * 2))
        horizon = {
            "memory": jnp.zeros((jnp.shape(dummy)[0], PREDICTION_HORIZON, *jnp.shape(dummy)[1:])),
            "front": jnp.zeros(jnp.shape(dummy)[0]),
            "rear": jnp.ones(jnp.shape(dummy)[0]) * -1,
            "n_elements": jnp.zeros(jnp.shape(dummy)[0]),
        }

        # vmap the 990-step scan over demonstrations; the `is_first` flag is shared by all of them
        _, representations = jax.vmap(jax.lax.scan, in_axes=(None, (0, 0, None, 0), None, None))(
            compute_step, (env_state, horizon, True, keys), None, 990
        )

        return jnp.squeeze(jnp.asarray(representations))

    representations = np.asarray(generate_predictions(state, cfg))

    if write_to_file:
        # Keep this f-string on a single line: a backslash continuation inside the literal would
        # embed the next line's indentation in the path and create a stray directory.
        results_dir = f"{os.path.dirname(os.path.abspath(__file__))}/../.results/{cfg['wandb']['experiment_name']}"
        os.makedirs(results_dir, exist_ok=True)
        np.save(f"{results_dir}/representations.npy", representations)

    return representations


@hydra.main(version_base=None, config_path="..")
def main(cfg: DictConfig) -> None:
    """
    Hydra entry point: restore the trained model, then run both evaluations on it.

    Both evaluations are called with their default ``write_to_file`` setting and their
    return values are not used here.
    """
    run_cfg = cfg["config"]
    restored_state = load_feedforward(run_cfg)
    evaluate_demo_feedforward(restored_state, run_cfg)
    output_representations_feedforward(restored_state, run_cfg)


# Run the Hydra entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()
