import queue
from pathlib import Path
from typing import Tuple, List

import torch
from torch.multiprocessing import Queue

import algorithms
from compute_result.factory import StoreTypes
from utils.logger import create_logger, create_file_log_path, create_splunk_path
from problems.benchmarks import find_problems_to_run
from problems.types import Benchmarks
from run_options import (
    RESULT_PATH_OPTION,
    RUN_NAME_OPTION,
    BUDGET_OPTION,
    GRAPH_OPTION,
    BENCHMARK_OPTION,
    CONCURRENCY_OPTION,
    PART_OPTION,
    CONFIG_FROM_RUN_NAME_OPTION,
    IGNORE_GPU_OPTION,
    STARTING_POINT_OPTION,
    USE_FILE_LOG_OPTION,
    USE_SPLUNK_LOG_OPTION,
    ANALYZE_OPTION,
    ADDITIONAL_LOGS_OPTION,
    FUNCTION_NUMBER_OPTION,
    AlgorithmCommand,
)
from scripts.run_algorithm import run
from utils.algorithms_data import Algorithms
from utils.python import timestamp_file_signature


def run_alg(
    rank: int,
    spaces_queue: Queue,
    algorithm_name: str,
    parameters: dict,
    use_config: List[str],
    starting_point: int,
    run_name: str,
    output_graph: bool,
    gpu_to_use: List[int],
    use_splunk_log: bool,
    use_file_log: bool,
    analyze_run: bool,
    additional_logs: bool,
    results_path: Tuple[Path, StoreTypes],
):
    """Worker entry point: drain ``spaces_queue`` and run the algorithm on each space.

    Started through ``torch.multiprocessing.spawn``, which passes the process
    index as ``rank``. Every space handled by this worker runs on
    ``gpu_to_use[rank]``. Several workers share the same queue, so each one
    keeps pulling spaces until the queue is exhausted. An exception raised while
    running one space is logged and the worker moves on to the next space.

    Args:
        rank: Process index supplied by ``torch.multiprocessing.spawn``; selects
            the device from ``gpu_to_use``.
        spaces_queue: Shared queue of problem spaces still to be run.
        algorithm_name: Algorithm name, forwarded to ``run`` and used in log paths.
        parameters: Extra algorithm parameters, forwarded to ``run``.
        use_config: Forwarded to ``run`` unchanged.
        starting_point: Value repeated ``space.dimension`` times to build the
            float64 tensor handed to ``run`` (presumably the initial point).
        run_name: Run name used for log paths, the logger and ``run``.
        output_graph: Forwarded to ``run`` as ``output_graph``.
        gpu_to_use: Device indices; entry ``rank`` is used by this worker.
        use_splunk_log: Whether to build a Splunk log path for each space.
        use_file_log: Whether to build a per-space file log path.
        analyze_run: Forwarded to ``run`` as ``analyze_run``.
        additional_logs: Forwarded to ``run`` as ``additional_logs``.
        results_path: Forwarded to ``run`` as ``results_path``.
    """
    # Seconds to wait for the next space before deciding the queue is drained.
    queue_timeout_s = 1.0

    torch.set_default_dtype(torch.float64)
    run_signature = timestamp_file_signature()
    # The device depends only on this worker's rank, so resolve it once.
    device = gpu_to_use[rank]
    base_dir = Path()

    while True:
        # ``empty()`` followed by ``get()`` is racy with several consumers: two
        # workers can both see the last remaining item and the loser then blocks
        # in ``get()`` forever. Pull with a timeout and stop on ``Empty``. A
        # non-blocking get is avoided because it also fails while another worker
        # momentarily holds the queue's read lock.
        try:
            space = spaces_queue.get(timeout=queue_timeout_s)
        except queue.Empty:
            break

        normal_logs_path = (
            (
                create_file_log_path(base_dir, algorithm_name, run_name)
                / rf"logs_for_{algorithm_name}_{repr(space)}_parallel-{timestamp_file_signature()}"
            )
            if use_file_log
            else None
        )
        splunk_logs_path = (
            create_splunk_path(base_dir, algorithm_name, run_name, run_signature)
            if use_splunk_log
            else None
        )

        # A fresh logger per space; the device line is written into each one.
        logger = create_logger(
            normal_logs_path, splunk_logs_path, run_name, algorithm_name, space
        )
        logger.info(f"Device {rank} {device}")
        try:
            run(
                algorithm_name,
                space,
                parameters,
                use_config,
                torch.tensor(
                    [starting_point] * space.dimension, dtype=torch.float64, device=device
                ),
                device,
                output_bar=False,
                output_graph=output_graph,
                analyze_run=analyze_run,
                additional_logs=additional_logs,
                run_name=run_name,
                logger=logger,
                results_path=results_path,
            )
        except Exception:
            # Per-space boundary: log the failure and keep draining the queue.
            logger.exception("An error occurred which stopped the algorithm")


@CONCURRENCY_OPTION
@BENCHMARK_OPTION(True)
@FUNCTION_NUMBER_OPTION()
@GRAPH_OPTION
@RUN_NAME_OPTION
@BUDGET_OPTION
@PART_OPTION
@CONFIG_FROM_RUN_NAME_OPTION
@IGNORE_GPU_OPTION
@STARTING_POINT_OPTION
@USE_FILE_LOG_OPTION
@USE_SPLUNK_LOG_OPTION
@ANALYZE_OPTION
@ADDITIONAL_LOGS_OPTION
@RESULT_PATH_OPTION
def run_parallel(
    algorithm: Algorithms,
    concurrency: int,
    benchmark: Benchmarks,
    func_num: List[int],
    graph: bool,
    run_name: str,
    budget: int,
    part: Tuple[int, int],
    use_config: List[str],
    ignore_gpu: Tuple[int],
    start: int,
    log_file: bool,
    log_splunk: bool,
    analyze: bool,
    alog: bool,
    results_path: Tuple[Path, StoreTypes],
    **kwargs,
):
    """Run ``algorithm`` on every selected benchmark space using worker processes.

    The selected spaces go onto a shared queue. Each of the ``concurrency``
    rounds calls ``torch.multiprocessing.spawn`` to start one ``run_alg``
    worker per usable GPU (worker ``i`` uses the ``i``-th usable GPU). Up to
    ``concurrency * len(usable GPUs)`` workers therefore drain the queue
    together. The function returns once every worker has finished.

    Args:
        algorithm: Algorithm to run, forwarded to workers as ``algorithm_name``.
        concurrency: Number of spawn rounds, i.e. workers per usable GPU.
        benchmark: Benchmark passed to ``find_problems_to_run``.
        func_num: Function numbers passed to ``find_problems_to_run``.
        graph: Forwarded to workers as ``output_graph``.
        run_name: Run name forwarded to workers.
        budget: Budget passed to ``find_problems_to_run``.
        part: Either a single part tuple or a tuple of part tuples, normalized
            to a list and passed to ``find_problems_to_run`` as ``parts``.
        use_config: Forwarded to workers unchanged.
        ignore_gpu: CUDA device indices to exclude.
        start: Forwarded to workers as ``starting_point``.
        log_file: Forwarded to workers as ``use_file_log``.
        log_splunk: Forwarded to workers as ``use_splunk_log``.
        analyze: Forwarded to workers as ``analyze_run``.
        alog: Forwarded to workers as ``additional_logs``.
        results_path: Forwarded to workers as ``results_path``.
        **kwargs: Remaining options, forwarded to workers as ``parameters``.

    Raises:
        RuntimeError: If no CUDA device remains after excluding ``ignore_gpu``.
            Previously the run silently did nothing in that case.
    """
    parts = list(part) if isinstance(part[0], tuple) else [part]

    # Keep device order deterministic: a set difference does not guarantee order.
    ignored_gpus = set(ignore_gpu)
    gpu_to_use = [
        gpu for gpu in range(torch.cuda.device_count()) if gpu not in ignored_gpus
    ]
    if not gpu_to_use:
        # With nprocs=0, spawn starts nothing and the queue is never drained;
        # fail loudly before filling it.
        raise RuntimeError(
            f"No usable GPU: {torch.cuda.device_count()} visible, ignoring {sorted(ignored_gpus)}"
        )

    torch.multiprocessing.set_start_method("spawn")
    spaces_queue = Queue()
    spaces_to_run = find_problems_to_run(benchmark, budget, func_num=func_num, parts=parts)
    print(f"Working on {len(spaces_to_run)}")
    for space in spaces_to_run:
        spaces_queue.put(space)

    contexts = [
        torch.multiprocessing.spawn(
            run_alg,
            args=(
                spaces_queue,
                algorithm,
                kwargs.copy(),
                use_config,
                start,
                run_name,
                graph,
                gpu_to_use,
                log_splunk,
                log_file,
                analyze,
                alog,
                results_path,
            ),
            nprocs=len(gpu_to_use),
            join=False,
        )
        for _ in range(concurrency)
    ]

    for context in contexts:
        # ``ProcessContext.join()`` returns as soon as *one* of its processes
        # exits (False while others are still running), so poll it until every
        # process is joined. This is the same loop ``spawn`` uses with join=True.
        while not context.join():
            pass


if __name__ == "__main__":
    # Wrap ``run_parallel`` in the project's ``AlgorithmCommand`` CLI, built
    # with the ``algorithms`` module, and invoke it.
    AlgorithmCommand(algorithms, run_parallel)()
