import torch
from sklearn.base import BaseEstimator
from sklearn.linear_model import LogisticRegression
from .nonparamcdf import kernel


class TorchModelWrapper:
    """Wrap an sklearn-style estimator so it can be fitted and queried with torch tensors.

    Input tensors are detached, moved to CPU and converted to NumPy before being
    handed to the estimator. Predictions are converted back to a tensor on the
    input's device and clamped to ``[min, max]``.
    """

    def __init__(self, model: "type[BaseEstimator]", min=-torch.inf, max=torch.inf, **model_args):
        """Instantiate the wrapped estimator as ``model(**model_args)``.

        Args:
            model: estimator class (or any callable returning an estimator with
                ``fit`` and ``predict``/``predict_proba``). It is called, not used as-is.
            min: lower bound applied to predictions via ``torch.clamp``.
            max: upper bound applied to predictions via ``torch.clamp``.
            **model_args: keyword arguments forwarded to ``model``.
        """
        # `min` / `max` shadow builtins but are part of the public signature.
        self.model = model(**model_args)
        self.model_args = model_args
        # `model` is a class, so `model.__class__` would be its metaclass
        # (e.g. "type"/"ABCMeta"); take the name from the instance instead.
        self.model_name = type(self.model).__name__
        self._is_fitted = False
        self.min = min
        self.max = max

    def fit(self, X: torch.Tensor, y: torch.Tensor):
        """Fit the wrapped estimator on ``X`` and ``y`` (converted to NumPy on CPU)."""
        # detach(): Tensor.numpy() refuses tensors that require grad.
        X = X.detach().cpu().numpy()
        y = y.detach().cpu().numpy()
        self.model.fit(X, y)
        self._is_fitted = True

    def predict(self, X: torch.Tensor):
        """Return the estimator's predictions for ``X`` as a tensor on ``X``'s device.

        If the estimator has ``predict_proba``, column 1 of its output is
        returned (the probability of the estimator's second class, so this
        assumes a binary target). Otherwise ``predict`` is used. The result
        keeps the dtype of the NumPy output and is clamped to
        ``[self.min, self.max]``.
        """
        device = X.device
        X = X.detach().cpu().numpy()
        if hasattr(self.model, "predict_proba"):
            y = self.model.predict_proba(X)[:, 1]
        else:
            y = self.model.predict(X)
        return torch.clamp(torch.tensor(y, device=device), min=self.min, max=self.max)


class KernelLogisticRegression:
    """Logistic regression on kernel features.

    Each input row is mapped to its kernel evaluations against a stored set of
    reference points (``X_fit``, taken from the training data), and an sklearn
    ``LogisticRegression`` is fitted on that feature matrix.
    """

    def __init__(self, kernel: "kernel.Kernel", num_points=None, **kwargs):
        """Configure the kernel and the underlying logistic regression.

        Args:
            kernel: object exposing ``eval(A, B)``, used to build the feature
                matrix from inputs ``A`` and reference points ``B``.
            num_points: if given, the number of randomly chosen training rows
                kept as reference points; ``None`` keeps every training row.
            **kwargs: forwarded to ``LogisticRegression``. ``penalty`` defaults
                to ``"l2"`` but may be overridden here.
        """
        # Default to an L2 penalty while still letting callers override it
        # (passing it explicitly used to raise a duplicate-keyword TypeError).
        kwargs.setdefault("penalty", "l2")
        self.logregmodel = LogisticRegression(**kwargs)
        self.kernel = kernel
        self._is_fitted = False
        self.num_points = num_points

    def fit(self, X: torch.Tensor, y: torch.Tensor):
        """Pick reference points from ``X`` and fit the model on ``kernel.eval(X, X_fit)``."""
        if self.num_points is None:
            self.X_fit = X.detach().clone()
        else:
            # Random subset of at most `num_points` rows (CPU RNG).
            subset_ind = torch.randperm(X.shape[0])[:self.num_points]
            self.X_fit = X[subset_ind].detach().clone()
        kernel_X = self.kernel.eval(X, self.X_fit)
        # detach(): kernel_X / y may require grad (e.g. when X does), and
        # Tensor.numpy() refuses such tensors.
        self.logregmodel.fit(kernel_X.detach().cpu().numpy(), y.detach().cpu().numpy())
        self._is_fitted = True

    def predict(self, X: torch.Tensor):
        """Return column 1 of ``predict_proba`` for ``X`` as float32 on ``X``'s device.

        Column 1 is the probability of the model's second class, so this
        assumes a binary target. Calling this before ``fit`` fails with
        AttributeError because ``X_fit`` is only set in ``fit``.
        """
        device = X.device
        new_kernel_X = self.kernel.eval(X, self.X_fit)
        probs = self.logregmodel.predict_proba(new_kernel_X.detach().cpu().numpy())[:, 1]
        return torch.tensor(probs, device=device, dtype=torch.float32)
