import abc
import logging
import math
from abc import ABC
from logging import Logger
from typing import List, Dict

import numpy as np
import torch
from scipy.optimize import minimize
from torch import Tensor

from algorithms.convergence_algorithms.base import Algorithm
from algorithms.convergence_algorithms.typing import BoundedEvaluatedSpace
from algorithms.convergence_algorithms.utils import float_range
from algorithms.mapping.base import InputMapping
from algorithms.mapping.trust_region import LinearTrustRegion
from algorithms.space.exceptions import NoMoreBudgetError
from handlers.base_handler import AlgorithmCallbackHandler
from handlers.drawers.drawable_algorithms import ConvergenceDrawable


class ScipyMinimizer(ConvergenceDrawable, Algorithm, ABC):
    """Base class for algorithms that delegate each epoch to ``scipy.optimize.minimize``.

    Subclasses implement :meth:`minimize`, which runs one scipy minimization
    starting at ``self.curr_point`` and returns scipy's result object. :meth:`train`
    calls it repeatedly, restarting from the previous result. If an input mapping
    is configured, ``train`` squeezes the mapping when progress stalls.
    """

    def __init__(
        self,
        env: BoundedEvaluatedSpace,
        initial_point: np.ndarray,
        input_mapping: InputMapping = None,
        logger: Logger = logging.getLogger(__name__),
    ):
        """Store the environment, starting point, optional input mapping and logger.

        Args:
            env: Objective that is passed directly to ``scipy.optimize.minimize``
                as the function to minimize.
            initial_point: Point the first minimization starts from.
            input_mapping: Optional mapping between the environment's space and
                the space the optimizer works in. ``None`` disables mapping and
                the squeeze-on-stall behaviour of :meth:`train`.
            logger: Logger used for progress messages.
        """
        self.env = env
        self.curr_point = initial_point
        self.input_mapping = input_mapping
        self.logger = logger

    def set_start_point(self, start_point: Tensor):
        """Set the point the next minimization starts from.

        The tensor is moved to CPU and converted to a numpy array. If an input
        mapping is configured, the point is first passed through
        ``environment.normalize`` and then through ``input_mapping.map``.
        """
        start_point = start_point.cpu().numpy()
        if self.input_mapping:
            start_point = self.input_mapping.map(self.environment.normalize(start_point))
        self.curr_point = start_point

    @property
    def best_point_until_now(self):
        """Current point as a tensor.

        NOTE(review): this is the latest point, not a tracked best point.
        When a mapping is used, it is expressed in the mapped space.
        """
        return torch.tensor(self.curr_point)

    @property
    def curr_point_to_draw(self):
        """Current point as a tensor, for drawing handlers."""
        return torch.tensor(self.curr_point)

    @property
    def environment(self):
        """The environment/objective this algorithm optimizes."""
        return self.env

    @abc.abstractmethod
    def minimize(self, cg_callbacks):
        """Run one scipy minimization from ``self.curr_point``.

        Args:
            cg_callbacks: Callable forwarded to scipy as the per-iteration
                ``callback``; it receives the current iterate.

        Returns:
            A scipy ``OptimizeResult`` (``train`` reads its ``x`` and ``fun``).
        """
        raise NotImplementedError()

    def train(
        self,
        epochs: int = math.inf,
        no_improvement_limit: int = 20,
        callback_handlers: List[AlgorithmCallbackHandler] = None,
    ):
        """Run repeated scipy minimizations until a stop condition is met.

        Each epoch restarts :meth:`minimize` from the result of the previous one.
        The loop ends when any of these happens:

        * ``epochs`` iterations have run;
        * ``environment.is_goal_reached()`` is true;
        * the environment raises ``NoMoreBudgetError``.

        If the objective value has not improved for ``no_improvement_limit``
        consecutive epochs and an input mapping is configured, the mapping is
        squeezed around the current point. The stall counter is then reset.

        Args:
            epochs: Maximum number of outer iterations (may be ``math.inf``).
            no_improvement_limit: Consecutive non-improving epochs that trigger
                a squeeze of the input mapping.
            callback_handlers: Handlers notified on start, after every scipy
                iteration and epoch (``on_epoch_end``), and on end.

        Returns:
            The last point reached, in the optimizer's (possibly mapped) space.
        """
        self.logger.info(f"Starting cg on {self.env}")
        callback_handlers = callback_handlers or []
        no_improvement_counter = 0
        best_value = torch.inf
        for handler in callback_handlers:
            handler.on_algorithm_start(self)

        def cg_callback(xk):
            # Called by scipy after each inner iteration; keep handlers in sync.
            self.curr_point = xk
            for c_handler in callback_handlers:
                c_handler.on_epoch_end(self)

        try:
            for _ in float_range(epochs):
                self.logger.info(
                    f"{self.ALGORITHM_NAME} - new iteration with best solution {self.best_point_until_now} for {self.env}"
                )
                res = self.minimize(cg_callback)
                self.curr_point = res.x
                for handler in callback_handlers:
                    handler.on_epoch_end(self)
                if self.environment.is_goal_reached():
                    self.logger.info(
                        f"Finish {self.ALGORITHM_NAME} - best point found in {self.best_point_until_now}"
                    )
                    break
                if best_value <= res.fun:
                    no_improvement_counter += 1
                    self.logger.info(
                        f"No improvement {res.fun} counter {no_improvement_counter}/{no_improvement_limit}"
                    )
                else:
                    # Record the new best; otherwise the stall counter could never grow.
                    best_value = res.fun
                    no_improvement_counter = 0
                    self.logger.info(f"Improved solution to {res.fun}")
                if (
                    no_improvement_counter >= no_improvement_limit
                    and self.input_mapping is not None
                ):
                    self.logger.info(
                        f"{self.ALGORITHM_NAME} - no improvement limit reached {no_improvement_limit}"
                    )
                    # Squeeze in the environment's space, then re-map the point.
                    real_best_point = self.input_mapping.inverse(self.curr_point)
                    self.input_mapping.squeeze(real_best_point)
                    self.curr_point = self.input_mapping.map(real_best_point)
                    # Give the squeezed region a fresh allowance instead of
                    # squeezing again on every following epoch.
                    no_improvement_counter = 0
        except NoMoreBudgetError:
            return self.curr_point
        finally:
            for handler in callback_handlers:
                handler.on_algorithm_end(self)
        return self.curr_point

    @classmethod
    def _default_types(cls) -> Dict[str, type]:
        """Default concrete types for constructor arguments (``input_mapping``)."""
        return {"input_mapping": LinearTrustRegion}


class ConjugateGradient(ScipyMinimizer):
    """Scipy minimizer using the conjugate-gradient method (``method="CG"``)."""

    ALGORITHM_NAME = "cg"

    def minimize(self, cg_callbacks):
        """Run one CG minimization of ``self.env`` starting at ``self.curr_point``.

        ``jac=False`` lets scipy estimate the gradient numerically.
        ``cg_callbacks`` is forwarded as scipy's per-iteration ``callback``.

        Returns:
            scipy's ``OptimizeResult`` for this run.
        """
        scipy_options = dict(method="CG", jac=False, callback=cg_callbacks)
        return minimize(self.env, self.curr_point, **scipy_options)
