"""
Evaluate trained models on demonstration data.
"""
import os

os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
import sys

sys.setrecursionlimit(4000)

from functools import partial
import time

import numpy as np
import jax
import jax.numpy as jnp
import einops as e

import flax
from flax.training import train_state
import optax as tx
import orbax.checkpoint as ocp

from frechetdist import frdist
import pyLasaDataset as lasa

import hydra
from omegaconf import DictConfig

from implicit_nonlinear_dynamics.data_preprocessing.lasa_handwriting import generate_character_image
from implicit_nonlinear_dynamics.model_architectures.feedforward import (
    init_feedforward_multi_model,
)
from implicit_nonlinear_dynamics.evaluate.utils import sample_within_threshold, deque_append


def load_feedforward_multi(cfg):
    """
    Rebuild the feedforward multi model's train state and restore it from the "best_model" checkpoint.

    A fresh model, optimiser and train state are created first; that state is only used as a
    shape/dtype template which is handed to the checkpointer when restoring.

    Args:
        cfg: configuration mapping. Reads ``cfg["architecture"]``, the learning-rate and clipping
            entries of ``cfg["training"]``, ``cfg["training"]["model_save_dir"]`` and
            ``cfg["evaluation"]["run_id"]``.

    Returns:
        Whatever ``StandardCheckpointer.restore`` yields for
        ``<model_save_dir>/<run_id>/best_model`` given the abstract train state.
    """
    # initialise model
    model, variables = init_feedforward_multi_model(cfg["architecture"])

    # initialise optimiser for neural network params only: every leaf whose top-level key is
    # "dynamics_params" is labelled "none", everything else "adamw"
    def label_fn(tree):
        flat_tree = flax.traverse_util.flatten_dict(tree)
        labels = {path: ("none" if path[0] == "dynamics_params" else "adamw") for path in flat_tree}
        return flax.traverse_util.unflatten_dict(labels)

    training_cfg = cfg["training"]
    lr_schedule = tx.warmup_cosine_decay_schedule(
        init_value=training_cfg["initial_lr"],
        peak_value=training_cfg["peak_lr"],
        warmup_steps=training_cfg["warmup_steps"],
        decay_steps=training_cfg["decay_steps"],
        end_value=training_cfg["end_lr"],
    )

    # "adamw"-labelled leaves are optimised, "none"-labelled leaves receive zero updates
    per_label_transforms = {
        "adamw": tx.adamw(learning_rate=lr_schedule),
        "none": tx.set_to_zero(),
    }
    labelled_opt = tx.multi_transform(per_label_transforms, label_fn)

    # gradient clipping runs before the per-label transforms
    optimiser = tx.chain(tx.clip_by_global_norm(training_cfg["clip_grad_norm"]), labelled_opt)

    # create training state (used as the restore template)
    template_state = train_state.TrainState.create(apply_fn=model.apply, params=variables, tx=optimiser)

    # restore training state
    abstract_state = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, template_state)
    checkpoint_dir = os.path.join(
        cfg["training"]["model_save_dir"],
        cfg["evaluation"]["run_id"],
        "best_model",
    )
    return ocp.StandardCheckpointer().restore(
        checkpoint_dir,
        args=ocp.args.StandardRestore(abstract_state),
    )


def evaluate_prediction_latency(state, cfg, num_trials=1000):
    """
    Measure the mean wall-clock time of one jitted forward pass of the model.

    The forward pass is run on all-ones inputs whose shapes are taken from
    ``cfg["architecture"]["init_shapes"]["input"]`` and ``["image"]``.

    Args:
        state: object exposing ``apply_fn`` and ``params`` that can be passed through ``jax.jit``.
        cfg: configuration mapping; it is passed to ``jax.jit`` as a static argument, so it must
            be hashable.
        num_trials: number of timed forward passes to average over (must be >= 1).

    Returns:
        Mean latency in seconds per forward pass, as a Python float.

    Raises:
        ValueError: if ``num_trials`` is smaller than 1.
    """
    if num_trials < 1:
        raise ValueError(f"num_trials must be >= 1, got {num_trials}")

    @partial(jax.jit, static_argnums=(1,))
    def generate_predictions(state, cfg):
        # The output must be returned: a jitted function without outputs lets XLA drop the
        # whole forward pass as dead code, so nothing would actually be timed.
        return state.apply_fn(
            state.params,
            jnp.ones(cfg["architecture"]["init_shapes"]["input"]),
            jnp.ones(cfg["architecture"]["init_shapes"]["image"]),
        )

    # Warm-up call so tracing/compilation time is not counted in the measurement.
    jax.block_until_ready(generate_predictions(state, cfg))

    start = time.perf_counter()
    for _ in range(num_trials):
        # JAX dispatches asynchronously; block so the computation itself is timed rather than
        # just the enqueue.
        jax.block_until_ready(generate_predictions(state, cfg))
    elapsed = time.perf_counter() - start

    return elapsed / num_trials


def evaluate_demo_feedforward_multi(state, cfg, write_to_file=False):
    """
    Evaluate frechet distance against expert demonstrations, based on predicting position coordinates directly.

    For every shape in ``cfg["dataset"]["shapes"]`` the model is rolled out for 990 steps, starting
    from the first position of each demonstration, with noise added to each raw prediction and
    temporal ensembling applied over a buffer of recent predictions.

    Args:
        state: restored train state exposing ``apply_fn`` and ``params``.
        cfg: configuration mapping (passed to ``jax.jit`` as a static argument, so it must be
            hashable). Reads ``architecture.inference_params``, ``evaluation.noise_scale``,
            ``dataset.shapes`` and, when writing, ``wandb.experiment_name``.
        write_to_file: when True, save the four result arrays as ``.npy`` files under
            ``<this file's directory>/../.results/<experiment_name>/``.

    Returns:
        Tuple ``(demo_vals, predicted_vals, fr_dists, latency)`` of numpy arrays: the demonstration
        positions, the rolled-out predictions, one frechet distance per trajectory, and the
        per-forward-pass latency multiplied by the mean number of sampled prediction points.
    """
    ACTION_CHUNKS = cfg["architecture"]["inference_params"]["action_chunks"]
    PREDICTION_HORIZON = cfg["architecture"]["inference_params"]["prediction_horizon"]
    NUM_PREDICTED_ACTIONS = cfg["architecture"]["inference_params"]["num_predicted_actions"]
    ENSEMBLE_WEIGHT = cfg["architecture"]["inference_params"]["ensemble_weight"]
    NOISE_SCALE = cfg["evaluation"]["noise_scale"]

    @partial(jax.jit, static_argnums=(1,))
    def generate_predictions(state, cfg):
        def compute_step(carry, _):
            env_state, img, horizon, is_first, key = carry
            key, subkey = jax.random.split(key)
            prediction = state.apply_fn(state.params, jnp.expand_dims(env_state, axis=0), jnp.expand_dims(img, axis=0))

            # add random noise to prediction
            prediction += jax.random.normal(subkey, prediction.shape) * NOISE_SCALE

            # perform temporal ensembling: push the new prediction into the buffer
            # NOTE(review): deque_append is defined elsewhere; presumably a fixed-size ring
            # buffer of the last PREDICTION_HORIZON predictions -- confirm.
            memory, rear, front, num_elements = deque_append(
                horizon["memory"],
                horizon["rear"],
                horizon["front"],
                horizon["n_elements"],
                jnp.squeeze(prediction),
                is_first,
            )

            updated_horizon = {
                "memory": memory,
                "front": front,
                "rear": rear,
                "n_elements": num_elements,
            }

            # row i of the weights is exp(-ENSEMBLE_WEIGHT * i) repeated ACTION_CHUNKS * 2 times
            ensemble_weights = jnp.asarray(
                [
                    jnp.repeat(jnp.exp(-ENSEMBLE_WEIGHT * i), (ACTION_CHUNKS * 2), axis=-1)
                    for i in range(PREDICTION_HORIZON)
                ]
            )
            # from row i of the buffer (and of the weights) keep only the slice
            # [i * NUM_PREDICTED_ACTIONS, (i + 1) * NUM_PREDICTED_ACTIONS)
            memory_filtered = jnp.asarray(
                [
                    memory[i, i * (NUM_PREDICTED_ACTIONS) : i * (NUM_PREDICTED_ACTIONS) + (NUM_PREDICTED_ACTIONS)]
                    for i in range(PREDICTION_HORIZON)
                ]
            )
            ensemble_weights_filtered = jnp.asarray(
                [
                    ensemble_weights[
                        i, i * (NUM_PREDICTED_ACTIONS) : i * (NUM_PREDICTED_ACTIONS) + (NUM_PREDICTED_ACTIONS)
                    ]
                    for i in range(PREDICTION_HORIZON)
                ]
            )
            # weighted average over the buffer rows
            output_prediction = jnp.sum(memory_filtered * ensemble_weights_filtered, axis=0) / jnp.sum(
                ensemble_weights_filtered, axis=0
            )

            # the last two values are fed back as the next input and also emitted as the step output
            return (output_prediction[-2:], img, updated_horizon, False, key), output_prediction[-2:]

        predicted_vals = []
        demo_vals = []
        for dataset in cfg["dataset"]["shapes"]:
            raw_data = lasa.DataSet.__getattr__(dataset)
            demos = raw_data.demos
            pos = jnp.hstack([demo.pos for demo in demos])
            # pos becomes (demonstration, points, timesteps)
            pos = e.rearrange(pos, "points (demonstration timesteps) -> demonstration points timesteps", timesteps=1000)
            img = generate_character_image(dataset[0])
            # every rollout starts from the first timestep of its demonstration
            env_state = jnp.array(pos[:, :, 0])
            key = jax.random.PRNGKey(0)
            keys = jax.random.split(key, pos.shape[0])

            # one empty prediction buffer per demonstration
            dummy = jnp.zeros((pos.shape[0], ACTION_CHUNKS * 2))
            horizon = {
                "memory": jnp.zeros((jnp.shape(dummy)[0], PREDICTION_HORIZON, *jnp.shape(dummy)[1:])),
                "front": jnp.zeros(jnp.shape(dummy)[0]),
                "rear": jnp.ones(jnp.shape(dummy)[0]) * -1,
                "n_elements": jnp.zeros(jnp.shape(dummy)[0]),
            }

            # vmap the 990-step scan over demonstrations; the image and the is_first flag are shared
            _, env_states = jax.vmap(jax.lax.scan, in_axes=(None, (0, None, 0, None, 0), None, None))(
                compute_step, (env_state, img, horizon, True, keys), None, 990
            )
            predicted_vals.append(env_states)
            demo_vals.append(pos)

        predicted_vals = jnp.asarray(predicted_vals)
        demo_vals = jnp.asarray(demo_vals)
        # merge the dataset and demonstration axes into a single trajectory axis
        predicted_vals = e.rearrange(
            predicted_vals, "dataset demonstrations timesteps values -> (dataset demonstrations) timesteps values"
        )
        # for demo_vals the last two axes are actually (points, timesteps); only the first two
        # axes are merged so the axis names in the pattern do not affect the result
        demo_vals = e.rearrange(
            demo_vals, "dataset demonstrations timesteps values -> (dataset demonstrations) timesteps values"
        )

        return predicted_vals, demo_vals

    predicted_vals, demo_vals = generate_predictions(state, cfg)

    fr_dists = []
    num_predictions = []
    # Iterate over the trajectories that actually exist instead of a hard-coded count: JAX clamps
    # out-of-range indices silently, so a fixed range would yield wrong results without an error.
    for i in range(predicted_vals.shape[0]):
        sample_predicted, sample_demo = sample_within_threshold(predicted_vals[i, :, -2:].T, demo_vals[i, :, :1000])
        frechet_dist = frdist(sample_predicted.T, sample_demo.T)
        fr_dists.append(frechet_dist)
        num_predictions.append(sample_predicted.shape[1])

    latency = evaluate_prediction_latency(state, cfg) * np.mean(num_predictions)
    demo_vals = np.asarray(demo_vals)
    predicted_vals = np.asarray(predicted_vals)
    fr_dists = np.asarray(fr_dists)
    latency = np.asarray(latency)

    if write_to_file:
        # Build the directory once with os.path.join. A backslash line continuation inside an
        # f-string embeds the next line's indentation in the path and points at a missing directory.
        results_dir = os.path.join(
            os.path.dirname(os.path.abspath(__file__)),
            "..",
            ".results",
            str(cfg["wandb"]["experiment_name"]),
        )
        os.makedirs(results_dir, exist_ok=True)
        outputs = (
            ("demo_vals.npy", demo_vals),
            ("predicted_vals.npy", predicted_vals),
            ("fr_dists.npy", fr_dists),
            ("latency.npy", latency),
        )
        for filename, array in outputs:
            np.save(os.path.join(results_dir, filename), array)

    return demo_vals, predicted_vals, fr_dists, latency


def output_representations_feedforward_multi(state, cfg, write_to_file=False):
    """
    Roll the model out on every demonstration and collect an intermediate layer's activations.

    The rollout mirrors the demonstration evaluation (990 steps per demonstration, noise added to
    each raw prediction, temporal ensembling over a prediction buffer), but instead of positions
    it records the captured ``"Dense_12"`` ``__call__`` output at every step.

    Args:
        state: restored train state exposing ``apply_fn`` and ``params``.
        cfg: configuration mapping (passed to ``jax.jit`` as a static argument, so it must be
            hashable). Reads ``architecture.inference_params``, ``evaluation.noise_scale``,
            ``dataset.shapes`` and, when writing, ``wandb.experiment_name``.
        write_to_file: when True, save the embeddings to
            ``<this file's directory>/../.results/<experiment_name>/embeddings.npy``.

    Returns:
        Numpy array of embeddings with the dataset and demonstration axes merged into the first axis.
    """
    ACTION_CHUNKS = cfg["architecture"]["inference_params"]["action_chunks"]
    PREDICTION_HORIZON = cfg["architecture"]["inference_params"]["prediction_horizon"]
    NUM_PREDICTED_ACTIONS = cfg["architecture"]["inference_params"]["num_predicted_actions"]
    ENSEMBLE_WEIGHT = cfg["architecture"]["inference_params"]["ensemble_weight"]
    NOISE_SCALE = cfg["evaluation"]["noise_scale"]

    @partial(jax.jit, static_argnums=(1,))
    def generate_predictions(state, cfg):
        def compute_step(carry, _):
            env_state, img, horizon, is_first, key = carry
            key, subkey = jax.random.split(key)
            # also collect the intermediate activations of the forward pass
            prediction, intermediates = state.apply_fn(
                state.params,
                jnp.expand_dims(env_state, axis=0),
                jnp.expand_dims(img, axis=0),
                capture_intermediates=True,
                mutable=["intermediates"],
            )

            # add random noise to prediction
            prediction += jax.random.normal(subkey, prediction.shape) * NOISE_SCALE

            # perform temporal ensembling: push the new prediction into the buffer
            # NOTE(review): deque_append is defined elsewhere; presumably a fixed-size ring
            # buffer of the last PREDICTION_HORIZON predictions -- confirm.
            memory, rear, front, num_elements = deque_append(
                horizon["memory"],
                horizon["rear"],
                horizon["front"],
                horizon["n_elements"],
                jnp.squeeze(prediction),
                is_first,
            )

            updated_horizon = {
                "memory": memory,
                "front": front,
                "rear": rear,
                "n_elements": num_elements,
            }

            # row i of the weights is exp(-ENSEMBLE_WEIGHT * i) repeated ACTION_CHUNKS * 2 times
            weights = jnp.asarray(
                [
                    jnp.repeat(jnp.exp(-ENSEMBLE_WEIGHT * i), (ACTION_CHUNKS * 2), axis=-1)
                    for i in range(PREDICTION_HORIZON)
                ]
            )
            # NOTE(review): this averages the whole buffer, whereas the demonstration evaluation
            # averages per-row slices of it -- confirm the difference is intended.
            current_prediction = jnp.sum(memory * weights, axis=0) / jnp.sum(weights, axis=0)
            output_prediction = current_prediction[NUM_PREDICTED_ACTIONS - 2 : NUM_PREDICTED_ACTIONS]

            # feed the ensembled prediction back as the next input; emit the captured activation
            return (output_prediction, img, updated_horizon, False, key), intermediates["intermediates"]["Dense_12"][
                "__call__"
            ][0]

        embeddings = []
        for dataset in cfg["dataset"]["shapes"]:
            raw_data = lasa.DataSet.__getattr__(dataset)
            demos = raw_data.demos
            pos = jnp.hstack([demo.pos for demo in demos])
            # pos becomes (demonstration, points, timesteps)
            pos = e.rearrange(pos, "points (demonstration timesteps) -> demonstration points timesteps", timesteps=1000)
            img = generate_character_image(dataset[0])
            # every rollout starts from the first timestep of its demonstration
            env_state = jnp.array(pos[:, :, 0])
            key = jax.random.PRNGKey(0)
            keys = jax.random.split(key, pos.shape[0])

            # one empty prediction buffer per demonstration
            dummy = jnp.zeros((pos.shape[0], ACTION_CHUNKS * 2))
            horizon = {
                "memory": jnp.zeros((jnp.shape(dummy)[0], PREDICTION_HORIZON, *jnp.shape(dummy)[1:])),
                "front": jnp.zeros(jnp.shape(dummy)[0]),
                "rear": jnp.ones(jnp.shape(dummy)[0]) * -1,
                "n_elements": jnp.zeros(jnp.shape(dummy)[0]),
            }

            # vmap the 990-step scan over demonstrations; the image and the is_first flag are shared
            _, embedding = jax.vmap(jax.lax.scan, in_axes=(None, (0, None, 0, None, 0), None, None))(
                compute_step, (env_state, img, horizon, True, keys), None, 990
            )
            embeddings.append(embedding)

        # NOTE(review): squeeze drops every size-1 axis, so a config with a single shape would also
        # lose the dataset axis and the rearrange below would fail -- confirm this case never occurs.
        embeddings = jnp.squeeze(jnp.asarray(embeddings))
        embeddings = e.rearrange(
            embeddings, "dataset demonstrations timesteps values -> (dataset demonstrations) timesteps values"
        )

        return embeddings

    embeddings = np.asarray(generate_predictions(state, cfg))

    if write_to_file:
        # Build the directory with os.path.join. A backslash line continuation inside an f-string
        # embeds the next line's indentation in the path and points at a missing directory.
        results_dir = os.path.join(
            os.path.dirname(os.path.abspath(__file__)),
            "..",
            ".results",
            str(cfg["wandb"]["experiment_name"]),
        )
        os.makedirs(results_dir, exist_ok=True)
        np.save(os.path.join(results_dir, "embeddings.npy"), embeddings)

    return embeddings


@hydra.main(version_base=None, config_path="..")
def main(cfg: DictConfig) -> None:
    """
    Hydra entry point: restore the trained model, then run both evaluations on it.

    Only the "config" entry of the composed Hydra configuration is used. Neither evaluation is
    asked to write its results to file.
    """
    eval_cfg = cfg["config"]
    restored_state = load_feedforward_multi(eval_cfg)

    # frechet-distance / latency evaluation first, then the embedding extraction
    evaluate_demo_feedforward_multi(restored_state, eval_cfg)
    output_representations_feedforward_multi(restored_state, eval_cfg)


if __name__ == "__main__":
    main()
