import subprocess
import argparse
from pathlib import Path
import numpy as np
import os
from datetime import timedelta
import itertools


# Command-line interface: which SLURM account to charge, and whether to only
# resubmit jobs whose result file is missing.
parser = argparse.ArgumentParser(description="Automatically submit jobs using SLURM")
parser.add_argument('--account',
                    help="SLURM account to charge the jobs to (passed to sbatch --account)")
parser.add_argument('--missing', action='store_true',
                    help="only resubmit jobs whose score_*.pkl file is absent from output/<fair_model>/")
args = parser.parse_args()


# Create the sbatch output-log directory (slurm_logs/) and logs/ up front;
# existing directories are left untouched.
for _log_dir in ("slurm_logs/", "logs/"):
    Path(_log_dir).mkdir(parents=True, exist_ok=True)
time_format = "%H:%M:%S"  # strftime-style pattern; not referenced in the code shown here.

# Hyperparameter grid. Jobs are formed from combinations of these values, which
# are kept as strings because they end up on the sbatch command line.
surv_model_list = ["DeepHit", "NnetSurv", "PMFSurv"]
fair_model_list = ["None"]
dataset_list = ["mimiccxr", "adni", "areds"]
sensitive_attribute_list = ["age", "sex", "race"]
metric_list = ["ctd", "brier", "auc"]
hparams_seed_list = [str(seed) for seed in range(10)]  # '0' .. '9'
seed_list = ["0"]
pretrained_list = ["True", "False"]
shift_list = ["None"]
group_shift_list = ["None"]


def str_2_timedelta(runtime):
    """Parse an ``H:MM:SS`` string (e.g. ``'4:00:00'``) into a ``timedelta``.

    Only the first three ``:``-separated fields are read; each must parse as an
    integer. Fewer than three fields raises ``IndexError``.
    """
    fields = runtime.split(':')
    return timedelta(
        hours=int(fields[0]),
        minutes=int(fields[1]),
        seconds=int(fields[2]),
    )


def timedelta_2_str(runtime):
    """Format a ``timedelta`` as zero-padded ``HH:MM:SS``.

    Fractional seconds are truncated, and hours are not wrapped at 24, so
    ``timedelta(days=1)`` becomes ``'24:00:00'``.
    """
    whole_seconds = int(runtime.total_seconds())
    return "{:02}:{:02}:{:02}".format(
        whole_seconds // 3600,
        whole_seconds % 3600 // 60,
        whole_seconds % 60,
    )


# Wall-clock limit (H:MM:SS, as passed to sbatch --time) requested per dataset.
_DATASET_RUNTIMES = {
    'mimiccxr': '4:00:00',
    'adni': '4:00:00',
    'areds': '4:00:00',
}


def get_runtime(dataset):
    """Return the SLURM ``--time`` limit string to request for *dataset*.

    Raises:
        ValueError: if *dataset* has no configured runtime. (An unknown name
            previously fell through every branch and raised UnboundLocalError.)
    """
    try:
        return _DATASET_RUNTIMES[dataset]
    except KeyError:
        raise ValueError(f"No runtime configured for dataset {dataset!r}") from None


def list_files_in_directory(directory):
    """Return the names of non-directory files directly inside *directory*.

    Names are returned without the directory prefix, in ``os.listdir`` order.
    Errors from ``os.listdir`` (e.g. a missing directory) propagate.
    """
    def _is_regular_file(name):
        # os.path.isfile follows symlinks, so a link to a file is kept.
        return os.path.isfile(os.path.join(directory, name))

    return [name for name in os.listdir(directory) if _is_regular_file(name)]


# Order of the fields in score_*.pkl / log_*.txt file names.
_NAME_FIELDS = ('surv_model', 'fair_model', 'dataset', 'sensitive_attribute', 'metric',
                'pretrained', 'shift', 'group_shift', 'hparams_seed', 'seed')
# Order of the variables handed to script/slurm_train.sh via sbatch --export.
_EXPORT_FIELDS = ('surv_model', 'fair_model', 'dataset', 'sensitive_attribute', 'metric',
                  'hparams_seed', 'seed', 'pretrained', 'shift', 'group_shift')


def _iter_jobs(fair_models):
    """Yield one job config (field name -> string value) per grid combination.

    Iteration order is ``itertools.product`` over the module-level lists, with
    *fair_models* in the fair-model position. The (adni, race) pair is skipped.
    """
    for (dataset, sensitive_attribute, hparams_seed, seed, surv_model, fair_model,
         metric, pretrained, shift, group_shift) in itertools.product(
            dataset_list, sensitive_attribute_list, hparams_seed_list, seed_list,
            surv_model_list, fair_models, metric_list, pretrained_list, shift_list,
            group_shift_list):
        # NOTE(review): presumably ADNI has no race attribute — confirm.
        if dataset == 'adni' and sensitive_attribute == 'race':
            continue
        yield {
            'surv_model': surv_model,
            'fair_model': fair_model,
            'dataset': dataset,
            'sensitive_attribute': sensitive_attribute,
            'metric': metric,
            'pretrained': pretrained,
            'shift': shift,
            'group_shift': group_shift,
            'hparams_seed': hparams_seed,
            'seed': seed,
        }


def _job_tag(job):
    """Join the job's field values in file-name order with underscores."""
    return '_'.join(job[field] for field in _NAME_FIELDS)


def _build_sbatch_command(job):
    """Return the sbatch argument list (no shell involved) that submits *job*."""
    command = ['sbatch']
    # Without --account the flag used to be sent as the literal "--account=None";
    # omit it instead so SLURM applies the user's default account.
    if args.account is not None:
        command.append('--account=%s' % args.account)
    command += [
        '--time=%s' % get_runtime(job['dataset']),
        '-o', 'slurm_logs/log_%s.txt' % _job_tag(job),
        '--export=' + ','.join('%s=%s' % (field, job[field]) for field in _EXPORT_FIELDS),
        'script/slurm_train.sh',
    ]
    return command


def _submit(job):
    """Submit *job* with sbatch and print a message if submission fails."""
    command = _build_sbatch_command(job)
    # Argument list with shell=False: user-supplied args.account is never
    # interpreted by a shell.
    exit_status = subprocess.call(command)
    # Any non-zero status is a failure; the old `is 1` check compared identity
    # against an int literal and ignored every other error code.
    if exit_status != 0:
        print("Job {0} failed to submit".format(' '.join(command)))


def _existing_outputs(directory):
    """Return the set of file names in *directory*, or an empty set if it is absent."""
    try:
        return set(list_files_in_directory(directory))
    except FileNotFoundError:
        # No outputs written yet: every job for this fair model counts as missing.
        return set()


if args.missing:
    print('Running failed jobs')
    directory_path = 'output'
    for fair_model in fair_model_list:
        existing = _existing_outputs(os.path.join(directory_path, fair_model))
        # One fair model at a time, matching the original submission order.
        for job in _iter_jobs([fair_model]):
            file_name = 'score_%s.pkl' % _job_tag(job)
            if file_name in existing:
                continue
            print('Missing file:', file_name)
            _submit(job)
else:
    for job in _iter_jobs(fair_model_list):
        _submit(job)

