from pathlib import Path
from typing import Tuple, List

import torch

import algorithms
from compute_result.factory import StoreTypes
from utils.logger import create_logger
from problems.benchmarks import find_problems_to_run
from problems.types import Benchmarks
from run_options import (
    GRAPH_OPTION,
    BENCHMARK_OPTION,
    RUN_NAME_OPTION,
    BUDGET_OPTION,
    CONFIG_FROM_RUN_NAME_OPTION,
    PART_OPTION,
    USE_SPLUNK_LOG_OPTION,
    USE_FILE_LOG_OPTION,
    STARTING_POINT_OPTION,
    ANALYZE_OPTION,
    ADDITIONAL_LOGS_OPTION,
    FUNCTION_NUMBER_OPTION,
    FUNCTION_DIM_OPTION,
    FUNCTION_INSTANCE_OPTION,
    AlgorithmCommand,
    DEVICE_OPTION,
    RESULT_PATH_OPTION,
    NETWORK_SIZE_OPTION,
    MODEL_OPTIMIZER_LR_OPTION,
    VALUE_OPTIMIZER_LR_OPTION,
    SET_OPTION,
)
from scripts.run_algorithm import run
from utils.algorithms_data import Algorithms
from utils.python import timestamp_file_signature


@GRAPH_OPTION
@BENCHMARK_OPTION(True)
@FUNCTION_NUMBER_OPTION()
@FUNCTION_DIM_OPTION()
@FUNCTION_INSTANCE_OPTION
@RUN_NAME_OPTION
@BUDGET_OPTION
@STARTING_POINT_OPTION
@PART_OPTION
@CONFIG_FROM_RUN_NAME_OPTION
@DEVICE_OPTION
@ANALYZE_OPTION
@ADDITIONAL_LOGS_OPTION
@USE_FILE_LOG_OPTION
@USE_SPLUNK_LOG_OPTION
@RESULT_PATH_OPTION
@NETWORK_SIZE_OPTION
@MODEL_OPTIMIZER_LR_OPTION
@VALUE_OPTIMIZER_LR_OPTION
@SET_OPTION
def main(
    algorithm: Algorithms,
    benchmark: Benchmarks,
    func_num: List[int],
    func_dim: List[int],
    func_inst: List[int],
    graph: bool,
    run_name: str,
    budget: int,
    start: int,
    part: Tuple[int, int],
    use_config: List[str],
    device: int,
    analyze: bool,
    alog: bool,
    log_file: bool,
    log_splunk: bool,
    results_path: Tuple[Path, StoreTypes],
    net_size,
    model_lr: float,
    grad_lr: float,
    setn: List[Tuple[str, str]],
    **kwargs,
):
    """Run ``algorithm`` on every problem space selected from ``benchmark``.

    CLI entry point: all parameters are filled in by the option decorators
    above. Options without a dedicated parameter arrive in ``kwargs``.
    ``kwargs`` is extended as follows before it is forwarded to ``run``:

    * ``layers`` is set from ``net_size``;
    * ``model_lr`` is set from ``model_lr``;
    * ``value_lr`` is set from ``grad_lr``;
    * the key/value pairs in ``setn`` are applied last, so they override
      any of the above on a key collision.

    Side effect: torch's default dtype is set to ``float64`` for the whole
    process.

    ``part`` may be a single part or a tuple of parts (detected by whether
    its first element is itself a tuple). Either form is normalised to a list
    before problem selection.

    For each selected space the function creates a logger and calls ``run``.
    The starting point passed to ``run`` has every coordinate equal to
    ``start``.

    Log locations, relative to the working directory:

    * normal log (only when ``log_file``):
      ``logs/<algorithm>/<run_name>/normal/logs_for_main-<stamp>``.
      A new stamp is taken for every space.
    * splunk log (only when ``log_splunk``):
      ``logs/<algorithm>/<run_name>/splunk/<run_name>_splunk_logs_for_main_<stamp>``.
      A single stamp is shared by all spaces of this invocation.
    """
    # CLI-derived settings go in first; ``setn`` pairs are applied afterwards
    # so an explicit override wins on a key collision.
    kwargs.update(layers=net_size, model_lr=model_lr, value_lr=grad_lr)
    kwargs.update(setn)
    torch.set_default_dtype(torch.float64)

    # ``part`` is either one part or a collection of parts; normalise to a list.
    if isinstance(part[0], tuple):
        requested_parts = list(part)
    else:
        requested_parts = [part]

    problem_spaces = find_problems_to_run(
        benchmark,
        budget,
        func_num=func_num,
        parts=requested_parts,
        func_dim=func_dim,
        func_inst=func_inst,
    )

    # Taken once per invocation and reused in every space's splunk log name.
    session_stamp = timestamp_file_signature()

    for problem_space in problem_spaces:
        log_root = Path("logs") / algorithm / run_name

        normal_log_file = None
        if log_file:
            # Unlike the splunk stamp, this one is re-taken for every space.
            normal_log_file = (
                log_root / "normal" / f"logs_for_main-{timestamp_file_signature()}"
            )

        splunk_log_file = None
        if log_splunk:
            splunk_log_file = (
                log_root
                / "splunk"
                / f"{run_name}_splunk_logs_for_main_{session_stamp}"
            )

        space_logger = create_logger(
            normal_log_file, splunk_log_file, run_name, algorithm, problem_space
        )

        # Every coordinate starts at ``start``, in float64 on the requested device.
        initial_point = torch.tensor(
            [start] * problem_space.dimension, dtype=torch.float64, device=device
        )
        run(
            algorithm,
            problem_space,
            # Shallow copy: top-level key changes made by one run don't reach the next.
            kwargs.copy(),
            use_config,
            initial_point,
            device,
            output_bar=False,
            output_graph=graph,
            analyze_run=analyze,
            additional_logs=alog,
            run_name=run_name,
            logger=space_logger,
            results_path=results_path,
        )


if __name__ == "__main__":
    # Build the CLI command around ``main`` (with the ``algorithms`` module)
    # and invoke it so the decorated options are parsed from the command line.
    algorithm_cli = AlgorithmCommand(algorithms, main)
    algorithm_cli()
