import functools
import logging
import os
import typing
from logging import Logger
from pathlib import Path
from typing import Type, TypeVar, Tuple, List

import torch
from torch import Tensor
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

import algorithms
from algorithms.convergence_algorithms.base import Configurable, Algorithm
from algorithms.convergence_algorithms.basic_config import ConfigValue
from algorithms.convergence_algorithms.typing import (
    BoundedEvaluatedSamplerIdentifiableSpace,
)
from compute_result.factory import StoreTypes
from utils.algorithms_data import (
    algorithms_map,
    algorithm_train_params,
)
from utils.dynamically_load_class import (
    get_full_signature_parameters,
    need_params_for_signature,
    find_class_by_name,
)
from utils.python import timestamp_file_signature
from problems.utils import bounds_from_space
from run_options import PARAM_TYPE_FORMAT
from scripts.callbacks import (
    create_algorithm_analyzer_handlers,
    create_save_run_handlers,
    full_convergence_drawer,
)

# Type variable for the Configurable subclass being configured or instantiated by
# the helpers below; the bound is given as the forward-reference string "Configurable".
T = TypeVar("T", bound="Configurable")


def update_params_from_class_default_values(
    signature, klass: Type[T], parameters: dict, additional_default_values: dict = None
) -> dict:
    """Fill ``parameters`` with a value for every configurable argument of ``signature``.

    Every positional-or-keyword / keyword-only parameter reported by
    ``get_full_signature_parameters(signature, Configurable)`` that is not in
    ``klass.ignored_params()`` is resolved as follows:

    1. A supplied value equal to the string ``"None"`` becomes ``None``.
    2. If no value was supplied and ``need_params_for_signature(default_type)`` is
       true, an instance of ``default_type`` is built recursively from
       ``parameters`` (with the current default values passed down).
    3. Otherwise the first non-``None`` of: the supplied value, the default from
       ``klass.class_default_values() | additional_default_values``, and the
       signature default. A ``ConfigValue`` result is then evaluated with
       ``**parameters``. A string given for a parameter annotated as a type is
       resolved to a class via ``find_class_by_name(algorithms, ...)``.

    ``default_type`` comes from ``klass.default_types()``, falling back to the
    parameter annotation. It can be overridden by naming a class under the key
    ``PARAM_TYPE_FORMAT.format(param)``.

    Args:
        signature: Callable/class whose parameters should be resolved.
        klass: Configurable class supplying default types, ignored params and
            default values.
        parameters: Mapping of user-supplied values; it is updated in place.
        additional_default_values: Extra defaults that take precedence over
            ``klass.class_default_values()``.

    Returns:
        The same ``parameters`` dict, mutated in place.
    """
    additional_default_values = additional_default_values or {}
    default_types = klass.default_types()
    ignored_params = klass.ignored_params()
    # Right-hand side of ``|`` wins: caller-supplied defaults override class defaults.
    default_values = klass.class_default_values() | additional_default_values
    for param, value in get_full_signature_parameters(signature, Configurable).items():
        if value.kind not in (value.POSITIONAL_OR_KEYWORD, value.KEYWORD_ONLY):
            continue
        if param in ignored_params:
            continue
        default_type = default_types.get(param, None) or value.annotation
        if new_type_name := parameters.get(PARAM_TYPE_FORMAT.format(param), None):
            default_type = find_class_by_name(algorithms, new_type_name)

        # The parameter may legitimately be absent from ``parameters``; use ``.get``
        # so that case falls through to the default-resolution branches below.
        supplied_value = parameters.get(param, None)
        if isinstance(supplied_value, str) and supplied_value == "None":
            # An explicit "None" string (e.g. from the command line) means None.
            parameters[param] = None
        elif need_params_for_signature(default_type) and supplied_value is None:
            # Build the nested configurable object from the same parameter pool.
            parameters[param] = create_object_from_parameters(
                default_type,
                update_params_from_class_default_values(
                    default_type, default_type, parameters, default_values
                ),
            )
        else:
            # First, the value the user specifically supplied for the run.
            parameter_value = supplied_value
            # Second, the value from the default-value dict.
            if parameter_value is None:
                parameter_value = default_values.get(param, None)
            # Third, the default declared in the signature.
            if parameter_value is None and value.default is not value.empty:
                parameter_value = value.default
            # Last, evaluate configurable values against all parameters.
            if isinstance(parameter_value, ConfigValue):
                parameter_value = parameter_value(**parameters)
            parameters[param] = parameter_value

            # A string for a ``Type[...]`` parameter names a class to load.
            if typing.get_origin(value.annotation) is type and isinstance(
                parameter_value, str
            ):
                parameters[param] = find_class_by_name(algorithms, parameter_value)

    return parameters


def create_object_from_parameters(klass: Type[T], parameters: dict) -> T:
    """Instantiate ``klass`` from the entries of ``parameters`` matching its signature.

    Only positional-or-keyword and keyword-only parameters reported by
    ``get_full_signature_parameters(klass, Configurable)`` are passed, as keyword
    arguments. Extra keys in ``parameters`` are ignored.

    A signature parameter missing from ``parameters`` is simply not passed. This
    happens, for example, for ``ignored_params`` that
    ``update_params_from_class_default_values`` skips. ``klass``'s own default then
    applies, or the constructor raises ``TypeError`` naming the missing argument,
    instead of an opaque ``KeyError`` here.

    Args:
        klass: Configurable class to instantiate.
        parameters: Pool of resolved parameter values.

    Returns:
        The new ``klass`` instance.
    """
    constructor_kwargs = {
        param: parameters[param]
        for param, value in get_full_signature_parameters(klass, Configurable).items()
        if value.kind in (value.POSITIONAL_OR_KEYWORD, value.KEYWORD_ONLY)
        and param in parameters
    }
    return klass(**constructor_kwargs)


def manipulate_parameters(parameters: dict, algorithm: Type[Algorithm]) -> dict:
    """Return a new dict with ``algorithm``'s parameter manipulations applied.

    ``algorithm.manipulate_parameters()`` maps parameter names to callables. For
    each name already present in ``parameters``, the callable is invoked with the
    original ``parameters`` as keyword arguments, and its result replaces that
    entry. Every callable sees the unmodified input, not earlier results. The input
    dict itself is left untouched.
    """
    overrides = {}
    for target_name, transform in algorithm.manipulate_parameters().items():
        if target_name not in parameters:
            continue
        overrides[target_name] = transform(**parameters)
    return parameters | overrides


def run(
    algorithm_name: str,
    space: BoundedEvaluatedSamplerIdentifiableSpace,
    parameters: dict,
    use_config: List[str],
    start_point: Tensor = None,
    device: int = 0,
    output_bar: bool = True,
    output_graph: bool = True,
    analyze_run: bool = False,
    additional_logs: bool = False,
    run_name: str = "",
    results_path: Tuple[Path, StoreTypes] = None,
    grad_error: bool = False,
    logger: Logger = None,
):
    """Configure the algorithm named ``algorithm_name`` on ``space`` and train it.

    The steps are:

    1. ``parameters`` is copied, so the caller's dict is not modified.
    2. The copy is extended with ``env``, ``lower_bounds``, ``upper_bounds``,
       ``dims``, ``device`` and ``logger``.
    3. Each entry of ``algorithm_class.additional_configs()`` named in
       ``use_config`` is merged in order, with later entries winning.
    4. Missing values are filled from class defaults for the constructor and for
       ``train``, and the algorithm's parameter manipulations are applied.
    5. The algorithm is instantiated and ``train`` is called with the parameters
       listed in ``algorithm_train_params[algorithm_name]`` plus callback handlers.

    Args:
        algorithm_name: Key into ``algorithms_map`` and ``algorithm_train_params``.
        space: Search space; its ``lower_bound``/``upper_bound`` are forwarded and
            the space itself is passed as ``env``.
        parameters: User-supplied parameter values (copied, not mutated).
        use_config: Names of additional configs to merge; may be empty or None.
        start_point: If given, passed to ``algorithm.set_start_point``.
        device: Stored under the ``"device"`` parameter.
        output_bar: When False, tqdm progress bars are disabled. This works by
            patching ``tqdm.__init__`` process-wide, not only for this run.
        output_graph: When True, adds a ``full_convergence_drawer`` handler.
        analyze_run: When True, the TensorBoard writer is given to the analyzer
            handlers.
        additional_logs: When True, ``logger`` is given to the analyzer handlers.
        run_name: When non-empty, TensorBoard logs go to
            ``runs/<timestamp>-<algorithm_name>-<run_name>``. Otherwise
            SummaryWriter's default directory (suffixed with ``algorithm_name``)
            is used.
        results_path: Forwarded to ``create_save_run_handlers``.
        grad_error: Forwarded to ``create_save_run_handlers``.
        logger: Passed to the algorithm and handlers. It is also used for this
            function's own messages, falling back to a module logger when None.

    Raises:
        ValueError: ``algorithm_name`` is not in ``algorithms_map``, or a name in
            ``use_config`` is not a key of the algorithm's ``additional_configs()``.
    """
    # Validate before any global side effects (tqdm patch, dtype, log directory).
    algorithm_class = algorithms_map.get(algorithm_name, None)
    if algorithm_class is None:
        raise ValueError(
            f"Unknown algorithm {algorithm_name!r}; "
            f"available: {list(algorithms_map)}"
        )
    # ``logger`` defaults to None; keep our own messages from crashing in that case.
    run_logger = logger if logger is not None else logging.getLogger(__name__)

    my_parameters = parameters.copy()
    if not output_bar:
        # NOTE: patches tqdm for the whole process; repeated calls stack partials.
        tqdm.__init__ = functools.partialmethod(tqdm.__init__, disable=True)

    torch.set_default_dtype(torch.float64)
    dims = len(space.lower_bound)

    log_dir = (
        Path("runs") / f"{timestamp_file_signature()}-{algorithm_name}-{run_name}"
        if run_name
        else None
    )
    # With log_dir=None SummaryWriter picks its default directory and appends
    # ``comment``; str(None) would instead create a literal "None" directory.
    writer = SummaryWriter(
        log_dir=str(log_dir) if log_dir is not None else None,
        filename_suffix=algorithm_name,
        comment=algorithm_name,
    )
    printed_dims = [0, 1]
    x_lower_bound, x_upper_bound = bounds_from_space(space, printed_dims[0])
    y_lower_bound, y_upper_bound = bounds_from_space(space, printed_dims[1])

    my_parameters["env"] = space
    my_parameters["lower_bounds"] = space.lower_bound
    my_parameters["upper_bounds"] = space.upper_bound
    my_parameters["dims"] = dims
    my_parameters["device"] = device
    my_parameters["logger"] = logger

    if use_config:
        additional_configs = algorithm_class.additional_configs()
        for config in use_config:
            extra_parameters = additional_configs.get(config)
            if extra_parameters is None:
                raise ValueError(
                    f"Unknown config {config!r} for algorithm {algorithm_name!r}; "
                    f"available: {list(additional_configs)}"
                )
            my_parameters = my_parameters | extra_parameters

    # Resolve defaults for the constructor first, then for ``train``.
    my_parameters = update_params_from_class_default_values(
        algorithm_class, algorithm_class, my_parameters
    )
    my_parameters = update_params_from_class_default_values(
        algorithm_class.train, algorithm_class, my_parameters
    )
    my_parameters = manipulate_parameters(my_parameters, algorithm_class)
    use_trust_region = my_parameters.get("input_mapping", None) is not None

    algorithm = create_object_from_parameters(algorithm_class, my_parameters)
    if start_point is not None:
        algorithm.set_start_point(start_point)

    train_signature = algorithm_train_params[algorithm_name]
    train_parameters = {
        param: value
        for param, value in my_parameters.items()
        if param in train_signature
    }

    run_logger.info(f"Start running with {algorithm}")
    run_logger.info(
        f"Train with {os.linesep.join([f'{key}={value}' for key, value in train_parameters.items()])}"
    )
    try:
        callback_handlers = [
            *create_save_run_handlers(
                algorithm_name, run_name, use_trust_region, results_path, grad_error
            ),
        ]
        if output_graph:
            callback_handlers.append(
                full_convergence_drawer(
                    space,
                    algorithm_name,
                    writer,
                    dims,
                    printed_dims,
                    x_lower_bound,
                    x_upper_bound,
                    y_lower_bound,
                    y_upper_bound,
                    convert_to_real=use_trust_region,
                )
            )
        callback_handlers.extend(
            create_algorithm_analyzer_handlers(
                space,
                writer if analyze_run else None,
                logger if additional_logs else None,
                True,
            )
        )
        algorithm.train(**train_parameters, callback_handlers=callback_handlers)
    finally:
        # Flush pending TensorBoard events and release the event file.
        writer.close()
