import pytest
import numpy as np
import random as rand
from pathlib import Path
import kdai._logging
import kdai.train
import kdtpp.experiments as exp
import logging
import kdtpp.mea
import json


_logger = logging.getLogger(__name__)


@pytest.fixture
def seed_random():
    """Seed the module-level ``random`` and legacy ``np.random`` RNGs.

    Tests that request this fixture get reproducible draws from the global
    ``random.*`` and ``np.random.*`` functions (both seeded with 123).
    """
    seed = 123
    for seeder in (rand.seed, np.random.seed):
        seeder(seed)


@pytest.fixture
def np_rng():
    """Provide a fresh NumPy ``Generator`` seeded with 123 for each test."""
    generator = np.random.default_rng(seed=123)
    return generator


@pytest.fixture
def resource_dir():
    """Return the ``resources`` directory located next to this test module."""
    here = Path(__file__).parent
    return here.joinpath("resources")


@pytest.fixture
def trainable_factory():
    """Provide a factory that builds a trainable for a model/dataset pair.

    The returned ``_create`` callable takes the first run produced by
    ``exp_class.for_model_and_ds`` for a single training length, builds the
    trainable via ``exp.trainable_fns[model_name]`` in ``"train-loss"`` eval
    mode, and optionally loads weights from the run's model directory under
    ``./out/test/trainable_factory``.
    """
    OUT_DIR = Path("./out/test/trainable_factory")

    def _create(
        exp_class, model_name, ds_name, train_len=None, load_ckpt=False
    ):
        """Build (and optionally restore) a trainable.

        Args:
            exp_class: Experiment class providing ``default_train_lens`` and
                ``for_model_and_ds``.
            model_name: Key into ``exp.trainable_fns``.
            ds_name: Dataset name passed through to ``exp_class``.
            train_len: Training length to use. Defaults to the first entry of
                ``exp_class.default_train_lens(ds_name)``.
            load_ckpt: If True, load model weights from the run's model dir.

        Returns:
            The trainable created by ``exp.trainable_fns[model_name]``.

        Raises:
            FileNotFoundError: If ``load_ckpt`` is True and either the model
                dir or its ``checkpoint_best_loss.pth`` file is missing.
        """
        if train_len is None:
            # Just use 1 length: the first entry (presumably the largest,
            # per the original comment — confirm default_train_lens ordering).
            train_len = exp_class.default_train_lens(ds_name)[0]
        run_itr = exp_class.for_model_and_ds(model_name, ds_name, [train_len])
        runs = list(run_itr)
        run_spec, ds_mgr_fn = runs[0]
        model_dir = OUT_DIR / run_spec.model_dir
        trainable = exp.trainable_fns[model_name](
            ds_mgr_fn, eval_mode="train-loss"
        )
        if not load_ckpt:
            return trainable

        if not model_dir.exists():
            # Both halves must be f-strings so the path is interpolated.
            raise FileNotFoundError(
                f"Requested load_ckpt={load_ckpt}, but "
                f"model dir not found: {model_dir}"
            )
        _logger.info("Model dir found: %s", model_dir)
        checkpoint_path = model_dir / "checkpoint_best_loss.pth"
        if not checkpoint_path.exists():
            raise FileNotFoundError(
                f"Model dir found but no checkpoint: {checkpoint_path}"
            )
        _logger.info("Model checkpoint: %s", checkpoint_path)
        # NOTE(review): load_model receives the directory, not checkpoint_path;
        # presumably it locates the checkpoint file itself — confirm.
        kdai._logging.load_model(trainable.model, model_dir)
        return trainable

    return _create
    # NOTE: Disabled fallback that trained the model when no checkpoint
    # existed. It is unreachable here (after the return) and refers to names
    # local to `_create`; kept only for reference.
    # else:
    #     _logger.info(f"Model dir not found ({model_dir})")
    #     _logger.info(f"Training {model_name} on {ds_name}")
    #     kwargs = exp.get_train_args(ds_name, model_name)
    #     # Just do 1 epoch
    #     kwargs["n_epochs"] = 1 # run_spec.n_epochs
    #     kwargs["batch_size"] = run_spec.batch_size
    #     kwargs["steps_til_eval"] = run_spec.steps_til_eval
    #     kdai.train.train(trainable, out_dir=model_dir, **kwargs)


@pytest.fixture
def rec(resource_dir):
    """Load the single recording from ``chicken_2021_08_17.json``.

    Reads the JSON file from ``resource_dir``, checks that its
    ``"recordings"`` list has exactly one entry, and builds a
    ``kdtpp.mea.CompressedSpikeRecording`` from that entry.
    """
    ds_path = resource_dir / "chicken_2021_08_17.json"
    # Read explicitly as UTF-8 rather than the platform's locale encoding.
    with open(ds_path, "r", encoding="utf-8") as f:
        ds_details = json.load(f)
    # The file's "recording_cell_ids" map is not needed here.
    recordings = ds_details["recordings"]
    assert len(recordings) == 1, (
        f"Expected exactly 1 recording in {ds_path}, found {len(recordings)}"
    )
    return kdtpp.mea.CompressedSpikeRecording.from_json(recordings[0])
