import functools
from pathlib import Path

from torch import Tensor

from algorithms.convergence_algorithms.convergence import ConvergenceAlgorithm
from compute_result.typing import ProblemSpace
from handlers.base_handler import AlgorithmCallbackHandler
from handlers.drawers.drawable_algorithms import ConvergenceDrawable

# File-name suffix that min_max_result_file_path appends to an algorithm name
# to build the name of that algorithm's min/max result file.
MIN_MAX_FILE_ENDING = "_min_max"


def min_max_result_file_path(dir_path: Path, alg_name: str) -> Path:
    """Return the location of ``alg_name``'s min/max result file inside ``dir_path``.

    The file name is ``alg_name`` immediately followed by ``MIN_MAX_FILE_ENDING``.
    """
    file_name = "{}{}".format(alg_name, MIN_MAX_FILE_ENDING)
    return dir_path / file_name


def construct_result_file_name(alg_name: str, env_name: str) -> str:
    """Build a result file name by joining the algorithm and environment names with ``-``."""
    return "{}-{}".format(alg_name, env_name)


def write_best_result(best_result_path: Path, min_result: float, max_result: float):
    """Overwrite ``best_result_path`` with ``"<min_result>,<max_result>"``.

    The content is a single comma-separated pair with no trailing newline.
    Nothing is returned.
    """
    serialized = "{},{}".format(min_result, max_result)
    best_result_path.write_text(serialized)


def problem_space_from_alg(alg: ConvergenceDrawable) -> ProblemSpace:
    """Describe the problem ``alg`` is working on.

    Returns:
        The tuple ``(suite, func_id, dimension, func_instance)`` read from
        ``alg.environment``.
    """
    env = alg.environment
    return env.suite, env.func_id, env.dimension, env.func_instance


def map_parameter_to_real(func):
    """Decorate ``func`` so its Tensor keyword arguments are mapped to real space.

    The decorated callable must receive the algorithm as its first positional
    argument (``alg``). Every keyword argument whose value is a ``Tensor`` is
    replaced by ``alg.environment.denormalize(alg.input_mapping.inverse(value))``
    before ``func`` is called. Non-Tensor keyword arguments and all positional
    arguments after ``alg`` are forwarded unchanged.

    If ``alg`` has no ``input_mapping`` attribute, or the attribute is falsy
    (e.g. ``None``), every argument is forwarded untouched. This is the same rule
    ``map_to_real`` applies to return values.
    """

    @functools.wraps(func)
    def map_parameter_wrapper(alg: "ConvergenceAlgorithm", *args, **kwargs):
        input_mapping = getattr(alg, "input_mapping", None)
        # A missing or None mapping means "no mapping"; previously a None
        # mapping crashed on ``None.inverse``.
        if not input_mapping:
            return func(alg, *args, **kwargs)

        # Build a fresh dict rather than rebinding entries of ``kwargs``
        # while iterating over it.
        real_kwargs = {
            name: alg.environment.denormalize(input_mapping.inverse(value))
            if isinstance(value, Tensor)
            else value
            for name, value in kwargs.items()
        }
        # NOTE(review): Tensors passed positionally are NOT mapped; confirm that
        # callers always pass them by keyword.
        return func(alg, *args, **real_kwargs)

    return map_parameter_wrapper


def map_to_real(func):
    """Decorate ``func`` so its return value is mapped to real space.

    The decorated callable must receive the algorithm as its first positional
    argument (``alg``). Its result is returned as
    ``alg.environment.denormalize(alg.input_mapping.inverse(result))``.

    The result is returned unchanged when ``alg`` has no ``input_mapping``
    attribute, when that attribute is falsy (e.g. ``None``), or when ``func``
    returns ``None``. Previously a missing attribute raised ``AttributeError``,
    even though ``map_parameter_to_real`` tolerates it and the two are composed
    together.
    """

    @functools.wraps(func)
    def map_to_real_wrapper(alg: "ConvergenceAlgorithm", *args, **kwargs):
        points = func(alg, *args, **kwargs)
        # Looked up after the call, as before, in case ``func`` changes it.
        input_mapping = getattr(alg, "input_mapping", None)
        # ``is None`` rather than truthiness: ``points`` is typically a Tensor.
        if not input_mapping or points is None:
            return points
        return alg.environment.denormalize(input_mapping.inverse(points))

    return map_to_real_wrapper


def convert_to_real_class(klass):
    """Wrap every callable attribute of ``klass`` whose name contains ``"point"``.

    Each matching attribute is replaced, via ``setattr`` on ``klass``, with
    ``map_to_real(map_parameter_to_real(attribute))``. This maps its Tensor
    keyword arguments and its return value into real space.

    ``klass`` may be a class or an instance; ``dir``/``getattr``/``setattr`` act
    on whichever object is passed. When an instance is passed, bound methods are
    wrapped, so the first argument a caller passes becomes ``alg``. When a class
    is passed, the instance (``self``) becomes ``alg``.

    Returns:
        ``klass`` itself, so this can also be used as a class decorator. Previously
        it returned ``None``.
    """
    # Snapshot the names first; setattr below must not affect the iteration.
    for attr_name in list(dir(klass)):
        # "point" is a substring of "points", so one check covers both spellings.
        # Filtering by name first avoids evaluating unrelated attributes
        # (e.g. properties on an instance).
        if "point" not in attr_name:
            continue
        attr = getattr(klass, attr_name)
        if callable(attr):
            setattr(klass, attr_name, map_to_real(map_parameter_to_real(attr)))
    return klass


def convert_to_real_handler(handler: AlgorithmCallbackHandler):
    """Make ``handler`` operate in real space and return it.

    First every ``*point*`` callable is wrapped by ``convert_to_real_class``. Then
    the four algorithm callbacks are rebound on ``handler`` so that their Tensor
    keyword arguments are mapped through ``map_parameter_to_real``.
    """
    convert_to_real_class(handler)

    callback_names = (
        "on_algorithm_start",
        "on_epoch_end",
        "on_algorithm_update",
        "on_algorithm_end",
    )
    for callback_name in callback_names:
        original_callback = getattr(handler, callback_name)
        setattr(handler, callback_name, map_parameter_to_real(original_callback))
    return handler
