# import numpy as np
# import sklearn
# import matplotlib.pyplot as plt
# import gpflow
# import scipy
# from misc import gp_classes
# from sklearn.utils.validation import check_is_fitted
# # import tensorflow as tf
# import tensorflow.compat.v1 as tf

# class CalibrationMethod(sklearn.base.BaseEstimator):
#     """
#     A generic class for probability calibration

#     A calibration method takes a set of posterior class probabilities and transforms them into calibrated posterior
#     probabilities. Calibrated in this sense means that the empirical frequency of a correct class prediction matches its
#     predicted posterior probability.
#     """

#     def __init__(self):
#         super().__init__()

#     def fit(self, X, y):
#         """
#         Fit the calibration method based on the given uncalibrated class probabilities X and ground truth labels y.

#         Parameters
#         ----------
#         X : array-like, shape (n_samples, n_classes)
#             Training data, i.e. predicted probabilities of the base classifier on the calibration set.
#         y : array-like, shape (n_samples,)
#             Target classes.

#         Returns
#         -------
#         self : object
#             Returns an instance of self.
#         """
#         raise NotImplementedError("Subclass must implement this method.")

#     def predict_proba(self, X):
#         """
#         Compute calibrated posterior probabilities for a given array of posterior probabilities from an arbitrary
#         classifier.

#         Parameters
#         ----------
#         X : array-like, shape (n_samples, n_classes)
#             The uncalibrated posterior probabilities.

#         Returns
#         -------
#         P : array, shape (n_samples, n_classes)
#             The predicted probabilities.
#         """
#         raise NotImplementedError("Subclass must implement this method.")

#     def predict(self, X):
#         """
#         Predict the class of new samples after scaling. Predictions are identical to the ones from the uncalibrated
#         classifier.

#         Parameters
#         ----------
#         X : array-like, shape (n_samples, n_classes)
#             The uncalibrated posterior probabilities.

#         Returns
#         -------
#         C : array, shape (n_samples,)
#             The predicted classes.
#         """
#         return np.argmax(self.predict_proba(X), axis=1)

#     def plot(self, filename, xlim=[0, 1], **kwargs):
#         """
#         Plot the calibration map.

#         Parameters
#         ----------
#         xlim : array-like
#             Range of inputs of the calibration map to be plotted.

#         **kwargs :
#             Additional arguments passed on to :func:`matplotlib.plot`.
#         """
#         # TODO: Fix this plotting function

#         # Generate data and transform
#         x = np.linspace(0, 1, 10000)
#         y = self.predict_proba(np.column_stack([1 - x, x]))[:, 1]

#         # Plot and label
#         plt.plot(x, y, **kwargs)
#         plt.xlim(xlim)
#         plt.xlabel("p(y=1|x)")
#         plt.ylabel("f(p(y=1|x))")

# class GPCalibration(CalibrationMethod):
#     """
#     Probability calibration using a latent Gaussian process

#     Gaussian process calibration [1]_ is a non-parametric approach to calibrate posterior probabilities from an arbitrary
#     classifier based on a hold-out data set. Inference is performed using a sparse variational Gaussian process
#     (SVGP) [2]_ implemented in `gpflow` [3]_.

#     Parameters
#     ----------
#     n_classes : int
#         Number of classes in calibration data.
#     logits : bool, default=False
#         Are the inputs for calibration logits (e.g. from a neural network)?
#     mean_function : GPflow object
#         Mean function of the latent GP.
#     kernel : GPflow object
#         Kernel function of the latent GP.
#     likelihood : GPflow object
#         Likelihood giving a prior on the class prediction.
#     n_inducing_points : int, default=100
#         Number of inducing points for the variational approximation.
#     maxiter : int, default=1000
#         Maximum number of iterations for the likelihood optimization procedure.
#     n_monte_carlo : int, default=100
#         Number of Monte Carlo samples for the inference procedure.
#     max_samples_monte_carlo : int, default=10**7
#         Maximum number of Monte Carlo samples to draw in one batch when predicting. Setting this value too large can
#         cause memory issues.
#     inf_mean_approx : bool, default=False
#         If True, when inferring calibrated probabilities, only the mean of the latent Gaussian process is taken into
#         account, not its covariance.
#     session : tf.Session, default=None
#         `tensorflow` session to use.
#     random_state : int, default=0
#         Random seed for reproducibility. Needed for Monte-Carlo sampling routine.
#     verbose : bool
#         Print information on optimization routine.

#     References
#     ----------
#     .. [1] Wenger, J., Kjellström H. & Triebel, R. Non-Parametric Calibration for Classification in
#            Proceedings of AISTATS (2020)
#     .. [2] Hensman, J., Matthews, A. G. d. G. & Ghahramani, Z. Scalable Variational Gaussian Process Classification in
#            Proceedings of AISTATS (2015)
#     .. [3] Matthews, A. G. d. G., van der Wilk, M., et al. GPflow: A Gaussian process library using TensorFlow. Journal
#            of Machine Learning Research 18, 1–6 (Apr. 2017)
#     """

#     def __init__(self,
#                  n_classes,
#                  logits=False,
#                  mean_function=None,
#                  kernel=None,
#                  likelihood=None,
#                  n_inducing_points=10,
#                  maxiter=1000,
#                  n_monte_carlo=100,
#                  max_samples_monte_carlo=10 ** 7,
#                  inf_mean_approx=False,
#                  session=None,
#                  random_state=1,
#                  verbose=False):
#         super().__init__()

#         # Parameter initialization
#         self.n_classes = n_classes
#         self.verbose = verbose

#         # Initialization of tensorflow session
#         self.session = session

#         # Initialization of Gaussian process components and inference parameters
#         self.input_dim = n_classes
#         self.n_monte_carlo = n_monte_carlo
#         self.max_samples_monte_carlo = max_samples_monte_carlo
#         self.n_inducing_points = n_inducing_points
#         self.maxiter = maxiter
#         self.inf_mean_approx = inf_mean_approx
#         self.random_state = random_state
#         np.random.seed(self.random_state)  # Set seed for optimization of hyperparameters

#         with gpflow.defer_build():

#             # Set likelihood
#             if likelihood is None:
#                 self.likelihood = gp_classes.MultiCal(num_classes=self.n_classes,
#                                                       num_monte_carlo_points=self.n_monte_carlo)
#             else:
#                 self.likelihood = likelihood

#             # Set mean function
#             if mean_function is None:
#                 if logits:
#                     self.mean_function = gpflow.conditionals.mean_functions.Identity()
#                 else:
#                     self.mean_function = gp_classes.Log()
#             else:
#                 self.mean_function = mean_function

#             # Set kernel
#             if kernel is None:
#                 k_white = gpflow.kernels.White(1, variance=0.01)
#                 if logits:
#                     kernel_lengthscale = 10
#                     k_rbf = gpflow.kernels.RBF(input_dim=1, lengthscales=kernel_lengthscale, variance=1)
#                 else:
#                     kernel_lengthscale = 0.5
#                     k_rbf = gpflow.kernels.RBF(input_dim=1, lengthscales=kernel_lengthscale, variance=1)
#                     # Place constraints [a,b] on kernel parameters
#                     k_rbf.lengthscales.transform = gpflow.transforms.Logistic(a=.001, b=10)
#                     k_rbf.variance.transform = gpflow.transforms.Logistic(a=0, b=5)
#                 self.kernel = k_rbf + k_white
#             else:
#                 self.kernel = kernel

#     def fit(self, X, y):
#         """
#         Fit the calibration method based on the given uncalibrated class probabilities or logits X and ground truth
#         labels y.

#         Parameters
#         ----------
#         X : array-like, shape (n_samples, n_classes)
#             Training data, i.e. predicted probabilities or logits of the base classifier on the calibration set.
#         y : array-like, shape (n_samples,)
#             Target classes.

#         Returns
#         -------
#         self : object
#             Returns an instance of self.
#         """
#         # Check for correct dimensions
#         if X.ndim == 1 or np.shape(X)[1] != self.n_classes:
#             raise ValueError("Calibration data must have shape (n_samples, n_classes).")

#         # Create a new TF session if none is given
#         if self.session is None:
#             self.session = tf.Session(graph=tf.Graph())

#         # Fit GP in TF session
#         with self.session.as_default(), self.session.graph.as_default():
#             if X.ndim == 1:
#                 raise ValueError("Calibration training data must have shape (n_samples, n_classes).")
#             else:
#                 self._fit_multiclass(X, y)
#         return self

#     def _fit_multiclass(self, X, y):
#         # Setup
#         y = y.reshape(-1, 1)

#         # Select inducing points through scipy.cluster.vq.kmeans
#         Z = scipy.cluster.vq.kmeans(obs=X.flatten().reshape(-1, 1),
#                                     k_or_guess=min(X.shape[0] * X.shape[1], self.n_inducing_points, ))[0]

#         # Define SVGP calibration model with multiclass softargmax calibration likelihood
#         self.model = gp_classes.SVGPcal(X=X, Y=y, Z=Z,
#                                         mean_function=self.mean_function,
#                                         kern=self.kernel,
#                                         likelihood=self.likelihood,
#                                         whiten=True,
#                                         q_diag=True)

#         # Optimize parameters
#         opt = gpflow.train.ScipyOptimizer()
#         opt.minimize(self.model, maxiter=self.maxiter, disp=self.verbose)

#         return self

#     def predict_proba(self, X, mean_approx=False):
#         """
#         Compute calibrated posterior probabilities for a given array of posterior probabilities from an arbitrary
#         classifier.

#         Parameters
#         ----------
#         X : array-like, shape=(n_samples, n_classes)
#             The uncalibrated posterior probabilities.
#         mean_approx : bool, default=False
#             If True, inference is performed using only the mean of the latent Gaussian process, not its covariance.
#         Note: if `self.inf_mean_approx` is True, this option is ignored and the mean approximation is always used.

#         Returns
#         -------
#         P : array, shape (n_samples, n_classes)
#             The predicted probabilities.
#         """
#         check_is_fitted(self, "model")

#         if mean_approx or self.inf_mean_approx:

#             # Evaluate latent GP
#             with self.session.as_default(), self.session.graph.as_default():
#                 f, _ = self.model.predict_f(X_onedim=X.reshape(-1, 1))
#                 latent = f.eval().reshape(np.shape(X))

#             # Return softargmax of fitted GP at input
#             return scipy.special.softmax(latent, axis=1)
#         else:

#             with self.session.as_default(), self.session.graph.as_default():
#                 # Seed for Monte_Carlo
#                 tf.set_random_seed(self.random_state)

#                 if X.ndim == 1 or np.shape(X)[1] != self.n_classes:
#                     raise ValueError("Calibration data must have shape (n_samples, n_classes).")
#                 else:
#                     # Predict in batches to keep memory usage in Monte-Carlo sampling low
#                     n_data = np.shape(X)[0]
#                     samples_monte_carlo = self.n_classes * self.n_monte_carlo * n_data
#                     if samples_monte_carlo >= self.max_samples_monte_carlo:
#                         n_pred_batches = np.divmod(samples_monte_carlo, self.max_samples_monte_carlo)[0]
#                     else:
#                         n_pred_batches = 1

#                     p_pred_list = []
#                     for i in range(n_pred_batches):
#                         if self.verbose:
#                             print("Predicting batch {}/{}.".format(i + 1, n_pred_batches))
#                         ind_range = np.arange(start=self.max_samples_monte_carlo * i,
#                                               stop=np.minimum(self.max_samples_monte_carlo * (i + 1), n_data))
#                         p_pred_list.append(tf.exp(self.model.predict_full_density(Xnew=X[ind_range, :])).eval())

#                     return np.concatenate(p_pred_list, axis=0)

#     def latent(self, z):
#         """
#         Evaluate the latent function f(z) of the GP calibration method.

#         Parameters
#         ----------
#         z : array-like, shape=(n_evaluations,)
#             Input confidence for which to evaluate the latent function.

#         Returns
#         -------
#         f : array-like, shape=(n_evaluations,)
#             Values of the latent function at z.
#         f_var : array-like, shape=(n_evaluations,)
#             Variance of the latent function at z.
#         """
#         # Evaluate latent GP
#         with self.session.as_default(), self.session.graph.as_default():
#             f, var = self.model.predict_f(z.reshape(-1, 1))
#             latent = f.eval().flatten()
#             latent_var = var.eval().flatten()

#         return latent, latent_var

#     # def plot_latent(self, z, filename, plot_classes=True, **kwargs):
#     #     """
#     #     Plot the latent function of the calibration method.

#     #     Parameters
#     #     ----------
#     #     z : array-like, shape=(n_evaluations,)
#     #         Input confidence to plot latent function for.
#     #     filename :
#     #         Filename / -path where to save output.
#     #     plot_classes : bool, default=True
#     #         Should classes also be plotted?
#     #     kwargs
#     #         Additional arguments passed on to matplotlib.pyplot.subplots.

#     #     Returns
#     #     -------

#     #     """
#     #     # Evaluate latent GP
#     #     with self.session.as_default(), self.session.graph.as_default():
#     #         f, var = self.model.predict_f(z.reshape(-1, 1))
#     #         latent = f.eval().flatten()
#     #         latent_var = var.eval().flatten()
#     #         Z = self.model.X.value

#     #     # Plot latent GP
#     #     if plot_classes:
#     #         fig, axes = pycalib.texfig.subplots(nrows=2, ncols=1, sharex=True, **kwargs)
#     #         axes[0].plot(z, latent, label="GP mean")
#     #         axes[0].fill_between(z, latent - 2 * np.sqrt(latent_var), latent + 2 * np.sqrt(latent_var), alpha=.2)
#     #         axes[0].set_ylabel("GP $g(\\textnormal{z}_k)$")
#     #         axes[1].plot(Z.reshape((np.size(Z),)),
#     #                      np.matlib.repmat(np.arange(0, self.n_classes), np.shape(Z)[0], 1).reshape((np.size(Z),)), 'kx',
#     #                      markersize=5)
#     #         axes[1].set_ylabel("class $k$")
#     #         axes[1].set_xlabel("confidence $\\textnormal{z}_k$")
#     #         fig.align_labels()
#     #     else:
#     #         fig, axes = pycalib.texfig.subplots(nrows=1, ncols=1, sharex=True, **kwargs)
#     #         axes.plot(z, latent, label="GP mean")
#     #         axes.fill_between(z, latent - 2 * np.sqrt(latent_var), latent + 2 * np.sqrt(latent_var), alpha=.2)
#     #         axes.set_xlabel("GP $g(\\textnormal{z}_k)$")
#     #         axes.set_ylabel("confidence $\\textnormal{z}_k$")

#     #     # Save plot to file
#     #     pycalib.texfig.savefig(filename)
#     #     plt.close()

import gpflow
import numpy as np
import tensorflow as tf
from gpflow import kernels, mean_functions
from scipy.special import softmax
from sklearn.base import BaseEstimator
from sklearn.exceptions import NotFittedError
from sklearn.utils.validation import check_is_fitted

class GPCalibration(BaseEstimator):
    """
    Probability calibration with a sparse variational Gaussian process (SVGP).

    An SVGP with ``n_classes`` latent functions is fitted to the uncalibrated inputs ``X`` against one-hot encoded
    labels using a Gaussian likelihood. Calibrated probabilities are the softmax of the latent posterior mean.

    Parameters
    ----------
    n_classes : int
        Number of classes in the calibration data (= number of latent GPs).
    logits : bool, default=False
        Whether the inputs are logits. Only affects the defaults: identity mean function and RBF lengthscale 10 for
        logits; zero mean function and lengthscale 0.5 otherwise.
    mean_function : gpflow.mean_functions.MeanFunction, optional
        Mean function of the latent GP. If None, a default is chosen from ``logits``. Note that the identity mean
        presumably requires ``X`` to have exactly ``n_classes`` columns — TODO confirm against gpflow.
    kernel : gpflow.kernels.Kernel, optional
        Kernel of the latent GP. If None, ``RBF(variance=1) + White(variance=0.01)`` is used.
    n_inducing_points : int, default=10
        Maximum number of inducing points, sampled from the training inputs on the first call to ``fit``.
    maxiter : int, default=1000
        Maximum number of iterations of the scipy optimizer.
    inf_mean_approx : bool, default=False
        If True, ``predict_proba`` uses only the latent posterior mean. This is currently the only implemented
        inference mode; otherwise ``predict_proba`` must be called with ``mean_approx=True``.
    random_state : int, default=1
        Seed used to select the inducing points.
    verbose : bool, default=False
        Stored for API compatibility; not currently read by this class.

    Attributes
    ----------
    model : gpflow.models.SVGP or None
        The fitted model, or None before the first call to ``fit``.
    """

    def __init__(self,
                 n_classes,
                 logits=False,
                 mean_function=None,
                 kernel=None,
                 n_inducing_points=10,
                 maxiter=1000,
                 inf_mean_approx=False,
                 random_state=1,
                 verbose=False):
        super().__init__()

        # Every constructor argument is stored under its own name: sklearn's get_params() (used by repr() and
        # clone()) reads each __init__ parameter back as an attribute and fails if one is missing.
        self.n_classes = n_classes
        self.logits = logits
        self.verbose = verbose
        self.input_dim = n_classes
        self.n_inducing_points = n_inducing_points
        self.maxiter = maxiter
        self.inf_mean_approx = inf_mean_approx
        self.random_state = random_state
        # Global seeding kept for backward compatibility; fit() draws inducing points from its own seeded RNG.
        np.random.seed(self.random_state)

        # NOTE(review): a Gaussian likelihood regresses the latent GPs onto one-hot targets; it is not a
        # classification likelihood (e.g. gpflow.likelihoods.Softmax). Kept as-is to preserve model behavior.
        self.likelihood = gpflow.likelihoods.Gaussian()

        if mean_function is None:
            self.mean_function = mean_functions.Identity() if logits else mean_functions.Zero()
        else:
            self.mean_function = mean_function

        if kernel is None:
            # The two default kernels differ only in their RBF lengthscale.
            kernel_lengthscale = 10 if logits else 0.5
            k_white = kernels.White(variance=0.01)
            k_rbf = kernels.RBF(lengthscales=kernel_lengthscale, variance=1)
            self.kernel = k_rbf + k_white
        else:
            self.kernel = kernel

        # Created lazily by the first call to fit().
        self.model = None

    def fit(self, X, y):
        """
        Fit the calibration GP on uncalibrated inputs X and ground-truth labels y.

        The SVGP model is built on the first call only; subsequent calls keep the existing model (including its
        inducing points) and continue optimizing it on the new data.

        Parameters
        ----------
        X : array-like, shape (n_samples, n_features)
            Uncalibrated probabilities or logits of the base classifier on the calibration set.
        y : array-like, shape (n_samples,) or (n_samples, 1)
            Integer class labels in ``[0, n_classes)``.

        Returns
        -------
        self : GPCalibration
            The fitted estimator.

        Raises
        ------
        ValueError
            If X is not a non-empty 2D array, if X and y differ in length, or if a label is out of range.
        """
        X = np.asarray(X, dtype=np.float64)
        y = np.asarray(y).reshape(-1)
        if X.ndim != 2 or X.shape[0] == 0:
            raise ValueError("Calibration data must be a non-empty 2D array of shape (n_samples, n_features).")
        if X.shape[0] != y.shape[0]:
            raise ValueError("X and y must contain the same number of samples.")
        # tf.one_hot maps out-of-range labels to an all-zero row, which would silently corrupt the targets.
        if y.min() < 0 or y.max() >= self.n_classes:
            raise ValueError(f"Labels in y must lie in [0, {self.n_classes}).")

        if self.model is None:
            # Inducing points: a random subset of at most n_inducing_points training inputs. A local RNG keeps the
            # selection independent of any other use of numpy's global RNG.
            rng = np.random.RandomState(self.random_state)
            Z = X[rng.permutation(X.shape[0])[:self.n_inducing_points], :]
            self.model = gpflow.models.SVGP(
                kernel=self.kernel,
                likelihood=self.likelihood,
                inducing_variable=Z,
                mean_function=self.mean_function,
                num_latent_gps=self.n_classes,
            )

        # Each of the n_classes latent GPs is regressed onto its class-indicator column.
        data = (tf.convert_to_tensor(X), tf.one_hot(y, depth=self.n_classes, dtype=tf.float64))

        def loss():
            # The Scipy optimizer expects a zero-argument closure returning the loss.
            return self.model.training_loss(data)

        opt = gpflow.optimizers.Scipy()
        opt.minimize(loss, variables=self.model.trainable_variables, options=dict(maxiter=self.maxiter))
        return self

    def predict_proba(self, X, mean_approx=False):
        """
        Compute calibrated class probabilities.

        Parameters
        ----------
        X : array-like, shape (n_samples, n_features)
            Uncalibrated probabilities or logits.
        mean_approx : bool, default=False
            If True, use only the latent posterior mean. Ignored (treated as True) when ``self.inf_mean_approx`` is
            True.

        Returns
        -------
        P : ndarray, shape (n_samples, n_classes)
            Softmax of the latent posterior mean.

        Raises
        ------
        NotFittedError
            If ``fit`` has not been called yet.
        NotImplementedError
            If neither ``mean_approx`` nor ``self.inf_mean_approx`` is set.
        """
        # check_is_fitted(self, "model") cannot be used: __init__ sets model = None, so the attribute always exists.
        if self.model is None:
            raise NotFittedError(
                f"This {type(self).__name__} instance is not fitted yet. "
                "Call 'fit' with appropriate arguments before using this estimator."
            )

        if not (mean_approx or self.inf_mean_approx):
            raise NotImplementedError("Approximation with full covariance not yet implemented")

        f_mean, _ = self.model.predict_f(tf.cast(X, tf.float64))
        return softmax(f_mean.numpy(), axis=1)

    def predict(self, X):
        """
        Predict class labels as the argmax of ``predict_proba(X)``.

        Parameters
        ----------
        X : array-like, shape (n_samples, n_features)
            Uncalibrated probabilities or logits.

        Returns
        -------
        C : ndarray, shape (n_samples,)
            Predicted class indices.
        """
        return np.argmax(self.predict_proba(X), axis=1)