import sys, os
import argparse
from subprocess import Popen, PIPE
import time
from datetime import date
from itertools import product

def outer_product(inp):
    """Enumerate every combination of the option lists in a sweep mapping.

    Args:
        inp: Mapping from option name to an iterable of candidate values.

    Returns:
        A generator that yields one dict per combination, mapping each key
        of ``inp`` to a single chosen value. Combinations come out in
        ``itertools.product`` order, so the last key varies fastest.
    """
    combos = product(*inp.values())
    return (
        {name: choice for name, choice in zip(inp.keys(), combo)}
        for combo in combos
    )

def check_for_done(l):
    """Find the first process in ``l`` that has already terminated.

    Args:
        l: Sequence of objects exposing ``poll()``, which returns ``None``
            while the process is still running.

    Returns:
        ``(True, index)`` for the first entry whose ``poll()`` is not
        ``None``, or ``(False, False)`` when every entry is still running
        (or ``l`` is empty). Entries after the first finished one are not
        polled.
    """
    finished_at = next(
        (idx for idx, proc in enumerate(l) if proc.poll() is not None),
        None,
    )
    if finished_at is None:
        return False, False
    return True, finished_at


# Command-line interface of the sweep launcher.
# Example:
#   python3 scripts/pipe.py --num_runs 4 --gpu_ids 0 --wandb --print --num_epochs 500 --group_name <name>
parser = argparse.ArgumentParser(description='Training arguments')
# NOTE: despite the help text, --num_runs is used below as the maximum number
# of training processes kept alive at the same time, not the total run count.
parser.add_argument('--num_runs', default=1, type=int, help='Number of runs to run')
# Comma-separated GPU ids, e.g. "0,1,3". Required: without it GPU_IDS would be
# None and len(GPU_IDS) below would fail with an unhelpful TypeError.
parser.add_argument('--gpu_ids', help='delimited list input',
                    required=True,
                    type=lambda s: [int(item) for item in s.split(',')])
parser.add_argument('--wandb', action='store_true', help='To use wandb logging or not')
parser.add_argument('--logrun', action='store_true', help='save the model or not')
parser.add_argument('--print',
                    action='store_true', help='to print the outputs or not')
parser.add_argument('--num_epochs',
                    default=500, type=int,
                    help='Number of epochs to learn')
parser.add_argument('--group_name',
                    default=None, type=str,
                    help="Group Name")

args = parser.parse_args()

print(args.gpu_ids)

# Module-level settings derived from the parsed arguments.
NUM_RUNS = args.num_runs
GPU_IDS = args.gpu_ids
NUM_GPUS = len(GPU_IDS)
USE_WANDB = args.wandb
LOG_RUN = args.logrun
PRINT = args.print
NUM_EPOCHS = args.num_epochs
# Number of training processes launched so far.
counter = 0


# Hyper-parameter grid. Every key maps to the list of values to try; one run
# is launched per element of the Cartesian product of all active lists.
# Commented-out entries are alternatives that can be switched on by hand.
sweep = {
    'model': ['s4dualseq'],
    # 'model': ['fno_1d'],
    # 'model': ['fno_1d_linear_conv'],
    # 'dataset': ['burgers_1d_markov'],
    'dataset': ['ks_1d'],
    # 'batch_size': [32],
    'step': ['sequential'],
    # 'warmup_epochs': [0, 2],
    'lr': [0.005],
    # 'seed': [3344, 4455],
    # 'n_layers': [2, 4],
    # 'd_model': [32, 64, 128],
    # 'reduced_resolution_t': [20],
    # 'channels': [1],
    # 'scheduler': ['cosine'],
    # 'modes': [64, 128],
    # 'spectral_type': ['full'],  # full unless experimenting, fno spectral type
    # 'residual_type': ['weighted'],  # weighted unless experimenting, fno residual connection
    # 'fconv_type': ['standard'],  # standard unless experimenting, s4 convolution type
    's4kernel': ['s4d'],
    # 'inside_residual_type': ['zero'],  # zero unless experimenting, residual connection INSIDE S4Block (right after FFTConv)
    # 'ssm_init': ['legs', 'diag-lin'],
    't_train': [11, 21],
}

# Materialise the grid so its size can be reported before anything starts.
sweep_product = [*outer_product(sweep)]
length = len(sweep_product)
print(f"Number of runs that will be executed: {length}")

# wandb group: the --group_name argument when given, otherwise a fixed default.
wandb_group = args.group_name if args.group_name is not None else "dec27_channel_increase"
# Tags appended to every run's command line when non-empty.
wandb_tags_original = []


# Launch one `train.py` subprocess per sweep configuration, keeping at most
# NUM_RUNS of them alive at a time.
procs = []
# GPU ids currently free to hand out: one slot per concurrent run, cycling
# through GPU_IDS so several runs share a GPU when NUM_RUNS > NUM_GPUS.
gpu_queue = [GPU_IDS[i % NUM_GPUS] for i in range(NUM_RUNS)]
# gpu_use[k] is the GPU held by procs[k]; both lists are kept index-aligned.
gpu_use = []

# Reject unsupported models before anything is launched, so a bad sweep entry
# cannot leave earlier runs behind as unsupervised child processes.
for run in sweep_product:
    if not run['model'].startswith(('fno', 's4')):
        raise ValueError(f"invalid model: {run['model']!r}")

for run in sweep_product:
    gpu_id = gpu_queue.pop(0)
    gpu_use.append(gpu_id)
    cmd = (
        f"CUDA_VISIBLE_DEVICES={gpu_id} "
        f"python train.py "
        f"num_epochs={NUM_EPOCHS} "
        f"dataset={run['dataset']} "
        # f"batch_size={run['batch_size']} "
        f"model={run['model']} "
        f"model.optimizer.lr={run['lr']} "
        # f"model.params.d_model={run['d_model']} "
        # f"model.params.n_layers={run['n_layers']} "
        # f"model.scheduler={run['scheduler']} "
        # f"seed={run['seed']} "
        # f"model.warmup_epochs={run['warmup_epochs']} "
        # f"dataset.dataloader.reduced_resolution_t={run['reduced_resolution_t']} "
        # Single quotes inside the f-string: reusing the outer quote character
        # is a SyntaxError on Python versions older than 3.12.
        f"step={run['step']} "
        f"log_model={LOG_RUN} "
        f"wandb.group={wandb_group} "
        f"print={PRINT} "
        f"use_wandb={USE_WANDB} "
        f"dataset.t_train={run['t_train']} "
    )
    if run['model'].startswith('fno'):
        # NOTE: needs a `modes` entry in the sweep; a KeyError is raised here
        # when it is missing.
        cmd += (
            f"model.params.modes={run['modes']} "
            # f"model.params.residual_type={run['residual_type']} "
            # f"model.params.spectral_type={run['spectral_type']} "
        )
    elif run['model'].startswith('s4'):
        cmd += (
            # f"model.params.channels={run['channels']} "
            # f"model.params.fconv_type={run['fconv_type']} "
            # f"model.params.inside_residual_type={run['inside_residual_type']} "
            f"model.params.s4block_args.kernel={run['s4kernel']} "
            # f"model.params.s4block_args.init={run['ssm_init']} "
        )

    wandb_tags = wandb_tags_original

    if wandb_tags:
        cmd += f"'wandb.tags={wandb_tags}' "

    print(cmd)

    # shell=True is required because the command starts with an environment
    # variable assignment (CUDA_VISIBLE_DEVICES=...).
    procs.append(Popen(cmd, shell=True))

    # Pause between launches (presumably to stagger start-up of the runs).
    time.sleep(3)

    counter += 1

    if len(procs) == NUM_RUNS:
        # All slots are busy: poll until some run exits, then free its slot
        # and return its GPU to the queue.
        while True:
            done, num = check_for_done(procs)
            if done:
                break
            time.sleep(3)

        procs.pop(num)
        gpu_queue.append(gpu_use.pop(num))

        print("\n \n \n \n --------------------------- \n \n \n \n")
        # `counter` counts launched runs; subtract the ones still tracked in
        # `procs` to report how many have actually finished.
        print(f"{date.today()} - {counter - len(procs)} runs completed")
        print("\n \n \n \n --------------------------- \n \n \n \n")
        sys.stdout.flush()

# Wait for the runs that are still in flight once the sweep is exhausted.
for p in procs:
    p.wait()
procs = []
