from pathlib import Path
from typing import List, Optional, Tuple, Type

from torch.utils.tensorboard import SummaryWriter

from compute_result.factory import StoreFactory, StoreTypes
from handlers.base_handler import AlgorithmCallbackHandler
from handlers.drawer_handlers import TensorboardDrawerHandler, LoggerDrawerHandler
from handlers.drawers.base_drawer import MultipleDrawerEpoch
from handlers.drawers.convergence_drawer import (
    FullFigConvergenceDrawer,
    BestModelDrawer,
    SamplesDrawer,
    ConvergenceDrawer,
)
from handlers.drawers.draw_gradient import GradientDrawer
from handlers.drawers.draw_regions import TrustRegionDrawer
from handlers.drawers.loss_drawer import (
    StepSizeDrawer,
    GradientLossDrawer,
    GradSizeDrawer,
)
from handlers.drawers.utils import convert_to_real_drawer
from handlers.metrics_handlers import GradientErrorMetricHandler, GradientMetric
from handlers.save_run import SaveRunToFileHandler, SaveMinMaxHandler, SaveRunLosses
from handlers.utils import convert_to_real_handler
from utils.algorithms_data import Algorithms
from utils.algorithms_data import Algorithms


def full_convergence_drawer(
    space,
    alg_name: str,
    writer: SummaryWriter,
    dim_size: int,
    dim_to_print: List[int],
    x_lower_bound: float,
    x_upper_bound: float,
    y_lower_bound: float,
    y_upper_bound: float,
    convert_to_real: bool = True,
    with_output_mapping: bool = True,
    figure_drawer_type: Type[ConvergenceDrawer] = FullFigConvergenceDrawer,
    best_model_drawer_type: Type[BestModelDrawer] = BestModelDrawer,
    should_draw_grad: bool = True,
):
    """Build a Tensorboard handler that combines several convergence drawers.

    The drawers (convergence figure, best model, trust region, samples and,
    optionally, gradient) are grouped with ``MultipleDrawerEpoch`` and all
    receive ``dim_to_print`` as their ``dims``.

    Args:
        space: Only used to build the handler name, through ``repr(space)``.
        alg_name: Prefix of the handler name.
        writer: Tensorboard writer handed to the handler.
        dim_size: Forwarded to the figure drawer and the trust-region drawer.
        dim_to_print: Dimensions passed as ``dims`` to every drawer.
        x_lower_bound, x_upper_bound, y_lower_bound, y_upper_bound: Forwarded
            to the figure drawer as ``x_lower_bounds`` / ``x_upper_bounds`` /
            ``y_lower_bounds`` / ``y_upper_bounds``.
        convert_to_real: If True, every drawer is wrapped with
            ``convert_to_real_drawer``.
        with_output_mapping: Forwarded to the figure drawer as ``map_output``.
        figure_drawer_type: Class instantiated for the main convergence figure.
        best_model_drawer_type: Class instantiated for the best-model drawer.
        should_draw_grad: If True, the gradient drawer is included as well.

    Returns:
        A ``TensorboardDrawerHandler`` named ``"<alg_name> <repr(space)>"``.
    """

    def as_drawn(drawer):
        # Apply the optional real-space wrapper uniformly to every drawer.
        return convert_to_real_drawer(drawer) if convert_to_real else drawer

    selected = [
        figure_drawer_type(
            dim_size=dim_size,
            dims=dim_to_print,
            map_output=with_output_mapping,
            x_lower_bounds=x_lower_bound,
            x_upper_bounds=x_upper_bound,
            y_lower_bounds=y_lower_bound,
            y_upper_bounds=y_upper_bound,
        ),
        best_model_drawer_type(dims=dim_to_print),
        TrustRegionDrawer(dim_size=dim_size, dims=dim_to_print),
        SamplesDrawer(dims=dim_to_print),
    ]
    # The gradient drawer is always constructed; it is only used when requested.
    gradient_drawer = GradientDrawer(dims=dim_to_print)
    if should_draw_grad:
        selected.append(gradient_drawer)

    return TensorboardDrawerHandler(
        MultipleDrawerEpoch([as_drawn(drawer) for drawer in selected]),
        writer=writer,
        name=f"{alg_name} {repr(space)}",
    )


def create_save_run_handlers(
    alg_name: str,
    run_name: str,
    convert_to_real: bool = False,
    results_path: Optional[Tuple[Path, StoreTypes]] = None,
    grad_error: bool = False,
) -> List[AlgorithmCallbackHandler]:
    """Create the handlers that persist a run into a result store.

    Args:
        alg_name: Algorithm name, converted with ``Algorithms(alg_name)``.
        run_name: Name of the run; also used to derive the default store file.
        convert_to_real: If True, the run-saving and min/max handlers are
            wrapped with ``convert_to_real_handler``.
        results_path: ``(data_path, store_type)`` pair describing the store.
            When omitted (or falsy) it defaults to ``"<run_name>_results.db"``
            (``"results.db"`` if ``run_name`` is empty) with
            ``StoreTypes.SQLITE``.
        grad_error: If True, a ``GradientErrorMetricHandler`` is appended.

    Returns:
        ``[run_save_handler, min_max_saver_handler]``, followed by the
        gradient-error handler when ``grad_error`` is set. All handlers share
        the same result store.
    """
    # `or` (not `is None`) is kept so any falsy value still selects the default.
    data_path, store_type = results_path or (
        Path(f"{run_name}_results.db" if run_name else "results.db"),
        StoreTypes.SQLITE,
    )
    result_store = StoreFactory.get_store(store_type=store_type, data_path=data_path)

    # NOTE(review): the meaning of the constant 100 is not visible here —
    # confirm against GradientErrorMetricHandler's signature.
    gradient_error = (
        [GradientErrorMetricHandler(result_store, 100, run_name)] if grad_error else []
    )
    run_save_handler = SaveRunToFileHandler(run_name, Algorithms(alg_name), result_store)
    min_max_saver_handler = SaveMinMaxHandler(result_store)

    if convert_to_real:
        run_save_handler = convert_to_real_handler(run_save_handler)
        min_max_saver_handler = convert_to_real_handler(min_max_saver_handler)

    # NOTE(review): the gradient-error handler is never wrapped with
    # convert_to_real_handler — confirm this is intended.
    return [run_save_handler, min_max_saver_handler] + gradient_error


def create_algorithm_analyzer_handlers(space, writer=None, logger=None, to_real=False):
    """Create handlers that chart step-size and gradient statistics of a run.

    Args:
        space: Interpolated into every handler name with an f-string.
        writer: When truthy, three ``TensorboardDrawerHandler`` instances are
            created (step size, gradient distance, gradient size).
        logger: When truthy, one ``LoggerDrawerHandler`` reporting the step
            size is created.
        to_real: If True, every drawer is wrapped with
            ``convert_to_real_drawer``.

    Returns:
        A list with the tensorboard handlers first (if any), then the logger
        handler (if any); empty when neither ``writer`` nor ``logger`` is set.
    """

    def prepare(drawer):
        return convert_to_real_drawer(drawer) if to_real else drawer

    # Every drawer is built up front, whether or not it ends up being used.
    tb_step_size = prepare(StepSizeDrawer())
    # The logger gets its own StepSizeDrawer instance instead of sharing the
    # tensorboard one.
    log_step_size = prepare(StepSizeDrawer())
    tb_grad_loss = prepare(GradientLossDrawer())
    tb_grad_size = prepare(GradSizeDrawer())

    handlers = []
    if writer:
        tensorboard_plan = (
            (tb_step_size, "step size"),
            (tb_grad_loss, "gradient distance"),
            (tb_grad_size, "gradient size"),
        )
        handlers.extend(
            TensorboardDrawerHandler(drawer, name=f"{label} {space}", writer=writer)
            for drawer, label in tensorboard_plan
        )
    if logger:
        handlers.append(
            LoggerDrawerHandler(log_step_size, logger=logger, name=f"step size {space}")
        )
    return handlers
