from tempfile import NamedTemporaryFile
import datetime
import anymarkup
from iterexp_utils import product
from pathlib import Path
import asyncio


def grid(*items):
    """Expand parameter specifications into a list of merged assignment dicts.

    Each positional argument is either a dict or an iterable of dicts.

    A dict argument ``{key: value}`` is expanded first:

    * a list/tuple ``value`` contributes one option ``{key: element}`` per
      element;
    * any other ``value`` contributes the single option ``{key: value}``.

    The per-key option lists are then combined by calling ``grid``
    recursively.

    All arguments are finally combined with ``product``, which is imported from
    ``iterexp_utils`` (presumably a Cartesian product like
    ``itertools.product`` -- confirm). Each resulting tuple of dicts is merged
    left to right, so later dicts win on key collisions.

    Returns:
        list[dict]: one merged dict per combination yielded by ``product``.
    """
    def expand(spec):
        # Build one option list per key, then cross them via grid itself.
        per_key = []
        for key, value in spec.items():
            choices = value if isinstance(value, (list, tuple)) else [value]
            per_key.append([{key: choice} for choice in choices])
        return grid(*per_key)

    # Dict arguments are expanded before anything is handed to product.
    axes = [expand(entry) if isinstance(entry, dict) else entry for entry in items]

    merged = []
    for combo in product(*axes):
        result = {}
        for part in combo:
            result.update(part)
        merged.append(result)
    return merged


class Task:
    """Launch one shell script per mapping as asyncio tasks with bounded concurrency.

    On construction this:

    * resolves ``log_dir`` by substituting ``{uid}`` with a
      ``%y%m%d-%H%M%S`` timestamp and expanding ``~``;
    * creates that directory (parents included); it must not already exist;
    * serializes ``mappings`` to ``<log_dir>/.status.json5`` with
      ``anymarkup.serialize_file``;
    * renders ``template`` once per mapping with ``str.format``, using the
      mapping's keys plus ``log_dir``;
    * schedules every rendered script with ``asyncio.create_task``.

    Because of the last step, the constructor must be called while an asyncio
    event loop is running (``asyncio.create_task`` raises ``RuntimeError``
    otherwise). Await ``self.tasks`` (e.g. with ``asyncio.gather``) to wait
    for completion. Each task resolves to the finished
    ``asyncio.subprocess.Process``. Its stdout/stderr have been drained, so
    only ``returncode`` is meaningful.

    Args:
        log_dir: directory path; may contain ``{uid}`` and ``~``.
        template: script text with ``str.format`` placeholders.
        mappings: sized iterable of dicts, one per script to run.
        shell: interpreter command line. It is split on whitespace, and the
            temporary script file path is appended as the last argument.
        pool_size: maximum number of scripts running at once. ``None`` or
            ``0`` means no limit beyond the number of scripts.

    Raises:
        ValueError: if ``pool_size`` is negative (checked before any
            directory is created).
        FileExistsError: if the resolved ``log_dir`` already exists.
    """

    def __init__(self, log_dir, template, mappings, *, shell='bash', pool_size=None):
        self.shell = shell
        self.template = template
        self.pool_size = self._effective_pool_size(len(mappings), pool_size)

        uid = datetime.datetime.now().strftime('%y%m%d-%H%M%S')
        self.log_dir = Path(log_dir.format(uid=uid)).expanduser()
        self.log_dir.mkdir(exist_ok=False, parents=True)
        anymarkup.serialize_file(mappings, self.log_dir / '.status.json5')

        self.mappings = mappings
        self.scripts = scripts = [template.format(**mapping, log_dir=str(self.log_dir)) for mapping in mappings]

        semaphore = asyncio.Semaphore(self.pool_size)
        # Split on any whitespace so repeated spaces don't produce empty argv entries.
        argv = shell.split()

        async def execute(script):
            async with semaphore:
                # The temp file lives until the child has exited.
                with NamedTemporaryFile() as fp:
                    fp.write(script.encode())
                    fp.flush()
                    proc = await asyncio.create_subprocess_exec(
                        *argv, fp.name,
                        stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE)
                    # communicate() keeps draining both pipes until the child
                    # exits. A bare wait() deadlocks once the child fills an
                    # OS pipe buffer.
                    await proc.communicate()
                    return proc

        self.tasks = [asyncio.create_task(execute(script)) for script in scripts]

    @staticmethod
    def _effective_pool_size(n_jobs, pool_size):
        """Return the concurrency limit for ``n_jobs`` scripts.

        ``None`` or ``0`` means "run every job at once". Otherwise the limit is
        capped at ``n_jobs``. Negative values raise ``ValueError`` (previously
        they failed later, inside ``asyncio.Semaphore``).
        """
        if not pool_size:
            return n_jobs
        if pool_size < 0:
            raise ValueError(f'pool_size must be >= 0, got {pool_size}')
        return min(n_jobs, pool_size)

    def __repr__(self):
        return f'Task(log_dir = {self.log_dir}, pool = {self.pool_size})'


def main():
    """Command-line entry point: render ``--template`` per mapping and run the scripts.

    Each ``--map PATH VALUE`` pair supplies the value for the ``{PATH}``
    placeholder. Values are grouped per PATH and expanded with ``grid``, so a
    PATH given several times is swept over all of its values. With no
    ``--map`` at all, a single job with no substitutions is run.

    ``Task`` schedules its scripts with ``asyncio.create_task``, so it is built
    inside ``asyncio.run`` and all of its tasks are awaited before returning.
    """
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--template', help='template', required=True)
    parser.add_argument('--map', help='replace arguments', nargs=2, action='append', metavar=('PATH', 'VALUE'),
                        default=[])
    # The default must not pre-exist: Task creates it with mkdir(exist_ok=False).
    parser.add_argument('--log_dir', default='/tmp/{uid}',
                        help="directory for logs; '{uid}' is replaced by a timestamp "
                             "and the directory must not exist yet")
    parser.add_argument('--pool_size', '-p', help='size of process pool, default to # total jobs', default=0, type=int)

    args = parser.parse_args()

    # Group values per PATH, keeping first-seen order, so repeats become a sweep.
    choices = {}
    for path, value in args.map:
        choices.setdefault(path, []).append(value)
    mappings = grid(choices) if choices else [{}]

    async def run():
        # Task calls asyncio.create_task, which needs a running event loop.
        task = Task(args.log_dir, args.template, mappings, pool_size=args.pool_size or None)
        print(task)
        return await asyncio.gather(*task.tasks)

    asyncio.run(run())


# Public API of this module.
__all__ = ['Task', 'grid']


if __name__ == '__main__':
    main()
