import os
import sys
import subprocess
import pandas as pd
import glob
import json
import itertools


def blockPrint():
    """Silence ``print`` output by pointing ``sys.stdout`` at the null device.

    A new handle on ``os.devnull`` is opened on every call and is not closed
    here; call ``enablePrint`` to route output back to the interpreter's
    original stdout.
    """
    null_sink = open(os.devnull, 'w')
    sys.stdout = null_sink


def enablePrint():
    """Restore ``sys.stdout`` to the interpreter's original stream.

    If the stream currently installed is a file opened on ``os.devnull`` (what
    ``blockPrint`` installs), it is closed first. Closing it here releases the
    handle immediately, so repeated block/enable cycles do not accumulate open
    files (and ResourceWarnings). Any other replacement stream is left open,
    because it belongs to whoever installed it.
    """
    current = sys.stdout
    if current is not sys.__stdout__ and getattr(current, 'name', None) == os.devnull:
        current.close()
    sys.stdout = sys.__stdout__


def create_df_results(root_dir):
    """Collect every ``output.json`` below *root_dir* into one DataFrame.

    The search is recursive (``root_dir/**/output.json``, which also matches
    ``root_dir/output.json`` itself). Each file is parsed with ``json.load``,
    its path is stored under the extra key ``'json_filename'``, and the parsed
    objects become the rows of the returned frame.

    Args:
        root_dir: Directory to search. It is matched literally, so names that
            contain glob metacharacters such as ``[``, ``]`` or ``*`` work.

    Returns:
        A ``pandas.DataFrame`` with one row per file, ordered by file path
        (empty when no file is found).
    """
    # Escape root_dir so e.g. "lr=[0.1]" is a literal directory name, not a
    # character class; only the "**" and file name act as a pattern.
    pattern = os.path.join(glob.escape(os.fspath(root_dir)), '**', 'output.json')

    experiment_results = []
    # glob returns paths in arbitrary filesystem order; sort for stable rows.
    for json_file_name in sorted(glob.glob(pattern, recursive=True)):
        # JSON text is UTF-8 (RFC 8259); don't depend on the platform locale.
        with open(json_file_name, encoding='utf-8') as json_file:
            json_exper = json.load(json_file)
        json_exper['json_filename'] = json_file_name
        experiment_results.append(json_exper)

    return pd.DataFrame.from_dict(experiment_results)


def _dict_product(options):
    """Expand a ``{key: candidate_values}`` mapping into every combination.

    Returns a list of dicts, one per element of the Cartesian product of the
    mapping's values, in ``itertools.product`` order (last key varies fastest).
    An empty or ``None`` mapping yields ``[{}]`` so callers still loop once.
    """
    if not options:
        return [{}]
    keys = list(options)
    return [dict(zip(keys, combo)) for combo in itertools.product(*options.values())]


def _format_overrides(overrides):
    """Render *overrides* as ``key=value`` pairs joined by ``+``.

    List values are wrapped in single quotes (``key='[1, 2]'``), presumably so
    the spaces in their ``str()`` survive shell word-splitting. TODO confirm
    against main.py's argument parser. NOTE(review): a list of strings yields
    nested single quotes, which a shell will not keep as one token.
    """
    return '+'.join(
        f"{key}='{value}'" if isinstance(value, list) else f'{key}={value}'
        for key, value in overrides.items()
    )


def create_experiment_job_file(run_filename, dataset_dict, model_dict, optim_dict, trainer_dict,
                               yaml_file, seed_list,
                               root_dir=''):
    """Write one ``python main.py ...`` command line per experiment configuration.

    Every combination of seed x dataset x model x optimizer overrides becomes
    one line of *run_filename*, which is overwritten. The number of lines
    written is printed at the end.

    Args:
        run_filename: Path of the job file to (re)create.
        dataset_dict, model_dict, optim_dict: Mappings from an override key to
            an iterable of candidate values. Each mapping is expanded into its
            Cartesian product and passed via ``-d`` / ``-m`` / ``-o``. An empty
            mapping contributes no flag.
        trainer_dict: Fixed overrides passed via ``-t``. Any value that is not a
            dict is ignored.
        yaml_file: Sequence whose first three items are passed as
            ``--dataset_file``, ``--model_file`` and ``--trainer_file``.
        seed_list: Iterable of seeds; each one gets its own command (``-s``).
        root_dir: Passed via ``-r`` when non-empty. ``''`` or ``None`` omits it.
    """
    with open(run_filename, 'w') as run_file:
        dataset_args = [_format_overrides(p) for p in _dict_product(dataset_dict)]
        model_args = [_format_overrides(p) for p in _dict_product(model_dict)]
        optim_args = [_format_overrides(p) for p in _dict_product(optim_dict)]
        trainer_args = _format_overrides(trainer_dict) if isinstance(trainer_dict, dict) else ''

        # Pieces shared by every command: the config files come first; the
        # trainer overrides and output root come last, just before the seed.
        head = (f'python main.py --dataset_file {yaml_file[0]} '
                f'--model_file {yaml_file[1]} --trainer_file {yaml_file[2]}')
        tail = []
        if trainer_args:
            tail.append(f'-t {trainer_args}')
        if root_dir:
            tail.append(f'-r {root_dir}')

        n_commands = 0
        # Seed varies slowest, then dataset, model, and optimizer fastest.
        for seed, d_args, m_args, o_args in itertools.product(
                seed_list, dataset_args, model_args, optim_args):
            parts = [head]
            for flag, args in (('-d', d_args), ('-m', m_args), ('-o', o_args)):
                if args:  # an empty override group contributes no flag at all
                    parts.append(f'{flag} {args}')
            parts.extend(tail)
            parts.append(f'-s {seed}')
            run_file.write(' '.join(parts) + '\n')
            n_commands += 1

    print(f'Number of permutations: {n_commands}')
