import sys, os
import argparse
from subprocess import Popen, PIPE
import time
from datetime import date
from itertools import product

from hydra import compose, initialize
from omegaconf import OmegaConf
import tempfile
from pprint import pprint

def outer_product(sweep):
    """Expand a sweep specification into every combination of values.

    Args:
        sweep: Mapping from an arbitrary label to a
            ``(override_key, candidate_values)`` pair. The label itself is
            not used in the result; only the pair is.

    Returns:
        A list with one dict per element of the Cartesian product of all
        ``candidate_values``; each dict maps ``override_key`` to the value
        chosen for that combination. An empty ``sweep`` yields ``[{}]``
        (a single combination with nothing overridden), and any empty
        ``candidate_values`` list yields ``[]``.
    """
    # Read the pairs once; dict iteration order fixes both the key order in
    # every result dict and the order in which combinations are enumerated.
    specs = list(sweep.values())
    override_keys = [spec[0] for spec in specs]
    value_lists = [spec[1] for spec in specs]
    return [dict(zip(override_keys, combo)) for combo in product(*value_lists)]

def check_for_done(l):
    """Find the first process in ``l`` that has already terminated.

    Each element must expose ``poll()`` returning ``None`` while it is still
    running. Elements are polled in order and polling stops at the first one
    that has finished.

    Returns:
        ``(True, index)`` for the first finished element, or
        ``(False, False)`` when nothing has finished (including an empty list).
    """
    first_finished = next(
        (position for position, proc in enumerate(l) if proc.poll() is not None),
        None,
    )
    if first_finished is None:
        return False, False
    return True, first_finished


# Command-line interface of the sweep launcher. Typical invocations:
# sbatch --nice=1000 --gres=gpu:$NUM_GPUS -c $NUM_CPUS run_rb.sh --num_runs ??? --gpu_ids ???
# sbatch --nice=1000 --nodes=$NUM_NODES --gres=gpu:$NUM_GPUS -c $NUM_CPUS --nodelist=... run_rb.sh --num_runs ??? --gpu_ids ???
# python3 scripts/pipe2.py --num_runs 1 --gpu_ids ???
parser = argparse.ArgumentParser(description='Training arguments')
parser.add_argument('--num_runs', default=1, type=int, help='Number of runs to run')
# Required: the list is measured and indexed further down, so omitting it
# would otherwise only surface later as a TypeError on len(None).
parser.add_argument('--gpu_ids', required=True, help='delimited list input',
    type=lambda s: [int(item) for item in s.split(',')])
# The three switches below default to on. With store_true alone they could
# never be turned off, so each one gets a --no-* twin writing to the same dest.
parser.add_argument('--wandb', action='store_true', help='To use wandb logging or not', default=True)
parser.add_argument('--no-wandb', dest='wandb', action='store_false', help='Disable wandb logging')
parser.add_argument('--logrun', action='store_true', help='save the model orn ot', default=True)
parser.add_argument('--no-logrun', dest='logrun', action='store_false', help='Do not save the model')
parser.add_argument('--print',
                    action='store_true', help='to print the outputs or not', default=True)
parser.add_argument('--no-print', dest='print', action='store_false', help='Do not print the outputs')
# parser.add_argument('--num_epochs',
#                     default=500, type=int,
#                     help='Number of epochs to learn')
# parser.add_argument('--group_name',
#                     default=None, type=str,
#                     help="Group Name")

args = parser.parse_args()

# One GPU slot is queued per run; with no slots the launch loop would fail on
# an empty queue, so reject that configuration up front with a usage error.
if args.num_runs < 1:
    parser.error('--num_runs must be at least 1')

print(args.gpu_ids)

# Directory that receives the fully composed per-run config files.
CFGs_DIR = "pipe_cfgs"
os.makedirs(CFGs_DIR, exist_ok=True)

NUM_RUNS = args.num_runs
GPU_IDS = args.gpu_ids
NUM_GPUS = len(GPU_IDS)
USE_WANDB = args.wandb
LOG_RUN = args.logrun
PRINT = args.print
# NUM_EPOCHS = args.num_epochs
# Number of training processes launched so far.
counter = 0

# Name shared by every run of this sweep; also used in the sweep record's
# file name.
GROUP_NAME = "NS5_noise"

# Sweep specification: label -> (config override key, list of values to try).
# One run is generated per element of the Cartesian product of all value
# lists, so a single-element list simply pins that setting. Entry order is
# significant: it determines override order and the order runs are generated.
# Commented-out entries are alternative settings kept for quick toggling.
sweep = dict(
    lr=('model.optimizer.lr', [0.001]),
    noise_std=('step.noise.std', [0.0]),
    final_eval_frequency=('final_eval_frequency', [20]),
    reduced_resolution=('dataset.dataloader.reduced_resolution', [2]),
    reduced_resolution_t=('dataset.dataloader.reduced_resolution_t', [1]),

    # Fixed: the second name used to be ' ffno_2d' with a stray leading space.
    model=('model', ['s4ffno2d', 'ffno_2d']),

    use_noise=('step.noise.use_noise', [True]),

    num_epochs=('num_epochs', [200]),
    dataset=('dataset', ['nstokes/e-5_5']),  # Add other datasets to the list as needed
    # lr=('model.optimizer.lr', [0.001,0.0005,0.002]),

    batch_size=('model.batch_size', [16]),
    seed=('seed', [180]),
    d_model=('model.params.d_model', [64]),
    modes=('model.params.modes', [-1]),
    t_train=('dataset.t_train', [32]),
    t_test=('dataset.t_test', [48]),
    train_timesteps=('step.train_timesteps', [16]),
    # n_layers=('model.params.n_layers', [4,8]),
    loss_type=('loss_type', ['nRMSE']),

    # # version=('step.version',[1.5,1.8]),
    # stride = ('model.params.fast.stride', [2,4]),
    # fast_kernel = ('model.params.fast.kernel_size', [8]),

    # num_samples_max=('dataset.dataloader.num_samples_max',[9000]),
    evaluate_frequency=('evaluate_frequency', [1]),
    warmup_epochs=('model.warmup_epochs', [1]),
    # model = ('model', ['fno_1d']),
    # model= ('model', ['playground/fno-s4_based']),
    # model= ('model', ['ffno_2d']),

    normalize_per_trajectory=('step.normalize_per_trajectory', [False]),

    chunk_train=('step.chunk_train', [True]),
    unfold=('dataset.dataloader.unfold', [False]),

    scale=('dataset.dataloader.scale', [False]),

    # ffn = ('model.params.ffn_type', ["zero"]),

    # final_mlp_hidden_expansion = ('model.params.final_mlp_hidden_expansion',[1]),
    # activation = ('model.params.s4block_args.activation',["gelu"]),
    # norm_type = ('model.params.norm_type',["identity"]),
    # residual_type = ('model.params.residual_type',["identity","zero", ["identity","identity","identity","zero"]]),
    # final_mlp_act = ('model.params.final_mlp_act',["gelu"]),

    scheduler=('model.scheduler', ['step']),
    # NOTE(review): the label says eta_min_factor but the override targets
    # model.step_size. Only the override key reaches the run config; the label
    # appears just in the printed/saved sweep record. Consider renaming.
    eta_min_factor=('model.step_size', [60]),

    # reduced_resolution_t=('dataset.dataloader.reduced_resolution_t', [20]),
    # channels=('model.params.channels', [1]),
    # spectral_type=('model.params.spectral_type', ['full']),
    # residual_type=('model.params.residual_type', ['weighted']),
    # fconv_type=('model.params.fconv_type', ['standard']),
    # s4kernel=('model.params.s4block_args.kernel', ['dplr']),
    # inside_residual_type=[('model.params.inside_residual_type', ['zero'])],
    # ssm_init=('model.params.s4block_args.init', ["['legs','diag-lin','legs','diag-lin']"]),
    # ssm_init=('model.params.s4block_args.init', ["['legs','diag-lin','legs','diag-lin']","['diag-lin','legs','diag-lin','legs']",'legs', 'diag-lin']),
    # random_final_upper=('step.random_train.final_upper',[2]),
    # random_final_lower=('step.random_train.final_lower',[0]),
    # random_initial_upper=('step.random_train.initial_upper',[0]),
    # train_average_steps=('step.train_evaluator.average_steps', [True]),
    # test_average_steps=('step.test_evaluator.average_steps', [True]),
    # n_timesteps=('step.n_timesteps', [2]),
    # use_noise=('step.noise.use_noise', [False]),
    # trank=('model.params.seq_model.params.s4block_args.trank',[["_EMPTY", 4, "_EMPTY", 4],["_EMPTY", 1, "_EMPTY", 1]]),
    # prob_forecasting=('prob_forecasting', ['student-t']),
    # discard_step=('step.discard_state', [True,False]),
)


# Expand the sweep into one override dict per run and report the workload.
sweep_product = outer_product(sweep)
length = len(sweep_product)
print(f"Number of runs that will be executed: {length}")
pprint(sweep)

# Keep a timestamped record of the sweep definition next to the results.
timestamp = time.strftime("%Y-%m-%d_%H-%M-%S")
# The record directory is not guaranteed to exist on a fresh checkout;
# without this the open() below fails with FileNotFoundError.
os.makedirs("sweeps", exist_ok=True)
with open(f"sweeps/{timestamp}_{GROUP_NAME}.yaml", "w") as f:
    OmegaConf.save({"sweep": sweep}, f)

wandb_group = GROUP_NAME
wandb_tags_original = []

# Scheduler state:
#   procs     - running training processes
#   gpu_use   - GPU id held by the process at the same index in `procs`
#   gpu_queue - free GPU slots; one slot per allowed run, assigned to the
#               given GPU ids round-robin
procs = list()
gpu_use = list()
gpu_queue = [GPU_IDS[i % NUM_GPUS] for i in range(NUM_RUNS)]


# Compose, resolve and persist one full config per sweep combination. The
# training processes are started later from these files.
cfgs_paths = []
for run in sweep_product:
    overrides = [f"{key}={value}" for key, value in run.items()]

    overrides.extend([
        f"print={PRINT}",
        f"use_wandb={USE_WANDB}",
        f"log_model={LOG_RUN}",
        # f"num_epochs={NUM_EPOCHS}",
        f"wandb.group={wandb_group}"
    ])

    # Work on a copy so per-run tag edits can never leak into the shared list.
    wandb_tags = list(wandb_tags_original)

    if wandb_tags:
        # No literal quotes around the override: the list goes straight to
        # compose(), not through a shell, so quotes would be read as part of
        # the key.
        overrides.append(f"wandb.tags={wandb_tags}")

    with initialize(version_base=None, config_path="../configs"):
        cfg = compose(config_name="config", overrides=overrides)
        OmegaConf.resolve(cfg)
        # Pause one second between runs so their timestamps differ
        # (presumably a time-based value resolved in the config -- confirm).
        time.sleep(1)

    # delete=False: the file must outlive this block because the training
    # subprocess reads it. Write through the already-open handle instead of
    # reopening the same path by name.
    with tempfile.NamedTemporaryFile(mode='w', delete=False, dir=CFGs_DIR, suffix=".tmp") as f:
        OmegaConf.save(cfg, f)
        cfgs_paths.append(f.name)
    
# Launch one training process per composed config, keeping at most NUM_RUNS
# alive at a time. Each process is pinned to a GPU taken from the free queue.
for path in cfgs_paths:
    gpu_id = gpu_queue.pop(0)
    gpu_use.append(gpu_id)
    # Human-readable form of the launch, kept for the log only.
    cmd = f"CUDA_VISIBLE_DEVICES={gpu_id} python train_precomposed.py --cfg_path {path}"
    print(cmd)

    # Start the trainer directly (no shell): the config path is passed as a
    # single argument even if it contains spaces, and the GPU pin travels in
    # the child's environment.
    child_env = dict(os.environ, CUDA_VISIBLE_DEVICES=str(gpu_id))
    procs.append(Popen(["python", "train_precomposed.py", "--cfg_path", path], env=child_env))

    # Stagger start-ups a little before considering the next launch.
    time.sleep(3)

    counter += 1

    if len(procs) == NUM_RUNS:
        # All slots are busy: block until some process exits, then hand its
        # GPU back to the free queue.
        while True:
            done, num = check_for_done(procs)
            if done:
                break
            time.sleep(3)

        finished = procs.pop(num)
        gpu_queue.append(gpu_use.pop(num))
        if finished.returncode != 0:
            print(f"WARNING: run exited with code {finished.returncode}: {finished.args}")

        print("\n \n \n \n --------------------------- \n \n \n \n")
        # `counter` counts launches; subtract the processes still tracked to
        # report the runs that have actually finished and been collected.
        print(f"{date.today()} - {counter - len(procs)} runs completed")
        print("\n \n \n \n --------------------------- \n \n \n \n")
        sys.stdout.flush()

# Drain whatever is still running, reporting failures instead of hiding them.
for p in procs:
    if p.wait() != 0:
        print(f"WARNING: run exited with code {p.returncode}: {p.args}")
procs = []
