from pathlib import Path
from typing import Tuple, List

import torch

import algorithms
from algorithms.convergence_algorithms.typing import BoundedEvaluatedSamplerIdentifiableSpace
from compute_result.factory import StoreTypes
from utils.logger import create_logger, create_file_log_path, create_splunk_path
from run_options import (
    GRAPH_OPTION,
    RUN_NAME_OPTION,
    STARTING_POINT_OPTION,
    CONFIG_FROM_RUN_NAME_OPTION,
    ANALYZE_OPTION,
    ADDITIONAL_LOGS_OPTION,
    USE_FILE_LOG_OPTION,
    USE_SPLUNK_LOG_OPTION,
    AlgorithmCommand,
    space_instance,
    RESULT_PATH_OPTION,
)
from scripts.run_algorithm import run
from utils.algorithms_data import Algorithms
from utils.python import timestamp_file_signature


@GRAPH_OPTION
@RUN_NAME_OPTION
@STARTING_POINT_OPTION
@CONFIG_FROM_RUN_NAME_OPTION
@ANALYZE_OPTION
@ADDITIONAL_LOGS_OPTION
@USE_FILE_LOG_OPTION
@USE_SPLUNK_LOG_OPTION
@RESULT_PATH_OPTION
@space_instance
def main(
    algorithm: Algorithms,
    graph: bool,
    run_name: str,
    start: int,
    use_config: List[str],
    device: int,
    space: BoundedEvaluatedSamplerIdentifiableSpace,
    analyze: bool,
    alog: bool,
    log_file: bool,
    log_splunk: bool,
    results_path: Tuple[Path, StoreTypes],
    **kwargs,
):
    # Entry point for a single algorithm run on `space`: set up the optional
    # file / Splunk log destinations, build a logger, and delegate the actual
    # work to `scripts.run_algorithm.run`. Any extra CLI options arrive in
    # `kwargs` and are forwarded to `run` as a fresh dict.
    # NOTE(review): documented with comments rather than a docstring in case
    # the CLI wrapper surfaces `main.__doc__` as help text — confirm.

    # Floating-point tensors created without an explicit dtype from here on
    # (this is a process-wide setting) default to float64.
    torch.set_default_dtype(torch.float64)

    signature = timestamp_file_signature()
    root_dir = Path()  # i.e. "." — relative to the current working directory

    file_log_dir = None
    if log_file:
        file_log_dir = create_file_log_path(root_dir, algorithm, run_name)
        # Nest under the first comma-separated field of the space's string form,
        # then a per-invocation folder keyed by the timestamp signature.
        file_log_dir = (
            file_log_dir / str(space).partition(",")[0] / f"logs_for_main-{signature}"
        )

    splunk_log_path = None
    if log_splunk:
        splunk_log_path = create_splunk_path(root_dir, algorithm, run_name, signature)

    run_logger = create_logger(file_log_dir, splunk_log_path, run_name, algorithm, space)

    # Starting point: a float64 vector with every one of the space's
    # `dimension` coordinates equal to `start`, placed on `device`.
    initial_point = torch.tensor(
        [start] * space.dimension, dtype=torch.float64, device=device
    )
    run(
        algorithm,
        space,
        kwargs.copy(),
        use_config,
        initial_point,
        device,
        output_bar=False,
        output_graph=graph,
        analyze_run=analyze,
        additional_logs=alog,
        run_name=run_name,
        logger=run_logger,
        grad_error=True,
        results_path=results_path,
    )


if __name__ == "__main__":
    # Wrap `main` in the project's command object (together with the
    # `algorithms` package) and invoke it straight away.
    AlgorithmCommand(algorithms, main)()
