import os
import torch
import hydra
# import runners

from omegaconf import DictConfig, OmegaConf
from utils import get_logger, set_seed_everywhere
import runners


# Root directory under which every run's workdir is created (relative to the
# current working directory).
_RUN_ROOT = 'run'


def _clear_directory_contents(path):
    """Delete every file, symlink and subdirectory inside ``path``.

    The directory ``path`` itself is kept. Symlinks are unlinked, never
    followed, so a link pointing outside ``path`` cannot cause anything
    outside it to be deleted.

    Args:
        path: An existing directory whose contents should be removed.
    """
    import shutil  # local import: only needed when clearing is requested

    # Snapshot the listing first so we never delete entries while the
    # directory handle is still being iterated.
    with os.scandir(path) as it:
        entries = list(it)
    for entry in entries:
        if entry.is_dir(follow_symlinks=False):
            shutil.rmtree(entry.path)
        else:
            # Regular files and symlinks (including links to directories).
            os.remove(entry.path)


def _log_device_info(logger, device):
    """Log the CPU count and, on CUDA, each visible GPU's name and total memory."""
    logger.info(f"Found {os.cpu_count()} total number of CPUs.")
    if device == 'cuda':
        n_gpus = torch.cuda.device_count()
        logger.info(f"Found {n_gpus} CUDA devices.")
        for i in range(n_gpus):
            props = torch.cuda.get_device_properties(i)
            logger.info(f"{props.name} with Memory: {props.total_memory / (1024 ** 3):.2f}GB")
    logger.info(f"Using device: {device}")


@hydra.main(version_base=None, config_path="configs", config_name="main")
def main(config: DictConfig) -> None:
    """Set up a run from the Hydra config and execute the configured runner.

    Steps:
      1. Seed all RNGs via ``set_seed_everywhere(config.seed)``.
      2. If ``config.cleardir`` is true, empty the ``run`` directory (the
         directory itself is kept).
      3. Build ``config.workdir`` as
         ``run/<manifold>/<dataset>/<save_prefix>-<seed>-<now>-<sampler>``.
      4. Create a logger in ``<workdir>/logs``, log CPU/GPU info, and save the
         config as ``config.yml`` in that log directory.
      5. Instantiate ``runners.<config.problem.runner>(config)`` and call its
         ``run()`` method.

    Args:
        config: Config composed by Hydra from ``configs/main``. Reads ``seed``,
            ``cleardir``, ``problem.{manifold,dataset,runner}``,
            ``save_prefix``, ``now`` and ``sample.sampler``; writes
            ``workdir`` and ``device``.
            NOTE(review): Hydra passes the config in struct mode, so assigning
            ``workdir``/``device`` presumably relies on those keys already
            being declared in the config files. Confirm.
    """
    set_seed_everywhere(config.seed)  # set random seed for reproducibility

    # Deliberately strict comparison: only a real boolean true (not any
    # truthy value such as the string "false") triggers the destructive clear.
    if config.cleardir == True:  # noqa: E712
        if os.path.exists(_RUN_ROOT):
            print(f"Clearing the '{_RUN_ROOT}' directory.")
            _clear_directory_contents(_RUN_ROOT)
        else:
            print(f"'{_RUN_ROOT}' directory does not exist. No need to clear.")

    # Per-run working directory.
    config.workdir = os.path.join(
        _RUN_ROOT,
        config.problem.manifold,
        config.problem.dataset,
        f"{config.save_prefix}-{config.seed}-{config.now}-{config.sample.sampler}",
    )

    log_dir = os.path.join(config.workdir, 'logs')
    logger = get_logger(log_dir)
    # Make sure the directory exists before writing config.yml into it; done
    # after get_logger and with exist_ok so it never conflicts with whatever
    # get_logger does to the same path.
    os.makedirs(log_dir, exist_ok=True)

    config.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    _log_device_info(logger, config.device)

    # Persist the config next to the logs and echo it into the log.
    yaml_str = OmegaConf.to_yaml(config)
    with open(os.path.join(log_dir, 'config.yml'), 'w', encoding='utf-8') as file:
        file.write(yaml_str)
    logger.info(f"Writing log file to {log_dir}")
    logger.info(">" * 80)
    logger.info(yaml_str)
    logger.info("<" * 80)

    # ------------------------------ run ------------------------------
    # The runner class is looked up by name in the `runners` module.
    runner = getattr(runners, config.problem.runner)(config)
    runner.run()  # run the training or evaluation process


if __name__ == "__main__":
    # Script entry point. The @hydra.main decorator composes the config and
    # passes it in, so main() is intentionally called without arguments.
    main()