# Implementation of MAML using JAX and JAXopt.
import jax
import jax.numpy as jnp
from jax.tree_util import tree_map
from jaxopt import OptaxSolver
import hydra
from hydra.utils import to_absolute_path, instantiate
import matplotlib.pyplot as plt
import optax
from optax import softmax_cross_entropy
from tqdm.auto import tqdm

from maml.data import sample_task_batch, OmniglotDataset
from maml.evaluate import eval_maml_meta_params, plot_predictions
from maml.model import save_weights


def _make_task_sampler(
    datasource,
    seed,
    inner_batch_size_train,
    inner_batch_size_eval,
    omniglot_kwargs,
):
    """Return the callable used to draw one meta-batch of tasks.

    The returned callable is invoked as
    ``sampler(key, meta_batch_size, inner_batch_size_train, inner_batch_size_eval)``
    and its result is unpacked as ``(train_batch, eval_batch, new_key)``.

    Raises:
        ValueError: if ``datasource`` is neither "sinusoid" nor "omniglot".
    """
    if datasource == "sinusoid":
        # ``omniglot_kwargs`` are deliberately unused for this datasource.
        return sample_task_batch
    if datasource != "omniglot":
        raise ValueError(
            f"unknown datasource {datasource!r}; expected 'sinusoid' or 'omniglot'"
        )

    dataset = OmniglotDataset(
        **omniglot_kwargs,
        seed=seed,
        inner_batch_size_train=inner_batch_size_train,
        inner_batch_size_eval=inner_batch_size_eval,
    )
    dataset.load_characters_labels()

    def get_sample(key, meta_batch_size, *_args, **_kwargs):
        # The inner batch sizes were already given to the dataset constructor,
        # so the extra positional arguments passed by the caller are ignored.
        key, new_key = jax.random.split(key)
        (
            tasks_inputs_train,
            tasks_outputs_train,
            tasks_inputs_eval,
            tasks_outputs_eval,
        ) = dataset.sample_tasks(key, meta_batch_size)
        return (
            (tasks_inputs_train, tasks_outputs_train),
            (tasks_inputs_eval, tasks_outputs_eval),
            new_key,
        )

    return get_sample


def _make_outer_optimizer(
    meta_lr,
    reg_lr,
    n_outer_steps,
    stop_reg_learning_after_steps,
    cosine_schedule,
    clip_params_norm,
    clip_reg_norm,
):
    """Build the optax transform applied to the "meta_params" and "reg" groups."""
    # Scaling the rate by 0 at the boundary freezes the "reg" group from then on.
    if stop_reg_learning_after_steps is None:
        reg_boundaries_and_scales = None
    else:
        reg_boundaries_and_scales = {stop_reg_learning_after_steps: 0.}
    reg_learning_rate_schedule = optax.piecewise_constant_schedule(
        reg_lr,
        boundaries_and_scales=reg_boundaries_and_scales,
    )

    if cosine_schedule:
        meta_params_learning_rate_schedule = optax.cosine_decay_schedule(
            meta_lr,
            decay_steps=n_outer_steps,
            alpha=1e-7,
        )
    else:
        meta_params_learning_rate_schedule = meta_lr

    # NOTE(review): clipping is chained *after* Adam, so it bounds the Adam
    # update rather than the raw gradient -- confirm this ordering is intended.
    return optax.multi_transform(
        {
            "meta_params": optax.chain(
                optax.adam(meta_params_learning_rate_schedule),
                optax.clip_by_global_norm(clip_params_norm),
            ),
            "reg": optax.chain(
                optax.adam(reg_learning_rate_schedule),
                optax.clip_by_global_norm(clip_reg_norm),
            ),
        },
        {"meta_params": "meta_params", "reg": "reg"},
    )


def meta_train(
    meta_model,
    n_outer_steps=10,
    stop_reg_learning_after_steps=None,
    meta_lr=1e-3,
    reg_lr=1e-3,
    debug_optim=False,
    seed=0,
    clip_params_norm=1.,
    cosine_schedule=False,
    clip_reg_norm=1.,
    datasource="sinusoid",
    test_loss="mse",
    meta_batch_size=25,
    inner_batch_size_train=10,
    inner_batch_size_eval=10,
    n_unique_meta_batches=None,
    **omniglot_kwargs,
):
    """Run the outer (meta) optimisation loop and return the learned parameters.

    Args:
        meta_model: object providing ``initialize_params_and_reg()``, callable
            as ``meta_model(train_batch, meta_params)`` to obtain task-adapted
            parameters, and providing ``batch_predict(adapted_params, inputs)``.
        n_outer_steps: number of outer updates; also the decay horizon of the
            cosine schedule when ``cosine_schedule`` is set.
        stop_reg_learning_after_steps: if not None, outer step at which the
            "reg" learning rate is scaled to 0.
        meta_lr: Adam learning rate (or cosine-schedule initial value) for the
            "meta_params" group.
        reg_lr: Adam learning rate for the "reg" group.
        debug_optim: if true, print the solver error, the per-leaf norms of
            the meta-parameters and the "reg" values after every step.
        seed: seed of the PRNG key used to sample tasks in the training loop;
            also forwarded to ``OmniglotDataset``.
        clip_params_norm: max global norm for the "meta_params" updates.
        cosine_schedule: use a cosine decay instead of a constant ``meta_lr``.
        clip_reg_norm: max global norm for the "reg" updates.
        datasource: "sinusoid" or "omniglot".
        test_loss: "mse" or "cross_entropy"; the latter also reports accuracy.
        meta_batch_size: passed to the task sampler on every draw.
        inner_batch_size_train: passed to the task sampler / Omniglot dataset.
        inner_batch_size_eval: passed to the task sampler / Omniglot dataset.
        n_unique_meta_batches: if not None, the sampling key is reset to
            ``PRNGKey(seed)`` every this many steps (presumably so the same
            meta-batches are revisited -- relies on the sampler being
            deterministic in the key).
        **omniglot_kwargs: forwarded to ``OmniglotDataset``; ignored for the
            "sinusoid" datasource.

    Returns:
        Tuple ``(meta_params_and_reg, gradient_subopt)`` where
        ``gradient_subopt`` lists the solver's ``state.error`` after each step.

    Raises:
        ValueError: if ``datasource`` or ``test_loss`` is not supported.
    """
    # Fail fast with a clear message: an unknown option used to surface much
    # later as a NameError / UnboundLocalError, after the data had been loaded.
    if test_loss not in ("mse", "cross_entropy"):
        raise ValueError(
            f"unknown test_loss {test_loss!r}; expected 'mse' or 'cross_entropy'"
        )

    # data
    get_sample = _make_task_sampler(
        datasource,
        seed,
        inner_batch_size_train,
        inner_batch_size_eval,
        omniglot_kwargs,
    )

    meta_params_and_reg = meta_model.initialize_params_and_reg()

    def full_loss(meta_params, training_task_batch_train, training_task_batch_test):
        """Outer loss: adapt on the train split, score on the held-out split."""
        task_adapted_params = meta_model(training_task_batch_train, meta_params)
        tasks_inputs, tasks_outputs = training_task_batch_test
        predictions = meta_model.batch_predict(task_adapted_params, tasks_inputs)
        if test_loss == "mse":
            losses = (predictions - tasks_outputs) ** 2
            accuracy = None
        else:  # "cross_entropy" (validated above)
            losses = softmax_cross_entropy(logits=predictions, labels=tasks_outputs)
            accuracies = jnp.argmax(predictions, axis=-1) == jnp.argmax(tasks_outputs, axis=-1)
            accuracy = jnp.mean(accuracies)
        loss = jnp.mean(losses)
        return loss, accuracy

    # outer-solver
    multi_opt = _make_outer_optimizer(
        meta_lr,
        reg_lr,
        n_outer_steps,
        stop_reg_learning_after_steps,
        cosine_schedule,
        clip_params_norm,
        clip_reg_norm,
    )
    solver = OptaxSolver(
        opt=multi_opt,
        fun=full_loss,
        maxiter=n_outer_steps,
        has_aux=True,  # we return accuracy
    )

    # A throw-away batch, drawn with a fixed key, is only used to initialise
    # the solver state; the training loop below uses its own key.
    key = jax.random.PRNGKey(0)
    training_task_batch_train, training_task_batch_test, _ = get_sample(
        key, meta_batch_size, inner_batch_size_train, inner_batch_size_eval
    )
    state = solver.init_state(
        meta_params_and_reg, training_task_batch_train, training_task_batch_test
    )
    jitted_update = jax.jit(solver.update)
    gradient_subopt = []

    key = jax.random.PRNGKey(seed)
    pbar = tqdm(range(solver.maxiter))
    for it in pbar:
        # sample tasks
        if n_unique_meta_batches is not None and it % n_unique_meta_batches == 0:
            # Restart the key sequence so earlier meta-batches are drawn again.
            key = jax.random.PRNGKey(seed)
        training_task_batch_train, training_task_batch_test, key = get_sample(
            key, meta_batch_size, inner_batch_size_train, inner_batch_size_eval
        )
        meta_params_and_reg, state = jitted_update(
            meta_params_and_reg,
            state,
            training_task_batch_train,
            training_task_batch_test,
        )
        if debug_optim:
            # print the norm of params
            print(f"total norm of grad {state.error:.3f}")
            print(f"norm of meta_params: {tree_map(jnp.linalg.norm, meta_params_and_reg['meta_params'])}")
            print(f"reg: {meta_params_and_reg['reg']}")
        # The auxiliary output is None for the "mse" loss; display 0 then.
        accuracy = state.aux if state.aux is not None else 0
        pbar.set_description(f"outer loss (val loss): {state.value:.3f}, val accuracy: {accuracy:.3f}")
        gradient_subopt.append(state.error)

    return meta_params_and_reg, gradient_subopt


def plot_gradient_suboptimality(gradient_subopt, save_path):
    """Plot the per-step values in ``gradient_subopt``, save the figure, show it.

    Args:
        gradient_subopt: sequence of values, one per outer step, drawn as a
            single line on the current pyplot figure.
        save_path: destination passed to ``plt.savefig``.
    """
    plt.plot(gradient_subopt)

    # Decorate the axes; these calls are independent of one another.
    plt.title("MAML learning curve")
    plt.ylabel("gradient suboptimality")
    plt.xlabel("outer step")

    # Write the file before ``show`` so the saved image matches what is drawn.
    savefig_options = {"dpi": 300, "bbox_inches": "tight"}
    plt.savefig(save_path, **savefig_options)
    plt.show()


@hydra.main(config_path="../conf", config_name="config")
def maml_training_and_evaluation(cfg):
    """Train with ``cfg.train``, persist the results, then evaluate with ``cfg.eval``.

    The gradient-suboptimality curve is saved to ``cfg.train.save_path_fig``
    and the learned parameters to ``cfg.train.save_path_params`` (resolved
    through ``to_absolute_path``). Predictions are plotted only when the
    evaluation datasource is "sinusoid".

    Returns:
        Whatever ``eval_maml_meta_params`` returns for the trained parameters.
    """
    train_cfg = cfg.train

    # --- training -----------------------------------------------------------
    train_model = instantiate(train_cfg.model)
    learned_params, subopt_history = meta_train(
        train_model,
        **train_cfg.data,
        **train_cfg.optim,
    )
    plot_gradient_suboptimality(subopt_history, save_path=train_cfg.save_path_fig)
    params_path = to_absolute_path(train_cfg.save_path_params)
    save_weights(learned_params, params_path)

    # --- evaluation of the learned meta-parameters ---------------------------
    eval_cfg = cfg.eval
    eval_model = instantiate(eval_cfg.model)
    eval_results = eval_maml_meta_params(
        eval_model,
        learned_params,
        test_loss=eval_cfg.test_loss,
        **eval_cfg.data,
    )
    if eval_cfg.data.datasource != "sinusoid":
        return eval_results

    plot_predictions(eval_results, save_path=eval_cfg.save_path_fig)
    return eval_results


# Script entry point: the function is called without arguments because it is
# wrapped by the ``hydra.main`` decorator, which supplies ``cfg``.
if __name__ == "__main__":
    maml_training_and_evaluation()
