import os
from datetime import datetime
import subprocess
import random

def create_sample_batch_job(epoch, sampling_step, experiment_name, run_id, ts, job_directory, condition_first, output_dir, project, partition,
                            num_gpus, hours_per_sample, minutes_per_sample, num_arrays, n_conditions, model_weights_job_id,
                            edge_conditional_set, n_samples_per_condition, total_cond_eval, shuffle, random_seed, jobids_file,
                            inpaint_on_one_reactant):
    """Write the Slurm array-job script for sampling and submit it with ``sbatch``.

    The script is written to ``<job_directory>/<experiment_name>.job`` (an
    existing file of that name is overwritten) and runs
    ``src/sample_array_job.py`` with ``$SLURM_ARRAY_TASK_ID`` passed as
    ``test.condition_index``. After a successful submission the line
    ``<experiment_name>.job: <job id>`` is appended to ``jobids_file``.

    Args:
        epoch: Checkpoint epoch, written as ``general.wandb.checkpoint_epochs=[epoch]``.
        sampling_step: Written as ``diffusion.diffusion_steps_eval``.
        experiment_name: Base name for the job file, the Slurm job name, the
            ``.out``/``.err`` files and the hydra run dir ``experiments/<experiment_name>/``.
        run_id: Written as ``general.wandb.run_id``.
        ts: Not used by this function; kept so existing callers keep working.
        job_directory: Directory the ``.job`` script is written to (must exist).
        condition_first: Written as ``test.condition_first``.
        output_dir: Directory used for the Slurm ``--output``/``--error`` files.
        project: Slurm ``--account``.
        partition: Slurm ``--partition``.
        num_gpus: Written as ``--gres=gpu:<num_gpus>``.
        hours_per_sample: Hours part of ``--time``.
        minutes_per_sample: Minutes part of ``--time``.
        num_arrays: Number of array tasks; the array range is ``0-(num_arrays-1)``.
        n_conditions: Written as ``test.n_conditions``.
        model_weights_job_id: Job id this job waits for (``--dependency=afterok:<id>``).
        edge_conditional_set: Written as ``diffusion.edge_conditional_set``.
        n_samples_per_condition: Written as ``test.n_samples_per_condition``.
        total_cond_eval: Written as ``test.total_cond_eval``.
        shuffle: Written as ``dataset.shuffle``.
        random_seed: Zero-argument callable; it is called once and its result
            is written as ``train.seed``.
        jobids_file: Text file the submitted job id is appended to.
        inpaint_on_one_reactant: Written as ``test.inpaint_on_one_reactant``.

    Returns:
        The Slurm job id reported by ``sbatch``, as a string.

    Raises:
        RuntimeError: If ``sbatch`` exits with a non-zero status or its output
            contains no job id.
    """
    print(f"Creating job {experiment_name}... ")
    job_file = os.path.join(job_directory, f"{experiment_name}.job")

    sbatch_header = [
        "#!/bin/bash\n",
        f"#SBATCH --job-name={experiment_name}_%a.job\n",
        f"#SBATCH --output={output_dir}/{experiment_name}_%a.out\n",
        f"#SBATCH --error={output_dir}/{experiment_name}_%a.err\n",
        f"#SBATCH --account={project}\n",
        f"#SBATCH --partition={partition}\n",
        f"#SBATCH --gres=gpu:{num_gpus}\n",
        "#SBATCH --cpus-per-task=7\n",
        "#SBATCH --mem-per-cpu=10G\n",
        f"#SBATCH --time={hours_per_sample}:{minutes_per_sample}:00\n",
        f"#SBATCH --array=0-{num_arrays-1}\n",
        f"#SBATCH --dependency=afterok:{model_weights_job_id}\n",
    ]
    # The seed callable is evaluated here, once, so every array task of this
    # job shares the same seed.
    command = (
        f"python3 src/sample_array_job.py general.wandb.mode='offline' general.wandb.run_id={run_id} diffusion.edge_conditional_set={edge_conditional_set} "
        f" 'general.wandb.checkpoint_epochs=[{epoch}]' test.condition_first={condition_first} test.condition_index=$SLURM_ARRAY_TASK_ID test.n_conditions={n_conditions} "
        f" test.n_samples_per_condition={n_samples_per_condition} dataset.shuffle={shuffle} general.wandb.load_run_config=True "
        f" hydra.run.dir=experiments/{experiment_name}/ test.total_cond_eval={total_cond_eval} train.seed={random_seed()} 'diffusion.diffusion_steps_eval={sampling_step}' "
        f" test.inpaint_on_one_reactant={inpaint_on_one_reactant} "
    )

    with open(job_file, 'w') as fh:
        fh.writelines(sbatch_header)
        # TODO: ADD HERE IMPORTS THAT YOU NEED IN YOUR SERVER
        fh.write('\n')
        fh.write(command)

    result = subprocess.run(["sbatch", job_file], stdout=subprocess.PIPE, stderr=subprocess.PIPE,
                            universal_newlines=True)
    if result.stderr:
        print(result.stderr)

    # sbatch reports success as "Submitted batch job <id>", optionally followed
    # by more text; the id is the first token after "job ".
    id_tokens = result.stdout.partition('job ')[2].split()
    if result.returncode != 0 or not id_tokens:
        raise RuntimeError(
            f"sbatch failed for {job_file} (exit code {result.returncode}). "
            f"stdout: {result.stdout.strip()!r}; stderr: {result.stderr.strip()!r}"
        )
    sampling_job_id = id_tokens[0]

    with open(jobids_file, 'a') as f:
        f.write(f"{experiment_name}.job: {sampling_job_id}\n")
    print(f"=== Sampling. Slurm ID ={sampling_job_id}.")

    return sampling_job_id

def create_evaluate_batch_job(epoch, sampling_step, run_id, experiment_name, experiment_folder, job_directory, condition_first, output_dir, project, partition,
                              num_gpus, hours_per_eval, minutes_per_eval, num_arrays, n_conditions, load_samples_to_wandb_job_id,
                              edge_conditional_set, n_samples_per_condition, total_cond_eval, shuffle, random_seed, jobids_file, ts):
    """Write the Slurm array-job script for evaluation and submit it with ``sbatch``.

    The script is written to ``<job_directory>/<experiment_name>.job`` (an
    existing file of that name is overwritten) and runs
    ``src/evaluate_array_job.py`` with ``$SLURM_ARRAY_TASK_ID`` passed as
    ``test.condition_index``. After a successful submission the line
    ``<experiment_name>.job: <job id>`` is appended to ``jobids_file``.

    Args:
        epoch: Checkpoint epoch, written as ``general.wandb.checkpoint_epochs=[epoch]``.
        sampling_step: Written as ``diffusion.diffusion_steps_eval``.
        run_id: Written as ``general.wandb.run_id``.
        experiment_name: Base name for the job file, the Slurm job name and the
            ``.out``/``.err`` files.
        experiment_folder: Hydra run dir is set to ``experiments/<experiment_folder>/``.
        job_directory: Directory the ``.job`` script is written to (must exist).
        condition_first: Written as ``test.condition_first``.
        output_dir: Directory used for the Slurm ``--output``/``--error`` files.
        project: Slurm ``--account``.
        partition: Slurm ``--partition``.
        num_gpus: Written as ``--gres=gpu:<num_gpus>``.
        hours_per_eval: Hours part of ``--time``.
        minutes_per_eval: Minutes part of ``--time``.
        num_arrays: Number of array tasks; the array range is ``0-(num_arrays-1)``.
        n_conditions: Written as ``test.n_conditions``.
        load_samples_to_wandb_job_id: Job id this job waits for
            (``--dependency=afterok:<id>``); pass ``None`` to submit without a dependency.
        edge_conditional_set: Written as ``diffusion.edge_conditional_set``.
        n_samples_per_condition: Written as ``test.n_samples_per_condition``.
        total_cond_eval: Written as ``test.total_cond_eval``.
        shuffle: Written as ``dataset.shuffle``.
        random_seed: Zero-argument callable; it is called once and its result
            is written as ``train.seed``.
        jobids_file: Text file the submitted job id is appended to.
        ts: Not used by this function; kept so existing callers keep working.

    Returns:
        The Slurm job id reported by ``sbatch``, as a string.

    Raises:
        RuntimeError: If ``sbatch`` exits with a non-zero status or its output
            contains no job id.
    """
    print(f"Creating job {experiment_name}... ")
    job_file = os.path.join(job_directory, f"{experiment_name}.job")

    sbatch_header = [
        "#!/bin/bash\n",
        f"#SBATCH --job-name={experiment_name}_%a.job\n",
        f"#SBATCH --output={output_dir}/{experiment_name}_%a.out\n",
        f"#SBATCH --error={output_dir}/{experiment_name}_%a.err\n",
        f"#SBATCH --account={project}\n",
        f"#SBATCH --partition={partition}\n",
        f"#SBATCH --gres=gpu:{num_gpus}\n",
        "#SBATCH --cpus-per-task=7\n",
        "#SBATCH --mem-per-cpu=20G\n",
        f"#SBATCH --time={hours_per_eval}:{minutes_per_eval}:00\n",
        f"#SBATCH --array=0-{num_arrays-1}\n",
    ]
    if load_samples_to_wandb_job_id is not None:
        sbatch_header.append(f"#SBATCH --dependency=afterok:{load_samples_to_wandb_job_id}\n")

    # The seed callable is evaluated here, once, so every array task of this
    # job shares the same seed.
    command = (
        f"python3 src/evaluate_array_job.py general.wandb.mode='offline' general.wandb.load_run_config=True general.wandb.run_id={run_id} diffusion.edge_conditional_set={edge_conditional_set} "
        f" 'general.wandb.checkpoint_epochs=[{epoch}]' test.condition_first={condition_first} test.condition_index=$SLURM_ARRAY_TASK_ID test.n_conditions={n_conditions} "
        f" test.n_samples_per_condition={n_samples_per_condition} dataset.shuffle={shuffle} test.total_cond_eval={total_cond_eval} "
        f" hydra.run.dir=experiments/{experiment_folder}/ train.seed={random_seed()} diffusion.diffusion_steps_eval={sampling_step} "
    )

    with open(job_file, 'w') as fh:
        fh.writelines(sbatch_header)
        # TODO: ADD HERE IMPORTS THAT YOU NEED IN YOUR SERVER
        fh.write('\n')
        fh.write(command)

    result = subprocess.run(["sbatch", job_file], stdout=subprocess.PIPE, stderr=subprocess.PIPE,
                            universal_newlines=True)
    if result.stderr:
        print(result.stderr)

    # sbatch reports success as "Submitted batch job <id>", optionally followed
    # by more text; the id is the first token after "job ".
    id_tokens = result.stdout.partition('job ')[2].split()
    if result.returncode != 0 or not id_tokens:
        raise RuntimeError(
            f"sbatch failed for {job_file} (exit code {result.returncode}). "
            f"stdout: {result.stdout.strip()!r}; stderr: {result.stderr.strip()!r}"
        )
    eval_job_id = id_tokens[0]

    with open(jobids_file, 'a') as f:
        f.write(f"{experiment_name}.job: {eval_job_id}\n")
    print(f"=== Evaluation. Slurm ID ={eval_job_id}.")
    return eval_job_id

def create_and_run_sample_eval_pipeline(job_directory, output_dir, project, partition, run_id, epochs,
                                        hours_per_sample, minutes_per_sample, hours_per_eval, minutes_per_eval,
                                        num_arrays, n_conditions, total_cond_eval,
                                        n_samples_per_condition, edge_conditional_set, sampling_step_counts,
                                        random_seed=lambda: random.randint(0,1000), shuffle=True,
                                        condition_first_random_func=lambda: 0,
                                        inpaint_on_one_reactant=False):
    """Inputs:
    job_directory: e.g., os.path.join('outputs/sample_eval_array_job')
    output_dir:     os.path.join(job_directory, 'out') (Contains the slurm error and out files)
    project:        Which slurm server project
    partition:      Slurm server partition
    run_id:         The wandb run id to evaluate on
    epochs:         The epochs to evaluate on, as a list
    hours_per_sample: Num of hours to reserve on sampling (individual 1-GPU job in the array job)
    minutes_per_sample: Num of minutes to reserve on sampling
    hours_per_eval: Num of hours to reserve on eval
    minutes_per_eval: Num of minutes to reserve on eval
    num_arrays:     Number of parallel jobs to compute sampling and evaluation in
    n_conditions:   Number of conditions to process per job
    total_cond_eval: Total conditions processed, should be n_conditions * total_cond_eval. This is here to make sure you know exactly what you are doing!
    n_samples_per_condition: Number of samples per condition
    edge_conditional_set: Data subset to draw conditions from (train/val/test)
    random_seed: A function that returns a random seed to use. Can vary for each different epoch, or can be set to constant with lambda: 1
    shuffle: Whether to shuffle the data set. True by default, but maybe better to set to False for evaluation of the entire data set just to be sure
    condition_first_random_func: a function that returns a number to use to define the condition range for each sampling array job. By default, just deterministic.
    sampling_step_counts: A list containing the number of sampling steps to evaluate on. Having len(sampling_step_counts) is mutually exclusive with having len(epochs) == 1.
    inpaint_on_one_reactant: Whether to inpaint on a randomly chosen reactant.
    """

    assert len(epochs) == 1 or len(sampling_step_counts) == 1, "Either epochs or sampling_step_counts should have length 1"

    if len(sampling_step_counts) == 1:
        condition_firsts = [condition_first_random_func() for _ in epochs]
    else:
        condition_firsts = [condition_first_random_func() for _ in sampling_step_counts]

    ts = int(round(datetime.now().timestamp()))
    jobids_file = os.path.join(job_directory, 'jobids.txt')
    num_gpus = 1
    if not partition=='standard-g':
        assert total_cond_eval == num_arrays * n_conditions
    else:
        assert total_cond_eval == num_arrays * 8 * n_conditions

    '''
        1. load the model weights first
        
        Notes:
        - the script will download all the model weights specified in the variable epochs=[...] 
        - this is done in a  separate job because it happened before that parallel jobs tried to download the wandb artifact
        and some failed (error: python reading corrupt file)
        - the script first checks if the weights file exists, if it doesn't it downloads it from wandb
    '''
    job_name = f"download_model_weights_{ts}"
    job_file = os.path.join(job_directory, f"{job_name}.job")
    with open(job_file, 'w') as fh:
        fh.writelines("#!/bin/bash\n")
        fh.writelines(f"#SBATCH --job-name={job_name}.job\n") # add time stamp?
        fh.writelines(f"#SBATCH --output={output_dir}/{job_name}.out\n")
        fh.writelines(f"#SBATCH --error={output_dir}/{job_name}.err\n")
        fh.writelines(f"#SBATCH --account={project}\n")
        fh.writelines(f"#SBATCH --partition=????\n")
        fh.writelines(f"#SBATCH --cpus-per-task=7\n")
        fh.writelines(f"#SBATCH --time=00:10:00\n")
        # TODO: ADD HERE IMPORTS THAT YOU NEED IN YOUR SERVER
        fh.writelines(f'\n')
        fh.writelines(f"python3 src/wandb_download_model_weights.py general.wandb.run_id={run_id} 'general.wandb.checkpoint_epochs={epochs}' \n")

    result = subprocess.Popen(["sbatch", job_file], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    stdout, stderr = result.communicate()

    if 'job' not in stdout.decode("utf-8"):
        print(result)
    else:
        model_weights_job_id = stdout.decode("utf-8").strip().split('job ')[1]

        with open(jobids_file, 'a') as f:
            f.write(f"{job_name}.job: {model_weights_job_id}\n")
        print(f"=== Getting model weights. Slurm ID = {model_weights_job_id}.")
        
    '''
        2. Sample from the full dataset parallelized at the level of slurm jobs.
        
        Note: 
        - job will only run if downloading the model weights is successful (see line dependency=afterok:model_weights_job_id)
        - for upsto50k test data, array job range should be 0-49
        - the start index of the condition_range for each array job is defined as condition_start*n_conditions
        - you can sample from multiple checkpoints at once by submitting multiple jobs => 
        add more numbers to the epochs variable
        - it takes about 2 hours to sample n_conditions=100 and n_samples_per_condition=100
    '''
    all_sampling_job_ids = []
    experiment_name = f"{run_id}_sample_{ts}"#_{ts}"
    if len(sampling_step_counts) == 1:
        for i,e in enumerate(epochs):
            sampling_job_id = create_sample_batch_job(e, sampling_step_counts[0], experiment_name, run_id, ts, job_directory, condition_firsts[i], output_dir, project, partition,
                                num_gpus, hours_per_sample, minutes_per_sample, num_arrays, n_conditions, model_weights_job_id,
                                edge_conditional_set, n_samples_per_condition, total_cond_eval, shuffle, random_seed, jobids_file, inpaint_on_one_reactant)
            all_sampling_job_ids.append(sampling_job_id)
    else:
        for i,s in enumerate(sampling_step_counts):
            sampling_job_id = create_sample_batch_job(epochs[0], s, experiment_name, run_id, ts, job_directory, condition_firsts[i], output_dir, project, partition,
                                num_gpus, hours_per_sample, minutes_per_sample, num_arrays, n_conditions, model_weights_job_id,
                                edge_conditional_set, n_samples_per_condition, total_cond_eval, shuffle, random_seed, jobids_file, inpaint_on_one_reactant)
            all_sampling_job_ids.append(sampling_job_id)

    # '''
    # 3. Submit job to upload to wandb
    # '''
    job_name = f"wandblog_{ts}"
    job_file = os.path.join(job_directory, f"{job_name}.job")
    with open(job_file, 'w') as fh:
        fh.writelines("#!/bin/bash\n")
        fh.writelines(f"#SBATCH --job-name={job_name}.job\n")
        fh.writelines(f"#SBATCH --output={output_dir}/{job_name}.out\n")
        fh.writelines(f"#SBATCH --error={output_dir}/{job_name}.err\n")
        fh.writelines(f"#SBATCH --account={project}\n")
        fh.writelines(f"#SBATCH --partition=????\n")
        fh.writelines("#SBATCH --cpus-per-task=7\n")
        fh.writelines(f"#SBATCH --time=00:30:00\n")
        fh.writelines(f"#SBATCH --dependency=afterok:{','.join(all_sampling_job_ids)}\n")
        # TODO: ADD HERE IMPORTS THAT YOU NEED IN YOUR SERVER
        fh.writelines(f'\n')
        fh.writelines(f"python3 src/sample_wandblog.py general.wandb.run_id={run_id} diffusion.edge_conditional_set={edge_conditional_set} "+\
                        f" 'general.wandb.checkpoint_epochs={epochs}' test.n_conditions={n_conditions} test.n_samples_per_condition={n_samples_per_condition} " +\
                        f" hydra.run.dir=experiments/{experiment_name}/ test.total_cond_eval={total_cond_eval} 'general.wandb.eval_sampling_steps={sampling_step_counts}' ")

    result = subprocess.Popen(["sbatch", job_file], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    stdout, stderr = result.communicate()
    print(stderr)
    if 'job' not in stdout.decode("utf-8"):
        print(result)
    else:
        load_samples_to_wandb_job_id = stdout.decode("utf-8").strip().split('job ')[1]

        with open(jobids_file, 'a') as f:
            f.write(f"{job_name}.job: {load_samples_to_wandb_job_id}\n")
        print(f"=== Loading models to wandb. Slurm ID = {load_samples_to_wandb_job_id}.")
        
    # '''
    #     4. eval from the sample file given
        
    #     Note: 
    #     - job will only run if downloading the model weights is successful (see line dependency=afterok:model_weights_job_id)
    #     - for upsto50k test data, array job range should be 0-49
    #     - the start index of the condition_range for each array job is defined as condition_start*n_conditions (eg. 0, 100, ..., 4900)
    #     - you can sample from multiple checkpoints at once by submitting multiple jobs => 
    #     add more numbers to the epochs variable
    # '''
    all_eval_job_ids = []
    experiment_folder = experiment_name # f"{run_id}_sample_{ts}"
    experiment_name = f"{run_id}_eval_{ts}"
    if len(sampling_step_counts) == 1:
        for i, e in enumerate(epochs):
            eval_job_id = create_evaluate_batch_job(e, sampling_step_counts[0], run_id, experiment_name, experiment_folder, job_directory, condition_firsts[i], output_dir, project, partition,
                                num_gpus, hours_per_eval, minutes_per_eval, num_arrays, n_conditions, load_samples_to_wandb_job_id,
                                edge_conditional_set, n_samples_per_condition, total_cond_eval, shuffle, random_seed, jobids_file, ts)
            all_eval_job_ids.append(eval_job_id)
    else:
        for i, s in enumerate(sampling_step_counts):
            eval_job_id = create_evaluate_batch_job(epochs[0], s, run_id, experiment_name, experiment_folder, job_directory, condition_firsts[i], output_dir, project, partition,
                                num_gpus, hours_per_eval, minutes_per_eval, num_arrays, n_conditions, load_samples_to_wandb_job_id,
                                edge_conditional_set, n_samples_per_condition, total_cond_eval, shuffle, random_seed, jobids_file, ts)
            all_eval_job_ids.append(eval_job_id)

    '''
        5. load eval to wandb
    '''
    condition_firsts_all_same = all([c == condition_firsts[0] for c in condition_firsts])
    # submit job to upload to wandb
    job_name = f"eval_wandblog_{ts}"
    job_file = os.path.join(job_directory, f"{job_name}.job")
    with open(job_file, 'w') as fh:
        fh.writelines("#!/bin/bash\n")
        fh.writelines(f"#SBATCH --job-name={job_name}.job\n") # add time stamp?
        fh.writelines(f"#SBATCH --output={output_dir}/{job_name}.out\n")
        fh.writelines(f"#SBATCH --error={output_dir}/{job_name}.err\n")
        fh.writelines(f"#SBATCH --account={project}\n")
        fh.writelines(f"#SBATCH --partition=small\n")
        fh.writelines("#SBATCH --cpus-per-task=7\n")
        fh.writelines(f"#SBATCH --time=00:30:00\n")
        fh.writelines(f"#SBATCH --dependency=afterok:{','.join(all_eval_job_ids)}\n")
        # TODO: ADD HERE IMPORTS THAT YOU NEED IN YOUR SERVER
        fh.writelines(f'\n')
        fh.writelines(f"python3 src/evaluate_wandblog.py "+\
                        f" 'general.wandb.checkpoint_epochs={epochs}' " +\
                        f" hydra.run.dir=experiments/{experiment_folder}/ general.wandb.run_id={run_id} " +\
                        f" test.total_cond_eval={total_cond_eval} test.n_samples_per_condition={n_samples_per_condition} diffusion.edge_conditional_set={edge_conditional_set} " +\
                        f" 'general.wandb.eval_sampling_steps={sampling_step_counts}' test.condition_first={condition_firsts[0] if condition_firsts_all_same else -1} ")

    result = subprocess.Popen(["sbatch", job_file], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    stdout, stderr = result.communicate()
    print(stderr)
    if 'job' not in stdout.decode("utf-8"):
        print(result)
    else:
        load_samples_to_wandb_job_id = stdout.decode("utf-8").strip().split('job ')[1]

        with open(jobids_file, 'a') as f:
            f.write(f"{job_name}.job: {load_samples_to_wandb_job_id}\n")
        print(f"=== Loading models to wandb. Slurm ID = {load_samples_to_wandb_job_id}.")