import glob
import os
from concurrent.futures import ProcessPoolExecutor

import invoke
import slackweb
from loguru import logger

from utils import WandbDataLoader


@invoke.task
def run(
    c: invoke.Context,
    args: str = "",
    gpu: int = 4,
) -> None:
    """Run the whole experiment pipeline once from ``./src`` on a single GPU.

    Executes, in order: ``train.py``, ``tree_weight_matching.py``,
    ``tree_activation_matching.py``, ``mlp_weight_matching.py`` and
    ``evaluation.py``. Each script is invoked as
    ``poetry run python <script> <args>``.

    Args:
        c: Invoke context used to run the commands.
        args: Override string appended verbatim to every script's command line.
        gpu: GPU index written to ``CUDA_VISIBLE_DEVICES``.

    With invoke's default settings, a script exiting non-zero raises and the
    remaining scripts are not run.
    """
    pipeline = (
        "train.py",
        "tree_weight_matching.py",
        "tree_activation_matching.py",
        "mlp_weight_matching.py",
        "evaluation.py",
    )
    # Set on this process's environment so the spawned commands inherit it.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu)
    with c.cd("./src/"):
        for script in pipeline:
            c.run(f"poetry run python {script} {args}")


def _unitrun(args):
    """Run the pipeline for one grid point; the process-pool entry point.

    ``args`` is the 10-tuple assembled by ``exact_run`` / ``deep_run``::

        (c, dataset, gpu, n_tree, depth, learning_rate,
         seed_a, seed_b, split, wandb_project)

    All fields except ``c`` and ``gpu`` are turned into ``key=value``
    overrides and forwarded to ``run`` together with the GPU index.
    """
    (c, dataset, gpu, n_tree, depth, learning_rate,
     seed_a, seed_b, split, wandb_project) = args

    # Insertion order fixes the order of the overrides on the command line.
    overrides = {
        "dataset": dataset,
        "depth": depth,
        "learning_rate": learning_rate,
        "n_tree": n_tree,
        "seed_a": seed_a,
        "seed_b": seed_b,
        "split": split,
        "wandb_project": wandb_project,
    }
    override_str = " ".join(f"{key}={value}" for key, value in overrides.items())
    run(c, override_str, gpu)


@invoke.task
def exact_run(
    c: invoke.Context,
    gpus: str = "4",
) -> None:
    """Run the ``exact`` experiment grid in parallel, then notify Slack.

    For each value of ``c.exact.depth``, one job is built for every
    combination of:

    - ``c.exact.split``;
    - the paired ``(c.exact.seed_a, c.exact.seed_b)``;
    - each dataset config in ``src/config/dataset/*.yaml``, excluding
      ``mnist`` and ``cifar10``;
    - ``c.exact.n_tree``;
    - ``c.exact.learning_rate``.

    Jobs are assigned to the given GPUs round-robin. Each depth runs in its
    own process pool, smallest depth first. Failed jobs are logged, and their
    count is appended to the Slack message.

    Args:
        c: Invoke context providing the ``exact`` settings.
        gpus: Comma-separated GPU ids, e.g. ``"0,1,2"``.

    Raises:
        ValueError: If ``c.exact.depth`` contains a depth with no entry in the
            per-depth worker table. This is checked before any job is launched.
    """
    # Number of concurrent workers per GPU for each supported depth.
    n_workers_dict = {1: 10, 2: 10, 3: 10}
    unsupported = sorted(set(c.exact.depth).difference(n_workers_dict))
    if unsupported:
        # Fail fast instead of after earlier depths have already run.
        raise ValueError(
            f"exact-run: no worker count configured for depth(s) {unsupported}"
        )

    gpu_ids = [int(gpu) for gpu in gpus.split(",")]

    # Dataset names are the config file stems. Filtering, rather than
    # list.remove, tolerates mnist/cifar10 configs being absent.
    excluded = {"mnist", "cifar10"}
    dataset_candidates = sorted(
        name
        for name in (
            os.path.splitext(os.path.basename(path))[0]
            for path in glob.glob("src/config/dataset/*.yaml")
        )
        if name not in excluded
    )

    counter = 0
    grid_args_dict = {}
    for depth in sorted(c.exact.depth, reverse=True):
        grid_args = []
        for split in c.exact.split:
            for seed_a, seed_b in zip(c.exact.seed_a, c.exact.seed_b):
                for dataset in dataset_candidates:
                    for n_tree in sorted(c.exact.n_tree, reverse=True):
                        for learning_rate in c.exact.learning_rate:
                            # Round-robin GPU assignment across the whole grid.
                            gpu = gpu_ids[counter % len(gpu_ids)]
                            grid_args.append(
                                (
                                    c,
                                    dataset,
                                    gpu,
                                    n_tree,
                                    depth,
                                    learning_rate,
                                    seed_a,
                                    seed_b,
                                    split,
                                    c.exact.wandb_project,
                                )
                            )
                            counter += 1
        grid_args_dict[depth] = grid_args

    failed_jobs = []
    for depth in sorted(c.exact.depth):
        jobs = grid_args_dict[depth]
        # Leaving the `with` block waits for every submitted job to finish.
        with ProcessPoolExecutor(
            max_workers=len(gpu_ids) * n_workers_dict[depth]
        ) as executor:
            futures = [executor.submit(_unitrun, job) for job in jobs]
            # A job's exception only surfaces when its result is retrieved,
            # so inspect every future explicitly; otherwise failures vanish.
            for job, future in zip(jobs, futures):
                error = future.exception()
                if error is not None:
                    failed_jobs.append(job)
                    # job[0] is the invoke Context; log only the parameters.
                    logger.error("exact-run job {} failed: {!r}", job[1:], error)

    message = "exact-run is completed"
    if failed_jobs:
        message += f" ({len(failed_jobs)} job(s) failed)"
    send_slack(c, message=message)


@invoke.task
def deep_run(
    c: invoke.Context,
    gpus: str = "4",
) -> None:
    """Run the ``deep`` experiment grid in parallel, then notify Slack.

    For each value of ``c.deep.depth``, one job is built for every
    combination of:

    - ``c.deep.split``;
    - the paired ``(c.deep.seed_a, c.deep.seed_b)``;
    - each dataset config in ``src/config/dataset/*.yaml``, excluding
      ``mnist`` and ``cifar10``;
    - ``c.deep.n_tree``;
    - ``c.deep.learning_rate``.

    Jobs are assigned to the given GPUs round-robin. Each depth runs in its
    own process pool, largest depth first. Failed jobs are logged, and their
    count is appended to the Slack message.

    Args:
        c: Invoke context providing the ``deep`` settings.
        gpus: Comma-separated GPU ids, e.g. ``"0,1,2"``.

    Raises:
        ValueError: If ``c.deep.depth`` contains a depth with no entry in the
            per-depth worker table. This is checked before any job is launched.
    """
    # Number of concurrent workers per GPU for each supported depth
    # (fewer for the deepest setting).
    n_workers_dict = {2: 10, 4: 10, 8: 5}
    unsupported = sorted(set(c.deep.depth).difference(n_workers_dict))
    if unsupported:
        # Fail fast instead of after earlier depths have already run.
        raise ValueError(
            f"deep-run: no worker count configured for depth(s) {unsupported}"
        )

    gpu_ids = [int(gpu) for gpu in gpus.split(",")]

    # Dataset names are the config file stems. Filtering, rather than
    # list.remove, tolerates mnist/cifar10 configs being absent.
    excluded = {"mnist", "cifar10"}
    dataset_candidates = sorted(
        name
        for name in (
            os.path.splitext(os.path.basename(path))[0]
            for path in glob.glob("src/config/dataset/*.yaml")
        )
        if name not in excluded
    )

    counter = 0
    grid_args_dict = {}
    for depth in sorted(c.deep.depth, reverse=True):
        grid_args = []
        for split in c.deep.split:
            for seed_a, seed_b in zip(c.deep.seed_a, c.deep.seed_b):
                for dataset in dataset_candidates:
                    for n_tree in sorted(c.deep.n_tree, reverse=True):
                        for learning_rate in c.deep.learning_rate:
                            # Round-robin GPU assignment across the whole grid.
                            gpu = gpu_ids[counter % len(gpu_ids)]
                            grid_args.append(
                                (
                                    c,
                                    dataset,
                                    gpu,
                                    n_tree,
                                    depth,
                                    learning_rate,
                                    seed_a,
                                    seed_b,
                                    split,
                                    c.deep.wandb_project,
                                )
                            )
                            counter += 1
        grid_args_dict[depth] = grid_args

    failed_jobs = []
    for depth in sorted(c.deep.depth, reverse=True):
        jobs = grid_args_dict[depth]
        # Leaving the `with` block waits for every submitted job to finish.
        with ProcessPoolExecutor(
            max_workers=len(gpu_ids) * n_workers_dict[depth]
        ) as executor:
            futures = [executor.submit(_unitrun, job) for job in jobs]
            # A job's exception only surfaces when its result is retrieved,
            # so inspect every future explicitly; otherwise failures vanish.
            for job, future in zip(jobs, futures):
                error = future.exception()
                if error is not None:
                    failed_jobs.append(job)
                    # job[0] is the invoke Context; log only the parameters.
                    logger.error("deep-run job {} failed: {!r}", job[1:], error)

    message = "deep-run is completed"
    if failed_jobs:
        message += f" ({len(failed_jobs)} job(s) failed)"
    send_slack(c, message=message)


@invoke.task
def debug(
    c: invoke.Context,
    gpu: int = 4,
    disable_cache: bool = False,
    wandb: bool = False,
    only_modelcheck: bool = False,
):
    """Smoke-test the pipeline with a tiny configuration, then notify Slack.

    The scripts run for depth 1 and then depth 2, with ``epochs=1``,
    ``seed_a=1``, ``seed_b=1`` and ``n_tree=64``.

    Args:
        c: Invoke context used to run the commands.
        gpu: GPU index written to ``CUDA_VISIBLE_DEVICES``.
        disable_cache: Forwarded as the ``disable_cache=`` override.
        wandb: Forwarded as the ``wandb=`` override.
        only_modelcheck: If true, run only ``debug.py``. Otherwise run the
            full pipeline, with ``debug.py`` before ``evaluation.py``.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu)

    if only_modelcheck:
        scripts = ("debug.py",)
    else:
        scripts = (
            "train.py",
            "tree_weight_matching.py",
            "tree_activation_matching.py",
            "mlp_weight_matching.py",
            "debug.py",
            "evaluation.py",
        )

    for depth in (1, 2):
        overrides = f"epochs=1 seed_a=1 seed_b=1 depth={depth} n_tree=64 wandb={wandb} disable_cache={disable_cache}"
        with c.cd("./src/"):
            for script in scripts:
                c.run(f"poetry run python {script} {overrides}")

    send_slack(c, message="debug-run is completed")


@invoke.task
def kill(c: invoke.Context):
    """SIGTERM every process whose command line contains ``tree-re-basin``.

    The shell wrapper running the pipeline and this invoke process itself are
    never targeted. Finding no matching process is not an error.
    """
    # The regex `[t]ree-re-basin` matches the same processes as
    # `tree-re-basin`, but not the `sh -c "pgrep ..."` wrapper, whose own
    # command line contains the bracketed text. pgrep only excludes itself;
    # with the plain pattern the task killed its own shell and reported
    # failure.
    # NOTE(review): this invoke process may also match if it runs from a path
    # containing "tree-re-basin", so its pid is filtered out explicitly.
    # `xargs -r` skips running `kill`, which errors without arguments, when
    # nothing matched.
    c.run(f"pgrep -f '[t]ree-re-basin' | grep -vx {os.getpid()} | xargs -r kill")


@invoke.task
def send_slack(c: invoke.Context, message: str):
    """Post ``message`` as plain text to the Slack URL configured as ``c.slack``.

    ``c.slack`` is presumably an incoming-webhook URL from the invoke config.
    """
    slackweb.Slack(url=c.slack).notify(text=message)


@invoke.task
def summarize_wandb(c: invoke.Context):
    """Load local W&B run files and pickle the rows of each project separately.

    Reads ``./src/wandb/*/*.wandb`` through ``WandbDataLoader`` and logs the
    loaded table. For each distinct value of the ``project`` column, it writes
    that project's rows (index reset) to ``<project>.pkl`` in the current
    working directory.
    """
    # load_data() presumably returns a pandas DataFrame (it is indexed,
    # filtered and pickled like one below).
    runs = WandbDataLoader("./src/wandb/*/*.wandb").load_data()
    logger.info(runs)

    for project in runs["project"].unique():
        project_rows = runs[runs["project"] == project]
        project_rows = project_rows.reset_index(drop=True)
        project_rows.to_pickle(f"{project}.pkl")


@invoke.task
def clear_wandb(c: invoke.Context):
    """Delete everything under ``src/wandb`` but keep the directory itself."""
    target_dir = "src/wandb"
    # -mindepth 1 stops find from applying -delete to target_dir itself.
    c.run(f"find {target_dir} -mindepth 1 -delete")
