import os

import torch
from loguru import logger
from project_utils.logging_utils import init_logger_and_wandb, load_args_from_yaml
from project_utils.str_repo import debug_string


def check_previous_runs(cfg, train_data, test_data, val_data):
    """Inspect artefacts of an earlier run and prepare the datasets for this run.

    Outside debug mode, and unless ``cfg.replace`` is set, this looks for:

    * ``<cfg.model_dir>/model.pt`` -- if it exists, the saved
      ``<cfg.model_dir>/cfg.yaml`` (when loadable) is merged into ``cfg`` in
      place, the checkpoint is loaded with ``torch.load`` and ``cfg.no_train``
      is set to ``True``.
    * ``<cfg.model_outputs_dir>/test_outputs.npz`` -- if it exists,
      ``cfg.no_test`` is set to ``True``.

    In debug mode no artefacts are inspected; instead the datasets are cut
    down to their first 3000 / 100 / 500 items (train / test / val).

    Args:
        cfg: Run configuration. Reads ``debug``, ``replace``, ``model_dir``,
            ``model_outputs_dir``, ``no_train``, ``no_test`` and
            ``only_test_train_on_test_set``. It is mutated in place (never
            rebound), so the caller sees every update made here.
        train_data: Training dataset; must support slicing.
        test_data: Test dataset; must support slicing.
        val_data: Validation dataset; must support slicing.

    Returns:
        ``(best_model_info, train_data, test_data, val_data)``.
        ``best_model_info`` is whatever ``torch.load`` returned for
        ``model.pt``. If no checkpoint was loaded, it is a dict whose
        ``"epoch"``, ``"model_state"`` and ``"optimizer_state"`` are all
        ``None``.

    Raises:
        AssertionError: (when assertions are enabled) if
            ``cfg.only_test_train_on_test_set`` is set while ``cfg.no_train``
            is false.
    """
    best_model_info = {
        "epoch": None,
        "model_state": None,
        "optimizer_state": None,
    }

    if cfg.debug:
        logger.info(debug_string)
        train_data = train_data[:3000]
        test_data = test_data[:100]
        val_data = val_data[:500]
    elif cfg.replace:
        logger.info("Replacing existing model")
    else:
        # Resolve the checkpoint path before cfg can be overwritten by the
        # saved config, so the file we load is the one whose existence we checked.
        model_path = os.path.join(cfg.model_dir, "model.pt")
        if os.path.exists(model_path):
            # We have already trained the model
            logger.info("Model already trained")
            saved_cfg_path = os.path.join(cfg.model_dir, "cfg.yaml")
            try:
                saved_cfg = load_args_from_yaml(saved_cfg_path)
            except FileNotFoundError:
                logger.info("Could not load cfg.yaml")
            else:
                # Update the caller's object in place rather than rebinding the
                # local name: a rebound `cfg` would silently drop the
                # no_train / no_test flags set below.
                # NOTE(review): assumes both configs keep their fields in
                # __dict__ (e.g. argparse.Namespace). This also copies the saved
                # run's own flags (such as `replace` or `no_test`) over the
                # current ones -- confirm that is intended.
                vars(cfg).update(vars(saved_cfg))
            try:
                best_model_info = torch.load(model_path)
                cfg.no_train = True
            except FileNotFoundError:
                # Only reachable if the file vanished after the existence check.
                logger.info("Could not load model.pt")
                cfg.no_train = False

        if os.path.exists(os.path.join(cfg.model_outputs_dir, "test_outputs.npz")):
            # We have already stored the outputs
            logger.info("Model outputs already stored")
            cfg.no_test = True

    if cfg.no_train and cfg.no_test:
        logger.info("No training or testing required")
    if cfg.only_test_train_on_test_set:
        assert cfg.no_train, "Cannot do only test when no train is false"
    return best_model_info, train_data, test_data, val_data


def test_assertions(cfg, best_model_info):
    """Check that ``cfg`` and ``best_model_info`` agree with each other.

    * If ``cfg.only_test_train_on_test_set`` is truthy, no checkpoint epoch may
      be recorded (``best_model_info["epoch"] is None``). In addition,
      ``cfg.only_test_train_on_test_set`` must equal ``cfg.train_on_test_set``.
    * Otherwise ``best_model_info["epoch"]`` must be an ``int``.

    The checks use ``assert``, so they are skipped under ``python -O``.

    Args:
        cfg: Run configuration exposing ``only_test_train_on_test_set`` and
            (read only when that flag is truthy) ``train_on_test_set``.
        best_model_info: Mapping with an ``"epoch"`` key.

    Raises:
        AssertionError: If either check fails.
    """
    epoch_msg = (
        "We should not have a trained model when only-test-train-on-test-set is true,"
        " else num epochs should be an integer"
    )
    if cfg.only_test_train_on_test_set:
        assert best_model_info["epoch"] is None, epoch_msg
        assert (
            cfg.only_test_train_on_test_set == cfg.train_on_test_set
        ), "only-test-train-on-test-set should be true only if train-on-test-set is true"
    else:
        assert isinstance(best_model_info["epoch"], int), epoch_msg


# Despite its ``test_`` prefix this is a runtime helper, not a pytest test.
# Without this marker pytest collects it from any test module that imports it
# and fails with "fixture 'cfg' not found".
test_assertions.__test__ = False
