from pathlib import Path

from hydra import initialize, compose
import pytest

from maml.model import MetaMLP, MetaConvNet
from maml.meta_train import meta_train, maml_training_and_evaluation


# Location of the resized Omniglot images, relative to the directory the tests
# are launched from. Tests that need Omniglot are skipped when it is absent.
default_omniglot_path = Path("omniglot_resized") / "all_images"


@pytest.mark.parametrize("model, datasource", [
    (MetaMLP, "sinusoid"),
    (MetaConvNet, "omniglot"),
])
@pytest.mark.parametrize("implicit", [True, False])
@pytest.mark.parametrize("learn_reg", [True, False])
def test_meta_train_call(model, datasource, implicit, learn_reg):
    """Smoke-test one outer step of ``meta_train`` for each model/data pairing.

    The test only checks that building the model and calling ``meta_train``
    completes without raising; nothing about the result is asserted.

    Skipped when:
      * the Omniglot datasource is requested but the dataset directory is
        missing from the working directory;
      * ``learn_reg`` is requested with ``implicit=False`` (per the skip
        message, regularization learning is not supported for unrolled models).
    """
    needs_omniglot = datasource == "omniglot"
    if needs_omniglot and not default_omniglot_path.exists():
        pytest.skip("Omniglot dataset not found")

    unrolled = not implicit
    if learn_reg and unrolled:
        pytest.skip("Regularization learning is not supported for unrolled models")

    # per_param_reg deliberately follows learn_reg in this test.
    model_kwargs = dict(
        implicit=implicit,
        learn_reg=learn_reg,
        per_param_reg=learn_reg,
    )
    network = model(**model_kwargs)
    meta_train(
        network,
        n_outer_steps=1,
        meta_batch_size=3,
        datasource=datasource,
    )


@pytest.mark.parametrize("per_param_reg", [True, False])
def test_meta_train_call_stop_reg_learn(per_param_reg):
    """Smoke-test ``meta_train`` with ``stop_reg_learning_after_steps=0``.

    Uses an implicit ``MetaMLP`` with ``learn_reg=True`` on the sinusoid
    datasource and runs two outer steps, once for each ``per_param_reg``
    setting. The test only checks that the call completes without raising.

    The model, datasource and flags are fixed here, so no skip conditions are
    needed: the sinusoid datasource does not depend on the Omniglot directory
    and regularization learning is always paired with ``implicit=True``.
    """
    meta_model = MetaMLP(implicit=True, learn_reg=True, per_param_reg=per_param_reg)
    meta_train(
        meta_model,
        n_outer_steps=2,
        meta_batch_size=3,
        datasource="sinusoid",
        # NOTE(review): presumably 0 stops regularization learning right away,
        # so later outer steps run with it frozen -- confirm against meta_train.
        stop_reg_learning_after_steps=0,
    )


@pytest.mark.parametrize("implicit", [True, False])
def test_meta_training_and_evaluation(implicit):
    """Run the Hydra-configured train/eval pipeline end to end with a MetaMLP.

    Composes the ``config`` configuration with overrides that shrink the run
    (meta batch size 3, a single outer step) and redirect the saved parameters
    and figure to ``test.pickle`` and ``test.png``. After a successful run,
    every expected output file must exist in the current working directory.

    The generated files are removed whether or not the run succeeds, so a
    failing run does not leave artifacts behind.
    """
    expected_outputs = (
        "test.pickle",
        "test.png",
        "maml_learning_curve.jpg",
    )
    try:
        with initialize(config_path="../../conf"):
            cfg = compose(config_name="config", overrides=[
                # data
                "train.data.meta_batch_size=3",
                "train.data.inner_batch_size_train=4",
                "train.data.inner_batch_size_eval=5",
                # model
                "train.model._target_=maml.model.MetaMLP",
                "+train.model.implicit=" + str(implicit),
                "+eval.model.implicit=" + str(implicit),
                # training
                "train.optim.n_outer_steps=1",
                "train.save_path_params=test.pickle",
                # eval
                "eval.save_path_fig=test.png",
            ])
            maml_training_and_evaluation(cfg)
        # The pipeline is expected to have produced all of these files.
        missing = [
            name for name in expected_outputs
            if not (Path.cwd() / name).exists()
        ]
        assert not missing, (
            "expected output files were not created: " + ", ".join(missing)
        )
    finally:
        # Clean up. Tolerate missing files so that a failure above is reported
        # as-is instead of being masked by a FileNotFoundError from unlink().
        for name in expected_outputs:
            artifact = Path.cwd() / name
            if artifact.exists():
                artifact.unlink()


@pytest.mark.skipif(not default_omniglot_path.exists(), reason="Omniglot dataset not found")
@pytest.mark.parametrize("implicit", [True, False])
def test_meta_training_and_evaluation_cnn(implicit):
    """Run the Hydra-configured train/eval pipeline end to end with a MetaConvNet.

    Skipped when the Omniglot directory is not present. Composes the ``cnn``
    configuration with overrides that shrink the run (meta batch size 3 for
    training and 1 for evaluation, a single outer step) and redirect the saved
    parameters to ``test.pickle``. After a successful run, ``test.pickle`` and
    ``maml_learning_curve.jpg`` must exist in the current working directory.

    The generated files are removed whether or not the run succeeds, so a
    failing run does not leave artifacts behind.
    """
    expected_outputs = (
        "test.pickle",
        "maml_learning_curve.jpg",
    )
    # The figure path is overridden to test.png below, but this test has never
    # required that file to exist, so it is only removed if present.
    # NOTE(review): confirm whether the CNN evaluation actually writes a figure.
    optional_outputs = ("test.png",)
    try:
        with initialize(config_path="../../conf"):
            cfg = compose(config_name="cnn", overrides=[
                # data
                "train.data.meta_batch_size=3",
                # model
                "train.model._target_=maml.model.MetaConvNet",
                "+train.model.implicit=" + str(implicit),
                "+eval.model.implicit=" + str(implicit),
                # training
                "train.optim.n_outer_steps=1",
                "train.save_path_params=test.pickle",
                # eval
                "eval.data.meta_batch_size=1",
                "eval.save_path_fig=test.png",
            ])
            maml_training_and_evaluation(cfg)
        # The pipeline is expected to have produced all of these files.
        missing = [
            name for name in expected_outputs
            if not (Path.cwd() / name).exists()
        ]
        assert not missing, (
            "expected output files were not created: " + ", ".join(missing)
        )
    finally:
        # Clean up. Tolerate missing files so that a failure above is reported
        # as-is instead of being masked by a FileNotFoundError from unlink().
        for name in expected_outputs + optional_outputs:
            artifact = Path.cwd() / name
            if artifact.exists():
                artifact.unlink()
