from pathlib import Path
from typing import Tuple, List

import torch
from beam.distributed.ray_dispatcher import RayDispatcher, RayClient

import algorithms
from compute_result.factory import StoreTypes
from utils.logger import create_logger, create_file_log_path, create_splunk_path
from problems.benchmarks import find_problems_to_run
from problems.types import Benchmarks
from run_options import (
    RESULT_PATH_OPTION,
    RUN_NAME_OPTION,
    BUDGET_OPTION,
    GRAPH_OPTION,
    BENCHMARK_OPTION,
    CONCURRENCY_OPTION,
    PART_OPTION,
    CONFIG_FROM_RUN_NAME_OPTION,
    IGNORE_GPU_OPTION,
    STARTING_POINT_OPTION,
    USE_FILE_LOG_OPTION,
    USE_SPLUNK_LOG_OPTION,
    ANALYZE_OPTION,
    ADDITIONAL_LOGS_OPTION,
    FUNCTION_NUMBER_OPTION,
    AlgorithmCommand,
    NETWORK_SIZE_OPTION,
    MODEL_OPTIMIZER_LR_OPTION,
    VALUE_OPTIMIZER_LR_OPTION,
    SET_OPTION,
    FUNCTION_DIM_OPTION,
)
from scripts.run_algorithm import run
from utils.python import timestamp_file_signature


# Instantiated once at import time, before run_on_ray creates any
# RayDispatcher. The returned instance is discarded, so only its constructor's
# side effects matter.
# NOTE(review): presumably this connects to / initialises the Ray cluster
# client; confirm in beam.distributed.ray_dispatcher.
RayClient()


def run_ray_worker(
    algorithm_name: str,
    run_name: str,
    space,
    parameters: dict,
    use_config: List[str],
    starting_point: int,
    output_graph: bool,
    use_splunk_log: bool,
    use_file_log: bool,
    analyze_run: bool,
    additional_logs: bool,
    results_path: Tuple[Path, StoreTypes],
):
    """Run one algorithm on one problem space; intended to execute as a Ray task.

    Sets the torch default dtype to float64 and creates the file and/or
    Splunk log destinations for this run. It then builds an initial point whose
    every coordinate equals ``starting_point`` (length ``space.dimension``,
    on device 0) and delegates to ``scripts.run_algorithm.run``.

    Ordinary exceptions raised while running are printed and logged instead
    of propagated, so one failing space does not abort the rest of the batch.
    ``BaseException`` subclasses such as ``KeyboardInterrupt`` / ``SystemExit``
    are deliberately not caught, so the task can still be interrupted or
    cancelled.

    Args:
        algorithm_name: Name of the algorithm to run; used in log paths.
        run_name: Name of the run; used in log paths and forwarded to ``run``.
        space: Problem space. Must expose ``dimension``; its ``repr`` is
            embedded in the file-log name.
        parameters: Extra algorithm parameters forwarded to ``run``.
        use_config: Forwarded to ``run``.
        starting_point: Value used for every coordinate of the initial point.
        output_graph: Forwarded to ``run`` as ``output_graph``.
        use_splunk_log: Whether to create a Splunk log path.
        use_file_log: Whether to create a file log path.
        analyze_run: Forwarded to ``run`` as ``analyze_run``.
        additional_logs: Forwarded to ``run`` as ``additional_logs``.
        results_path: Forwarded to ``run`` as ``results_path``.
    """
    torch.set_default_dtype(torch.float64)
    # NOTE(review): index 0 assumes the task was scheduled with (a share of)
    # a GPU so that device 0 is the one Ray assigned; confirm for CPU-only use.
    device = 0
    base_dir = Path(__file__).parent
    # Single timestamp for the whole run so file and Splunk logs share it.
    run_signature = timestamp_file_signature()

    normal_logs_path = (
        (
            create_file_log_path(base_dir, algorithm_name, run_name)
            / f"logs_for_{algorithm_name}_{space!r}_parallel-{run_signature}"
        )
        if use_file_log
        else None
    )
    splunk_logs_path = (
        create_splunk_path(base_dir, algorithm_name, run_name, run_signature)
        if use_splunk_log
        else None
    )

    logger = create_logger(
        normal_logs_path, splunk_logs_path, run_name, algorithm_name, space
    )

    try:
        run(
            algorithm_name,
            space,
            parameters,
            use_config,
            torch.tensor(
                [starting_point] * space.dimension, dtype=torch.float64, device=device
            ),
            device,
            output_bar=False,
            output_graph=output_graph,
            additional_logs=additional_logs,
            analyze_run=analyze_run,
            run_name=run_name,
            logger=logger,
            results_path=results_path,
        )
    except Exception as e:
        # Best-effort: report the failure for this space and let the task
        # finish normally so sibling tasks are unaffected.
        print(e)
        logger.exception(e)


@CONCURRENCY_OPTION
@BENCHMARK_OPTION(True)
@FUNCTION_NUMBER_OPTION()
@FUNCTION_DIM_OPTION()
@GRAPH_OPTION
@RUN_NAME_OPTION
@BUDGET_OPTION
@PART_OPTION
@CONFIG_FROM_RUN_NAME_OPTION
@IGNORE_GPU_OPTION
@STARTING_POINT_OPTION
@USE_FILE_LOG_OPTION
@USE_SPLUNK_LOG_OPTION
@ANALYZE_OPTION
@ADDITIONAL_LOGS_OPTION
@RESULT_PATH_OPTION
@NETWORK_SIZE_OPTION
@MODEL_OPTIMIZER_LR_OPTION
@VALUE_OPTIMIZER_LR_OPTION
@SET_OPTION
def run_on_ray(
    algorithm: str,
    concurrency: int,
    benchmark: Benchmarks,
    func_num: List[int],
    func_dim: List[int],
    graph: bool,
    run_name: str,
    budget: int,
    part: Tuple[int, int],
    use_config: List[str],
    ignore_gpu: Tuple[int],
    start: int,
    log_file: bool,
    log_splunk: bool,
    analyze: bool,
    alog: bool,
    results_path: Tuple[Path, StoreTypes],
    net_size,
    model_lr: float,
    grad_lr: float,
    setn: List[Tuple[str, str]],
    **kwargs,
):
    """CLI entry point: dispatch one ``run_ray_worker`` Ray task per problem space.

    Algorithm parameters are the remaining ``**kwargs``, plus ``net_size`` /
    ``model_lr`` / ``grad_lr`` stored under ``"layers"`` / ``"model_lr"`` /
    ``"value_lr"``, plus ``setn`` key/value pairs, which take precedence.

    ``part`` may be a single part or a collection of part tuples; it is
    normalised to a list before selecting spaces via ``find_problems_to_run``.
    Each task requests ``1 / concurrency`` GPUs. After dispatching every
    task, the ``.value`` of each result is printed.
    NOTE(review): presumably ``.value`` waits for the task to finish; confirm
    in RayDispatcher.
    """
    # setn overrides win over the dedicated optimizer/network options.
    overrides = {"layers": net_size, "model_lr": model_lr, "value_lr": grad_lr}
    overrides.update(dict(setn))
    kwargs.update(overrides)

    # A single part arrives as a flat tuple; several arrive as tuples of tuples.
    part_list = [part] if not isinstance(part[0], tuple) else list(part)
    spaces_to_run = find_problems_to_run(
        benchmark, budget, func_num=func_num, parts=part_list, func_dim=func_dim
    )
    space_count = len(spaces_to_run)
    print(f"Working on {space_count}")

    # All dispatchers are built before any task is submitted.
    dispatchers = [
        RayDispatcher(run_ray_worker, remote_kwargs={"num_gpus": 1 / concurrency})
        for _idx in range(space_count)
    ]
    pending = [
        dispatch(
            algorithm, run_name, space, kwargs, use_config, start, graph,
            log_splunk, log_file, analyze, alog, results_path,
        )
        for dispatch, space in zip(dispatchers, spaces_to_run)
    ]
    outcomes = [handle.value for handle in pending]
    print(outcomes)


if __name__ == "__main__":
    # Wire the algorithm registry module and the Ray entry point into the
    # command-line interface, then hand control to it.
    cli = AlgorithmCommand(
        algorithms,
        run_on_ray,
    )
    cli()
