# %%
from abc import ABC, abstractmethod
from sklearn.neighbors import NearestNeighbors
from sklearn.isotonic import IsotonicRegression
from scipy.interpolate import splrep, splev
from scipy.optimize import fsolve
import numpy as np
import torch
from torch.nn import functional as F  # noqa: F401
from typing import Union
from ..utils import torch_kernel as kernel
import rpy2.robjects as robjects
TT = torch.Tensor
zero = torch.tensor([0.])
# %%


class kernel_regressor(ABC):
    """Kernel-weighted regressor: predicts a weighted average of the stored training targets.

    Weights come from ``kernel.eval(X_new, X)``, are floored at ``min_weight_val`` and
    normalised per row; predictions are clamped to ``[min, max]``.
    """
    min_weight_val = 1e-43

    def __init__(self, kernel: kernel.Kernel, min: float = -torch.inf, max: float = torch.inf) -> None:
        """
        Args:
            kernel (kernel.Kernel): kernel object whose ``eval(A, B)`` gives pairwise weights.
            min (float, optional): lower clamp for predictions. Defaults to -inf.
            max (float, optional): upper clamp for predictions. Defaults to inf.
        """
        self.kernel = kernel
        self.min, self.max = min, max
        self._is_fitted = False

    def fit(self, X: TT, y: TT) -> None:
        """Memorise the training data; all computation happens at prediction time."""
        self.y, self.X = y, X
        self._is_fitted = True

    def get_y_weights(self, X_new: TT, normalise=True) -> TT:
        """Return kernel weights between ``X_new`` and the training covariates.

        Args:
            X_new (TT): query covariates.
            normalise (bool, optional): divide each row by its sum. Defaults to True.
        """
        # Floor every weight so no row can sum to zero before normalisation.
        weights = torch.clamp(self.kernel.eval(X_new, self.X), min=self.min_weight_val)
        if normalise:
            weights = weights / torch.sum(weights, dim=1, keepdim=True)
        return weights

    def predict(self, X_new: TT) -> TT:
        """Weighted average of the training targets for each query row, clamped to ``[min, max]``."""
        weighted_targets = self.get_y_weights(X_new) * self.y
        return torch.clamp(weighted_targets.sum(dim=1), min=self.min, max=self.max)

    __call__ = predict


class constant_regressor(ABC):
    """Baseline regressor that predicts the training-target mean for every input."""

    def __init__(self):
        self._is_fitted = False

    def fit(self, X: TT, y: TT):
        """Store the mean of ``y``; ``X`` is ignored.

        Args:
            X (TT): covariates (unused, kept for interface compatibility).
            y (TT): training targets.
        """
        self.mean = torch.mean(y)
        # Mark as fitted, consistent with the other estimators in this module.
        self._is_fitted = True

    def predict(self, X_new: TT) -> TT:
        """Return the stored mean once per row of ``X_new`` (shape ``X_new.shape[:-1]``)."""
        return torch.zeros_like(X_new[..., 0])+self.mean

    __call__ = predict


class faux_regressor(ABC):
    """Wrap a regression function and perturb its outputs with seeded Gaussian noise.

    Noise has standard deviation ``sd`` and mean ``bias``. When not given, both default
    to ``1/n**alpha`` if ``add_bias`` is set and to 0 otherwise. With ``logit=True``
    the noise is added on the logit scale of the wrapped predictions.
    """

    def __init__(self, func, bias=None, sd=None, n: int = None, alpha: float = None, add_bias=True, logit=True,
                 min: float = -torch.inf, max: float = torch.inf, seed: int = None):
        self.func = func
        self.alpha = alpha
        self.n = n

        if sd is not None:
            self.sd = sd
        elif add_bias:
            self.sd = 1/n**alpha
        else:
            self.sd = 0.

        if bias is not None:
            self.bias = bias
        elif add_bias:
            self.bias = 1/n**alpha
        else:
            self.bias = 0.

        # Draw a seed from the global RNG when none is supplied.
        if seed is None:
            seed = torch.randint(0, 100000, (1,)).item()
        self.seed = seed
        self.logit = logit
        self.min = min
        self.max = max
        self._is_fitted = False

    def fit(self, *args, **kwargs):
        """Fit the wrapped model (if it has ``fit``) and reset the noise generator to ``seed``."""
        if hasattr(self.func, "fit"):
            self.func.fit(*args, **kwargs)
        generator = torch.Generator()
        generator.manual_seed(self.seed)
        self.rng = generator
        self._is_fitted = True

    def predict(self, X_new: TT) -> TT:
        """Return the wrapped predictions plus noise, clamped to ``[min, max]``."""
        base = self.func.predict(X_new) if hasattr(self.func, "predict") else self.func(X_new)
        noise = torch.randn(base.shape, device=base.device, generator=self.rng)*self.sd+self.bias
        noisy = torch.sigmoid(torch.logit(base)+noise) if self.logit else base+noise
        return torch.clamp(noisy, min=self.min, max=self.max)

    __call__ = predict


# # KNN Version ##
class knncdf(ABC):
    """k-nearest-neighbour estimator of the conditional CDF of y given X.

    The CDF at a new point is built from the y-values of its ``n_neighbours`` nearest
    training points. The points are weighted equally, or by the normalised inverse of
    ``prop_func`` evaluated at those neighbours when ``prop_func`` is given.
    """

    def __init__(self, n_neighbours, algorithm, prop_func=None, min=0., max=1.):
        """
        Args:
            n_neighbours (int): number of neighbours used per query point.
            algorithm (str): neighbour-search algorithm passed to ``NearestNeighbors``.
            prop_func (Callable, optional): propensity function; neighbours are weighted by its inverse.
            min (float, optional): lower clamp for CDF values. Defaults to 0.
            max (float, optional): upper clamp for CDF values. Defaults to 1.
        """
        self.n_neighbours = n_neighbours
        self.algorithm = algorithm
        self.knn = NearestNeighbors(n_neighbors=self.n_neighbours, algorithm=self.algorithm)
        self.prop_func = prop_func
        self.min = min
        self.max = max
        self._is_fitted = False

    def fit(self, X: TT, y: TT):
        """Sort the data by y and build the neighbour index on the correspondingly sorted X."""
        self.y = y
        self.X = X
        self.y_sorted, self.sort_indices = torch.sort(y)
        self.X_data_sorted = X[self.sort_indices, :]
        self.knn.fit(self.X_data_sorted)
        self._is_fitted = True

    def get_nearest_ys(self, X_new: TT, sorted=True):
        """Return the neighbours' y-values and their weights for each row of ``X_new``.

        Because the index is built on y-sorted data, sorting the neighbour indices
        (``sorted=True``) also sorts the returned y-values within each row.
        """
        indices = torch.tensor(self.knn.kneighbors(X_new, return_distance=False))
        if sorted:
            indices = torch.sort(indices, 1)[0]
        sub_y = self.y_sorted[indices]

        if self.prop_func is not None:
            weights = 1/self.prop_func(self.X_data_sorted)[indices]
            weights = weights/torch.sum(weights, dim=1, keepdim=True)
        else:
            weights = torch.ones_like(sub_y)/self.n_neighbours
        return sub_y, weights

    def cdf(self, y_new: TT, X_new: TT):
        """Weighted fraction of neighbours with y <= ``y_new``, clamped to ``[min, max]``."""
        sub_y, weights = self.get_nearest_ys(X_new, sorted=False)
        size_indicator = (sub_y <= y_new.unsqueeze(-1)).float()
        return torch.clamp(torch.sum(size_indicator*weights, dim=1), min=self.min, max=self.max)

    def getallcdfs(self, X_new: TT):
        """Return the cumulative weights at each sorted neighbour y-value, together with those y-values."""
        sub_y, weights = self.get_nearest_ys(X_new, sorted=True)
        return torch.clamp(torch.cumsum(weights, dim=1), min=self.min, max=self.max), sub_y

    def inverse_cdf(self, alpha: Union[float, TT], X_new: TT):
        """Smallest neighbour y-value whose estimated CDF is at least ``alpha``.

        Args:
            alpha (float | TT): probability level. Any real scalar (float, int, NumPy scalar)
                or a tensor/array broadcastable against the rows of ``X_new``.
            X_new (TT): query covariates.
        """
        cdf_vals, sub_ys = self.getallcdfs(X_new)
        if not torch.is_tensor(alpha):
            # Accept any numeric scalar/array, not only an exact Python float.
            alpha = torch.as_tensor(alpha, dtype=torch.get_default_dtype())
            if alpha.dim() == 0:
                alpha = alpha.reshape(1)
        valid_ys = torch.where(cdf_vals >= alpha.unsqueeze(-1),
                               sub_ys, torch.tensor([torch.inf]))

        out_vals = torch.min(valid_ys, dim=1)[0]
        # Rows with no valid value currently output inf; replace them with the largest y.
        # (In theory this shouldn't happen, since the largest y-value should always have eCDF 1.)
        return torch.minimum(out_vals, sub_ys[:, -1])


# # Kernel Version ##
class kernel_cdf(ABC):
    """Kernel-weighted estimator of the conditional CDF of y given X.

    The CDF at ``(y, x)`` is the weighted fraction of training outcomes ``<= y``. Each
    training point's weight is ``kernel(x, X_i)``, floored at ``min_weight_val``, divided
    by its propensity score and normalised per query row. Without ``prop_func`` every
    score is 0.5, which cancels under normalisation.
    """
    min_weight_val = 1e-43

    def __init__(self, kernel: kernel.Kernel, prop_func=None, supremum=False, inverse_main=False, min=0., max=1.):
        """Initialise the kernel type as well as the propensity function if necessary.

        Args:
            kernel (kernel.Kernel): The kernel to use for the cdf estimation. Called using the eval method.
            prop_func (Callable, optional): The propensity function, evaluated on the sorted training
                covariates; kernel weights are divided by its output. Defaults to None.
            supremum (bool, optional): Whether inverse is done via infimum (default) or supremum. Defaults to False.
            inverse_main (bool, optional): If True, calling the object evaluates ``inverse_cdf``
                instead of ``cdf``. Defaults to False.
            min (float, optional): Lower clamp applied to CDF values. Defaults to 0.
            max (float, optional): Upper clamp applied to CDF values. Defaults to 1.
        """
        self.kernel = kernel
        self.prop_func = prop_func
        self.supremum = supremum
        self.inverse_main = inverse_main
        # True only while inverse_cdf runs; tells getallcdfs to shift its output for supremum inverses.
        self._in_inverse_cdf = False
        # Cache flag for get_y_weights(..., precalc=True).
        self.weights_calc = False
        self.min = min
        self.max = max
        self._is_fitted = False

    def fit(self, X: TT, y: TT):
        """Sort the training data by y and compute the propensity scores of the sorted covariates."""
        self.y = y
        self.X = X
        self.y_sorted, self.sort_indices = torch.sort(self.y)
        self.X_sorted = self.X[self.sort_indices]
        if self.prop_func is not None:
            self.prop_scores = self.prop_func(self.X_sorted)
        else:
            self.prop_scores = torch.ones_like(self.X_sorted[:, 0])-.5
        self._is_fitted = True

    def reset_y_weights(self):
        """Discard weights cached by ``get_y_weights(..., precalc=True)``."""
        self.weights_calc = False
        self.X_dists = None

    def get_y_weights(self, X_new: TT, precalc=False, normalise=True) -> TT:
        """Return propensity-adjusted kernel weights of ``X_new`` against the sorted training covariates.

        Args:
            X_new (TT): query covariates.
            precalc (bool, optional): cache the result. Once cached, later calls return the cached
                weights unchanged, ignoring ``X_new`` and ``normalise``, until ``reset_y_weights``.
            normalise (bool, optional): divide each row by its sum. Defaults to True.
        """
        if self.weights_calc:
            return self.X_dists
        X_dists = self.kernel.eval(X_new, self.X_sorted)
        # Floor the weights so no row sums to zero.
        X_dists = torch.clamp(X_dists, min=self.min_weight_val)
        # Re-adjust for propensity scores if necessary
        X_dists = X_dists/self.prop_scores
        if normalise:
            X_dists = X_dists/torch.sum(X_dists, dim=1, keepdim=True)
        if precalc:
            self.weights_calc = True
            self.X_dists = X_dists
        return X_dists

    def getallcdfs(self, X_new: TT):
        """Return the estimated CDF at every sorted training y for each row of ``X_new``, plus those y-values.

        Inside a supremum ``inverse_cdf`` the values are shifted one step right with a leading 0,
        so each entry is the cumulative weight strictly before that sorted position.
        """
        y_weights = self.get_y_weights(X_new)
        cumul_weights = torch.cumsum(y_weights, dim=-1)
        if self.supremum and self._in_inverse_cdf:
            cumul_weights = torch.cat((
                torch.zeros_like(cumul_weights[:, 0:1]),
                cumul_weights[:, :cumul_weights.shape[1]-1]), dim=-1)
        # Return weights and the change points they're associated with
        return torch.clamp(cumul_weights, min=self.min, max=self.max), self.y_sorted

    def cdf(self, y_new: TT, X_new: TT, y_weights: Union[TT, None] = None):
        """Estimated CDF at ``y_new`` given ``X_new`` (optionally using precomputed ``y_weights``)."""
        if y_weights is None:
            # X_new: dim ..., X_sorted: dim -1
            y_weights = self.get_y_weights(X_new)
        # y_new: dim ..., y_sorted: dim -1
        return torch.clamp(
                torch.sum(y_weights*(self.y_sorted <= y_new.unsqueeze(-1)), dim=-1),
                min=self.min, max=self.max)

    def inverse_cdf(self, alpha: Union[float, TT], X_new: TT):
        """Generalised inverse of the estimated CDF at level ``alpha`` for each row of ``X_new``.

        Infimum mode returns the smallest y with CDF >= alpha. Supremum mode returns the
        largest y whose cumulative weight strictly below it is <= alpha.

        Args:
            alpha (float | TT): probability level; any numeric scalar/array, or a tensor
                broadcastable against the rows of ``X_new``.
            X_new (TT): query covariates (2-D, one row per query).
        """
        # getallcdfs reads this flag; always reset it, even if getallcdfs raises.
        self._in_inverse_cdf = True
        try:
            cdf_vals, _ = self.getallcdfs(X_new)
        finally:
            self._in_inverse_cdf = False
        if not torch.is_tensor(alpha):
            # Accept any numeric scalar/array, not only an exact Python float.
            alpha = torch.as_tensor(alpha, dtype=torch.get_default_dtype())
            if alpha.dim() == 0:
                alpha = alpha.reshape(1)
        y_expanded = self.y_sorted.unsqueeze(0).expand(X_new.shape[0], -1)
        threshold = alpha.unsqueeze(-1)

        if not self.supremum:
            valid_ys = torch.where(cdf_vals >= threshold,
                                   y_expanded, torch.tensor([torch.inf]))
            out_vals = torch.min(valid_ys, dim=1)[0]
            # Rows with no valid value output inf; replace them with the largest y.
            # (In theory this shouldn't happen, since the largest y-value always has eCDF 1.)
            return torch.minimum(out_vals, self.y_sorted[-1])

        valid_ys = torch.where(cdf_vals <= threshold,
                               y_expanded, torch.tensor([-torch.inf]))
        out_vals = torch.max(valid_ys, dim=1)[0]
        # Rows with no valid value output -inf; replace them with the smallest y.
        # (In theory this shouldn't happen, since the smallest y-value always has shifted eCDF 0.)
        return torch.maximum(out_vals, self.y_sorted[0])

    def __call__(self, *args, **kwargs):
        """Dispatch to ``inverse_cdf`` if ``inverse_main`` is set, else to ``cdf``."""
        if not self.inverse_main:
            return self.cdf(*args, **kwargs)
        else:
            return self.inverse_cdf(*args, **kwargs)


class constant_cdf(ABC):
    """Covariate-free CDF estimator: the marginal empirical CDF of the training y, whatever X is.

    Exposes the same ``cdf``/``getallcdfs``/``inverse_cdf``/``__call__`` interface as
    ``kernel_cdf`` and reuses that class's ``inverse_cdf`` and ``__call__``.
    """

    def __init__(self, inverse_main=False, supremum=False, min=0., max=1.):
        """
        Args:
            inverse_main (bool, optional): if True, calling the object evaluates ``inverse_cdf``.
            supremum (bool, optional): whether ``inverse_cdf`` uses the supremum instead of the infimum.
                Defaults to False.
            min (float, optional): lower clamp for CDF values. Defaults to 0.
            max (float, optional): upper clamp for CDF values. Defaults to 1.
        """
        self.inverse_main = inverse_main
        # getallcdfs and the shared inverse_cdf read these, so they must always exist.
        self.supremum = supremum
        self.min = min
        self.max = max
        self._is_fitted = False
        self._in_inverse_cdf = False

    def fit(self, X: TT, y: TT):
        """Store and sort the outcomes; ``X`` is ignored."""
        self.y = y
        self.y_sorted, self.sort_indices = torch.sort(self.y)
        self._is_fitted = True

    def cdf(self, y_new: TT, X_new: TT):
        """Fraction of training outcomes ``<= y_new`` (``X_new`` ignored), clamped to ``[min, max]``."""
        return torch.clamp(torch.mean((self.y_sorted <= y_new.unsqueeze(-1)).float(), dim=-1),
                           min=self.min, max=self.max)

    def getallcdfs(self, X_new: TT):
        """Return the empirical CDF after each sorted training outcome, repeated for every row of ``X_new``."""
        n = self.y_sorted.shape[0]
        # k/n after the k-th sorted outcome; these must be probabilities for inverse_cdf to compare with alpha.
        steps = torch.arange(1, n+1, device=X_new.device)/n
        cumul_weights = steps+torch.zeros_like(X_new[:, 0:1])
        # A rearranging for the case of supremum which is only relevant for inverse cdf
        if self.supremum and self._in_inverse_cdf:
            cumul_weights = torch.cat((
                torch.zeros_like(cumul_weights[:, 0:1]),
                cumul_weights[:, :cumul_weights.shape[1]-1]), dim=-1)
        # Return weights and the change points they're associated with
        return torch.clamp(cumul_weights, min=self.min, max=self.max), self.y_sorted

    inverse_cdf = kernel_cdf.inverse_cdf
    __call__ = kernel_cdf.__call__


class exact_cdf(ABC):
    """Wrap a known conditional CDF (and optionally its inverse) in the CDF-estimator interface."""

    def __init__(self, CDF, inverse_CDF=None) -> None:
        """
        Args:
            CDF (Callable): ``CDF(y, X)`` returning CDF values.
            inverse_CDF (Callable, optional): ``inverse_CDF(alpha, X)``. Defaults to None.
        """
        self.CDF = CDF
        self.inverse_CDF = inverse_CDF
        self._is_fitted = False

    def fit(self, X: TT, y: TT = None) -> None:
        """Store the data; when ``y`` is given, sort it so ``getallcdfs`` has jump points.

        ``y`` is optional: ``cdf`` and ``inverse_cdf`` do not need it, only ``getallcdfs`` does.
        """
        self.y = y
        self.X = X
        if y is None:
            self.y_sorted, self.sort_indices = None, None
        else:
            self.y_sorted, self.sort_indices = torch.sort(y)
        self._is_fitted = True

    def cdf(self, y_new: TT, X_new: TT):
        """Evaluate the wrapped CDF at ``(y_new, X_new)``."""
        return self.CDF(y_new, X_new)

    def inverse_cdf(self, alpha: Union[float, TT], X_new: TT):
        """Evaluate the wrapped inverse CDF.

        Raises:
            ValueError: if no ``inverse_CDF`` was supplied.
        """
        if self.inverse_CDF is None:
            raise ValueError("No inverse CDF function defined.")
        return self.inverse_CDF(alpha, X_new)

    def getallcdfs(self, X_new: TT):
        """Return the CDF at every sorted training y for each row of ``X_new``, plus those y-values.

        Raises:
            ValueError: if ``fit`` has not been called with outcomes ``y``.
        """
        if getattr(self, "y_sorted", None) is None:
            raise ValueError("getallcdfs requires fit to be called with y.")
        # Insert an axis so each X row is broadcast against all sorted y-values.
        return self.CDF(self.y_sorted, X_new.unsqueeze(-2)), self.y_sorted

    __call__ = cdf


class faux_est_cdf(ABC):
    """Wrap a conditional CDF and perturb it with seeded Gaussian noise.

    The noise has mean ``bias`` and standard deviation ``sd``. When not given, both
    default to ``1/n**alpha`` if ``add_bias`` is set and to 0 otherwise. With
    ``logit=True`` the noise is added on the logit scale; otherwise it is added directly.
    Perturbed CDF values are clamped to ``[min, max]``.
    """

    def __init__(self, CDF, bias=None, sd=None, n: int = None, alpha: float = None, inverse_CDF=None,
                 add_bias=True, logit=True, min: float = -torch.inf, max: float = torch.inf, seed: int = None):
        self.CDF = CDF
        self.alpha = alpha
        self.n = n
        self.inverse_CDF = inverse_CDF
        if sd is None:
            self.sd = 1/n**alpha if add_bias else 0.
        else:
            self.sd = sd

        if bias is None and add_bias:
            self.bias = 1/n**alpha
        elif bias is None and not add_bias:
            self.bias = 0.
        else:
            self.bias = bias
        # Store these so they actually control the perturbation (defaults reproduce logit-scale, unclamped noise).
        self.logit = logit
        self.min = min
        self.max = max
        seed = torch.randint(0, 100000, (1,)).item() if seed is None else seed
        self.seed = seed
        self._is_fitted = False

    def _perturb(self, values: TT, errors: TT) -> TT:
        """Add ``errors`` on the logit scale (or directly when ``logit=False``) and clamp to ``[min, max]``."""
        if self.logit:
            perturbed = torch.sigmoid(torch.logit(values)+errors)
        else:
            perturbed = values+errors
        return torch.clamp(perturbed, min=self.min, max=self.max)

    def fit(self, X: TT, y: TT = None) -> None:
        """Store the data, fit the wrapped CDF if possible, and reset the noise generator to ``seed``.

        ``y`` is optional; without it ``getallcdfs`` needs a wrapped CDF that provides ``getallcdfs``.
        """
        self.y = y
        self.X = X
        if y is None:
            self.y_sorted, self.sort_indices = None, None
        else:
            self.y_sorted, self.sort_indices = torch.sort(self.y)
        if hasattr(self.CDF, "fit"):
            self.CDF.fit(X, y)
        self.rng = torch.Generator()
        self.rng.manual_seed(self.seed)
        self._is_fitted = True

    def cdf(self, y_new: TT, X_new: TT):
        """Perturbed CDF at ``(y_new, X_new)``, with independent noise per output entry."""
        original_CDF = self.CDF(y_new, X_new)
        errors = torch.randn(original_CDF.shape, device=original_CDF.device, generator=self.rng)*self.sd + self.bias
        return self._perturb(original_CDF, errors)

    def inverse_cdf(self, prob: Union[float, TT], X_new: TT):
        """Inverse of the perturbed CDF at ``prob``: shift ``prob`` by the noise, then apply the true inverse.

        Args:
            prob (float | TT): probability level(s); scalars are accepted.
            X_new (TT): query covariates, feature dimension last.

        Raises:
            ValueError: if neither ``inverse_CDF`` nor a wrapped ``CDF.inverse_cdf`` is available.
        """
        if not torch.is_tensor(prob):
            prob = torch.as_tensor(prob, dtype=torch.get_default_dtype(), device=X_new.device)
        # One error per query point: broadcast prob against X_new's batch shape (feature axis dropped).
        batch_shape = torch.broadcast_shapes(prob.shape, X_new.shape[:-1])
        errors = torch.randn(batch_shape, device=X_new.device, generator=self.rng) * self.sd + self.bias
        if self.logit:
            shifted_probs = torch.sigmoid(torch.logit(prob)-errors)
        else:
            # Additive noise: undo it directly, keeping the result a valid probability.
            shifted_probs = torch.clamp(prob-errors, min=0., max=1.)
        if self.inverse_CDF is not None:
            errored_icdf = self.inverse_CDF(shifted_probs, X_new)
        elif hasattr(self.CDF, "inverse_cdf"):
            errored_icdf = self.CDF.inverse_cdf(shifted_probs, X_new)
        else:
            raise ValueError("No inverse CDF function defined.")
        return errored_icdf

    def getallcdfs(self, X_new: TT):
        """Perturbed CDF at every jump point for each row of ``X_new``, plus the jump points.

        Raises:
            ValueError: if the wrapped CDF has no ``getallcdfs`` and ``fit`` was not given ``y``.
        """
        if hasattr(self.CDF, "getallcdfs"):
            original_CDF, self.y_sorted = self.CDF.getallcdfs(X_new)
        elif getattr(self, "y_sorted", None) is None:
            raise ValueError("getallcdfs requires fit to be called with y (or a CDF with getallcdfs).")
        else:
            original_CDF = self.CDF(self.y_sorted, X_new.unsqueeze(-2))
        # One error per row, shared across all jump points so each perturbed row stays monotone.
        errors = torch.randn(original_CDF[..., 0:1].shape, device=X_new.device, generator=self.rng)*self.sd + self.bias
        return self._perturb(original_CDF, errors), self.y_sorted

    __call__ = cdf


class faux_est_icdf(faux_est_cdf):
    """Noisy wrapper around a known inverse CDF alone (no forward CDF is supplied).

    Calling the object evaluates the perturbed inverse CDF.
    """

    def __init__(self, inverse_CDF, bias=None, sd=None, n: int = None, alpha: float = 0.5, add_bias=True,
                 seed: int = None) -> None:
        """
        Args:
            inverse_CDF (Callable): ``inverse_CDF(prob, X)``.
            bias, sd, n, alpha, add_bias: noise settings, as in ``faux_est_cdf``.
            seed (int, optional): seed for the noise generator, for reproducible runs.
                Drawn at random when None (the previous, and still default, behaviour).
        """
        faux_est_cdf.__init__(self, None, bias=bias, sd=sd, n=n, alpha=alpha,
                              inverse_CDF=inverse_CDF, add_bias=add_bias, seed=seed)

    inverse_cdf = faux_est_cdf.inverse_cdf
    __call__ = inverse_cdf


class InverseLearner(ABC):
    """Abstract base for learners that map a y0 value (with covariates X) to a y1 value.

    Subclasses define ``h(y0, y1, X)`` through ``get_single_h``/``get_all_hs``. ``predict``
    returns, per query, the y1 at which h crosses zero: the smallest candidate y1 with
    ``h >= 0``, or a root of the linearly interpolated h.
    """
    @abstractmethod
    def __init__(self, kernel: kernel.Kernel):
        super().__init__()
        self.kernel = kernel
        self._is_fitted = False

    def get_y_weights(self, X_new: TT):
        """Return kernel weights of ``X_new`` against the stored ``X0`` and ``X1_sorted`` covariates.

        Both weight matrices are divided by the same per-row total, so together they sum to
        one over the pooled group-0 and group-1 training points.
        NOTE(review): assumes ``kernel.eval`` returns (n_new, n_train) tensors — TODO confirm.
        """
        X0_dists = self.kernel.eval(X_new, self.X0)
        X1_dists = self.kernel.eval(X_new, self.X1_sorted)
        normaliser = (
            torch.sum(X0_dists, dim=1, keepdim=True)
            + torch.sum(X1_dists, dim=1, keepdim=True))
        # Normalise in place (this modifies the tensors returned by kernel.eval).
        X0_dists.div_(normaliser)
        X1_dists.div_(normaliser)
        return X0_dists, X1_dists

    @abstractmethod
    def fit(self, y0: TT, X0: TT, y1: TT, X1: TT):
        """Fit the model to the data.

        Stores the group-0 data as given and the group-1 data sorted by ``y1``.
        """
        self.y0 = y0
        self.X0 = X0
        # Sort y1 and X1 for future use
        self.y1_sorted, self.sort_indices_1 = torch.sort(y1)
        self.X1_sorted = X1[self.sort_indices_1, :]

    @abstractmethod
    def get_single_h(self, y0_new: TT, y1_new: TT, X_new: TT):
        """Get the h value for a single new sample."""
        pass

    @abstractmethod
    def get_all_hs(self, y0_new: TT, X_new: TT):
        """Get all h values for a given new sample.

        Expected to return ``(hs, y1_candidate)``: h evaluated at every candidate y1 (last axis),
        and the candidate y1 values themselves.
        """
        pass

    def predict(self, y0_new: TT, X_new: TT, sortcheck=False, linear=False,
                return_hvals=False, fsolve_kwargs=None, **kwargs):
        """Map each (y0_new, X_new) pair to the y1 value at which h crosses zero.

        Args:
            y0_new (TT): new y0 values.
            X_new (TT): covariates for each y0 value.
            sortcheck (bool, optional): raise ValueError unless the y1 candidates returned by
                ``get_all_hs`` are sorted. Defaults to False.
            linear (bool, optional): if False, return the smallest candidate y1 with ``h >= 0``
                (the largest candidate when none qualifies). If True, linearly interpolate h over
                the candidates for each row and solve ``h = 0`` with ``scipy.optimize.fsolve``,
                clamping the root to the candidate range. Defaults to False.
            return_hvals (bool, optional): also return h at the returned y1 values. In linear mode
                these are fsolve's residuals (``infodict['fvec']``). Defaults to False.
            fsolve_kwargs (dict, optional): extra keyword arguments for ``fsolve`` (linear mode only).
            **kwargs: forwarded to ``get_all_hs``.

        Returns:
            TT, or (TT, TT) when ``return_hvals`` is True.

        Raises:
            ValueError: if ``sortcheck`` fails, or if ``linear`` is True and h has more than 2 dims.
        """
        hs, y1_candidate = self.get_all_hs(y0_new, X_new, **kwargs)
        if sortcheck:
            if not torch.all(y1_candidate == torch.sort(y1_candidate)[0]):
                raise ValueError("y1_candidate is not sorted.")
        # If h kept discrete
        if not linear:
            # Get y1s which give sufficiently large h/ have sufficiently large CDF
            valid_ys = torch.where(hs >= 0, y1_candidate, torch.tensor([torch.inf]))
            # Find the smallest valid y1
            out_vals = torch.min(valid_ys, dim=-1)[0]
            # Correct for cases with no valid value which currently output inf
            # Instead output maximum of all ys
            # (This theoretically shouldn't happen, as the largest y-val should always have eCDF 1).
            out_ys = torch.minimum(out_vals, y1_candidate[-1])
            if return_hvals:
                return out_ys, self.get_single_h(y0_new, out_ys, X_new)
            else:
                return out_ys
        # If h made continuous via linear interpolation
        else:
            if fsolve_kwargs is None:
                fsolve_kwargs = {}
            if len(hs.shape) > 2:
                raise ValueError("Continuous prediction doesn't support additional batching dimensions.")
            results = []
            h_out = []
            # Iterate over each y0 sample and associated h values
            for h_sub in hs:
                # Define function to optimise over y1 as linear interpolation of h values
                # (the closure is used immediately inside this iteration, so late binding is not an issue).
                def h_opt(y_opt):
                    return np.interp(y_opt, y1_candidate, h_sub)
                # Solve for y1
                # Ensure start point comfortably inside interpolation region
                start_point = y1_candidate[y1_candidate.shape[0]//2]
                # NOTE(review): fsolve's convergence flag (ier) and message are not checked.
                sol, infodict, ier, mesg = fsolve(h_opt, start_point, full_output=True, **fsolve_kwargs)
                results.append(torch.tensor(sol))
                h_out.append(torch.tensor(infodict['fvec']))

            # Combine values and clamp to ensure no strange behaviour outside interpolation region
            y_out = torch.clamp(torch.cat(results, dim=0), y1_candidate[0], y1_candidate[-1])
            h_out = torch.cat(h_out, dim=0)
            if return_hvals:
                return y_out, h_out
            else:
                return y_out


class IPWInverseLearner(InverseLearner):
    """Inverse-propensity-weighted learner for h(y0, y1, X).

    h is the kernel- and propensity-weighted count of group-1 outcomes ``<= y1`` minus the
    corresponding weighted count of group-0 outcomes ``<= y0``.
    """

    def __init__(self, kernel: kernel.Kernel, prop_func=None):
        """
        Args:
            kernel (kernel.Kernel): kernel weighting training points by similarity to new covariates.
            prop_func (optional): propensity model, either an object with ``predict`` or a plain
                callable. Its output is used as the group-1 propensity (group 0 uses ``1 - output``).
                When None (or neither predictable nor callable), every score is 0.5.
        """
        super().__init__(kernel)
        self.propensity_model = prop_func

    def get_propensities(self):
        """Compute and store ``prop_scores0`` (for ``X0``) and ``prop_scores1`` (for ``X1_sorted``)."""
        model = self.propensity_model
        if hasattr(model, "predict"):
            score_fn = model.predict
        elif callable(model):
            # Plain callables used to be silently ignored; treat them like ``predict``.
            score_fn = model
        else:
            # No usable model (including None): treat both groups as equally likely.
            self.prop_scores0 = torch.ones_like(self.X0[:, 0])-.5
            self.prop_scores1 = torch.ones_like(self.X1_sorted[:, 0])-.5
            return
        self.prop_scores0 = 1-score_fn(self.X0)
        self.prop_scores1 = score_fn(self.X1_sorted)

    def fit(self, y0: TT, X0: TT, y1: TT, X1: TT):
        """Store (and sort) the data, then compute propensity scores."""
        super().fit(y0, X0, y1, X1)
        # Get propensity scores if necessary
        self.get_propensities()
        self._is_fitted = True

    def get_single_h(self, y0_new, y1_new, X_new):
        """Return h at the given (y0_new, y1_new, X_new) triples."""
        X0_dists, X1_dists = self.get_y_weights(X_new)
        # Get contribution for A=0 samples
        term_0 = torch.sum(X0_dists*(self.y0 <= y0_new.unsqueeze(-1))/self.prop_scores0, dim=1)
        # Get contribution for A=1 samples at the requested y1 values
        term_1 = torch.sum(X1_dists*(self.y1_sorted <= y1_new.unsqueeze(-1))/self.prop_scores1, dim=1)
        # Get value of h at each jumping point
        h = term_1 - term_0
        return h

    def get_all_hs(self, y0_new, X_new):
        """Return h at every sorted training y1 for each query, together with those y1 values."""
        X0_dists, X1_dists = self.get_y_weights(X_new)
        # Get contribution for A=0 samples
        term_0 = torch.sum(X0_dists*(self.y0 <= y0_new.unsqueeze(-1))/self.prop_scores0, dim=1, keepdim=True)
        # Because y1 is sorted, a cumulative sum gives the A=1 contribution at every jump point.
        term_1s = torch.cumsum(X1_dists/self.prop_scores1, dim=1)
        # Get value of h at each jumping point
        hs = term_1s - term_0
        return hs, self.y1_sorted


class DRInverseLearner(IPWInverseLearner):
    """Doubly-robust learner for h(y0, y1, X).

    Augments the propensity-weighted indicator terms of ``IPWInverseLearner`` with plug-in
    conditional CDF models ``cdf_0`` (for y0) and ``cdf_1`` (for y1). ``fit`` does not fit
    these models, so they are assumed to be fitted beforehand.
    """

    def __init__(self, kernel: kernel.Kernel, cdf_0: kernel_cdf, cdf_1: kernel_cdf, prop_func=None):
        """
        Args:
            kernel (kernel.Kernel): kernel weighting training points by similarity to new covariates.
            cdf_0 (kernel_cdf): conditional CDF model for the group-0 outcome.
            cdf_1 (kernel_cdf): conditional CDF model for the group-1 outcome.
            prop_func (optional): propensity model, handled by ``IPWInverseLearner.get_propensities``.
        """
        # Forward prop_func so the propensity scores computed in fit actually use it
        # (otherwise every score would silently be 0.5).
        IPWInverseLearner.__init__(self, kernel, prop_func)
        self.cdf_0 = cdf_0
        self.cdf_1 = cdf_1
        self.prop_func = prop_func

    @staticmethod
    def _weighted_sum(weights: TT, values: TT) -> TT:
        """Return ``sum_i weights[..., i] * values[i, :]`` without materialising the full product tensor."""
        dtype = torch.promote_types(weights.dtype, values.dtype)
        return torch.matmul(weights.to(dtype), values.to(dtype))

    def get_single_h(self, y0_new: TT, y1_new: TT, X_new: TT):
        """Return the doubly-robust h at the given (y0_new, y1_new, X_new) triples.

        Raises:
            ValueError: if y0_new, y1_new and X_new (without its last axis) differ in shape.
        """
        if not (y0_new.shape == y1_new.shape == X_new.shape[:-1]):
            raise ValueError("""y0_new, y1_new and X_new must have the same dimensions
            (excluding final additional dimension of x_new).""")
        # # Get weights for each fitting sample y given our new sample.
        # X_new: dim 0, X0/1_dists: dim 1.
        X0_dists, X1_dists = self.get_y_weights(X_new)
        # # Get CDFs
        # y0/1_new: dim 0, X0/1: dim 1.
        cdf_vals0 = self.cdf_0.cdf(y0_new.unsqueeze(-1), self.X0)
        cdf_vals01 = self.cdf_0.cdf(y0_new.unsqueeze(-1), self.X1_sorted)
        cdf_vals1 = self.cdf_1.cdf(y1_new.unsqueeze(-1), self.X1_sorted)
        cdf_vals10 = self.cdf_1.cdf(y1_new.unsqueeze(-1), self.X0)
        # # Get indicators/comparisons
        # y0_new: dim 0, y0: dim 1.
        Z0 = (self.y0 <= y0_new.unsqueeze(-1)).float()
        Z1 = (self.y1_sorted <= y1_new.unsqueeze(-1)).float()

        # # Get final h value
        term_0 = torch.sum(X0_dists*((Z0-cdf_vals0)/self.prop_scores0
                                     + cdf_vals0-cdf_vals10), dim=1)
        term_1 = torch.sum(X1_dists*((Z1-cdf_vals1)/self.prop_scores1
                                     + cdf_vals1-cdf_vals01), dim=1)
        h = term_1-term_0
        return h

    def get_all_hs(self, y0_new: TT, X_new: TT, isotonic=False, check_same=False, slow=False):
        """Get all h values for a given y0_new and X_new.

        h is evaluated at every jump point of either the group-1 training outcomes or the
        ``cdf_1`` model's candidates, merged into one sorted set.

        Args:
            y0_new (TT): new y0 data to predict h values for.
            X_new (TT): new X data to predict h values for.
            isotonic (bool, optional): Whether or not to project final output to isotonic vector. Defaults to False.
            check_same (bool, optional): Whether to check if dataset for fitting CDF and DR are the same and adjust.
                                         Defaults to False.
            slow (bool, optional): Unused; kept for backward compatibility.

        Returns:
            (TT,TT): (h values, y1 values used for h values)
        """
        same = False
        if check_same:
            # torch.allclose raises on shape mismatch, so compare shapes first.
            same = (self.cdf_1.y_sorted.shape == self.y1_sorted.shape
                    and self.cdf_1.X_sorted.shape == self.X1_sorted.shape
                    and torch.allclose(self.cdf_1.y_sorted, self.y1_sorted)
                    and torch.allclose(self.cdf_1.X_sorted, self.X1_sorted))

        # # Get weights for each fitting sample given our new sample (X_new: dim ..., samples: dim -1).
        X0_dists, X1_dists = self.get_y_weights(X_new)

        # ### Term 0 (depends only on y0_new) ###
        y0_col = y0_new.unsqueeze(-1)
        cdf_vals0 = self.cdf_0.cdf(y0_col, self.X0)
        cdf_vals01 = self.cdf_0.cdf(y0_col, self.X1_sorted)
        Z0 = (self.y0 <= y0_col).float()
        term_0 = (torch.sum(X0_dists*((Z0-cdf_vals0)/self.prop_scores0+cdf_vals0), dim=-1, keepdim=True)
                  + torch.sum(X1_dists*cdf_vals01, dim=-1, keepdim=True))

        # ### Term 1 (evaluated at every candidate y1) ###
        # Rows: training covariates; columns: cdf_1's jump points.
        all_cdf_vals1, y1_cdf_candidate = self.cdf_1.getallcdfs(self.X1_sorted)
        all_cdf_vals10 = self.cdf_1.getallcdfs(self.X0)[0]
        indicator_term_1 = torch.cumsum(X1_dists/self.prop_scores1, dim=-1)

        if same:
            all_y1_candidate = y1_cdf_candidate
        else:
            def with_leading_zero(t):
                # Column 0 = value of the step function before its first jump.
                return torch.cat([torch.zeros_like(t[..., :1]), t], dim=-1)

            device = self.y1_sorted.device
            # Tag each jump point by origin: False = DR sample y1, True = cdf_1 candidate.
            from_cdf = torch.cat([
                torch.zeros(self.y1_sorted.shape[0], dtype=torch.bool, device=device),
                torch.ones(y1_cdf_candidate.shape[0], dtype=torch.bool, device=device)])
            all_y1_candidate, merge_order = torch.sort(torch.cat([self.y1_sorted, y1_cdf_candidate]))
            from_cdf = from_cdf[merge_order]
            # Number of jumps of each step function at or before every merged point; with the
            # leading zero column these index each function's value at that point.
            cdf_idx = torch.cumsum(from_cdf, dim=0)
            ind_idx = torch.cumsum(~from_cdf, dim=0)
            all_cdf_vals1 = with_leading_zero(all_cdf_vals1)[..., cdf_idx]
            all_cdf_vals10 = with_leading_zero(all_cdf_vals10)[..., cdf_idx]
            indicator_term_1 = with_leading_zero(indicator_term_1)[..., ind_idx]

        # (1 - 1/pi_1) * F_1 for each group-1 training point; pi_1 broadcast along the y1 axis.
        cdf_pseudo_1 = (1-1/self.prop_scores1.unsqueeze(1))*all_cdf_vals1
        cdf_term1 = self._weighted_sum(X1_dists, cdf_pseudo_1)
        cdf_term10 = self._weighted_sum(X0_dists, all_cdf_vals10)
        term_1s = indicator_term_1+cdf_term1+cdf_term10

        hs = term_1s - term_0
        if isotonic:
            temp_h = []
            ir = IsotonicRegression()
            for h_row in hs:
                temp_h.append(torch.tensor(ir.fit_transform(np.arange(h_row.shape[0]), h_row)))
            hs = torch.stack(temp_h, dim=0)
        return hs, all_y1_candidate


class SeparateLearner(InverseLearner):
    """Plug-in learner with h(y0, y1, X) = F_1(y1 | X) - F_0(y0 | X), using two separate CDF models."""

    def __init__(self, cdf_0: kernel_cdf, cdf_1: kernel_cdf):
        """
        Args:
            cdf_0 (kernel_cdf): conditional CDF model for y0.
            cdf_1 (kernel_cdf): conditional CDF model for y1.
        """
        # InverseLearner.__init__ is not called: this learner uses no kernel of its own.
        self.cdf_0 = cdf_0
        self.cdf_1 = cdf_1
        self._is_fitted = False

    def fit(self, *args, **kwargs):
        """Mark as fitted. Nothing is fitted here, so cdf_0 and cdf_1 must already be fitted."""
        self._is_fitted = True

    def get_single_h(self, y0: TT, y1: TT, X: TT):
        """Return F_1(y1 | X) - F_0(y0 | X)."""
        # Use the X argument (a stray lowercase ``x`` here previously raised NameError).
        return self.cdf_1.cdf(y1, X) - self.cdf_0.cdf(y0, X)

    def get_all_hs(self, y0: TT, X: TT):
        """Return h at every jump point of cdf_1 for each row of X, together with those jump points."""
        all_cdfs_1, y1_candidate = self.cdf_1.getallcdfs(X)
        cdf_0 = self.cdf_0.cdf(y0, X)
        return all_cdfs_1 - cdf_0.unsqueeze(-1), y1_candidate


# Create spline regressor using scipy.interpolate splrep and splev
class spline_regressor(ABC):
    """One-dimensional spline regressor built on ``scipy.interpolate.splrep``/``splev``.

    Keyword arguments given at construction are forwarded unchanged to ``splrep``.
    """

    def __init__(self, **spline_kwargs):
        self.spline_kwargs = spline_kwargs
        self._is_fitted = False

    def fit(self, X: TT, y: TT):
        """Fit the spline to (X, y); X is squeezed to 1-D and sorted first."""
        self.X, self.y = X, y
        xs, order = torch.sort(X.squeeze())
        ys = y[order]
        self.X_sorted, self.sort_indices, self.y_sorted = xs, order, ys
        self.tks = splrep(xs.numpy(), ys.numpy(), **self.spline_kwargs)
        self._is_fitted = True

    def predict(self, new_X: TT):
        """Evaluate the fitted spline at the (squeezed) ``new_X``."""
        query = new_X.squeeze().numpy()
        return torch.tensor(splev(query, self.tks))

    __call__ = predict


class Rspline_regressor(ABC):
    """One-dimensional regressor backed by R's ``smooth.spline``, called through rpy2."""

    def __init__(self, **spline_kwargs):
        """
        Args:
            **spline_kwargs: extra keyword arguments forwarded to R's ``smooth.spline``.
        """
        self.spline_kwargs = spline_kwargs
        self._is_fitted = False

    def fit(self, X: TT, y: TT):
        """Fit ``smooth.spline`` on (X, y); X is squeezed to 1-D and sorted first."""
        self.X = X
        self.y = y
        self.X_sorted, self.sort_indices = torch.sort(X.squeeze())
        self.y_sorted = y[self.sort_indices].numpy()
        self.X_sorted = self.X_sorted.numpy()
        self.r_y = robjects.FloatVector(self.y_sorted)
        self.r_X = robjects.FloatVector(self.X_sorted)

        r_smooth_spline = robjects.r['smooth.spline']
        # Forward the user's options; they were previously stored but never passed to R.
        self.spline = r_smooth_spline(x=self.r_X, y=self.r_y, **self.spline_kwargs)
        self._is_fitted = True

    def predict(self, new_X: TT):
        """Evaluate the fitted R spline at the (squeezed) ``new_X``."""
        r_new_X = robjects.FloatVector(new_X.squeeze().numpy())
        prediction = robjects.r['predict'](self.spline, r_new_X)
        return torch.tensor(prediction.rx2('y'))

    __call__ = predict


class CQTE_Learner(ABC):
    """Two-stage learner for a conditional quantile treatment effect at level ``prob``.

    ``fit`` forms a pseudo-outcome from the supplied quantile functions (``icdf_0``/``icdf_1``),
    density functions (``xpdf_0``/``xpdf_1``) and propensity scores, then regresses it on X
    with the given regressor. ``predict`` evaluates that regression.
    """

    def __init__(self, regressor: kernel_regressor, prob: float, icdf_0, icdf_1, xpdf_0, xpdf_1, prop_func=None):
        self.prob = prob
        self.icdf_0, self.icdf_1 = icdf_0, icdf_1
        self.xpdf_0, self.xpdf_1 = xpdf_0, xpdf_1
        self.prop_func = prop_func
        # Attribute name (sic) kept as-is for compatibility with existing callers.
        self.internel_regressor = regressor
        self._is_fitted = False

    def fit(self, X: TT, y: TT, a: TT):
        """Compute pseudo-outcomes for (X, y, a) and fit the internal regressor on them.

        Args:
            X (TT): covariates.
            y (TT): observed outcomes.
            a (TT): treatment indicator per sample (presumably 0/1 floats — TODO confirm).
        """
        self.y, self.X, self.a = y, X, a

        # Propensity scores: constant 0.5, a model's ``predict``, or a plain callable.
        if self.prop_func is None:
            props = torch.ones_like(self.X[:, 0])-.5
        else:
            propensity = self.prop_func.predict if hasattr(self.prop_func, "predict") else self.prop_func
            props = propensity(self.X)

        # Density of the observed arm (arm 1 evaluated before arm 0).
        density_1 = self.xpdf_1(self.X)
        density_0 = self.xpdf_0(self.X)
        pdf_vals = a*density_1+(1-a)*density_0

        quantile_0 = self.icdf_0(self.prob, self.X)
        quantile_1 = self.icdf_1(self.prob, self.X)
        observed_quantile = a*quantile_1+(1-a)*quantile_0

        # For binary a, (props - 1 + a) is props when a=1 and -(1 - props) when a=0.
        ipw_factor = 1/((props-1+self.a)*pdf_vals)
        residual = self.prob-(self.y <= observed_quantile).float()
        self.pseudo_outcome = ipw_factor*residual+quantile_1-quantile_0

        self.internel_regressor.fit(self.X, self.pseudo_outcome)
        self._is_fitted = True

    def predict(self, X_new: TT):
        """Predict the effect at ``X_new`` using the fitted internal regressor."""
        return self.internel_regressor.predict(X_new)

    __call__ = predict


# %%
if __name__ == "__main__":
    # Smoke test on synthetic data where y | x ~ N(gs_0[0](x), gs_0[1](x)**2).
    gs_0 = [lambda x: torch.cos(2*x).squeeze(), lambda x: torch.exp(0.3*x).squeeze()]
    gs_1 = [lambda x: (2*torch.cos(2*x)+1.5).squeeze(), lambda x:torch.exp(0.3*x).squeeze()]

    def true_transform(y, x):
        # Map y to the value with the same standardised residual under (gs_1[0], gs_1[1]).
        return gs_1[1](x)*(y-gs_0[0](x))/gs_0[1](x)+gs_1[0](x)

    torch.manual_seed(1234)
    y_prex = torch.randn((1000,))
    y_new_prex = torch.randn((200,))
    x = torch.randn((1000, 1))
    x_new = torch.randn((200, 1))

    # Create X dependent ys
    y = y_prex * gs_0[1](x) + gs_0[0](x)
    y_new = y_new_prex * gs_0[1](x_new) + gs_0[0](x_new)

# %%
if __name__ == "__main__":
    # fit knn model (fit takes covariates first, then outcomes)
    knn_model = knncdf(10, 'auto')
    knn_model.fit(x, y)
    out = knn_model.cdf(y_new, x_new)
    reverse_out = knn_model.inverse_cdf(0.5, x_new)
# %%
if __name__ == "__main__":
    # fit kernel model
    kernel_model = kernel_cdf(kernel.KGauss(0.1))
    kernel_model.fit(x, y)
    out = kernel_model.cdf(y_new, x_new)
    reverse_out = kernel_model.inverse_cdf(0.5, x_new)
