import os
import tempfile
from concurrent.futures import ProcessPoolExecutor, as_completed
from inspect import ArgSpec, getfullargspec
from logging import Formatter, StreamHandler, getLogger
from unittest.mock import patch

import invoke
import slackweb

# Module logger: INFO level, not propagated to the root logger, so records are
# emitted only once through the handler attached below.
logger = getLogger(__name__)
logger.setLevel("INFO")
logger.propagate = False

# Single stderr handler using a format that includes the line number and the
# calling function name.
log_fmt = Formatter(
    "%(asctime)s %(name)s L%(lineno)d [%(levelname)s][%(funcName)s] %(message)s "
)
handler = StreamHandler()
handler.setFormatter(log_fmt)
handler.setLevel("INFO")
logger.addHandler(handler)

# Directory containing this tasks file; task paths are resolved against it.
REPOSITORY_ROOT = os.path.dirname(__file__)


def fix_invoke_annotations() -> None:
    """Make ``invoke.tasks.Task.argspec`` tolerate annotated task functions.

    ``Task.argspec`` is replaced by a wrapper that, for the duration of each
    call, swaps ``inspect.getargspec`` for a shim built on
    ``inspect.getfullargspec``. The shim returns an ``ArgSpec`` holding only
    the first four ``FullArgSpec`` fields.

    NOTE(review): presumably this targets invoke releases whose
    ``Task.argspec`` calls ``inspect.getargspec``, which rejects functions
    with annotations. Confirm against the pinned invoke version.
    """
    original_argspec = invoke.tasks.Task.argspec

    def _argspec_from_full(func):
        # FullArgSpec starts with (args, varargs, varkw, defaults): the exact
        # fields of the legacy ArgSpec tuple.
        args, varargs, varkw, defaults = getfullargspec(func)[:4]
        return ArgSpec(args, varargs, varkw, defaults)

    def _argspec_with_shim(*args, **kwargs):
        # The shim is installed only while the original method runs.
        with patch(target="inspect.getargspec", new=_argspec_from_full):
            return original_argspec(*args, **kwargs)

    invoke.tasks.Task.argspec = _argspec_with_shim


# Install the Task.argspec shim at import time, before the @invoke.task
# definitions below are created.
fix_invoke_annotations()

# --------------------------


@invoke.task
def setup(c: invoke.Context):
    """Install dependencies and git hooks, then download and unpack the data.

    Runs ``poetry install --no-root`` and ``pre-commit install``. It then
    downloads ``data.tar.gz`` into a temporary directory and extracts it into
    ``<repository root>/tree_ntk/data``.
    """
    import shlex

    c.run("poetry install --no-root")
    c.run("poetry run pre-commit install")

    # Absolute path, because the extraction below runs from inside a temp dir.
    data_dir = os.path.abspath(os.path.join(REPOSITORY_ROOT, "tree_ntk", "data"))
    # `tar -C` fails if the destination directory does not exist yet.
    os.makedirs(data_dir, exist_ok=True)

    with tempfile.TemporaryDirectory() as tmp_d:
        with c.cd(tmp_d):
            c.run(
                "wget http://persoal.citius.usc.es/manuel.fernandez.delgado/papers/jmlr/data.tar.gz"
            )
            # Quote the destination so checkouts under paths with spaces work.
            c.run(f"tar -xvzf data.tar.gz -C {shlex.quote(data_dir)}")
    logger.info("setup done")


@invoke.task(default=True)
def run(
    c: invoke.Context,
    model: str = "tree_ntk",
    mode: str = "train",
    max_tot: int = 5000,
    alpha: float = 1.0,
):
    """Run ``main.py`` once inside the ``tree_ntk`` directory via poetry.

    Args:
        model: Value passed as ``-model``.
        mode: Value passed as ``-mode``.
        max_tot: Value passed as ``-max_tot``.
        alpha: Value passed as ``-alpha``.
    """
    # Anchor on the repository root instead of the caller's cwd: invoke finds
    # tasks.py by searching upward, so it may be invoked from a subdirectory.
    with c.cd(os.path.join(REPOSITORY_ROOT, "tree_ntk")):
        c.run(
            f"poetry run python main.py -max_tot {max_tot} -model {model} -alpha {alpha} -mode {mode}"
        )


def train(
    c: invoke.Context,
    alpha: float,
    model: str,
    mode: str,
    max_tot: int,
    reg_coef: float,
    name: str,
):
    """Run ``main.py`` once through poetry and return ``Result.ok``.

    The command executes in whatever directory ``c`` is currently ``cd``-ed
    into; it does not change directory itself.

    Returns:
        ``ok`` of the ``Result`` returned by ``c.run``.
    """
    command = (
        "poetry run python main.py"
        f" -max_tot {max_tot}"
        f" -model {model}"
        f" -alpha {alpha}"
        f" -mode {mode}"
        f" -reg_coef {reg_coef}"
        f" -name {name}"
    )
    outcome = c.run(command)
    return outcome.ok


# Hyper-parameter grid swept by ``multi_run`` (alpha outer, model inner).
_MULTI_RUN_ALPHAS = (0.5, 1.0, 2.0, 4.0, 8.0, 16.0, 32.0, 64.0)
_MULTI_RUN_MODELS = ("tree_ntk", "asymtree_ntk", "inf_asymtree_ntk")


@invoke.task
def multi_run(
    c: invoke.Context,
    max_tot: int = 5000,
    reg_coef: float = 1e-8,
    name: str = "test",
    only_train: bool = False,
):
    """Sweep every (alpha, model) pair in kernel mode, then in train mode.

    Unless ``only_train`` is set, all kernel-mode runs are submitted to a pool
    of up to 32 worker processes and awaited. Afterwards the train-mode runs
    execute sequentially in this process. Every run executes inside
    ``<repository root>/tree_ntk``.

    Raises:
        RuntimeError: if a kernel-mode run reports a non-ok result.
    """
    work_dir = os.path.join(REPOSITORY_ROOT, "tree_ntk")

    if not only_train:
        # The executor is a context manager so worker processes are shut down
        # even when a run fails. Submitting inside `c.cd` means the context
        # handed to each worker carries the working directory.
        with c.cd(work_dir), ProcessPoolExecutor(max_workers=32) as executor:
            futures = [
                executor.submit(
                    train, c, alpha, model, "kernel", max_tot, reg_coef, name
                )
                for alpha in _MULTI_RUN_ALPHAS
                for model in _MULTI_RUN_MODELS
            ]
            for future in as_completed(futures):
                # result() re-raises any exception from the worker; the
                # explicit check (unlike `assert`) survives `python -O`.
                if not future.result():
                    raise RuntimeError("a kernel-mode run reported failure")

    with c.cd(work_dir):
        for alpha in _MULTI_RUN_ALPHAS:
            for model in _MULTI_RUN_MODELS:
                train(c, alpha, model, "train", max_tot, reg_coef, name)
    # slack = slackweb.Slack(
    #     url="....................."
    # )
    # slack.notify(text="Multi-run is Finished")


@invoke.task
def stop_all(c: invoke.Context):
    """SIGKILL every other process whose command line contains ``python``.

    The previous ``pgrep -f python | xargs kill -9`` also matched the
    ``bash -c`` shell running that pipeline (its command text contains
    "python") and this invoke process, so the task killed itself. Both are
    now excluded.
    """
    # "[p]ython" still matches "python" in other command lines, but the text
    # "[p]ython" in the shell's own command line does not match it.
    # pgrep exits 1 when nothing matches, hence warn=True.
    found = c.run("pgrep -f '[p]ython'", warn=True, hide=True)
    # Spare this invoke process and whatever launched it (e.g. `poetry run`).
    own_pids = {os.getpid(), os.getppid()}
    pids = [pid for pid in found.stdout.split() if int(pid) not in own_pids]
    if not pids:
        logger.info("no python processes to stop")
        return
    # warn=True: a process may exit between pgrep and kill; that is not an error.
    c.run("kill -9 " + " ".join(pids), warn=True)
    logger.info("sent SIGKILL to %d process(es)", len(pids))
