import torch
import numpy as np
import math
import time
from sklearn.linear_model import LogisticRegression

class SD_VI:
    """
    Spectral Decomposed Variational Inference (SD-VI)

    Variational inference scheme for Bayesian logistic regression with a
    Gaussian variational family q(β) = N(μ, S). Instead of maximizing a
    standard ELBO purely by automatic differentiation, the covariance S is
    updated with a Proximal Spectral Optimization (PSO) step: a gradient
    ascent step on the Monte Carlo expected log-likelihood followed by a
    shrinkage of the eigenvalues of S.

    The spectral penalty implemented here is the linear one, f(σ) = λ₁σ,
    whose proximal operator is soft-thresholding of the eigenvalues.
    ``lambda2`` and ``gamma`` are stored for the richer penalty described in
    the accompanying paper draft (Eq. 5) but are not used by this class.

    Attributes set by the constructor:
        mu: mean of the variational distribution, shape (n_features,).
        S: covariance of the variational distribution,
            shape (n_features, n_features).
        objective_values: objective value recorded after every iteration of
            ``fit`` (appended to across repeated ``fit`` calls).
    Attribute set by ``fit``:
        runtime: wall-clock duration of the last ``fit`` call, in seconds.
    """

    # Largest gradient norm allowed for the covariance ascent step.
    _GRAD_CLIP_NORM = 10.0
    # Frobenius norm above which the intermediate covariance is rescaled.
    _MAX_FRO_NORM = 1e6
    # Number of times the eigendecomposition is retried with a larger jitter.
    _MAX_EIGH_RETRIES = 5

    def __init__(self, n_features,
                 # --- Learning rates are separated for different components ---
                 lr_mu=0.01, lr_S=0.001,
                 max_iter=1000, tol=1e-5, jitter=1e-6,
                 # --- Hyperparameters for the spectral regularizer (Eq. 5) ---
                 lambda1=1e-2, lambda2=0.0, gamma=1.0,
                 # --- Monte Carlo samples for the expected log-likelihood ---
                 n_mc_samples=1,
                 random_state=None):
        """
        Args:
            n_features: dimensionality of the coefficient vector β.
            lr_mu: learning rate of the AdamW optimizer used for μ.
            lr_S: step size of the gradient step on S (η_S in the paper).
            max_iter: maximum number of optimization iterations in ``fit``.
            tol: absolute change of the objective below which ``fit`` stops
                (only checked after the first 51 iterations).
            jitter: multiple of the identity added to S before Cholesky and
                eigendecomposition.
            lambda1: weight λ₁ of the spectral penalty.
            lambda2: λ₂ in Eq. 5 (stored, unused here).
            gamma: γ in Eq. 5 (stored, unused here).
            n_mc_samples: number of Monte Carlo samples per likelihood
                estimate.
            random_state: optional integer seed. When given, it seeds the
                global torch and numpy generators and a private generator
                used for all Monte Carlo draws of this instance.
        """
        self.random_state = random_state
        # Private generator for the Monte Carlo draws. ``None`` means "use
        # torch's global generator", which keeps unseeded runs random.
        self._rng = None
        if random_state is not None:
            torch.manual_seed(random_state)
            np.random.seed(random_state)
            self._rng = torch.Generator()
            self._rng.manual_seed(random_state)

        self.n_features = n_features
        self.lr_mu = lr_mu
        self.lr_S = lr_S  # Step-size η_S in our paper
        self.max_iter = max_iter
        self.tol = tol
        self.jitter = jitter
        self.n_mc_samples = n_mc_samples

        # --- Hyperparameters for the Spectral Cost Function f(σ) ---
        self.lambda1 = lambda1  # λ₁ in Eq. 5
        self.lambda2 = lambda2  # λ₂ in Eq. 5
        self.gamma = gamma  # γ in Eq. 5

        # --- Variational Parameters ---
        # The full mean and covariance are parameterized directly.
        self.mu = torch.zeros(n_features, dtype=torch.float64, requires_grad=True)
        # S starts as a small multiple of the identity.
        self.S = torch.eye(n_features, dtype=torch.float64) * 0.01

        self.objective_values = []

    @staticmethod
    def _to_binary_labels(y):
        """
        Return a copy of ``y`` mapped to {0, 1}.

        If any entry of ``y`` is negative the labels are treated as {-1, +1}
        and mapped through (y + 1) / 2; otherwise ``y`` is cloned unchanged.
        """
        if torch.any(y < 0):
            return (y + 1) / 2
        return y.clone()

    def _spectral_cost_function_grad(self, sigma):
        """
        Return the derivative of the spectral cost function f(σ) w.r.t. σ.

        The penalty of Eq. 5 in the paper draft,
            f(σ) = λ₁ * log(1 + σ/γ) / (1 + λ₂ * exp(-γσ)),
        has a complicated derivative, so this implementation uses the linear
        penalty f(σ) = λ₁σ instead. Its derivative is the constant λ₁, so
        ``sigma`` is accepted for interface compatibility but not read.
        """
        return self.lambda1

    def _proximal_spectral_map(self, S_intermediate):
        """
        Apply the proximal spectral mapping (Eq. 8, 9, 10) to a covariance.

        Steps: cap the Frobenius norm, symmetrize, add jitter, take the
        eigendecomposition, soft-threshold the eigenvalues by
        ``lr_S * lambda1`` (clamping at zero) and reassemble the matrix.

        Args:
            S_intermediate: (n_features, n_features) matrix obtained after the
                gradient step on S.

        Returns:
            The reassembled matrix with shrunk, non-negative eigenvalues.

        Raises:
            torch.linalg.LinAlgError: if the eigendecomposition still fails
                after the jitter has been increased five times.
        """
        # Rescale matrices with an excessively large norm; extreme values
        # hurt the stability of the eigendecomposition.
        norm_S = torch.linalg.norm(S_intermediate, 'fro')
        if norm_S > self._MAX_FRO_NORM:
            S_intermediate = S_intermediate / norm_S * self._MAX_FRO_NORM

        # The gradient step can introduce tiny asymmetries; eigh expects a
        # symmetric input.
        S_sym = (S_intermediate + S_intermediate.T) / 2.0
        identity = torch.eye(self.n_features, dtype=torch.float64)

        # Try with the configured jitter first, then multiply it by 10 on
        # every failure until the decomposition succeeds.
        current_jitter = self.jitter
        for attempt in range(self._MAX_EIGH_RETRIES + 1):
            try:
                eigenvalues, eigenvectors = torch.linalg.eigh(
                    S_sym + current_jitter * identity)
            except torch.linalg.LinAlgError:
                if attempt == self._MAX_EIGH_RETRIES:
                    print("FATAL: Eigendecomposition failed after multiple retries.")
                    raise
                current_jitter *= 10
                print(f"Warning: Eigendecomposition failed. Retrying with larger jitter: {current_jitter}")
            else:
                if attempt > 0:
                    print("Success after increasing jitter.")
                break

        # Soft-thresholding: the proximal operator of σ -> λ₁σ with step
        # size lr_S, restricted to non-negative eigenvalues.
        threshold = self.lr_S * self.lambda1
        shrunk_eigenvalues = torch.relu(eigenvalues - threshold)

        # Reassemble the covariance matrix S_{t+1}.
        return eigenvectors @ torch.diag(shrunk_eigenvalues) @ eigenvectors.T

    def _compute_expected_log_likelihood(self, X, y, mu, S):
        """
        Monte Carlo estimate of E_q[log p(D|β)] (first term of Eq. 3).

        Samples are drawn with the reparameterization trick β = μ + Lε,
        where L is the Cholesky factor of ``S + jitter * I``, so the estimate
        is differentiable w.r.t. ``mu`` and ``S``.

        Args:
            X: design matrix, (n_samples, n_features), same dtype as ``mu``.
            y: labels in {0, 1}, shape (n_samples,).
            mu: mean of q, shape (n_features,).
            S: covariance of q, shape (n_features, n_features).

        Returns:
            Scalar tensor: log-likelihood summed over data points and
            averaged over the Monte Carlo samples. If the Cholesky
            factorization fails, a constant tensor of -1e20 that is not
            attached to any autograd graph.
        """
        S_reg = S + self.jitter * torch.eye(self.n_features, dtype=torch.float64)
        try:
            L = torch.linalg.cholesky(S_reg)
        except torch.linalg.LinAlgError:
            print("Warning: Cholesky failed. S might not be positive definite. Returning large loss.")
            return torch.tensor(-1e20, dtype=torch.float64)

        # Draw from N(0, I) using the instance generator so that
        # ``random_state`` actually controls the sampling; the global seed
        # is never touched here.
        eps = torch.randn(self.n_features, self.n_mc_samples,
                          dtype=torch.float64, generator=self._rng)

        # Reparameterization trick: β_samples = μ + L ε
        beta_samples = mu.unsqueeze(1) + L @ eps  # (n_features, n_mc_samples)
        logits = X @ beta_samples  # (n_samples, n_mc_samples)

        # log p(y|β) = y log σ(z) + (1 - y) log σ(-z). logsigmoid avoids the
        # -inf / nan produced by log(sigmoid(z)) for large |z|.
        targets = y.unsqueeze(1)
        log_sigmoid = torch.nn.functional.logsigmoid
        log_likelihood_samples = (targets * log_sigmoid(logits)
                                  + (1 - targets) * log_sigmoid(-logits))

        # Sum over data points, then average over MC samples.
        return torch.sum(log_likelihood_samples, dim=0).mean()

    def fit(self, X, y, verbose=False):
        """
        Fit the model with the Proximal Spectral Optimization algorithm.

        The mean is initialized from a scikit-learn logistic regression
        (falling back to zeros if that fails). Each iteration then performs
        an AdamW step on μ, a clipped gradient ascent step on S followed by
        the proximal spectral map, and records the penalized objective.

        Args:
            X: torch tensor of shape (n_samples, n_features); it is cast to
                the dtype of ``mu``.
            y: torch tensor of labels, either in {0, 1} or in {-1, +1}.
            verbose: if True, print progress messages.

        Returns:
            self
        """
        start_time = time.time()
        # matmul does not promote dtypes, so align X with the parameters.
        X = X.to(self.mu.dtype)
        y_01 = self._to_binary_labels(y)

        if verbose:
            print("Using scikit-learn LogisticRegression for mean initialization...")
        try:
            sklearn_lr = LogisticRegression(C=1e6, fit_intercept=False, max_iter=1000, penalty='l2', solver='liblinear')
            sklearn_lr.fit(X.numpy(), y_01.numpy())
            self.mu.data = torch.tensor(sklearn_lr.coef_[0], dtype=torch.float64)
        except Exception as e:
            # Initialization is best-effort: any failure falls back to μ = 0.
            print(f"Initialization failed: {e}. Starting from zero mean.")
            self.mu.data.zero_()

        optimizer_mu = torch.optim.AdamW([self.mu], lr=self.lr_mu)

        prev_objective = -float('inf')
        for iteration in range(self.max_iter):
            # === Step 1: Update μ ===
            optimizer_mu.zero_grad()
            log_likelihood_mu = self._compute_expected_log_likelihood(
                X, y_01, self.mu, self.S.detach())
            if not log_likelihood_mu.requires_grad:
                # The Cholesky fallback value carries no graph. S can only
                # change through a step that needs the same factorization,
                # so no further progress is possible.
                print("Warning: expected log-likelihood is not differentiable; stopping optimization.")
                break
            (-log_likelihood_mu).backward()
            optimizer_mu.step()

            # === Step 2: Update S ===
            # Differentiate w.r.t. a detached leaf so that the flags of
            # self.S itself are never modified.
            S_leaf = self.S.detach().requires_grad_(True)
            log_likelihood_S = self._compute_expected_log_likelihood(
                X, y_01, self.mu.detach(), S_leaf)
            if not log_likelihood_S.requires_grad:
                print("Warning: expected log-likelihood is not differentiable; stopping optimization.")
                break
            log_likelihood_S.backward()

            # Clipping keeps pathologically large gradients from making
            # S_intermediate ill-conditioned.
            torch.nn.utils.clip_grad_norm_(S_leaf, max_norm=self._GRAD_CLIP_NORM)

            with torch.no_grad():
                # Ascent step on the log-likelihood, then spectral shrinkage.
                S_intermediate = S_leaf + self.lr_S * S_leaf.grad
                self.S = self._proximal_spectral_map(S_intermediate)

                # Objective: expected log-likelihood minus λ₁ Σσ.
                ell = self._compute_expected_log_likelihood(X, y_01, self.mu, self.S)
                eigenvalues = torch.linalg.eigvalsh(self.S)
                spectral_penalty = self.lambda1 * torch.sum(eigenvalues)
                current_objective = (ell - spectral_penalty).item()

            self.objective_values.append(current_objective)

            if verbose and (iteration + 1) % 100 == 0:
                rank = torch.sum(eigenvalues > 1e-6).item()
                print(
                    f"Iter {iteration + 1}/{self.max_iter}, Objective: {current_objective:.4f}, Effective Rank: {rank}")

            # Stop on convergence regardless of verbosity.
            if abs(current_objective - prev_objective) < self.tol and iteration > 50:
                if verbose:
                    print(f"Convergence reached at iteration {iteration + 1}.")
                break
            prev_objective = current_objective

        self.runtime = time.time() - start_time
        return self

    def predict_proba(self, X):
        """
        Return sigmoid(X @ μ), a mean-based approximation of p(y=1|x).

        The posterior covariance is ignored; a more accurate predictive
        distribution would require sampling or integration.

        Args:
            X: tensor of shape (n_samples, n_features); cast to the dtype of
                ``mu``.

        Returns:
            Tensor of shape (n_samples,) with probabilities.
        """
        with torch.no_grad():
            logits = X.to(self.mu.dtype) @ self.mu
            return torch.sigmoid(logits)

    def compute_mse(self, X, y):
        """
        Return the mean squared error between predicted probabilities and
        the labels mapped to {0, 1} (Brier score), as a Python float.
        """
        with torch.no_grad():
            y_prob = self.predict_proba(X)
            y_01 = self._to_binary_labels(y)
            return torch.mean((y_prob - y_01) ** 2).item()

    def compute_ece(self, X, y, n_bins=10):
        """
        Return the expected calibration error over ``n_bins`` equal-width
        probability bins, as a Python float.

        Bins are (lower, upper], except the first one which also contains
        its lower edge so that probabilities equal to 0 are counted. For
        every non-empty bin the absolute gap between the mean predicted
        probability and the mean label is weighted by the fraction of
        samples in the bin. Returns 0.0 when no bin is populated (e.g. for
        empty input).
        """
        with torch.no_grad():
            y_prob = self.predict_proba(X)
            y_01 = self._to_binary_labels(y)
            bin_boundaries = torch.linspace(0, 1, n_bins + 1)
            # Tensor accumulator so that .item() is valid even when no bin
            # contributes.
            ece = torch.zeros((), dtype=y_prob.dtype)
            edges = zip(bin_boundaries[:-1], bin_boundaries[1:])
            for index, (bin_lower, bin_upper) in enumerate(edges):
                if index == 0:
                    above_lower = y_prob >= bin_lower
                else:
                    above_lower = y_prob > bin_lower
                in_bin = above_lower & (y_prob <= bin_upper)
                prop_in_bin = in_bin.float().mean()
                if prop_in_bin > 0:
                    confidence_in_bin = y_prob[in_bin].mean()
                    accuracy_in_bin = y_01[in_bin].float().mean()
                    ece = ece + torch.abs(confidence_in_bin - accuracy_in_bin) * prop_in_bin
            return ece.item()