import itertools
import os
import argparse
import numpy as np
from templates import sh_templates


def _build_job(script_name, base_job_name, arg_names, combination):
    """Build the command line and the job name for one argument combination.

    Args:
        script_name: Script to launch; it is the first token of the command.
        base_job_name: Experiment name used as the prefix of the job name.
        arg_names: Argument names, in the same order as ``combination``.
        combination: One value per entry of ``arg_names``.

    Returns:
        A ``(cmd, job_name_full)`` tuple: the command with every argument
        rendered as ``--name value`` and the job name with the values
        appended, separated by underscores.
    """
    cmd_parts = [script_name]
    name_parts = [base_job_name]

    for name, value in zip(arg_names, combination):
        if name == 'val_methods':
            # List-valued flag: items are space-separated on the command line.
            # NOTE: this value is not added to the job name, so combinations
            # that differ only in val_methods map to the same script file.
            cmd_parts.append(f"--{name} {' '.join(value)}")
        elif name == 'datasets_names':
            # List-valued flag: space-separated on the command line,
            # underscore-separated in the job name.
            cmd_parts.append(f"--{name} {' '.join(value)}")
            name_parts.append("_".join(value))
        else:
            cmd_parts.append(f"--{name} {value}")
            name_parts.append(f"{value}")

    return " ".join(cmd_parts), "_".join(name_parts)


def main():
    """Generate one submission script per argument combination.

    Every list in ``experiment_desc["args"]`` holds the candidate values of
    one command-line argument. The cartesian product of these lists is
    expanded, and for each combination a ``<job_name_full>.sh`` file is
    rendered from the selected entry of ``sh_templates`` and written to
    ``scripts/<job_name>/``. A ``conduct_<job_name>.sh`` file containing one
    ``sbatch`` line per generated script is written to the same directory.
    """
    current_dir = "scripts"

    experiment_desc = {
        "job_name": "cnn_supervised_patience50_1509",
        "environment": "MRAD",
        "script_name": "main.py",
        "args": {
            "data_type": ['raw'],
            "model_name": ['cnn_128'],
            "val_methods": [["first_window"]],
            "learning_rate": [0.0003],
            # Alternative values kept for reference:
            # [['OPPORTUNITY'], ['IOPS'], ['SVDB'], ['Daphnet'], ['MGAB'], ['MITDB'], ['Occupancy'], ['ECG'], ['GHL'], ['SensorScope'], ['SMD'], ['KDD21'], ['NAB'], ['Genesis'], ['YAHOO'], ['Dodgers']]
            "datasets_names": [['supervised']],
            "train_ratio": [0.7],
            "n_steps": [1024],
            "total_timesteps": [720000],
            "n_val": [10],
            "epochs": [20],
            "batch_size": [128],
            "patience": [5],
            "delta": [0.001],
            "seed": [42],
            "ressources": ['cpu']
        },
        "gpu_required": "1 if \"model_class\" == \"raw\" else 1"
    }

    # Unpack the experiment description.
    job_name = experiment_desc["job_name"]  # used to create the save_dir in logs
    saving_dir = os.path.join(current_dir, job_name)
    environment = experiment_desc["environment"]
    script_name = experiment_desc["script_name"]
    args = experiment_desc["args"]
    args_saving_path = 'results'
    arg_names = list(args.keys())
    arg_values = list(args.values())
    gpu_expression = experiment_desc["gpu_required"]

    # The GPU template is only selected when 'gpu' is the single resource.
    if args["ressources"] == ['gpu']:
        template = sh_templates['cleps_gpu']
    else:
        template = sh_templates['cleps_cpu']

    # Create the output directory once, up front, so it also exists when the
    # argument grid is empty and only the conduct script is written.
    os.makedirs(saving_dir, exist_ok=True)

    # A set drops duplicate job names (see the NOTE in _build_job).
    jobs = set()

    # Expand every possible combination of argument values.
    for combination in itertools.product(*arg_values):
        cmd, job_name_full = _build_job(
            script_name, job_name, arg_names, combination)

        # Resolve the GPU expression for this combination only. It is reset
        # from the original expression each time so that the result of one
        # combination cannot leak into the next.
        # NOTE(review): the resolved value is not passed to the template.
        gpu_required = gpu_expression
        for name, value in zip(arg_names, combination):
            if isinstance(gpu_required, str) and name in gpu_required:
                gpu_required = int(
                    eval(gpu_required.replace(name, str(value))))

        # Write the .sh file; the template takes seven positional fields.
        script_path = os.path.join(saving_dir, f'{job_name_full}.sh')
        with open(script_path, 'w', encoding='utf-8') as rsh:
            rsh.write(template.format(job_name_full, args_saving_path, job_name,
                      args_saving_path, job_name, environment, cmd))

        jobs.add(job_name_full)

    # Create the sh file that submits all experiments, in sorted order.
    run_all_sh = "".join(
        f"sbatch {os.path.join(saving_dir, f'{job}.sh')}\n"
        for job in sorted(jobs)
    )

    conduct_path = os.path.join(saving_dir, f'conduct_{job_name}.sh')
    with open(conduct_path, 'w', encoding='utf-8') as rsh:
        rsh.write(run_all_sh)


# Run the generator only when this file is executed as a script.
if __name__ == "__main__":
    main()