"""
Launch multiple training runs on a single machine.

Author:
Date: January 1, 2023
"""
import argparse
import datetime
import itertools
import os
import subprocess
import time
from dataclasses import dataclass
from typing import Any, Dict

import numpy as np

from krt import KRT_PATH


###########################################################################
# %% Parse the arguments.
###########################################################################
# Every flag carries its description in `help=` so that `--help` documents the
# launcher without having to read this file.
parser = argparse.ArgumentParser(
    description='Launch multiple training runs on a single machine.')
parser.add_argument('--script', type=str,
                    default='scripts/train_convnp.py',
                    help='Script passed to `python` for every job '
                         '(default: %(default)s).')
parser.add_argument('--name', type=str,
                    help="Name of the experiment. Defaults to today's date.")
parser.add_argument('--model', type=str, required=True,
                    help='Comma-separated models to use, e.g. gnpa,tnpa')
parser.add_argument('--data', type=str, required=True,
                    help='Comma-separated data sources to use, '
                         'e.g. fixed_1d_1k,fixed_4d_1M')
parser.add_argument('--num_seeds', type=int, required=True,
                    help='Number of seeds to run.')
parser.add_argument('--avail', type=str, required=True,
                    help='Comma-separated number of jobs each gpu can '
                         'currently take, e.g. 0,2,1,3 means we could run '
                         '0 jobs on gpu0, 2 jobs on gpu1, etc.')
parser.add_argument('--experimental', action='store_true',
                    help='Mark this as an experimental run: seeds start at '
                         '100 and "experimental/" is prefixed to the name.')
parser.add_argument('--seed_offset', type=int, default=0,
                    help='Amount to offset the seed by.')
args = parser.parse_args()

###########################################################################
# %% Initialize global variables.
###########################################################################
# The launcher changes into KRT_PATH before the first log write, so the log
# directory is anchored there; a path relative to the launch directory would
# create the directory in one place and try to write the log in another.
job_log_dir = os.path.abspath(os.path.join(KRT_PATH, 'job_log'))
os.makedirs(job_log_dir, exist_ok=True)
LOG_PATH = os.path.join(job_log_dir, f'{datetime.datetime.now()}.txt')
# Free job slots per gpu, indexed by gpu id. Updated as jobs start and finish.
AVAIL_JOB_COUNTS = [int(g) for g in args.avail.split(',')]
# Upper bound on the number of simultaneously running jobs.
MAX_RUNNING = sum(AVAIL_JOB_COUNTS)
if MAX_RUNNING <= 0:
    # With no free slot the scheduler would wait forever for a job to finish.
    raise SystemExit(
        f'--avail must provide at least one free job slot, got {args.avail!r}')
# Jobs that have been started and not yet pruned.
RUNNING = []
# Experimental runs start their seeds at 100; --seed_offset is added on top.
seed_offset = args.seed_offset + (100 if args.experimental else 0)
# Values to sweep over; one job is launched per combination.
ARG_DICT = {
    'model': args.model.split(','),
    'data': args.data.split(','),
    'seed': [i + seed_offset for i in range(args.num_seeds)],
}


###########################################################################
# %% Helper functions and classes.
###########################################################################
@dataclass
class Job:
    """A launched training run together with the gpu slot it occupies."""
    # Handle of the launched subprocess; polled to detect completion.
    proc: subprocess.Popen
    # Index of the gpu whose slot count was decremented for this job.
    gpu: int
    # Argument name -> value mapping the job was started with (used in logs).
    job_args: Dict[str, Any]


def prune_completed_job():
    """Retire at most one finished job from the running list.

    RUNNING is scanned in order and the scan stops at the first job whose
    process has exited. That job's gpu slot is handed back, a "Finished" line
    is appended to the log file, and the job is removed from RUNNING.

    Returns:
        True if a finished job was removed, False if no job has finished.
    """
    finished_at = next(
        (pos for pos, candidate in enumerate(RUNNING)
         if candidate.proc.poll() is not None),
        None,
    )
    if finished_at is None:
        return False
    done = RUNNING[finished_at]
    # Hand the slot back first, then record the event and forget the job.
    AVAIL_JOB_COUNTS[done.gpu] += 1
    with open(LOG_PATH, 'a') as log_file:
        log_file.write(
            f'{datetime.datetime.now()}\t Finished \t {done.job_args}\n')
    RUNNING.pop(finished_at)
    return True


# Captured once so that every job of a launch shares the same default name
# prefix, even when a queued job only starts after midnight.
_DEFAULT_PREFIX = str(datetime.date.today())


def add_job(job_args):
    """Start one training run, waiting for a free job slot if necessary.

    Appends a "Starting" line to the log, reserves a slot on a gpu, launches
    the training script as a subprocess and records it in RUNNING.

    Args:
        job_args: Mapping of argument name to value for this run. Must
            contain 'data' and 'model', which are used to build the run name.
            Every item is passed to the script as ``key=value``.
    """
    # At capacity: poll every 30 seconds until some running job has finished.
    if len(RUNNING) >= MAX_RUNNING:
        while not prune_completed_job():
            time.sleep(30)
    with open(LOG_PATH, 'a') as f:
        f.write(f'{datetime.datetime.now()}\t Starting \t {job_args}\n')
    # Use the lowest-numbered gpu with a free slot, falling back to the last
    # gpu when no gpu reports one.
    gpu = next(
        (idx for idx, free in enumerate(AVAIL_JOB_COUNTS) if free > 0),
        len(AVAIL_JOB_COUNTS) - 1,
    )
    AVAIL_JOB_COUNTS[gpu] -= 1
    prefix = args.name if args.name is not None else _DEFAULT_PREFIX
    if args.experimental:
        prefix = 'experimental/' + prefix
    name = f'{prefix}/{job_args["data"]}/{job_args["model"]}'
    # NOTE(review): the values below are interpolated unquoted into a shell
    # command line, so names containing spaces or shell metacharacters will
    # not reach the script intact.
    cmd = ' '.join([
        f'python {args.script}',
        f'name={name}',
        *(f'{k}={v}' for k, v in job_args.items()),
        f'cuda_device={gpu}',
    ])
    proc = subprocess.Popen(cmd, shell=True)
    RUNNING.append(Job(proc, gpu, job_args))


###########################################################################
# %% Run it!
###########################################################################
os.chdir(KRT_PATH)
with open(LOG_PATH, 'w') as f:
    f.write('Timestamp \t Status \t Args\n')
arg_keys = list(ARG_DICT.keys())
# Launch one job per combination of argument values. product() varies its
# last factor fastest, so the value lists go in reversed and each combination
# is reversed back: the first key still varies fastest and the key order of
# every job dict is unchanged. An empty value list simply yields no jobs.
value_lists = [ARG_DICT[k] for k in reversed(arg_keys)]
for combo in itertools.product(*value_lists):
    add_job(dict(zip(arg_keys, reversed(combo))))
# Wait for the remaining jobs, sleeping only when nothing has finished yet.
while RUNNING:
    if not prune_completed_job():
        time.sleep(30)
# Append, so the job history written above is kept instead of truncated.
with open(LOG_PATH, 'a') as f:
    f.write('Done!')
