from pathlib import Path

from hydra import initialize, compose
import pytest

from maml.evaluate import evaluate
from maml.meta_train import maml_training_and_evaluation


# Resized Omniglot images, as a path relative to the current working
# directory; the CNN test below is skipped when this path does not exist.
default_omniglot_path = Path("omniglot_resized") / "all_images"


@pytest.mark.parametrize("implicit", [True, False])
def test_evaluate(implicit):
    """Train a small ``MetaMLP`` for one outer step, then run ``evaluate``.

    The first ``compose`` call runs ``maml_training_and_evaluation`` with tiny
    batch sizes, saving parameters to ``test.pickle`` and a figure to
    ``test.png``. The second runs ``evaluate`` with ``train.save_path_params``
    pointing at the same pickle (presumably so it reloads those parameters)
    and writes ``test_only.png`` / ``test_only.csv``.

    The test passes when both runs complete and every expected output file
    exists. The files are removed afterwards, even when a run fails.

    Args:
        implicit: value injected as ``train.model.implicit`` and
            ``eval.model.implicit`` (parametrized over ``True``/``False``).
    """
    # Files the two runs are expected to write into the working directory.
    # ``maml_learning_curve.jpg`` is not named by any override below;
    # presumably it is a default output of the training run.
    artifacts = [
        Path.cwd() / filename
        for filename in (
            "test.pickle",
            "test.png",
            "maml_learning_curve.jpg",
            "test_only.png",
            "test_only.csv",
        )
    ]
    try:
        with initialize(config_path="../../conf"):
            cfg = compose(config_name="config", overrides=[
                # data
                "train.data.meta_batch_size=3",
                "train.data.inner_batch_size_train=4",
                "train.data.inner_batch_size_eval=5",
                # model
                "train.model._target_=maml.model.MetaMLP",
                f"+train.model.implicit={implicit}",
                f"+eval.model.implicit={implicit}",
                # training
                "train.optim.n_outer_steps=1",
                "train.save_path_params=test.pickle",
                # eval
                "eval.save_path_fig=test.png",
                "eval.data.meta_batch_size=1",
                "eval.data.inner_batch_size_train=1",
                "eval.data.inner_batch_size_eval=1",
            ])
            maml_training_and_evaluation(cfg)
        with initialize(config_path="../../conf"):
            cfg = compose(config_name="config", overrides=[
                # model
                "train.model.maxiter=2",
                "train.model._target_=maml.model.MetaMLP",
                f"+train.model.implicit={implicit}",
                f"+eval.model.implicit={implicit}",
                # training
                "train.save_path_params=test.pickle",
                # eval
                "eval.save_path_fig=test_only.png",
                "eval.save_path_csv=test_only.csv",
                "eval.data.meta_batch_size=1",
                "eval.data.inner_batch_size_train=2",
                "eval.data.inner_batch_size_eval=10",
            ])
            evaluate(cfg)
        # Make the "outputs were written" check explicit instead of relying
        # on unlink() raising during cleanup.
        missing = [path.name for path in artifacts if not path.exists()]
        assert not missing, f"expected output files were not written: {missing}"
    finally:
        # Always clean up, so a failing run does not leave files in the
        # working directory. Files that were never written are skipped, so
        # cleanup cannot hide the original error or stop early.
        for path in artifacts:
            try:
                path.unlink()
            except FileNotFoundError:
                pass


@pytest.mark.skipif(not default_omniglot_path.exists(), reason="Omniglot dataset not found")
@pytest.mark.parametrize("implicit", [True, False])
def test_evaluate_cnn(implicit):
    """Train a ``MetaConvNet`` on the ``cnn`` config, then run ``evaluate``.

    The first ``compose`` call runs ``maml_training_and_evaluation`` for one
    outer step, saving parameters to ``test.pickle``. The second runs
    ``evaluate`` with ``train.save_path_params`` pointing at the same pickle
    (presumably so it reloads those parameters) and writes ``test_only.csv``.

    The test is skipped when ``default_omniglot_path`` (relative to the
    working directory) does not exist. It passes when both runs complete and
    every expected output file exists. The files are removed afterwards, even
    when a run fails.

    Args:
        implicit: value injected as ``train.model.implicit`` and
            ``eval.model.implicit`` (parametrized over ``True``/``False``).
    """
    # Files the two runs are expected to write into the working directory.
    # ``maml_learning_curve.jpg`` is not named by any override below;
    # presumably it is a default output of the training run.
    # NOTE(review): unlike test_evaluate, no ``eval.save_path_fig`` override
    # is given here; if the cnn config writes a figure by default, that file
    # is not cleaned up -- confirm against conf/cnn.
    artifacts = [
        Path.cwd() / filename
        for filename in (
            "test.pickle",
            "maml_learning_curve.jpg",
            "test_only.csv",
        )
    ]
    try:
        with initialize(config_path="../../conf"):
            cfg = compose(config_name="cnn", overrides=[
                # data
                "train.data.meta_batch_size=3",
                # model
                "train.model._target_=maml.model.MetaConvNet",
                f"+train.model.implicit={implicit}",
                f"+eval.model.implicit={implicit}",
                # training
                "train.optim.n_outer_steps=1",
                "train.save_path_params=test.pickle",
                # eval
                "eval.data.meta_batch_size=1",
            ])
            maml_training_and_evaluation(cfg)
        with initialize(config_path="../../conf"):
            cfg = compose(config_name="cnn", overrides=[
                # model
                "train.model.maxiter=2",
                "train.model._target_=maml.model.MetaConvNet",
                f"+train.model.implicit={implicit}",
                f"+eval.model.implicit={implicit}",
                # training
                "train.save_path_params=test.pickle",
                # eval
                "eval.save_path_csv=test_only.csv",
                "eval.data.meta_batch_size=1",
            ])
            evaluate(cfg)
        # Make the "outputs were written" check explicit instead of relying
        # on unlink() raising during cleanup.
        missing = [path.name for path in artifacts if not path.exists()]
        assert not missing, f"expected output files were not written: {missing}"
    finally:
        # Always clean up, so a failing run does not leave files in the
        # working directory. Files that were never written are skipped, so
        # cleanup cannot hide the original error or stop early.
        for path in artifacts:
            try:
                path.unlink()
            except FileNotFoundError:
                pass
