import os
import numpy as np
import pandas as pd
import statsmodels.api as sm
from sklearn.metrics import log_loss
from tqdm import tqdm

def _env_flag(name: str, default: bool) -> bool:
    """Parse a boolean environment variable without executing it.

    The value is compared as text (case-insensitive, surrounding whitespace
    ignored): ``"0"``, ``"false"``, ``"no"`` and ``"off"`` mean False, any
    other non-empty value means True, and an unset or blank variable yields
    *default*. Unlike ``eval``, arbitrary text in the environment cannot run
    code here.
    """
    raw = os.environ.get(name, "").strip().lower()
    if not raw:
        return default
    return raw not in {"0", "false", "no", "off"}


def _bootstrap_logloss_ci(
        X_design: np.ndarray,
        y: np.ndarray,
        n_bootstrap: int,
        alpha: float,
        start_params=None,
        random_state=None,
        silent: bool = True,
    ) -> tuple[float, float]:
    """Percentile-bootstrap confidence interval for a logistic model's log loss.

    Each iteration resamples the rows of ``X_design``/``y`` with replacement,
    refits ``sm.Logit`` on the resample, and scores the refitted model on the
    ORIGINAL ``X_design``/``y``. Iterations whose fit or scoring raises are
    recorded as NaN and excluded from the percentiles.

    :param X_design: Design matrix exactly as passed to ``sm.Logit``
        (including the intercept column, if one was added).
    :param y: Response vector aligned with the rows of ``X_design``.
    :param n_bootstrap: Number of bootstrap resamples.
    :param alpha: Two-sided significance level; the interval spans the
        ``alpha/2`` and ``1 - alpha/2`` percentiles of the bootstrap losses.
    :param start_params: Optional starting values passed to every refit.
    :param random_state: ``None`` draws from NumPy's global RNG (so
        ``np.random.seed`` still controls it); an int seed or a
        ``np.random.Generator`` gives an independent, reproducible stream.
    :param silent: Suppress the progress bar and diagnostic output.
    :return: ``(lower, upper)``, or ``(nan, nan)`` if every iteration failed.
    """
    import warnings

    n_samples = len(y)
    # None keeps the legacy behaviour (global RNG); anything else -> Generator.
    rng = np.random if random_state is None else np.random.default_rng(random_state)

    fit_options = {"disp": 0}
    if start_params is not None:
        fit_options["start_params"] = start_params
    epsilon = 1e-15  # keep probabilities away from 0/1 so log() stays finite

    losses = np.full(n_bootstrap, np.nan)
    for i in tqdm(range(n_bootstrap), disable=silent):
        indices = rng.choice(n_samples, n_samples, replace=True)
        try:
            results_boot = sm.Logit(y[indices], X_design[indices]).fit(**fit_options)
            pred_proba = np.clip(results_boot.predict(X_design), epsilon, 1 - epsilon)
            losses[i] = log_loss(y, pred_proba)
        except Exception:
            # Best-effort: a degenerate resample (e.g. a single class or a
            # singular design) can make the refit fail; leave NaN and go on.
            continue

    n_failed = int(np.isnan(losses).sum())
    if n_failed == n_bootstrap:
        warnings.warn(
            "All bootstrap iterations failed for loss calculation.",
            RuntimeWarning,
            stacklevel=3,
        )
        return np.nan, np.nan
    if n_failed and not silent:
        print(f"Note: {n_failed}/{n_bootstrap} bootstrap iterations failed and were skipped.")

    lower_p = (alpha / 2) * 100
    upper_p = (1 - alpha / 2) * 100
    return (
        float(np.nanpercentile(losses, lower_p)),
        float(np.nanpercentile(losses, upper_p)),
    )


def logistic_coef_loss_ci(
        X: np.ndarray,
        y: np.ndarray,
        fit_intercept: bool = False,
        n_bootstrap: int = 2000,
        alpha: float = 0.05,
        random_state: "int | np.random.Generator | None" = None,
    ) -> tuple[list, list, float, tuple[float, float]]:
    """
    Fits a logistic regression and bootstraps a confidence interval for its log loss.

    The model is fitted once with ``statsmodels`` to obtain point estimates,
    Wald confidence intervals and the in-sample mean log loss. The log loss is
    then bootstrapped: the model is refitted on row resamples and each refit is
    scored on the original data.

    Verbosity is controlled by the ``SILENT`` environment variable. It is silent
    by default. Set it to ``0``, ``false``, ``no`` or ``off`` to print
    diagnostics and show the progress bar.

    :param X: The independent variables of the logistic regression model; a 1D input is treated as a single column.
    :type X: np.ndarray
    :param y: The response variable of the logistic regression model; a 2D input must be a single row or column.
    :type y: np.ndarray
    :param fit_intercept: Prepend a constant column (via ``sm.add_constant``) before fitting, defaults to False.
    :type fit_intercept: bool, optional
    :param n_bootstrap: The number of bootstrap iterations, defaults to 2000.
    :type n_bootstrap: int, optional
    :param alpha: The significance level for both the Wald and the bootstrap confidence intervals, defaults to 0.05.
    :type alpha: float, optional
    :param random_state: Seed or ``np.random.Generator`` for the bootstrap resampling; ``None`` (default) uses NumPy's global RNG.
    :type random_state: int | np.random.Generator | None, optional
    :return: The coefficients as a list, with the intercept first when ``fit_intercept`` is set. Then the Wald confidence intervals as a list of ``[lower, upper]`` pairs, one per coefficient. Then the in-sample mean log loss. Last, the bootstrap percentile confidence interval ``(lower, upper)`` for the log loss, which is ``(nan, nan)`` if every bootstrap fit failed.
    :rtype: tuple[list, list, float, tuple[float, float]]
    """
    silent = _env_flag("SILENT", default=True)
    ci_label = f"{(1 - alpha) * 100:g}%"

    X = np.asarray(X)
    y = np.asarray(y)
    if X.ndim == 1:
        X = X.reshape(-1, 1)
    if y.ndim == 2:
        assert y.shape[1] == 1 or y.shape[0] == 1, "y must be a 1D array"
        y = y.flatten()

    if not silent:
        # np.unique (unlike np.bincount) also works for float 0.0/1.0 labels.
        labels, counts = np.unique(y, return_counts=True)
        print("--- Data ---")
        print(f"X shape: {X.shape}, y shape: {y.shape}")
        print(f"Class distribution: {dict(zip(labels.tolist(), counts.tolist()))}")

    # --- Fit Model with Statsmodels (for detailed output & Wald CIs) ---
    X_design = sm.add_constant(X, prepend=True) if fit_intercept else X
    sm_results = sm.Logit(y, X_design).fit(disp=0)

    sm_params = np.asarray(sm_results.params)
    sm_wald_ci = np.asarray(sm_results.conf_int(alpha=alpha))
    sm_logloss = -sm_results.llf / sm_results.nobs  # mean negative log-likelihood

    if not silent:
        print("\n--- Statsmodels Fit ---")
        print(sm_results.summary())
        print(f"\nStatsmodels Point Estimate Log Loss: {sm_logloss:.4f}")
        print(f"\n--- Running {n_bootstrap} Bootstrap Iterations ---")

    # Reuse the same design matrix and warm-start each refit from the full fit.
    bootstrap_loss_ci = _bootstrap_logloss_ci(
        X_design,
        y,
        n_bootstrap,
        alpha,
        start_params=sm_params,
        random_state=random_state,
        silent=silent,
    )

    if not silent:
        print(f"\n--- Table 1: Comparison of Coefficient Confidence Intervals ({ci_label}) ---")
        comparison_data = {
            "Parameters": sm_params.tolist(),
            "Coef CI (lower)": sm_wald_ci[:, 0].tolist(),
            "Coef CI (upper)": sm_wald_ci[:, 1].tolist(),
        }
        print(pd.DataFrame(comparison_data))
        print("------------------------------------------------------------------")

        print("\n--- Cross-Entropy Loss (Log Loss) ---")
        print(f"Point Estimate (Statsmodels): {sm_logloss:.4f}")
        if not np.any(np.isnan(bootstrap_loss_ci)):
            print(f"Bootstrap {ci_label} CI: ({bootstrap_loss_ci[0]:.4f}, {bootstrap_loss_ci[1]:.4f})")
        else:
            print("Bootstrap CI for Log Loss could not be computed.")

    return sm_params.tolist(), sm_wald_ci.tolist(), sm_logloss, bootstrap_loss_ci