import functools
from pathlib import Path
from types import ModuleType
from typing import Callable, List, Optional, Tuple

import click
import torch
from click.core import MultiCommand, Context, Command, Option

import algorithms
from algorithms.convergence_algorithms.base import Algorithm
from compute_result.factory import StoreTypes
from compute_result.output_manager.base import AxisTypes
from compute_result.output_manager.image import ImageOutput
from compute_result.output_manager.tensorboard import TensorBoardOutput
from compute_result.output_manager.types import Outputs
from compute_result.output_manager.writer_factory import WriterFactory
from compute_result.typing import StatisticalAnalysisOptions
from problems.types import Benchmarks, Suites
from problems.utils import BUDGET
from utils.algorithms_data import Algorithms
from utils.dynamically_load_class import (
    find_subclasses,
    get_signature_parameters,
    Configurable,
    find_class_by_name,
)
from utils.python import PRIMITIVES, torch_uniform

# Directory containing this module; the default results location of
# RESULT_PATH_OPTION is built relative to it.
BASE_PATH = Path(__file__).parent

# Template for the companion "--<param>_type" option names produced by
# option_types_from_dict.
PARAM_TYPE_FORMAT = "{0}_type"


# The option *factories* in this module keep their historical UPPER_CASE names
# so existing call sites such as ``@RUN_NAME_AND_ALG_OPTION()`` keep working.
def RUN_NAME_AND_ALG_OPTION(multiple: bool = True):  # noqa: N802
    """Return the ``-r/--run`` click option taking an (algorithm, run name) pair.

    :param multiple: Whether ``--run`` may be given more than once.
    :return: A click option decorator.
    """
    return click.option(
        "-r",
        "--run",
        type=click.Tuple([Algorithms, str]),
        multiple=multiple,
        help="The run (algorithm and run name)",
    )


def RUN_WITH_NAME_OPTION(multiple: bool = True):  # noqa: N802
    """Return the ``-r/--run`` click option taking (algorithm, run name, graph name).

    :param multiple: Whether ``--run`` may be given more than once.
    :return: A click option decorator.
    """
    return click.option(
        "-r",
        "--run",
        type=click.Tuple([Algorithms, str, str]),
        multiple=multiple,
        help="The run (algorithm and run name) and a name for the graph",
    )


# --- What to operate on: suite, run name, algorithm, budget -------------------
SUITE_OPTION = click.option(
    "-s", "--suite", help="The suite", type=Suites, default=None
)
RUN_NAME_OPTION = click.option(
    "-n", "--run_name", help="The name of the run", type=str, default="normal"
)
ALG_OPTION = click.option(
    "-a",
    "--algorithm",
    help="Which algorithm to check?",
    type=Algorithms,
    default=None,
)
# Default budget comes from problems.utils.BUDGET.
BUDGET_OPTION = click.option(
    "--budget",
    help="The budget size for the environment",
    type=int,
    default=BUDGET,
)

# --- Reporting ---------------------------------------------------------------
# NOTE(review): GRAPH_OPTION and GRAPH_NAME_OPTION both declare the short flag
# "-g"; confirm they are never attached to the same command.
GRAPH_OPTION = click.option(
    "-g",
    "--graph",
    help="Do you want to print the graph",
    is_flag=True,
    type=bool,
    default=False,
)
STATISTICAL_ANALYSIS_OPTION = click.option(
    "--analysis",
    help="How to measure the results across environments",
    type=StatisticalAnalysisOptions,
    default=StatisticalAnalysisOptions.MEAN,
)
GRAPH_NAME_OPTION = click.option(
    "-g",
    "--graph_name",
    help="The graph name on which the data will be printed",
    type=str,
    default="",
)
def BENCHMARK_OPTION(multiple: bool = False):  # noqa: N802 - name kept for callers
    """Return the ``-b/--benchmark`` click option.

    :param multiple: Whether ``--benchmark`` may be repeated. A single-valued
        option defaults to ``Benchmarks.COCO``; a repeatable one defaults to
        ``None``.
    :return: A click option decorator.
    """
    return click.option(
        "-b",
        "--benchmark",
        type=Benchmarks,
        default=Benchmarks.COCO if not multiple else None,
        multiple=multiple,
        help="The name of the benchmark",
    )


CONCURRENCY_OPTION = click.option(
    "-c",
    "--concurrency",
    type=int,
    default=2,
    help="How many process for each gpu in parallel",
)
# Repeatable (int, int) pair. Presumably (part index, number of parts), so the
# default (0, 1) would mean "the whole job" -- TODO confirm against the consumer.
PART_OPTION = click.option(
    "-p",
    "--part",
    type=(int, int),
    multiple=True,
    default=[(0, 1)],
    help="Do you want to split this job into parts (may help when the job is too big)",
)
CONFIG_FROM_RUN_NAME_OPTION = click.option(
    "--use_config",
    type=str,
    default=None,
    multiple=True,
    help="Do you want to use the default config for this run?",
)
# Repeatable; yields an empty tuple when not given.
IGNORE_GPU_OPTION = click.option(
    "--ignore_gpu",
    type=int,
    multiple=True,
    help="which gpu to ignore?",
)
STARTING_POINT_OPTION = click.option(
    "--start",
    type=int,
    default=0,
    help="from where to start?",
)
AXIS_TYPE_OPTION = click.option(
    "--axis_type",
    type=AxisTypes,
    default=AxisTypes.NORMAL,
    help="Which type of axis output do you want? (log/normal)",
)
# NOTE(review): this help text is identical to --log_file's; confirm what the
# Splunk flag should actually describe.
USE_SPLUNK_LOG_OPTION = click.option(
    "--log_splunk",
    is_flag=True,
    type=bool,
    default=False,
    help="Create json file log?",
)
# File logging is on by default, so it is declared as an on/off pair. With a
# lone "--log_file" flag and default=True, click uses flag_value = not default
# (False), so passing the flag would *disable* the log, contrary to the help.
# "--no_log_file" now turns it off; the parameter is still named ``log_file``.
USE_FILE_LOG_OPTION = click.option(
    "--log_file/--no_log_file",
    is_flag=True,
    type=bool,
    default=True,
    help="Create json file log?",
)
# Which output manager to build; ``output_creator`` in this module consumes it.
TYPE_OF_OUTPUT_OPTION = click.option(
    "--output",
    help="Which type of output do you want?",
    type=Outputs,
    default=Outputs.TENSORBOARD,
)
# Boolean switches, both off unless passed.
ANALYZE_OPTION = click.option(
    "--analyze",
    help="Should we analyze the run data",
    is_flag=True,
    type=bool,
    default=False,
)
ADDITIONAL_LOGS_OPTION = click.option(
    "--alog",
    help="Add logs in handlers",
    is_flag=True,
    type=bool,
    default=False,
)
def FUNCTION_NUMBER_OPTION(multiple: bool = True):  # noqa: N802 - name kept for callers
    """Return the integer ``--func_num`` click option.

    :param multiple: Whether ``--func_num`` may be repeated (default ``True``).
    :return: A click option decorator.
    """
    return click.option(
        "--func_num",
        type=int,
        default=None,
        multiple=multiple,
        help="Do you want to run only a specific function?",
    )


def FUNCTION_DIM_OPTION(multiple: bool = True):  # noqa: N802 - name kept for callers
    """Return the integer ``--func_dim`` click option.

    :param multiple: Whether ``--func_dim`` may be repeated (default ``True``).
    :return: A click option decorator.
    """
    return click.option(
        "--func_dim",
        type=int,
        default=None,
        multiple=multiple,
        help="Do you want to run only a specific dim?",
    )


FUNCTION_INSTANCE_OPTION = click.option(
    "--func_inst",
    type=int,
    default=None,
    multiple=True,
    help="Do you want to run only a specific instance?",
)
# Class name of the space; ``space_instance`` looks it up in the ``algorithms`` package.
SPACE_CLASS_OPTION = click.option(
    "--space_name",
    type=str,
    default="MatrixSpace",
    help="The name of the class for the space",
)
# The default is evaluated once at import time: device index 0 when CUDA is
# available, otherwise None.
DEVICE_OPTION = click.option(
    "--device",
    type=int,
    default=0 if torch.cuda.is_available() else None,
    help="index of the device to use, if available",
)
# NOTE(review): "-r" is also the short flag of the ``--run`` options in this
# module; confirm RESULT_PATH_OPTION is never combined with them on one command.
RESULT_PATH_OPTION = click.option(
    "-r",
    "--results_path",
    default=((BASE_PATH / "results"), StoreTypes.SQLITE_HIERARCHY),
    type=click.types.Tuple([Path, StoreTypes]),
    help="The location of the algorithm result",
)
MIN_DIFF_OPTION = click.option(
    "--min_diff",
    type=float,
    default=1e-2,
    help="Minimum difference between the best value of the function and the best global value",
)


def _parse_int_list(ctx, param, value):
    """Click callback turning a comma-separated string into a list of ints.

    ``"64,32"`` -> ``[64, 32]``; empty or whitespace-only items are skipped, so
    ``""`` -> ``[]``. A non-integer item raises ``click.BadParameter`` so click
    reports a usage error instead of a raw ``ValueError`` traceback.
    """
    if not value:
        return []
    sizes = []
    for item in value.split(","):
        item = item.strip()
        if not item:
            continue
        try:
            sizes.append(int(item))
        except ValueError:
            raise click.BadParameter(
                f"{item!r} is not an integer", ctx=ctx, param=param
            ) from None
    return sizes


NETWORK_SIZE_OPTION = click.option(
    "--net_size",
    type=str,
    default="",
    callback=_parse_int_list,
    help="Comma-separated list of integers for the network size (e.g. 64,32)",
)
MODEL_OPTIMIZER_LR_OPTION = click.option(
    "--model_lr", type=float, default=0.01, help="The learning rate for the model"
)
VALUE_OPTIMIZER_LR_OPTION = click.option(
    "--grad_lr",
    type=float,
    default=1e-3,
    help="The learning rate for gradient evaluator",
)
# Repeatable (name, value) string pairs.
SET_OPTION = click.option(
    "--setn",
    type=click.Tuple([str, str]),
    multiple=True,
    default=None,
    help="Additional parameters to set for the algorithm",
)
BINS_OPTION = click.option("--bins", default=20, type=int)
PERCENTILE_OPTION = click.option("--percentile", default=0.01, type=float)
PLOT_IN_LINE_OPTION = click.option("--num_plot_in_line", type=int, default=None)


def space_instance(func):
    """Decorate a click callback so that it receives ready-built space objects.

    Adds the ``--space_name``, ``--func_dim`` (repeatable), ``--device`` and
    ``--budget`` options. For every requested dimension ``dim``, the class named
    ``space_name`` (resolved in the ``algorithms`` package via
    ``find_class_by_name``) is instantiated with:

    * ``torch_uniform(-1, 1, (dim, dim), ...)``,
    * ``torch_uniform(-1, 1, (dim,), ...)``,
    * ``torch_uniform(-5.0, 5.0, (1,), ...).squeeze()``,
    * lower/upper bound vectors filled with ``-5.0`` / ``5.0``,

    all as float64 on ``device``, plus ``budget=budget``. Torch is seeded with 0
    first.

    ``func`` is then called with ``space`` (the single instance when one
    dimension was requested, otherwise the list), ``device`` and ``budget`` in
    addition to the remaining arguments. Click options already attached to
    ``func`` are carried over to the returned wrapper.

    :raises click.UsageError: if no ``--func_dim`` was given.
    """

    # functools.wraps is deliberately not used. It copies func.__dict__,
    # including func's __click_params__ list, onto the wrapper. That list would
    # then clash with the options attached below and with the merge at the end.
    @SPACE_CLASS_OPTION
    @FUNCTION_DIM_OPTION()
    @DEVICE_OPTION
    @BUDGET_OPTION
    def space_creator_dec(
        *args,
        space_name: str,
        func_dim: Tuple[int, ...],
        device: Optional[int],
        budget: int,
        **kwargs,
    ):
        # A repeatable option that was not given arrives as an empty tuple.
        # Without this check, ``space[0]`` below fails with a bare IndexError.
        if not func_dim:
            raise click.UsageError("At least one --func_dim must be given.")

        # Seed torch's global RNG for reproducible spaces (assumes torch_uniform
        # draws from the global generator -- TODO confirm).
        torch.manual_seed(0)
        space_class = find_class_by_name(algorithms, space_name)
        lower, upper = -5.0, 5.0
        dtype = torch.float64
        space = [
            space_class(
                torch_uniform(-1, 1, (dim, dim), dtype, device),
                torch_uniform(-1, 1, (dim,), dtype, device),
                torch_uniform(lower, upper, (1,), dtype, device).squeeze(),
                torch.tensor([lower] * dim, dtype=dtype, device=device),
                torch.tensor([upper] * dim, dtype=dtype, device=device),
                budget=budget,
            )
            for dim in func_dim
        ]
        return func(
            *args,
            space=space if len(space) > 1 else space[0],
            device=device,
            budget=budget,
            **kwargs,
        )

    # Merge the options declared on ``func`` into the wrapper's own, so click
    # sees all of them when the wrapper becomes a command.
    space_creator_dec.__click_params__ += getattr(func, "__click_params__", [])
    return space_creator_dec


def output_creator(func):
    """Decorate a click callback so that it receives an output manager.

    Adds the ``--output`` option and replaces its value with an ``output_type``
    keyword argument:

    * ``Outputs.TENSORBOARD`` -> ``TensorBoardOutput(WriterFactory)``
    * ``Outputs.IMAGE`` -> ``ImageOutput()``
    * anything else -> ``NotImplementedError``

    Click options already attached to ``func`` are carried over to the wrapper.
    """

    def build_output_manager(selected: Outputs):
        # Compared with ``==`` (not a dict lookup) so equality semantics of the
        # incoming value are respected.
        if selected == Outputs.TENSORBOARD:
            return TensorBoardOutput(WriterFactory)
        if selected == Outputs.IMAGE:
            return ImageOutput()
        raise NotImplementedError("No such output")

    @TYPE_OF_OUTPUT_OPTION
    def output_creator_dec(*args, output: Outputs, **kwargs):
        return func(*args, output_type=build_output_manager(output), **kwargs)

    inherited_params = getattr(func, "__click_params__", [])
    output_creator_dec.__click_params__ += inherited_params
    return output_creator_dec


def options_from_dict(param_dict: dict, exclude=None) -> List[Option]:
    """Build one ``--<name>`` click option per entry of ``param_dict``.

    Each value is indexed as ``value[0]`` for its type; in this module the
    values are ``(type, default)`` pairs. If that type is in ``PRIMITIVES`` it
    is used as the click type; otherwise the option is parsed as ``str``. Every
    option defaults to ``None`` regardless of the pair's default.

    :param param_dict: Mapping of parameter name to a ``(type, ...)`` sequence.
    :param exclude: Names to skip; a falsy value means skip nothing.
    :return: The options, in ``param_dict`` iteration order.
    """
    skipped = exclude if exclude else []
    built: List[Option] = []
    for name, spec in param_dict.items():
        # Filter first: excluded entries are never indexed.
        if name in skipped:
            continue
        declared_type = spec[0]
        click_type = declared_type if declared_type in PRIMITIVES else str
        built.append(Option([f"--{name}"], type=click_type, default=None))
    return built


def option_types_from_dict(param_dict: dict, exclude=None) -> List[Option]:
    """Build a ``--<name>_type`` string option for each key of ``param_dict``.

    Option names follow ``PARAM_TYPE_FORMAT``. Every option is typed ``str``
    with default ``None``.

    :param param_dict: Mapping whose keys are parameter names (values unused).
    :param exclude: Names to skip; ``None`` (the default) means skip nothing.
        Previously ``None`` raised ``TypeError`` on the membership test.
    :return: One ``Option`` per non-excluded key, in ``param_dict`` order.
    """
    exclude = exclude or []
    return [
        Option([f"--{PARAM_TYPE_FORMAT.format(config_name)}"], type=str, default=None)
        for config_name in param_dict
        if config_name not in exclude
    ]


class AlgorithmCommand(MultiCommand):
    """Click multi-command with one sub-command per ``Algorithm`` subclass.

    Sub-commands are named after the ``ALGORITHM_NAME`` of every ``Algorithm``
    subclass that ``find_subclasses`` finds in ``algorithm_module``. Each
    sub-command gets these options:

    * ``--<param>`` and ``--<param>_type`` for every parameter reported by
      ``get_signature_parameters`` for the class and its ``train`` method, and
      for every entry of ``additional_parameters()``, minus
      ``ignored_params()``;
    * the click params attached to ``algorithm_command``.

    The sub-command calls ``algorithm_command`` with ``algorithm=<name>``.
    """

    def __init__(
        self,
        algorithm_module: ModuleType,
        algorithm_command: Callable,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.algorithm_module = algorithm_module
        self.algorithm_command = algorithm_command

        # ALGORITHM_NAME -> Algorithm subclass (a later duplicate name wins).
        self.algorithms_types = dict(
            (alg.ALGORITHM_NAME, alg)
            for alg in find_subclasses(self.algorithm_module, Algorithm)
        )
        # Configurable classes that are not themselves algorithms.
        self.configurable = list(
            filter(
                lambda klass: not issubclass(klass, Algorithm),
                find_subclasses(self.algorithm_module, Configurable),
            )
        )
        # Parameter specs of each algorithm's constructor ...
        self.init_signature = {
            name: get_signature_parameters(alg, alg, Algorithm)
            for name, alg in self.algorithms_types.items()
        }
        # ... and of its ``train`` method (looked up by name on the subclass).
        train_attr = Algorithm.train.__name__
        self.train_signature = {
            name: get_signature_parameters(getattr(alg, train_attr), alg, Algorithm)
            for name, alg in self.algorithms_types.items()
        }

    def list_commands(self, ctx: Context) -> List[str]:
        """Return the discovered algorithm names, in discovery order."""
        return [*self.init_signature]

    def get_command(self, ctx: Context, cmd_name: str) -> Optional[Command]:
        """Build the click ``Command`` for ``cmd_name``, or ``None`` if unknown."""
        if cmd_name not in self.init_signature:
            return None

        algorithm = self.algorithms_types[cmd_name]
        ignored = algorithm.ignored_params()
        extra = {
            name: (type(val), val)
            for name, val in algorithm.additional_parameters().items()
        }
        # ``|`` builds a fresh dict, so the overrides below never touch the
        # cached signatures. Later sources win on duplicate names.
        parameters = (
            self.init_signature[cmd_name]
            | self.train_signature.get(cmd_name)
            | extra
        )
        # Replace the declared type but keep the existing default value.
        for name, forced_type in algorithm.default_types().items():
            parameters[name] = (forced_type, parameters[name][1])

        options = [
            *options_from_dict(parameters, ignored),
            *option_types_from_dict(parameters, ignored),
            *getattr(self.algorithm_command, "__click_params__", []),
            *getattr(self.algorithm_command, "params", []),
        ]
        return Command(
            cmd_name,
            params=options,
            callback=functools.partial(self.algorithm_command, algorithm=cmd_name),
        )
