import os
# PYTHONHASHSEED is read only at interpreter start-up, so setting it here does not
# change hashing in this launcher itself; it is inherited through the environment by
# the training subprocesses started in `work`, fixing their hash() seed.
os.environ["PYTHONHASHSEED"] = "1"
import multiprocessing
import subprocess
from pathlib import Path
from itertools import product
from collections import defaultdict
import re

import click

_CPU_COUNT = multiprocessing.cpu_count() - 1

# print(_CPU_COUNT) -> There are 95 CPUs

def _find_named_configs():
    configs = defaultdict(list)
    for c in Path("configs/").glob("**/*.yaml"):
        parent = str(c.relative_to("configs/").parent)
        name = c.stem
        if parent == ".":
            parent = None
        configs[parent].append(name)
    return configs


_NAMED_CONFIGS = _find_named_configs()

def _get_ingredient_from_mask(mask):
    if "/" in mask:
        return mask.split("/")
    return None, mask


def _validate_config_mask(ctx, param, values):
    """Click callback for ``--config-mask``: reject unknown ingredients and bad regexes.

    Args:
        ctx: Click context (unused).
        param: The option being processed (unused).
        values: Tuple of mask strings collected from the (multiple) option.

    Returns:
        *values* unchanged when every mask is valid.

    Raises:
        click.BadParameter: if a mask names an ingredient not found in
            ``_NAMED_CONFIGS``, or its regex part does not compile.
    """
    for value in values:
        ingredient, pattern = _get_ingredient_from_mask(value)
        if ingredient not in _NAMED_CONFIGS:
            raise click.BadParameter(
                "Invalid ingredient '{}'. Valid ingredients are: {}".format(
                    ingredient, list(_NAMED_CONFIGS.keys())
                )
            )
        # Compile here so a malformed regex is reported as a usage error instead of
        # surfacing later as an re.error traceback from _filter_configs.
        try:
            re.compile(pattern)
        except re.error as err:
            raise click.BadParameter("Invalid regex '{}': {}".format(pattern, err)) from err
    return values


def _filter_configs(configs, mask):
    """Return a copy of *configs* with one ingredient narrowed by a ``--config-mask``.

    Only the ingredient named in *mask* is filtered: a config name is kept when the
    regex part of the mask matches anywhere in it (``re.search``). Other ingredients
    are carried over as-is. The input mapping is not modified, so passing the
    module-level ``_NAMED_CONFIGS`` cache is safe.

    Args:
        configs: Mapping of ingredient (or ``None``) to a list of config names.
        mask: ``"<ingredient>/<regex>"`` or a bare ``"<regex>"`` for top-level configs.

    Returns:
        A shallow copy of *configs* of the same mapping type, with the filtered list.
    """
    ingredient, pattern = _get_ingredient_from_mask(mask)
    regex = re.compile(pattern)
    # .copy() keeps the mapping type (a defaultdict stays a defaultdict); the list
    # assigned below is a new object, so no list in the input is mutated either.
    filtered = configs.copy()
    filtered[ingredient] = [name for name in configs.get(ingredient, []) if regex.search(name)]
    return filtered


def work(cmd):
    """Execute *cmd* without a shell and block until it finishes.

    The command line is tokenized by splitting on single spaces (no quoting rules).
    Returns the child process's exit status.
    """
    argv = cmd.split(" ")
    completed = subprocess.run(argv, shell=False)
    return completed.returncode


def _select_train_script(use_comm, use_mem, share_model, train_cifar):
    """Return the training script selected by the CLI flags.

    Precedence (same as the launcher has always used):
    * ``use_mem`` only has an effect together with ``use_comm``;
    * ``share_model`` takes precedence over ``train_cifar``;
    * ``train_cifar`` is not consulted when ``use_comm`` and ``use_mem`` are both set.
    """
    if use_comm and use_mem:
        return "train_comm_mem_shared.py" if share_model else "train_comm_mem.py"
    stem = "train_comm" if use_comm else "train"
    if share_model:
        return stem + "_shared.py"
    if train_cifar:
        return stem + "_cifar.py"
    return stem + ".py"


@click.command()
@click.option("--seeds", default=3, show_default=True, help="How many seeds to run")
@click.option(
    "--cpus",
    default=_CPU_COUNT,
    show_default=True,
    help="How many processes to run in parallel",
)
@click.option(
    "--config-mask",
    "-c",
    multiple=True,
    callback=_validate_config_mask,
    help="Regex mask to filter configs/. Separate the ingredient from the regex with a "
    "forward slash '/', e.g. 'algorithm/rware*'. By default all configs found are used.",
)
@click.option('--use_comm', '-uc', default=False, is_flag=True, help="Use communication")
@click.option('--use_mem', '-um', default=False, is_flag=True, help="Use memory")
@click.option('--share_model', '-sm', default=False, is_flag=True, help="Share model")
@click.option('--train_cifar', '-tc', default=False, is_flag=True, help="Train CIFAR")
def main(seeds, cpus, config_mask, use_comm, use_mem, share_model, train_cifar):
    """Run every combination of the selected configs and seeds in parallel."""
    configs = _NAMED_CONFIGS
    for mask in config_mask:
        configs = _filter_configs(configs, mask)

    # One axis per ingredient ("<ingredient>.<name>", or the bare name for top-level
    # configs) plus one axis of seed overrides; the runs are their Cartesian product.
    axes = [
        ["{}.{}".format(k, name) if k else str(name) for name in names]
        for k, names in configs.items()
    ]
    axes.append(["seed={}".format(seed) for seed in range(seeds)])

    click.echo("Running following combinations: ")
    click.echo(click.style(" X ", fg="red", bold=True).join(str(axis) for axis in axes))

    combinations = list(product(*axes))
    if not combinations:
        click.echo("No valid combinations. Aborted!")
        raise SystemExit(1)

    # Tell the user about flags that _select_train_script will not consult.
    if use_mem and not use_comm:
        click.echo("Warning: --use_mem has no effect without --use_comm.")
    if train_cifar and (share_model or (use_comm and use_mem)):
        click.echo(
            "Warning: --train_cifar is ignored with --share_model or with --use_comm --use_mem."
        )

    # multiprocessing.Pool rejects fewer than one worker process.
    processes = max(1, cpus)
    click.confirm(
        "There are {} combinations of configurations. Up to {} will run in parallel. "
        "Continue?".format(click.style(str(len(combinations)), fg="red"), processes),
        abort=True,
    )

    script = _select_train_script(use_comm, use_mem, share_model, train_cifar)
    commands = [
        "python3 {} -u with dummy_vecenv=False ".format(script) + " ".join(combo)
        for combo in combinations
    ]

    # Create the pool only after confirmation and let the context manager tear it
    # down; map() blocks until every command has finished.
    with multiprocessing.Pool(processes=processes) as pool:
        print(pool.map(work, commands))


if __name__ == "__main__":
    main()
