import sys
if sys.version_info.major == 3 and sys.version_info.minor >= 10:

    from collections.abc import MutableMapping
else:
    from collections import MutableMapping
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Optional

import jax
import jax.numpy as jnp
import hydra
from hydra.utils import to_absolute_path, instantiate
import matplotlib.pyplot as plt
from optax import softmax_cross_entropy
import pandas as pd

from maml.data import sample_task_batch, OmniglotDataset
from maml.model import load_weights


@dataclass
class EvalResults:
    """Results of ``eval_maml_meta_params``.

    Scalar metrics are means/standard deviations over every element of every
    evaluated batch. The array fields hold data from the LAST evaluated batch
    only (they are overwritten on each loop iteration of the evaluator).
    Scalars are produced with ``jnp.mean``/``jnp.std``, so at runtime they are
    0-d JAX arrays rather than Python floats.
    """

    # Predictions of the task-adapted parameters on the eval split (last batch).
    predictions: jnp.ndarray
    # Predictions of the un-adapted meta-parameters on the eval split (last batch).
    mean_predictions: jnp.ndarray
    # Eval-split inputs / targets (last batch).
    tasks_inputs: jnp.ndarray
    tasks_outputs: jnp.ndarray
    # (inputs, outputs) pair the model was adapted on (last batch); indexed as
    # [0] -> inputs and [1] -> outputs by callers.
    eval_task_batch_train: Any
    # Eval-split loss of the adapted parameters.
    loss: float = 0.0
    loss_std: float = 0.0
    # ``meta_model.loss`` of the adapted parameters on the train split.
    inner_loss: float = 0.0
    inner_loss_std: float = 0.0
    # Only set when the evaluator's test_loss is "cross_entropy"; None otherwise.
    accuracy: Optional[float] = None
    accuracy_std: Optional[float] = None
    # Eval-split loss of the un-adapted meta-parameters.
    mean_pred_loss: float = 0.0
    mean_pred_loss_std: float = 0.0
    # Only set when the evaluator's test_loss is "cross_entropy"; None otherwise.
    mean_pred_accuracy: Optional[float] = None
    mean_pred_accuracy_std: Optional[float] = None


def _concat_mean_std(chunks):
    """Concatenate per-batch arrays along axis 0; return ``(mean, std)`` over all elements."""
    values = jnp.concatenate(chunks, axis=0)
    return jnp.mean(values), jnp.std(values)


def eval_maml_meta_params(
        meta_model,
        meta_params_and_reg,
        seed=0,
        datasource="sinusoid",
        test_loss="mse",
        meta_batch_size=25,
        inner_batch_size_train=10,
        inner_batch_size_eval=10,
        linspace_eval=False,
        n_batches=1,
        **omniglot_kwargs,
    ):
    """Evaluate MAML meta-parameters on freshly sampled evaluation task batches.

    For each of ``n_batches`` batches, a batch of tasks is sampled. The task-adapted
    parameters ``meta_model(train_batch, meta_params_and_reg)`` and the raw
    (un-adapted) meta-parameters are then both scored on each task's eval split.

    Args:
        meta_model: model object exposing ``__call__`` (adaptation),
            ``batch_predict``, ``duplicate_params`` and ``loss``.
        meta_params_and_reg: mapping with ``"meta_params"`` and ``"reg"`` entries.
        seed: seed for ``jax.random.PRNGKey`` (and for ``OmniglotDataset``).
        datasource: ``"sinusoid"`` or ``"omniglot"``.
        test_loss: ``"mse"`` (element-wise squared error) or ``"cross_entropy"``
            (softmax cross-entropy; accuracy is reported as well).
        meta_batch_size: number of tasks per batch.
        inner_batch_size_train: samples per task in the adaptation (train) split.
        inner_batch_size_eval: samples per task in the eval split.
        linspace_eval: forwarded to ``sample_task_batch`` (sinusoid only).
        n_batches: number of task batches to aggregate over; must be >= 1.
        **omniglot_kwargs: extra ``OmniglotDataset`` constructor arguments;
            ignored for the sinusoid datasource.

    Returns:
        EvalResults: means/stds computed over all batches. The array fields hold
        the last batch only.

    Raises:
        ValueError: if ``datasource`` or ``test_loss`` is unknown, or
            ``n_batches < 1``.
    """
    # Validate up front: otherwise an unknown datasource surfaces as an
    # undefined-variable error and an unknown test_loss as a concatenate error.
    if datasource not in ("sinusoid", "omniglot"):
        raise ValueError(
            f"Unknown datasource {datasource!r}; expected 'sinusoid' or 'omniglot'."
        )
    if test_loss not in ("mse", "cross_entropy"):
        raise ValueError(
            f"Unknown test_loss {test_loss!r}; expected 'mse' or 'cross_entropy'."
        )
    if n_batches < 1:
        raise ValueError(f"n_batches must be >= 1, got {n_batches}.")
    classification = test_loss == "cross_entropy"

    # data
    if datasource == "omniglot":
        dataset = OmniglotDataset(
            **omniglot_kwargs,
            seed=seed,
            inner_batch_size_train=inner_batch_size_train,
            inner_batch_size_eval=inner_batch_size_eval,
        )
        dataset.load_characters_labels(train=False)

    meta_params = meta_params_and_reg["meta_params"]
    losses, inner_losses, mean_pred_losses = [], [], []
    accuracies, mean_pred_accuracies = [], []  # only filled for cross_entropy
    key = jax.random.PRNGKey(seed)
    for _ in range(n_batches):
        if datasource == "sinusoid":
            eval_task_batch_train, eval_task_batch_test, key = sample_task_batch(
                key,
                meta_batch_size=meta_batch_size,
                inner_batch_size_train=inner_batch_size_train,
                inner_batch_size_eval=inner_batch_size_eval,
                linspace_eval=linspace_eval,
            )
        else:  # omniglot
            # Consume one half of the split for sampling, carry the other forward.
            key, new_key = jax.random.split(key)
            (
                tasks_inputs_train,
                tasks_outputs_train,
                tasks_inputs_eval,
                tasks_outputs_eval,
            ) = dataset.sample_tasks(
                key,
                meta_batch_size,
                train=False,
            )
            key = new_key
            eval_task_batch_train = (tasks_inputs_train, tasks_outputs_train)
            eval_task_batch_test = (tasks_inputs_eval, tasks_outputs_eval)

        # Baseline: the same meta-parameters for every task, no adaptation.
        mean_predictions = meta_model.batch_predict(
            meta_model.duplicate_params(meta_params, n_tasks=meta_batch_size),
            eval_task_batch_test[0],
        )
        task_adapted_params = meta_model(eval_task_batch_train, meta_params_and_reg)
        # Per-task training loss of the adapted params (vmapped over tasks only).
        inner_loss = jax.vmap(
            meta_model.loss,
            (0, None, 0, 0, None),
        )(
            task_adapted_params,
            meta_params,
            eval_task_batch_train[0],
            eval_task_batch_train[1],
            meta_params_and_reg["reg"],
        )
        inner_losses.append(inner_loss)
        predictions = meta_model.batch_predict(task_adapted_params, eval_task_batch_test[0])
        tasks_inputs, tasks_outputs = eval_task_batch_test
        if classification:
            losses.append(softmax_cross_entropy(logits=predictions, labels=tasks_outputs))
            mean_pred_losses.append(softmax_cross_entropy(logits=mean_predictions, labels=tasks_outputs))
            true_labels = jnp.argmax(tasks_outputs, axis=-1)
            accuracies.append(jnp.argmax(predictions, axis=-1) == true_labels)
            mean_pred_accuracies.append(jnp.argmax(mean_predictions, axis=-1) == true_labels)
        else:  # mse
            losses.append((predictions - tasks_outputs) ** 2)
            mean_pred_losses.append((mean_predictions - tasks_outputs) ** 2)

    if classification:
        accuracy, accuracy_std = _concat_mean_std(accuracies)
        mean_pred_accuracy, mean_pred_accuracy_std = _concat_mean_std(mean_pred_accuracies)
        print(f"accuracy={accuracy}, mean_pred_accuracy={mean_pred_accuracy}")
    loss, loss_std = _concat_mean_std(losses)
    inner_loss, inner_loss_std = _concat_mean_std(inner_losses)
    mean_pred_loss, mean_pred_loss_std = _concat_mean_std(mean_pred_losses)
    print(f"loss={loss}, mean_pred_loss={mean_pred_loss}")

    # Array fields come from the final loop iteration only.
    result = EvalResults(
        loss=loss,
        loss_std=loss_std,
        inner_loss=inner_loss,
        inner_loss_std=inner_loss_std,
        predictions=predictions,
        mean_pred_loss=mean_pred_loss,
        mean_pred_loss_std=mean_pred_loss_std,
        mean_predictions=mean_predictions,
        tasks_inputs=tasks_inputs,
        tasks_outputs=tasks_outputs,
        eval_task_batch_train=eval_task_batch_train,
    )
    if classification:
        result.accuracy = accuracy
        result.accuracy_std = accuracy_std
        result.mean_pred_accuracy = mean_pred_accuracy
        result.mean_pred_accuracy_std = mean_pred_accuracy_std
    return result

def plot_predictions(eval_results, save_path="maml_predictions.png", n_tasks=2):
    """Plot ground truth vs. predictions for the first tasks of an evaluation batch.

    One subplot per task shows:
    - the eval-split ground truth (red line);
    - the task-adapted predictions (green dashed);
    - the un-adapted meta-parameter predictions (faded green);
    - the adaptation (train-split) points (purple triangles, labelled
      "grad inputs").

    The y-axis label assumes the sinusoid datasource.

    Args:
        eval_results: ``EvalResults`` as returned by ``eval_maml_meta_params``.
        save_path: image path passed to ``Figure.savefig``.
        n_tasks: maximum number of tasks to plot; clamped to the number of
            tasks available in ``eval_results``.

    Raises:
        ValueError: if there is no task to plot.
    """
    predictions = eval_results.predictions
    mean_predictions = eval_results.mean_predictions
    tasks_inputs = eval_results.tasks_inputs
    tasks_outputs = eval_results.tasks_outputs
    # (inputs, outputs) the model was adapted on; indexed [0] / [1] below.
    train_batch = eval_results.eval_task_batch_train

    n_tasks = min(n_tasks, len(tasks_inputs))
    if n_tasks < 1:
        raise ValueError("plot_predictions needs at least one task to plot.")

    # squeeze=False keeps a 2-D Axes array even when n_tasks == 1.
    fig, axs = plt.subplots(n_tasks, 1, sharex=True, sharey=True, squeeze=False)
    try:
        for i, ax in enumerate(axs[:, 0]):
            ax.plot(tasks_inputs[i], tasks_outputs[i], "-", marker="o", markersize=2, label="ground truth", color="red")
            ax.plot(tasks_inputs[i], predictions[i], "--", marker="x", markersize=2, label="predictions", color="green")
            ax.plot(tasks_inputs[i], mean_predictions[i], "--", marker="+", markersize=2, label="mean predictions", color="green", alpha=0.3)
            ax.scatter(train_batch[0][i], train_batch[1][i], marker="^", s=50, label="grad inputs", color="purple", zorder=10)
            ax.legend()
            ax.set_title(f"task {i}")
            ax.set_ylabel("y = sin(x - phase) * amp")
            ax.set_xlabel("x")
        fig.tight_layout()
        fig.savefig(save_path, dpi=300, bbox_inches="tight")
    finally:
        # Without this, every call leaves an open figure registered with pyplot.
        plt.close(fig)


def flatten_dict(d, parent_key='', sep='.'):
    """Flatten nested mappings into a single-level dict with ``sep``-joined keys.

    Values that are ``MutableMapping`` instances are recursed into. All other
    values are kept as-is. Keys are joined with ``sep`` via string formatting, so
    non-string keys (e.g. ints) are supported. Top-level keys are returned
    unchanged. When two paths flatten to the same key, the later value wins.

    Args:
        d: mapping to flatten.
        parent_key: prefix for every produced key; ``''`` (or ``None``) means
            no prefix.
        sep: separator inserted between key components.

    Returns:
        dict: the flattened mapping.

    >>> flatten_dict({"a": {"b": 1}, "c": 2})
    {'a.b': 1, 'c': 2}
    """
    # adapted from https://stackoverflow.com/a/6027615/4332585
    # Explicit check instead of truthiness, so a falsy key such as 0 still
    # acts as a prefix.
    has_parent = parent_key not in ('', None)
    flat = {}
    for key, value in d.items():
        new_key = f"{parent_key}{sep}{key}" if has_parent else key
        if isinstance(value, MutableMapping):
            flat.update(flatten_dict(value, new_key, sep=sep))
        else:
            flat[new_key] = value
    return flat

def save_loss(eval_results, cfg, save_path="mse.csv"):
    """Write one row of evaluation metrics plus the flattened config to a CSV file.

    The path is resolved with hydra's ``to_absolute_path``. If the file does not
    exist (or is empty), it is created with a header.

    If it exists with exactly the same columns, the row is appended. If the
    columns differ (e.g. a cross-entropy row carries four extra accuracy
    columns), the file is rewritten with the union of columns instead. This
    avoids silently writing values under the wrong headers.

    Args:
        eval_results: ``EvalResults``. The accuracy metrics are included only
            when ``eval_results.accuracy`` is not None.
        cfg: nested mapping (flattened with ``flatten_dict``) stored alongside
            the metrics. On a name clash, cfg values override metric values.
        save_path: CSV path.

    Returns:
        pandas.DataFrame: the single row that was written.
    """
    metric_names = [
        "loss",
        "loss_std",
        "inner_loss",
        "inner_loss_std",
        "mean_pred_loss",
        "mean_pred_loss_std",
    ]
    if eval_results.accuracy is not None:
        metric_names += [
            "accuracy",
            "accuracy_std",
            "mean_pred_accuracy",
            "mean_pred_accuracy_std",
        ]
    metrics_dict = {}
    for name in metric_names:
        value = getattr(eval_results, name)
        # Metrics are 0-d JAX arrays; plain floats give pandas a numeric dtype.
        metrics_dict[name] = None if value is None else float(value)

    df = pd.DataFrame([{
        **metrics_dict,
        **flatten_dict(cfg),
    }])
    save_path = Path(to_absolute_path(save_path))
    if not save_path.exists() or save_path.stat().st_size == 0:
        df.to_csv(save_path, index=False)
        return df

    existing_columns = list(pd.read_csv(save_path, nrows=0).columns)
    if existing_columns == list(df.columns):
        df.to_csv(save_path, mode="a", header=False, index=False)
    else:
        # Appending without a header would misalign values and columns; merge
        # on column names and rewrite (existing columns keep their order).
        combined = pd.concat([pd.read_csv(save_path), df], ignore_index=True, sort=False)
        combined.to_csv(save_path, index=False)
    return df


@hydra.main(config_path="../conf", config_name="config")
def evaluate(cfg):
    """Hydra entry point: evaluate saved MAML meta-parameters.

    Steps:
    1. Load the weights saved at ``cfg.train.save_path_params`` and build the
       model from ``cfg.eval.model``.
    2. Run ``eval_maml_meta_params`` with ``cfg.eval.test_loss`` and the
       ``cfg.eval.data`` options.
    3. For the sinusoid datasource only, plot predictions to
       ``cfg.eval.save_path_fig``.
    4. Record the metrics in ``cfg.eval.save_path_csv``.

    Returns:
        EvalResults: the evaluation results.
    """
    eval_cfg = cfg.eval
    weights_file = to_absolute_path(cfg.train.save_path_params)
    params_and_reg = load_weights(weights_file)
    model = instantiate(eval_cfg.model)

    results = eval_maml_meta_params(
        model,
        params_and_reg,
        test_loss=eval_cfg.test_loss,
        **eval_cfg.data,
    )

    # The plot (axis label included) only makes sense for 1-D sinusoid tasks.
    is_sinusoid = eval_cfg.data.datasource == "sinusoid"
    if is_sinusoid:
        plot_predictions(results, save_path=eval_cfg.save_path_fig)
    save_loss(results, cfg, save_path=eval_cfg.save_path_csv)
    return results


if __name__ == "__main__":
    # Script entry point. `evaluate` is wrapped by @hydra.main
    # (config_name="config" under ../conf), which supplies `cfg`, so no
    # argument is passed here.
    evaluate()
