import subprocess
import argparse
from pathlib import Path
import numpy as np
import os
from datetime import timedelta
import itertools
import pandas as pd


# Command-line interface. The parsed namespace is exposed as ``args``;
# ``args.account`` is interpolated into the sbatch command by the
# submission loop at the bottom of the script.
parser = argparse.ArgumentParser(
    description="Automatically submit jobs using SLURM",
)
parser.add_argument("--account", help="account")
# NOTE(review): --missing is parsed but not read anywhere in the visible
# part of this script.
parser.add_argument("--missing", action="store_true")
args = parser.parse_args()


# Make sure the output directories exist before any job is submitted.
# slurm_logs/ is the target of sbatch's -o option in the submission loop;
# logs/ is presumably written by the jobs themselves -- confirm.
os.makedirs('slurm_logs/', exist_ok=True)
os.makedirs('logs/', exist_ok=True)
time_format = "%H:%M:%S"

# Candidate values for each experiment axis (all stored as strings).
surv_model_list = ['DeepHit', 'NnetSurv', 'PMFSurv']
fair_model_list = ['None']
dataset_list = ['mimiccxr', 'adni', 'areds']
sensitive_attribute_list = ['age', 'sex', 'race']
metric_list = ['ctd', 'brier', 'auc']
# Hyper-parameter seeds '0' .. '9'.
hparams_seed_list = list(map(str, np.arange(10)))
seed_list = ['0']
pretrained_list = ['True']
shift_list = ['None']
group_shift_list = ['None']


def load_best_models(file_name):
    """Read the model-selection CSV and build one job configuration per row.

    The CSV must provide the columns ``dataset``, ``sensitive_attr``,
    ``model``, ``metric`` and ``hparam_seed``. The remaining settings are
    fixed for every row.

    Args:
        file_name: Path of the CSV file, passed to ``pandas.read_csv``.

    Returns:
        A list of 10-tuples, in CSV row order, laid out as
        ``(dataset, sensitive_attribute, surv_model, fair_model, metric,
        hparam_seed, seed, pretrain, shift, group_shift)``.
    """
    selection = pd.read_csv(file_name)

    # Settings that are identical for every selected model.
    fair_model = 'None'
    seed = '0'
    pretrain = 'True'
    shift = 'None'
    group_shift = 'None'

    return [
        (
            record['dataset'],
            record['sensitive_attr'],
            record['model'],
            fair_model,
            record['metric'],
            record['hparam_seed'],
            seed,
            pretrain,
            shift,
            group_shift,
        )
        for _, record in selection.iterrows()
    ]

# Load the selected model configurations once, at import time. The path is
# relative to the current working directory; the submission loop at the
# bottom of the script iterates over this list.
best_model_list = load_best_models('result/tte_model_selection.csv')


def str_2_timedelta(runtime):
    """Convert an ``'H:M:S'`` string into a ``datetime.timedelta``.

    Only the first three colon-separated fields are read; each must parse
    as an integer.

    Args:
        runtime: Duration string such as ``'0:30:00'``.

    Returns:
        The equivalent ``timedelta``.
    """
    fields = runtime.split(':')
    return timedelta(
        hours=int(fields[0]),
        minutes=int(fields[1]),
        seconds=int(fields[2]),
    )


def timedelta_2_str(runtime):
    """Format a ``timedelta`` as an ``'HH:MM:SS'`` string.

    Fractional seconds are dropped. Hours are not wrapped at 24, so long
    durations produce more than two hour digits.

    Args:
        runtime: The ``timedelta`` to format.

    Returns:
        The zero-padded ``'HH:MM:SS'`` representation.
    """
    whole_seconds = int(runtime.total_seconds())
    hours = whole_seconds // 3600
    leftover = whole_seconds - hours * 3600
    minutes = leftover // 60
    seconds = leftover % 60
    return "{:02}:{:02}:{:02}".format(hours, minutes, seconds)


def get_runtime(dataset):
    """Return the job time limit configured for *dataset*.

    The returned string is what the submission loop passes to
    ``sbatch --time``.

    Args:
        dataset: Dataset name; one of ``'mimiccxr'``, ``'adni'`` or
            ``'areds'``.

    Returns:
        The time limit as an ``'H:MM:SS'`` string.

    Raises:
        ValueError: If *dataset* is not one of the known dataset names.
    """
    runtime_by_dataset = {
        'mimiccxr': '0:30:00',
        'adni': '0:30:00',
        'areds': '0:30:00',
    }
    try:
        return runtime_by_dataset[dataset]
    except (KeyError, TypeError):
        # Fail with a clear message instead of falling through to an
        # unbound local variable.
        raise ValueError(
            "Unknown dataset %r; expected one of %s"
            % (dataset, sorted(runtime_by_dataset))
        ) from None


def list_files_in_directory(directory):
    """Return the names of the regular files directly inside *directory*.

    Sub-directories are skipped and the listing is not recursive. Names are
    returned in the order produced by ``os.listdir``, without the directory
    prefix.

    Args:
        directory: Path of the directory to inspect.

    Returns:
        A list of file names.
    """
    file_names = []
    for entry in os.listdir(directory):
        # Keep only entries that resolve to a file.
        if not os.path.isfile(os.path.join(directory, entry)):
            continue
        file_names.append(entry)
    return file_names


# Submit one SLURM job per selected model configuration.
for best_model in best_model_list:
    (dataset, sensitive_attribute, surv_model, fair_model, metric,
     hparams_seed, seed, pretrained, shift, group_shift) = best_model
    runtime = get_runtime(dataset)

    # Per-job stdout file; note the field order differs from the --export
    # list below (pretrained/shift/group_shift come before the seeds).
    log_file = "slurm_logs/log_repr_%s_%s_%s_%s_%s_%s_%s_%s_%s_%s.txt" % (
        surv_model, fair_model, dataset, sensitive_attribute, metric,
        pretrained, shift, group_shift, hparams_seed, seed)

    # Environment variables handed to the job script.
    export_vars = (
        "surv_model=%s,fair_model=%s,dataset=%s,sensitive_attribute=%s,"
        "metric=%s,hparams_seed=%s,seed=%s,pretrained=%s,shift=%s,"
        "group_shift=%s"
    ) % (surv_model, fair_model, dataset, sensitive_attribute, metric,
         hparams_seed, seed, pretrained, shift, group_shift)

    # NOTE(review): --account is optional in the parser, so an omitted
    # value is formatted here as "--account=None" -- confirm whether the
    # option should be required.
    submit_command = (
        "sbatch --account=%s --time=%s -o '%s' --export=%s "
        "script/slurm_get_representation.sh"
    ) % (args.account, runtime, log_file, export_vars)

    exit_status = subprocess.call(submit_command, shell=True)
    # Any non-zero status means the submission failed (e.g. 127 when sbatch
    # is not on PATH), so compare by value against 0 rather than by
    # identity against 1.
    if exit_status != 0:
        print("Job {0} failed to submit".format(submit_command))
