import logging
import math
from logging import Logger

import torch
from torch import Tensor

from algorithms.convergence_algorithms.utils import ball_perturb
from algorithms.stopping_condition.base import AlgorithmStopCondition


def distance_point_from_min(alg, eps: float, point: Tensor, real: bool = True):
    """Return the distance from ``point`` to the stationary point of a local quadratic fit.

    Points are sampled around ``point`` with ``ball_perturb`` and evaluated with
    ``alg.environment``. The values are fitted by least squares to the model
    ``f(x) ~= x^T H x + b^T x + c`` with symmetric ``H``. The model's stationary
    point ``-H^{-1} b / 2`` is then compared with ``point`` using the Euclidean
    norm. No check is made that ``H`` is positive definite, so the stationary
    point may be a maximum or a saddle of the model.

    Args:
        alg: Object providing ``environment`` and ``input_mapping``.
            ``environment`` is called as ``environment(x, debug_mode=True)``
            and exposes ``denormalize``. ``input_mapping`` is falsy or an
            object with ``inverse``.
        eps: Forwarded to ``ball_perturb`` (presumably the sampling radius).
        point: 1-D tensor; its length is the problem dimension.
        real: When true and ``alg.input_mapping`` is set, both the samples and
            ``point`` are mapped through ``input_mapping.inverse`` and
            ``environment.denormalize`` before fitting and measuring.

    Returns:
        A 0-dim tensor holding the Euclidean distance.

    Raises:
        RuntimeError (``torch.linalg.LinAlgError`` in recent PyTorch): if the
        fitted ``H`` is singular.
    """
    dim = len(point)
    # (dim + 1)(dim + 2) / 2 is the number of free coefficients of a quadratic
    # in ``dim`` variables; sample roughly twice that many points.
    n_samples = 2 * ((dim + 1) * (dim + 2) // 2 - 1)
    x = ball_perturb(point, eps, n_samples, point.device)
    if alg.input_mapping and real:
        x = alg.environment.denormalize(alg.input_mapping.inverse(x))
    # Accept either (n,) or (n, 1) objective values; a trailing column would
    # otherwise make the minimum (dim, 1) and silently broadcast below.
    y = alg.environment(x, debug_mode=True).reshape(-1)

    # Use only the monomials x_i * x_j with i <= j. The full outer product
    # duplicates every off-diagonal column, which makes the design matrix rank
    # deficient (the CUDA ``gels`` lstsq driver assumes full rank).
    rows, cols = torch.triu_indices(dim, dim, device=x.device)
    quadratic = x[:, rows] * x[:, cols]
    design = torch.hstack((quadratic, x, x.new_ones((x.shape[0], 1))))
    # Solve for the model coefficients by least squares.
    coefficients = torch.linalg.lstsq(design, y.to(design.dtype)).solution

    n_quadratic = rows.numel()
    h = torch.zeros((dim, dim), dtype=coefficients.dtype, device=coefficients.device)
    h[rows, cols] = coefficients[:n_quadratic]
    # Split each off-diagonal coefficient evenly so that x^T h x reproduces the fit.
    h = (h + h.T) / 2
    b = coefficients[n_quadratic : n_quadratic + dim]
    # grad f = 2 h x + b = 0  =>  x* = -h^{-1} b / 2
    minimum = -torch.linalg.solve(h, b) / 2

    best_point = point
    if alg.input_mapping and real:
        best_point = alg.environment.denormalize(alg.input_mapping.inverse(best_point))
    return (best_point - minimum).pow(2).sum().sqrt()


class QuadModelMinimumReach(AlgorithmStopCondition):
    """Stop once the best point lies close to the minimum of a local quadratic model.

    Each check fits a quadratic around ``alg.best_point_until_now`` (via
    ``distance_point_from_min``). The condition fires when that point is closer
    than ``proximity_to_minima_eps`` to the model's stationary point.
    """

    REASON = "Quad model figured we at a min point"

    def __init__(
        self, ball_epsilon: float, proximity_to_minima_eps: float, logger: Logger = None
    ):
        """
        Args:
            ball_epsilon: Forwarded as ``eps`` to ``distance_point_from_min``.
            proximity_to_minima_eps: Distance threshold below which
                ``should_stop`` returns True.
            logger: Logger for distance reports; defaults to
                ``logging.getLogger(__name__)``.
        """
        self.ball_epsilon = ball_epsilon
        self.proximity_to_minima_eps = proximity_to_minima_eps
        # Not read within this class; kept for any external code that inspects it.
        self._last_min = None
        self.logger = logger or logging.getLogger(__name__)

    def should_stop(self, alg, **kwargs) -> bool:
        """Return True when the model's stationary point is within ``proximity_to_minima_eps``."""
        distance_from_min = self.distance_from_model_min(alg)
        self.logger.info(f"Distance from model min {distance_from_min}")
        # The comparison yields a 0-dim tensor; convert so the ``-> bool`` contract holds.
        return bool(distance_from_min < self.proximity_to_minima_eps)

    def distance_from_model_min(self, alg, real=True):
        """Return the distance from ``alg.best_point_until_now`` to the fitted model's stationary point.

        ``real`` is forwarded to ``distance_point_from_min``. It controls whether
        points are mapped back through ``alg.input_mapping`` before measuring.
        """
        return distance_point_from_min(alg, self.ball_epsilon, alg.best_point_until_now, real)
