"""
Evaluate trained models on demonstration data.
"""
import os

os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
import sys

sys.setrecursionlimit(4000)
from functools import partial
import time

import numpy as np
import jax
import jax.numpy as jnp
import einops as e

import flax
from flax.training import train_state
import optax as tx
import orbax.checkpoint as ocp

from frechetdist import frdist
import pyLasaDataset as lasa

import hydra
from omegaconf import DictConfig

from implicit_nonlinear_dynamics.model_architectures.esn import init_esn_model
from implicit_nonlinear_dynamics.evaluate.utils import sample_within_threshold, evaluate_mean_absolute_jerk


def load_esn(cfg):
    """
    Rebuild the ESN model and restore the "best_model" checkpoint into a TrainState.

    A freshly initialised TrainState (model variables plus the optimiser built from cfg["training"]) is
    converted to shape/dtype structs and used as the restore target for the checkpoint stored under
    <model_save_dir>/<run_id>/best_model.

    Args:
        cfg: config mapping; reads cfg["architecture"], cfg["training"] (initial_lr, peak_lr, warmup_steps,
            decay_steps, end_lr, clip_grad_norm, model_save_dir) and cfg["evaluation"]["run_id"].

    Returns:
        The restored TrainState with its params placed on the first GPU, or on the default JAX device when
        no GPU backend is available.
    """
    import warnings

    # initialise model
    model, variables = init_esn_model(cfg["architecture"])

    # initialise optimiser for neural network params only
    def flattened_traversal(fn):
        def mask(tree):
            flat = flax.traverse_util.flatten_dict(tree)
            return flax.traverse_util.unflatten_dict({k: fn(k, v) for k, v in flat.items()})

        return mask

    # everything under the top-level "dynamics_params" collection gets zeroed updates, the rest uses adamw
    label_fn = flattened_traversal(lambda path, _: "none" if path[0] == "dynamics_params" else "adamw")

    training_cfg = cfg["training"]
    lr_schedule = tx.warmup_cosine_decay_schedule(
        init_value=training_cfg["initial_lr"],
        peak_value=training_cfg["peak_lr"],
        warmup_steps=training_cfg["warmup_steps"],
        decay_steps=training_cfg["decay_steps"],
        end_value=training_cfg["end_lr"],
    )

    opt = tx.multi_transform(
        {
            "adamw": tx.adamw(learning_rate=lr_schedule),
            "none": tx.set_to_zero(),
        },
        label_fn,
    )

    opt = tx.chain(
        tx.clip_by_global_norm(training_cfg["clip_grad_norm"]),
        opt,
    )

    # create training state; it is only used as a structural template for the restore below
    our_train_state = train_state.TrainState.create(
        apply_fn=model.apply,
        params=variables,
        tx=opt,
    )

    # restore training state
    abstract_tree = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, our_train_state)
    checkpointer = ocp.StandardCheckpointer()

    model_dir = os.path.join(training_cfg["model_save_dir"], cfg["evaluation"]["run_id"], "best_model")
    state = checkpointer.restore(
        model_dir,
        args=ocp.args.StandardRestore(abstract_tree),
    )

    # jax.devices("gpu") raises RuntimeError when no GPU backend exists; fall back to the default
    # device so the evaluation can still run (more slowly) on such hosts.
    try:
        device = jax.devices("gpu")[0]
    except RuntimeError:
        device = jax.devices()[0]
        warnings.warn(f"No GPU backend available; placing model params on {device} instead.")

    params = jax.device_put(state.params, device)
    state = state.replace(params=params)

    return state


def evaluate_prediction_latency(state, cfg, num_trials=1000):
    """
    Measure the mean wall-clock time of one jitted forward pass of the model.

    The model is applied to all-ones inputs whose shapes are read from
    cfg["architecture"]["init_shapes"]["input"] and ["state"].

    Args:
        state: train state exposing `apply_fn` and `params`; must be a valid argument to a jitted function.
        cfg: hashable config (it is passed to jit as a static argument).
        num_trials: number of timed forward passes to average over; must be at least 1.

    Returns:
        Mean latency of a single forward pass, in seconds.

    Raises:
        ValueError: if num_trials is smaller than 1.
    """
    if num_trials < 1:
        raise ValueError(f"num_trials must be >= 1, got {num_trials}")

    @partial(jax.jit, static_argnums=(1,))
    def generate_predictions(state, cfg):
        # The outputs are returned so the compiler cannot discard the forward pass as dead code
        # and so the caller has something to block on.
        return state.apply_fn(
            state.params,
            jnp.ones(cfg["architecture"]["init_shapes"]["input"]),
            jnp.ones(cfg["architecture"]["init_shapes"]["state"]),
        )

    # Warm-up call: the first invocation traces and compiles, which must not be counted as latency.
    jax.block_until_ready(generate_predictions(state, cfg))

    start = time.perf_counter()
    for _ in range(num_trials):
        # JAX dispatches asynchronously; block so each iteration includes the actual execution time.
        jax.block_until_ready(generate_predictions(state, cfg))
    end = time.perf_counter()

    return (end - start) / num_trials


def evaluate_demo_esn(state, cfg, write_to_file=False):
    """
    Roll the model out from the first point of every demonstration and score each rollout.

    Each rollout and its demonstration are passed through `sample_within_threshold`; the result is scored
    with the Frechet distance (`frdist`) and with `evaluate_mean_absolute_jerk` on each of the two
    coordinate rows of the sampled prediction.

    Args:
        state: train state exposing `apply_fn` and `params`.
        cfg: hashable config; reads cfg["architecture"]["inference_params"]["num_predicted_actions"],
            cfg["evaluation"]["noise_scale"], cfg["dataset"]["shape"] and, when writing,
            cfg["wandb"]["experiment_name"]. It is also forwarded to `evaluate_prediction_latency`.
        write_to_file: when True, save demo_vals, predicted_vals, fr_dists and latency as .npy files under
            ../.results/<experiment_name>/ relative to this file.

    Returns:
        Tuple (demo_vals, predicted_vals, fr_dists, latency, jerk_xs, jerk_ys) of numpy arrays, where
        latency is the per-call prediction latency multiplied by the mean number of sampled predictions.
    """
    NUM_PREDICTED_ACTIONS = cfg["architecture"]["inference_params"]["num_predicted_actions"]
    NOISE_SCALE = cfg["evaluation"]["noise_scale"]

    @partial(jax.jit, static_argnums=(1,))
    def generate_predictions(state, cfg):
        def compute_step(carry, x):
            # x is the (unused) scan input; all information flows through the carry.
            dynamics_state, env_state, key = carry
            key, subkey = jax.random.split(key)
            prediction, dynamics_state = state.apply_fn(
                state.params, jnp.expand_dims(env_state, axis=0), dynamics_state
            )
            prediction += jax.random.normal(subkey, prediction.shape) * NOISE_SCALE
            output_prediction = prediction[0, (NUM_PREDICTED_ACTIONS - 2) : NUM_PREDICTED_ACTIONS]
            # the sliced prediction is fed back as the next env_state and also emitted as the step output
            return (dynamics_state, output_prediction, key), output_prediction

        raw_data = lasa.DataSet.__getattr__(cfg["dataset"]["shape"])
        pos = jnp.hstack([demo.pos for demo in raw_data.demos])
        pos = e.rearrange(pos, "points (demonstration timesteps) -> demonstration points timesteps", timesteps=1000)
        start_points = pos[:, :, 0]
        # NOTE(review): (1, 5000) presumably matches the model's dynamics-state shape — confirm against cfg.
        dynamics_state = jnp.zeros((1, 5000))
        keys = jax.random.split(jax.random.PRNGKey(0), pos.shape[0])

        def rollout(env_state, key):
            # 100 autoregressive steps for one demonstration; only the per-step outputs are kept
            _, trajectory = jax.lax.scan(compute_step, (dynamics_state, env_state, key), None, length=100)
            return trajectory

        # one independent rollout (own start point and PRNG key) per demonstration
        env_states = jax.vmap(rollout)(start_points, keys)

        return env_states, pos

    predicted_vals, demo_vals = generate_predictions(state, cfg)

    fr_dists = []
    jerk_xs = []
    jerk_ys = []
    num_predictions = []
    # iterate over however many demonstrations the dataset provides rather than a fixed count
    for predicted, demo in zip(predicted_vals, demo_vals):
        sample_predicted, sample_demo = sample_within_threshold(predicted[:, -2:].T, demo[:, :1000])
        fr_dists.append(frdist(sample_predicted.T, sample_demo.T))
        jerk_xs.append(evaluate_mean_absolute_jerk(sample_predicted[0]))
        jerk_ys.append(evaluate_mean_absolute_jerk(sample_predicted[1]))
        num_predictions.append(sample_predicted.shape[1])

    latency = evaluate_prediction_latency(state, cfg) * np.mean(num_predictions)
    demo_vals = np.asarray(demo_vals)
    predicted_vals = np.asarray(predicted_vals)
    fr_dists = np.asarray(fr_dists)
    latency = np.asarray(latency)
    jerk_xs = np.asarray(jerk_xs)
    jerk_ys = np.asarray(jerk_ys)

    if write_to_file:
        # Build the directory once on a single line: a backslash continuation inside an f-string
        # would embed the next line's indentation in the path.
        results_dir = f"{os.path.dirname(os.path.abspath(__file__))}/../.results/{cfg['wandb']['experiment_name']}"
        os.makedirs(results_dir, exist_ok=True)
        outputs = {
            "demo_vals": demo_vals,
            "predicted_vals": predicted_vals,
            "fr_dists": fr_dists,
            "latency": latency,
        }
        for name, values in outputs.items():
            np.save(f"{results_dir}/{name}.npy", values)

    return demo_vals, predicted_vals, fr_dists, latency, jerk_xs, jerk_ys


def output_representations_esn(state, cfg, write_to_file=False):
    """
    Roll the model out from the first point of every demonstration and collect intermediate outputs.

    At every rollout step the model is applied with `capture_intermediates=True`, and the first captured
    `__call__` output of the `ESN_0` submodule is recorded.

    Args:
        state: train state exposing `apply_fn` and `params`.
        cfg: hashable config; reads cfg["architecture"]["inference_params"]["num_predicted_actions"],
            cfg["evaluation"]["noise_scale"], cfg["dataset"]["shape"] and, when writing,
            cfg["wandb"]["experiment_name"].
        write_to_file: when True, save the result as representations.npy under
            ../.results/<experiment_name>/ relative to this file.

    Returns:
        Numpy array of the recorded representations for all demonstrations and steps, with size-1
        dimensions squeezed out.
    """
    NUM_PREDICTED_ACTIONS = cfg["architecture"]["inference_params"]["num_predicted_actions"]
    NOISE_SCALE = cfg["evaluation"]["noise_scale"]

    @partial(jax.jit, static_argnums=(1,))
    def generate_predictions(state, cfg):
        def compute_step(carry, x):
            # x is the (unused) scan input; all information flows through the carry.
            dynamics_state, env_state, key = carry
            key, subkey = jax.random.split(key)
            (prediction, dynamics_state), intermediates = state.apply_fn(
                state.params,
                jnp.expand_dims(env_state, axis=0),
                dynamics_state,
                capture_intermediates=True,
                mutable=["intermediates"],
            )
            prediction += jax.random.normal(subkey, prediction.shape) * NOISE_SCALE
            output_prediction = prediction[0, (NUM_PREDICTED_ACTIONS - 2) : NUM_PREDICTED_ACTIONS]
            # the sliced prediction drives the next step; the captured ESN_0 output is what gets recorded
            return (dynamics_state, output_prediction, key), intermediates["intermediates"]["ESN_0"]["__call__"][0]

        raw_data = lasa.DataSet.__getattr__(cfg["dataset"]["shape"])
        pos = jnp.hstack([demo.pos for demo in raw_data.demos])
        pos = e.rearrange(pos, "points (demonstration timesteps) -> demonstration points timesteps", timesteps=1000)
        start_points = pos[:, :, 0]
        # NOTE(review): (1, 5000) presumably matches the model's dynamics-state shape — confirm against cfg.
        dynamics_state = jnp.zeros((1, 5000))
        keys = jax.random.split(jax.random.PRNGKey(0), pos.shape[0])

        def rollout(env_state, key):
            # 100 autoregressive steps for one demonstration; only the per-step records are kept
            _, recorded = jax.lax.scan(compute_step, (dynamics_state, env_state, key), None, length=100)
            return recorded

        # one independent rollout (own start point and PRNG key) per demonstration
        representations = jax.vmap(rollout)(start_points, keys)

        return jnp.squeeze(representations)

    representations = generate_predictions(state, cfg)
    representations = np.asarray(representations)

    if write_to_file:
        # Build the directory once on a single line: a backslash continuation inside an f-string
        # would embed the next line's indentation in the path.
        results_dir = f"{os.path.dirname(os.path.abspath(__file__))}/../.results/{cfg['wandb']['experiment_name']}"
        os.makedirs(results_dir, exist_ok=True)
        np.save(f"{results_dir}/representations.npy", representations)

    return representations


@hydra.main(version_base=None, config_path="..")
def main(cfg: DictConfig) -> None:
    """
    Restore the trained ESN and run both evaluation routines on it.

    The settings used by the evaluation live under the "config" key of the Hydra-provided config.
    """
    eval_cfg = cfg["config"]
    restored_state = load_esn(eval_cfg)
    # NOTE(review): both calls use the default write_to_file=False and their return values are dropped,
    # so nothing is persisted or reported here — confirm this is intended.
    evaluate_demo_esn(restored_state, eval_cfg)
    output_representations_esn(restored_state, eval_cfg)


# Script entry point: main is decorated with @hydra.main, so it is called without arguments here.
if __name__ == "__main__":
    main()
