from abc import ABC, abstractmethod

import numpy as np
import torch


class Solver(ABC):
    """Abstract base class for solvers with numpy and torch entry points.

    Subclasses implement :meth:`_solve_method` (the actual solving routine)
    and :meth:`compute_metrics`. The public :meth:`solve` wrapper counts
    calls (unless counting is frozen) and, when ``positive_y`` is set,
    clamps negative entries of ``y`` to zero before delegating.
    The ``*_from_torch`` helpers convert tensors to numpy, call the numpy
    API, and (for solving) convert the result back to a float32 tensor.
    """

    def __init__(self, minimization: bool, positive_y: bool):
        """
        Args:
            minimization: Whether the problem is a minimization problem.
                Only exposed via :attr:`is_minimization_problem`; this base
                class does not otherwise use it.
            positive_y: If True, :meth:`solve` replaces negative entries of
                ``y`` with zero before calling :meth:`_solve_method`.
        """
        self._minimization = minimization
        self._positive_y = positive_y
        # Number of counted calls to solve().
        self._calls = 0
        # When False, solve() does not increment the call counter.
        self._update_calls_count = True

    @property
    def is_minimization_problem(self) -> bool:
        """Whether this solver was configured for a minimization problem."""
        return self._minimization

    @property
    def is_y_positive(self) -> bool:
        """Whether negative ``y`` entries are clamped to zero in :meth:`solve`."""
        return self._positive_y

    @property
    def calls(self) -> int:
        """Number of :meth:`solve` calls counted so far."""
        return self._calls

    def reset_calls(self) -> None:
        """Reset the call counter to zero."""
        self._calls = 0

    def freeze_calls_count(self) -> None:
        """Stop counting subsequent :meth:`solve` calls."""
        self._update_calls_count = False

    def unfreeze_calls_count(self) -> None:
        """Resume counting :meth:`solve` calls."""
        self._update_calls_count = True

    @staticmethod
    def _tensor_to_numpy(value: object) -> object:
        """Convert a tensor to a numpy array; return any other value unchanged.

        Tensors are detached and moved to the CPU first because
        ``Tensor.numpy()`` only works on CPU tensors that do not require grad.
        """
        if isinstance(value, torch.Tensor):
            return value.detach().cpu().numpy()
        return value

    def solve_from_torch(self, x: torch.Tensor, y: torch.Tensor,
                         params: dict[str, torch.Tensor]) -> tuple[torch.Tensor, float]:
        """Run :meth:`solve` on torch inputs.

        Args:
            x: Input tensor (any device).
            y: Target tensor (any device).
            params: Mapping of parameter name to tensor; non-tensor values are
                passed through unchanged.

        Returns:
            ``(solution, runtime)`` where ``solution`` is a float32 tensor on
            the same device as ``x`` and ``runtime`` is whatever the
            subclass's ``_solve_method`` reports.
        """
        device = x.device
        x_np = self._tensor_to_numpy(x)
        y_np = self._tensor_to_numpy(y)
        params_np = {key: self._tensor_to_numpy(value) for key, value in params.items()}

        solution, runtime = self.solve(x_np, y_np, params_np)
        # np.asarray lets subclasses return lists/scalars as well as ndarrays.
        solution = torch.from_numpy(np.asarray(solution)).float().to(device)

        return solution, runtime

    def compute_metrics_from_torch(self, y: torch.Tensor, solution: torch.Tensor,
                                   params: dict[str, torch.Tensor]) -> dict:
        """Run :meth:`compute_metrics` on torch inputs (converted to numpy).

        Non-tensor values in ``params`` are passed through unchanged.
        """
        y_np = self._tensor_to_numpy(y)
        solution_np = self._tensor_to_numpy(solution)
        params_np = {key: self._tensor_to_numpy(value) for key, value in params.items()}

        return self.compute_metrics(y_np, solution_np, params_np)

    def solve(self, x: np.ndarray, y: np.ndarray, params: dict[str, np.ndarray]) -> tuple[np.ndarray, float]:
        """Count the call (unless frozen), optionally clamp ``y``, and solve.

        Returns:
            Whatever :meth:`_solve_method` returns: ``(solution, runtime)``.
        """
        if self._update_calls_count:
            self._calls += 1

        if self._positive_y:
            # np.maximum maps every negative value (including -inf) to 0 and
            # leaves NaN as NaN. `y * (y > 0)` would turn -inf into NaN
            # because -inf * 0 is NaN.
            y = np.maximum(y, 0)

        return self._solve_method(x, y, params)

    @abstractmethod
    def _solve_method(self, x: np.ndarray, y: np.ndarray, params: dict[str, np.ndarray]) -> tuple[np.ndarray, float]:
        """Solve the problem for numpy inputs; return ``(solution, runtime)``."""

    @abstractmethod
    def compute_metrics(self, y: np.ndarray, solution: np.ndarray, params: dict[str, np.ndarray]) -> dict:
        """Compute evaluation metrics for ``solution`` against ``y``."""
