import os
import random

import fire
import numpy as np
import psutil
import torch

import sys

import utils
from modules.config.config import Config

from pathlib import Path
import modules.system.system as system
from tasks.pruning.prune import prune_task
from tasks.test.test import test_task
from tasks.training.tune import tune_task


def set_random_seed(seed):
    """Seed the Python, NumPy and PyTorch (CPU + CUDA) RNGs with ``seed``.

    Also exports ``seed`` as the ``PYTHONHASHSEED`` environment variable.

    NOTE(review): PYTHONHASHSEED is read when an interpreter starts, so setting
    it here presumably only influences subprocesses launched afterwards, not
    the current process's hashing -- confirm this is the intent.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)


def _usable_cpu_count():
    """Return how many CPUs the current process is allowed to run on.

    Uses the process CPU-affinity mask (``os.sched_getaffinity``) where the
    platform provides it; otherwise falls back to ``psutil.cpu_count()``, which
    reports all logical CPUs rather than the affinity.
    """
    if hasattr(os, "sched_getaffinity"):
        return len(os.sched_getaffinity(0))
    return psutil.cpu_count()


def main(config_path: str = ''):
    """Load a config, prepare the output folder and system, then run the task.

    Steps: load the config, create ``c.task.output_folder`` and save the config
    into it, call ``system.init_system``, set up the logger, and dispatch on
    ``c.task.task_mode`` ('fuse'/'tune' -> ``tune_task``, 'prune' ->
    ``prune_task``, 'test' -> ``test_task``).

    Args:
        config_path: Path of the config file, passed straight to ``Config``.

    Raises:
        NotImplementedError: If ``c.task.task_mode`` is not a supported mode.
    """
    config = Config(config_path)
    c = config.get_config()

    # parents=True so a nested output path works even if its parents are
    # missing; exist_ok=True keeps re-runs into the same folder working.
    Path(c.task.output_folder).mkdir(parents=True, exist_ok=True)
    config.save_config(c.task.output_folder)
    system.init_system(c)

    # world_size is read only after init_system -- presumably set there.
    logger = utils.setup_logger(c.report, multi_process=system.world_size > 1)

    logger.info(f"loaded config: {config_path}")
    logger.info(c)
    # Report the CPUs this process may actually use, not the machine total.
    logger.info(f"current cpu affinity:{_usable_cpu_count()}")

    # 'fuse' shares the tuning entry point with 'tune'.
    tasks = {
        'fuse': tune_task,
        'tune': tune_task,
        'prune': prune_task,
        'test': test_task,
    }
    task_mode = c.task.task_mode
    task_fn = tasks.get(task_mode)
    if task_fn is None:
        raise NotImplementedError(
            f"unsupported task_mode {task_mode!r}; expected one of {sorted(tasks)}"
        )
    task_fn(c)


if __name__ == "__main__":
    # Echo the invoking command line (space-separated) before python-fire
    # parses it and calls ``main``.
    print(*sys.argv)
    fire.Fire(main)
