import argparse
import jax
from jax_smi import initialise_tracking
from typing import Dict
import yaml

from poppy.trainers import TrainingConfig
from poppy.utils.logger import EnsembleLogger, NeptuneLogger, TerminalLogger
from poppy.utils.config_utils import make_env_trainer, EnvironmentConfig


def read_config_file(filename: str) -> Dict:
    """Load and return the YAML configuration stored at ``filename``.

    Args:
        filename: path to the YAML configuration file.

    Returns:
        The parsed YAML document (callers in this script index it as a mapping).
    """
    # NOTE(review): ``yaml.Loader`` can construct arbitrary Python objects from
    # tagged YAML, so only load trusted config files. Consider ``yaml.safe_load``
    # if the configs never rely on Python-specific tags.
    # Decode explicitly as UTF-8 so the result does not depend on the platform locale.
    with open(filename, "r", encoding="utf-8") as f:
        return yaml.load(f, yaml.Loader)


def make_env_config(config: Dict) -> EnvironmentConfig:
    """Build an ``EnvironmentConfig`` from the ``environment`` section of ``config``.

    Args:
        config: parsed configuration; must contain an ``environment`` mapping
            with ``name`` and ``params`` entries (a ``KeyError`` is raised otherwise).

    Returns:
        An ``EnvironmentConfig`` carrying the environment name and its parameters.
    """
    env_section = config["environment"]
    env_name = env_section["name"]
    env_params = env_section["params"]
    return EnvironmentConfig(name=env_name, params=env_params)


def create_logger(args):
    """Create the experiment logger from parsed command-line arguments.

    A terminal logger is always included; a Neptune logger (project "Poppy")
    is added unless ``args.neptune_disable`` is set.

    Args:
        args: parsed CLI namespace providing ``neptune_disable`` and
            ``neptune_name``.

    Returns:
        An ``EnsembleLogger`` wrapping all configured loggers.

    Raises:
        ValueError: if Neptune logging is enabled but ``neptune_name`` is None.
    """
    use_neptune = not args.neptune_disable

    # Validate before constructing any logger so nothing is opened for a run
    # that is about to abort. An explicit raise (unlike ``assert``) also
    # survives ``python -O``.
    if use_neptune and args.neptune_name is None:
        raise ValueError(
            "--neptune_name is required unless --neptune_disable is set"
        )

    loggers = [TerminalLogger(label="", time_delta=10)]
    if use_neptune:
        loggers.append(NeptuneLogger(
            project="Poppy",
            name=args.neptune_name,
        ))
    return EnsembleLogger(loggers)


if __name__ == "__main__":
    # Parse arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("config", help="path to a YAML file containing the configuration parameters")
    parser.add_argument("--neptune_disable", default=False, action="store_true", help="whether disable logs to Neptune")
    parser.add_argument("--neptune_name", type=str, default=None, help="name for the Neptune experiment")
    args = parser.parse_args()
    # Report a missing experiment name as a CLI usage error instead of failing
    # later inside logger creation.
    if not args.neptune_disable and args.neptune_name is None:
        parser.error("--neptune_name is required unless --neptune_disable is set")

    # Determine the number of devices
    num_devices = len(jax.local_devices())
    print(f"Running on {num_devices} devices")

    # Read the configuration before creating the logger, so an unreadable or
    # invalid config file aborts without opening a logging run.
    config = read_config_file(args.config)

    # The per-device sizes below use floor division. Warn when a global size
    # does not split evenly across devices, since the remainder is dropped.
    for size_key in (
        "num_problems_training",
        "minibatch_size_training",
        "num_problems_validation",
        "minibatch_size_validation",
    ):
        if config[size_key] % num_devices != 0:
            print(
                f"Warning: {size_key}={config[size_key]} is not divisible by "
                f"{num_devices} devices; the per-device value is rounded down."
            )

    # Initialize the logger
    logger = create_logger(args)
    try:
        # Upload the config file and its parsed contents
        logger.write_artifact({"config.yaml": args.config})
        logger.write_config(config)  # Uploading the dict is useful to allow filtering fields in Neptune :)

        # Parse the environment configuration
        env_config = make_env_config(config)
        environment, trainer_class = make_env_trainer(env_config)

        training_config = TrainingConfig(
            learning_rate_encoder=1e-4 if config["train_encoder"] else 0.0,
            learning_rate_decoder=1e-4,
            batch_size=config["num_problems_training"] // num_devices,  # batch size *per device* (repeated for each element of the population)
            minibatch_train=config["minibatch_size_training"] // num_devices,
            pop_size=config["pop_size"],
            pomo_size=config["num_starting_points"],
            training_method=config["training_method"],
            loss_annealing_steps=config["loss_annealing_steps"],
            l2_regularization=1e-6,
            seed=config["seed"],  # it used to be 0 by default
            validation_freq=config["validation_freq"],
            num_validation_problems=config["num_problems_validation"] // num_devices,  # num_validation_problems *per device*
            minibatch_validation=config["minibatch_size_validation"] // num_devices,
            num_devices=-1,  # -1 --> auto-detect
            save_checkpoint=config["save_checkpoint"],
            save_best=config["save_best"],
            load_checkpoint=config["load_checkpoint"],  # '' --> no checkpoint
            load_decoder=config["load_decoder"],
            load_optimizer=config["load_optimizer"],
            compute_expensive_metrics=False,
            use_augmentation_validation=config["use_augmentation_validation"],
            save_matrix_freq=-1,
        )

        trainer = trainer_class(
            environment=environment,
            config=training_config,
            logger=logger,
        )

        initialise_tracking()
        trainer.train(num_steps=config["num_steps"])
    finally:
        # Always close the logger, even if setup or training raises.
        logger.close()
