from pathlib import Path
import logging
import os
import uuid
import subprocess
import yaml
import pprint

import submitit
import numpy as np
import argparse

# Root logging: timestamped messages at INFO level; `_logger` is the
# module's named logger ('train').
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(message)s',
)
_logger = logging.getLogger('train')

# Command-line interface of the launcher. The parsed namespace is handed to
# Trainer and, from there, to train_paws_dist.main_worker.
parser = argparse.ArgumentParser(description='Paws distributed')

# Slurm setting
parser.add_argument('--ngpus-per-node', default=6, type=int, metavar='N',
                    help='number of gpus per node')
parser.add_argument('--nodes', default=5, type=int, metavar='N',
                    help='number of nodes')
# Passed to the executor as `timeout_min`, i.e. the value is in minutes.
parser.add_argument("--timeout", default=360, type=int,
                    help="Duration of the job")
parser.add_argument("--partition", default="el8", type=str,
                    help="Partition where to submit")
# Used as the job name and as the sub-directory appended to both
# --checkpoint-dir and --log-dir.
parser.add_argument("--exp", default="run0", type=str,
                    help="track experiment")

# Output locations. Neither has a default; main() joins them with --exp.
parser.add_argument('--checkpoint-dir', type=Path,
                    metavar='DIR', help='path to checkpoint directory')
parser.add_argument('--log-dir', type=Path,
                    metavar='LOGDIR', help='path to tensorboard log directory')

# Training options. They are not interpreted in this file; they travel on
# `args` to train_paws_dist.main_worker, which defines their semantics.
parser.add_argument('--ema_decay', default=0.996, type=float)
parser.add_argument('--final_ema_decay', default=1., type=float)
parser.add_argument("--mom_warmup_epochs", default=0, type=int)

parser.add_argument("--num_epochs", default=100, type=int)
parser.add_argument("--use_ema", action='store_true')
parser.add_argument("--use_swa", action='store_true')
parser.add_argument("--swa_warmup", default=0, type=int)

parser.add_argument("--no_online", action='store_true')
parser.add_argument("--use_mom_scheduler", action='store_true')
# Only these three names are mapped to a config file by Trainer.__call__;
# restricting the choices rejects a typo at submission time instead of
# after the job has been queued and started.
parser.add_argument("--yaml_file", default='imagenet', type=str,
                    choices=['cifar10', 'imagenet', 'imagenet_1percent'],
                    help='which configs/paws/*.yaml training config to load')
parser.add_argument("--use_pred", action='store_true')
parser.add_argument("--input_lr", default=-1, type=float)
parser.add_argument("--input_pred", default=-1, type=int)

# parser.add_argument("--fname", type=str,
#                     help='yaml file containing config file names to launch',
#                     default='configs/paws/imgnt_train.yaml')
# parser.add_argument(
#     '--sel', type=str,
#     help='which script to run',
#     choices=[
#         'paws_train',
#         'suncet_train',
#         'fine_tune',
#         'snn_fine_tune'
#     ])

class Trainer(object):
    """Callable job object submitted to the submitit executor.

    Holds the parsed command-line namespace. Calling the instance fills in
    the per-process distributed settings, loads the selected YAML config
    and hands control to ``train_paws_dist.main_worker``.
    """

    # Maps the ``--yaml_file`` short name to the config file it selects.
    _CONFIG_FILES = {
        'cifar10': 'configs/paws/cifar10_train.yaml',
        'imagenet': 'configs/paws/imgnt_train.yaml',
        'imagenet_1percent': 'configs/paws/imgnt_train_1percent.yaml',
    }

    def __init__(self, args):
        """Store the argument namespace that is forwarded to the worker."""
        self.args = args

    def __call__(self):
        """Run training for the current task.

        Raises:
            ValueError: if ``args.yaml_file`` is not one of the known
                config names.
        """
        # Resolve the config first so that a bad name fails immediately
        # with a clear message, before the training module is imported.
        yaml_file = self._config_path()

        import train_paws_dist
        self._setup_gpu_args()

        # -- load script params
        with open(yaml_file, 'r') as y_file:
            params = yaml.load(y_file, Loader=yaml.FullLoader)
        pprint.PrettyPrinter(indent=4).pprint(params)

        train_paws_dist.main_worker(self.args.gpu, self.args, params)

    def _config_path(self):
        """Return the config file path selected by ``args.yaml_file``."""
        name = self.args.yaml_file
        try:
            return self._CONFIG_FILES[name]
        except KeyError:
            known = ', '.join(sorted(self._CONFIG_FILES))
            raise ValueError(
                f"unknown yaml_file {name!r}; expected one of: {known}"
            ) from None

    def checkpoint(self):
        """Build the delayed submission used to requeue this job.

        A fresh init file is generated for ``args.dist_url`` and, when
        ``checkpoint.pth`` exists in the checkpoint directory,
        ``args.resume`` is pointed at it.
        NOTE(review): presumably invoked through submitit's checkpointing
        hook on timeout/preemption -- confirm against the submitit docs.

        Returns:
            submitit.helpers.DelayedSubmission wrapping a new Trainer that
            carries the updated arguments.
        """
        import os
        import submitit

        self.args.dist_url = get_init_file(self.args).as_uri()
        checkpoint_file = os.path.join(self.args.checkpoint_dir, "checkpoint.pth")
        print(checkpoint_file, "checkpoint file exists: ", os.path.exists(checkpoint_file))
        if os.path.exists(checkpoint_file):
            self.args.resume = checkpoint_file
            print("resuming from ", self.args.resume)
        print("Requeuing ", self.args)
        # type(self) keeps subclasses requeueing as themselves.
        empty_trainer = type(self)(self.args)
        return submitit.helpers.DelayedSubmission(empty_trainer)

    def _setup_gpu_args(self):
        """Copy local rank, global rank and task count from the job environment onto args."""
        import submitit

        job_env = submitit.JobEnvironment()
        self.args.gpu = job_env.local_rank
        self.args.rank = job_env.global_rank
        self.args.world_size = job_env.num_tasks
        print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}")


def get_init_file(args):
    """Return a fresh absolute path for an init file inside ``args.job_dir``.

    The directory is created if needed. The returned file does not exist;
    its name is a random UUID hex followed by ``_init``.

    Args:
        args: namespace with a ``job_dir`` attribute (``Path`` or ``str``).

    Returns:
        pathlib.Path: absolute path of the (non-existent) init file.
    """
    job_dir = Path(args.job_dir)
    # The init file must not exist, but its parent dir must exist.
    os.makedirs(job_dir, exist_ok=True)
    # Callers convert the result with ``as_uri()``, which raises ValueError
    # for relative paths, so always hand back an absolute one.
    init_file = job_dir.resolve() / f"{uuid.uuid4().hex}_init"
    try:
        os.remove(str(init_file))
    except FileNotFoundError:
        # Expected case: a fresh UUID name has no file behind it.
        pass
    return init_file


def main():
    """Parse the command line, configure the executor and submit one Trainer job.

    ``--exp`` is appended to both ``--checkpoint-dir`` and ``--log-dir``;
    the resulting checkpoint directory also serves as the executor folder
    and as the location of the init file behind ``args.dist_url``.
    """
    args = parser.parse_args()

    # Neither directory option has a default; without this check the path
    # joins below fail with an opaque TypeError on None.
    if args.checkpoint_dir is None or args.log_dir is None:
        parser.error('--checkpoint-dir and --log-dir are required')

    args.checkpoint_dir = args.checkpoint_dir / args.exp
    args.log_dir = args.log_dir / args.exp
    args.job_dir = args.checkpoint_dir

    args.checkpoint_dir.mkdir(parents=True, exist_ok=True)
    args.log_dir.mkdir(parents=True, exist_ok=True)

    # The executor folder is the per-experiment checkpoint directory.
    executor = submitit.AutoExecutor(folder=args.job_dir, slurm_max_num_timeout=30)

    num_gpus_per_node = args.ngpus_per_node
    nodes = args.nodes
    timeout_min = args.timeout
    partition = args.partition

    kwargs = {'slurm_gres': f'gpu:{num_gpus_per_node}', }

    executor.update_parameters(
        mem_gb=30 * num_gpus_per_node,
        gpus_per_node=num_gpus_per_node,
        tasks_per_node=num_gpus_per_node,  # one task per GPU
        cpus_per_task=24,
        nodes=nodes,
        timeout_min=timeout_min,  # max is 60 * 6
        # Below are cluster dependent parameters
        slurm_partition=partition,
        slurm_signal_delay_s=120,
        **kwargs
    )

    executor.update_parameters(name=args.exp)

    # File-based init URL stored on args before the Trainer is submitted.
    args.dist_url = get_init_file(args).as_uri()

    trainer = Trainer(args)
    job = executor.submit(trainer)

    print(job)
    print("Submitted job_id:",job.job_id)
    # _logger.info("Submitted job_id:", job.job_id)


# Submit the job only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
