from pathlib import Path
from typing import Tuple, List

import torch
from beam.distributed.ray_dispatcher import RayDispatcher, RayClient

import algorithms
from algorithms.convergence_algorithms.typing import (
    BoundedEvaluatedSamplerIdentifiableSpace,
)
from compute_result.factory import StoreTypes
from run_options import (
    RESULT_PATH_OPTION,
    RUN_NAME_OPTION,
    GRAPH_OPTION,
    CONCURRENCY_OPTION,
    CONFIG_FROM_RUN_NAME_OPTION,
    IGNORE_GPU_OPTION,
    STARTING_POINT_OPTION,
    USE_FILE_LOG_OPTION,
    USE_SPLUNK_LOG_OPTION,
    ANALYZE_OPTION,
    ADDITIONAL_LOGS_OPTION,
    AlgorithmCommand,
    NETWORK_SIZE_OPTION,
    MODEL_OPTIMIZER_LR_OPTION,
    VALUE_OPTIMIZER_LR_OPTION,
    SET_OPTION,
    space_instance,
)
from scripts.run_algorithm import run
from utils.logger import create_logger, create_file_log_path, create_splunk_path
from utils.python import timestamp_file_signature

# Module-level side effect: runs at import time, i.e. before run_on_ray
# dispatches any worker. NOTE(review): presumably this connects to /
# initializes the Ray client for the dispatchers below — confirm against
# beam.distributed.ray_dispatcher.
RayClient()


def run_ray_worker(
    algorithm_name: str,
    run_name: str,
    space: BoundedEvaluatedSamplerIdentifiableSpace,
    parameters: dict,
    use_config: List[str],
    starting_point: int,
    output_graph: bool,
    use_splunk_log: bool,
    use_file_log: bool,
    analyze_run: bool,
    additional_logs: bool,
    results_path: Tuple[Path, StoreTypes],
) -> None:
    """Run ``algorithm_name`` on a single ``space``; used as the Ray task body.

    Sets torch's default dtype to float64, creates the requested file / Splunk
    log destinations and a logger, then calls ``run`` starting from the point
    ``[starting_point] * space.dimension`` on device 0. Exceptions raised while
    building the start point or inside ``run`` are printed and logged rather
    than propagated; the function returns ``None``.

    Args:
        algorithm_name: Algorithm name, used in log paths and passed to ``run``.
        run_name: Run identifier, used in log paths and passed to ``run``.
        space: Space to run on; its ``dimension`` sizes the starting point.
        parameters: Algorithm parameters forwarded to ``run``.
        use_config: Config selection forwarded to ``run``.
        starting_point: Scalar replicated along every dimension of ``space``.
        output_graph: Forwarded to ``run`` as ``output_graph``.
        use_splunk_log: When True, a Splunk log path is created.
        use_file_log: When True, a file log path is created.
        analyze_run: Forwarded to ``run`` as ``analyze_run``.
        additional_logs: Forwarded to ``run`` as ``additional_logs``.
        results_path: Result destination forwarded to ``run``.
    """
    torch.set_default_dtype(torch.float64)
    # NOTE(review): always device index 0 — presumably Ray exposes only the GPU
    # share assigned to this worker, so index 0 is "this worker's GPU"; confirm.
    device = 0
    base_dir = Path(__file__).parent
    # Sample the signature once so the file log and the Splunk log of this run
    # are named with the same timestamp.
    run_signature = timestamp_file_signature()

    normal_logs_path = (
        (
            create_file_log_path(base_dir, algorithm_name, run_name)
            / f"logs_for_{algorithm_name}_{repr(space)}_parallel-{run_signature}"
        )
        if use_file_log
        else None
    )
    splunk_logs_path = (
        create_splunk_path(base_dir, algorithm_name, run_name, run_signature)
        if use_splunk_log
        else None
    )

    logger = create_logger(
        normal_logs_path, splunk_logs_path, run_name, algorithm_name, space
    )

    try:
        run(
            algorithm_name,
            space,
            parameters,
            use_config,
            torch.tensor(
                [starting_point] * space.dimension, dtype=torch.float64, device=device
            ),
            device,
            output_bar=False,
            output_graph=output_graph,
            additional_logs=additional_logs,
            analyze_run=analyze_run,
            run_name=run_name,
            logger=logger,
            grad_error=True,
            results_path=results_path,
        )
    except Exception as e:
        # Best effort: report the failure and let this worker finish normally.
        # Catching Exception (not BaseException) lets KeyboardInterrupt and
        # SystemExit propagate, so interrupts/cancellations are not swallowed.
        print(e)
        logger.exception(e)


@CONCURRENCY_OPTION
@GRAPH_OPTION
@RUN_NAME_OPTION
@CONFIG_FROM_RUN_NAME_OPTION
@IGNORE_GPU_OPTION
@STARTING_POINT_OPTION
@USE_FILE_LOG_OPTION
@USE_SPLUNK_LOG_OPTION
@ANALYZE_OPTION
@ADDITIONAL_LOGS_OPTION
@RESULT_PATH_OPTION
@NETWORK_SIZE_OPTION
@MODEL_OPTIMIZER_LR_OPTION
@VALUE_OPTIMIZER_LR_OPTION
@SET_OPTION
@space_instance
def run_on_ray(
    algorithm: str,
    concurrency: int,
    graph: bool,
    run_name: str,
    use_config: List[str],
    ignore_gpu: Tuple[int],
    start: int,
    log_file: bool,
    log_splunk: bool,
    analyze: bool,
    alog: bool,
    results_path: Tuple[Path, StoreTypes],
    net_size,
    model_lr: float,
    grad_lr: float,
    setn: List[Tuple[str, str]],
    space: List[BoundedEvaluatedSamplerIdentifiableSpace],
    **kwargs,
):
    """Run the selected algorithm on every selected space in parallel via Ray."""
    # The space option may yield a single instance or a list; normalise to a list.
    spaces = [space] if not isinstance(space, list) else space

    # Leftover CLI kwargs become the algorithm parameters; the network/optimizer
    # options are added first and the --set pairs are applied last (they win).
    kwargs.update(layers=net_size, model_lr=model_lr, value_lr=grad_lr)
    kwargs.update(setn)
    print(f"Working on {len(spaces)}")

    # One dispatcher per space, each requesting a 1/concurrency share of a GPU.
    # NOTE(review): ``ignore_gpu`` is accepted but never used here — confirm intent.
    dispatchers = [
        RayDispatcher(run_ray_worker, remote_kwargs={"num_gpus": 1 / concurrency})
        for _ in spaces
    ]

    # Arguments are positional and must follow run_ray_worker's parameter order
    # (note: use_splunk_log comes before use_file_log).
    pending = []
    for dispatcher, target_space in zip(dispatchers, spaces):
        pending.append(
            dispatcher(
                algorithm,
                run_name,
                target_space,
                kwargs,
                use_config,
                start,
                graph,
                log_splunk,
                log_file,
                analyze,
                alog,
                results_path,
            )
        )

    # NOTE(review): ``.value`` presumably waits for each remote result; confirm.
    print([handle.value for handle in pending])


if __name__ == "__main__":
    # Wrap run_on_ray in the algorithm-aware CLI and hand control to it.
    AlgorithmCommand(algorithms, run_on_ray)()
