import wandb
import logging
import hydra
from omegaconf import DictConfig, OmegaConf
import argparse
from src import create_logger
from src import Trainer
import os

def main(cfg: DictConfig):
    """Build a Trainer from the composed config and run it.

    The whole config is echoed to stdout as YAML. ``create_logger`` is then
    called with ``cfg.logger_params``, every parameter section is logged via
    ``_print_config``, and a ``Trainer`` is constructed from the five
    sections and run.

    Args:
        cfg: Composed config; must expose ``env_params``, ``model_params``,
            ``optimizer_params``, ``trainer_params`` and ``logger_params``.
    """
    print(OmegaConf.to_yaml(cfg))

    # The section names double as the keyword-argument names accepted by
    # both _print_config and Trainer, so one mapping feeds both calls.
    section_names = ('env_params', 'model_params', 'optimizer_params',
                     'trainer_params', 'logger_params')
    sections = {name: getattr(cfg, name) for name in section_names}

    # create_logger runs first -- presumably it configures the handlers that
    # _print_config logs through (TODO confirm in src.create_logger).
    create_logger(sections['logger_params'])
    _print_config(**sections)

    Trainer(**sections).run()


def _print_config(env_params, model_params, optimizer_params, trainer_params, logger_params):
    """Log every configuration section at INFO level on the 'root' logger.

    One ``"<section_name>: <value>"`` message is emitted per section, in the
    order env, model, optimizer, trainer, logger.
    """
    root_logger = logging.getLogger('root')
    labelled_sections = (
        ('env_params', env_params),
        ('model_params', model_params),
        ('optimizer_params', optimizer_params),
        ('trainer_params', trainer_params),
        ('logger_params', logger_params),
    )
    for label, section in labelled_sections:
        root_logger.info(f'{label}: {section}')


def str2bool(value):
    """Convert a command-line string to a bool for use as an argparse ``type``.

    ``type=bool`` is a trap with argparse: ``bool("False")`` is ``True``
    because every non-empty string is truthy, so ``--use_wandb False`` would
    silently leave the flag enabled. This converter parses the usual
    spellings explicitly instead.

    Args:
        value: The raw argument string. A ``bool`` is returned unchanged.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean
            spelling; argparse turns this into a clean usage error.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0', 'off'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train model with config file")
    parser.add_argument('config', type=str, help="Path to the config file")
    parser.add_argument('overrides', nargs='*', help="Override configuration parameters (key=value)")
    # GPU device number; copied into cfg.trainer_params.cuda_device_num below.
    parser.add_argument('--gpu', type=int, default=0, help="GPU device number")
    # Whether to use wandb. str2bool (not bool) so that `--use_wandb False` really means False.
    # NOTE(review): args.use_wandb is parsed but not forwarded anywhere in this file -- confirm
    # whether it should be written into the config or passed to the Trainer.
    parser.add_argument('--use_wandb', type=str2bool, default=True, help="Use wandb")

    args = parser.parse_args()

    # Always make the config path relative to 'configs/train'
    full_config_path = os.path.join("configs/train", args.config)

    # Hydra takes the config directory and the config file name separately.
    config_path, config_name = os.path.split(full_config_path)

    # Hydra documents a relative `config_path` for `initialize` as relative to the
    # calling module's directory, not the current working directory -- verify if this
    # script is moved.
    hydra.initialize(config_path=config_path, version_base=None)
    cfg = hydra.compose(config_name=config_name, overrides=args.overrides)

    # --gpu (default 0) always overwrites this value, including any
    # `trainer_params.cuda_device_num=...` supplied in `overrides`.
    cfg.trainer_params.cuda_device_num = args.gpu

    main(cfg)