
import numpy as np
import matplotlib.pyplot as plt
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.tree import DecisionTreeRegressor
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.metrics import log_loss, accuracy_score
from sklearn.metrics.pairwise import laplacian_kernel
from sklearn.preprocessing import StandardScaler
from scipy.special import expit
from sklearn.utils import resample

import cvxpy as cp

class NormalizedGradientBoosting:
    """Binary gradient boosting on the logistic loss with a decaying step size.

    Each round fits a shallow ``DecisionTreeRegressor`` to the pseudo-residuals
    ``y - expit(F)`` and updates the raw scores with

        F <- (F + alpha_t * (h_t(X) - F)) / (1 - alpha_t),

    where ``alpha_t = learning_rate / (t + 2) ** schedule`` for t = 1..n_estimators.
    Algebraically the update equals ``F + alpha_t / (1 - alpha_t) * h_t(X)``, so
    every ``alpha_t`` must be strictly below 1.

    Parameters
    ----------
    n_estimators : int
        Number of boosting rounds (one tree per round).
    learning_rate : float
        Base step size, decayed by ``(t + 2) ** schedule``.
    max_depth : int
        Depth of each regression tree.
    schedule : float
        Decay exponent of the step size (stored as ``self.beta``).
    """

    def __init__(self, n_estimators=400, learning_rate=0.1, max_depth=1, schedule=0.0):
        self.n_estimators = n_estimators
        self.learning_rate = learning_rate
        self.max_depth = max_depth
        self.models = []  # fitted trees, one per round
        self.alphas = []  # step size used in each round
        self.beta = schedule  # step-size decay exponent

    @staticmethod
    def _update_scores(F, tree_pred, alpha):
        """Apply one boosting step to the raw scores (the rule used by ``fit``)."""
        return (F + alpha * (tree_pred - F)) / (1 - alpha)

    def fit(self, X, y):
        """Fit ``n_estimators`` trees to ``(X, y)`` and return ``self``.

        ``y`` holds binary targets; the pseudo-residuals ``y - expit(F)`` are the
        negative log-loss gradient for labels in {0, 1}. Refitting discards any
        previously fitted trees.

        Raises
        ------
        ValueError
            If any step size ``alpha_t`` is >= 1, which would divide by zero or
            reverse the direction of the update.
        """
        alphas = [self.learning_rate / (t + 2) ** self.beta
                  for t in range(1, self.n_estimators + 1)]
        if any(alpha >= 1 for alpha in alphas):
            raise ValueError(
                "learning_rate / (t + 2) ** schedule must be < 1 for every round; "
                f"got a maximum step of {max(alphas)}"
            )

        # Start from scratch so repeated fit() calls do not accumulate trees.
        self.models = []
        self.alphas = []
        F = np.zeros(X.shape[0])
        for alpha in alphas:
            # Bound the raw scores before computing residuals.
            F = np.clip(F, -20, 20)
            residuals = y - expit(F)
            model = DecisionTreeRegressor(max_depth=self.max_depth)
            model.fit(X, residuals)
            F = self._update_scores(F, model.predict(X), alpha)
            self.models.append(model)
            self.alphas.append(alpha)
        return self

    def predict_proba(self, X):
        """Return an (n, 2) array ``[expit(F), 1 - expit(F)]`` using all trees.

        Uses the same update rule as ``fit``, so the result equals the last item
        of ``staged_predict_proba``. Column 0 is ``expit(F)`` (the positive-class
        probability), which is the reverse of scikit-learn's column order.
        """
        F = np.zeros(X.shape[0])
        for model, alpha in zip(self.models, self.alphas):
            F = self._update_scores(F, model.predict(X), alpha)
        p = expit(F)
        return np.vstack([p, 1 - p]).T

    def staged_predict_proba(self, X):
        """Yield ``[expit(F), 1 - expit(F)]`` (shape (n, 2)) after each round."""
        F = np.zeros(X.shape[0])
        for model, alpha in zip(self.models, self.alphas):
            F = self._update_scores(F, model.predict(X), alpha)
            p = expit(F)
            yield np.vstack([p, 1 - p]).T

    def staged_predict_proba2(self, X):
        """Yield probabilities from the running average of convex-combination iterates.

        The iterates follow ``F <- F + alpha * (h(X) - F)``. Unlike ``fit``, this
        rule has no ``1 - alpha`` rescaling. Stage ``t`` yields
        ``expit(mean(F_1, ..., F_t))``. All iterates are computed before the
        first stage is yielded.
        """
        F = np.zeros(X.shape[0])
        iterates = []
        for model, alpha in zip(self.models, self.alphas):
            F = F + alpha * (model.predict(X) - F)
            iterates.append(F)
        running_sum = np.zeros(X.shape[0])
        for t, F_t in enumerate(iterates):
            running_sum += F_t
            p = expit(running_sum / (t + 1))
            yield np.vstack([p, 1 - p]).T
            
            
def compute_binned_ece(probs, y_true):
    """Expected calibration error with equal-width bins on [0, 1].

    Uses ``max(floor(n ** (1/3)), 1)`` bins and returns
    ``sum_b (|b| / n) * |mean(y in b) - mean(p in b)|``.

    Parameters
    ----------
    probs : array-like of float
        Predicted positive-class probabilities.
    y_true : array-like of {0, 1}
        Binary labels aligned with ``probs``.

    Returns
    -------
    float
        The binned ECE; 0.0 for empty input.
    """
    probs = np.asarray(probs, dtype=float)
    y_true = np.asarray(y_true, dtype=float)
    n = len(probs)
    num_bins = max(int(n ** (1 / 3)), 1)
    bin_boundaries = np.linspace(0, 1, num_bins + 1)
    # right=True puts p in (b_i, b_{i+1}] into bin i, but p == 0 would land in
    # bin -1 and be silently dropped; clipping keeps every sample in some bin.
    bin_indices = np.clip(np.digitize(probs, bin_boundaries, right=True) - 1, 0, num_bins - 1)
    ece = 0.0
    for i in range(num_bins):
        bin_mask = bin_indices == i
        count = np.sum(bin_mask)
        if count > 0:
            bin_acc = np.mean(y_true[bin_mask])
            bin_conf = np.mean(probs[bin_mask])
            ece += (count / n) * np.abs(bin_acc - bin_conf)
    return ece

def compute_kernel_ece(probs, y_true, bandwidth=0.1):
    """Kernel-smoothed calibration error with a Gaussian kernel over predictions.

    For each prediction ``c``, the labels and predictions are averaged with
    normalized weights ``exp(-(p - c)^2 / (2 * bandwidth^2))``. The absolute gap
    between those two weighted means is then averaged over all predictions.
    Predictions are first clipped to ``[1e-6, 1 - 1e-6]``. Cost is O(n^2).
    """
    clipped = np.clip(probs, 1e-6, 1 - 1e-6)
    denom = 2 * bandwidth**2

    def _local_gap(center):
        weights = np.exp(-((clipped - center) ** 2) / denom)
        weights = weights / np.sum(weights)
        return np.abs(np.sum(weights * y_true) - np.sum(weights * clipped))

    total = sum(_local_gap(center) for center in clipped)
    return total / len(clipped)

def compute_mmce(probs, y_true, gamma=1.0):
    """Maximum mean calibration error (MMCE) with a Laplacian kernel.

    Returns ``sqrt(sum_ij e_i e_j K_ij / n^2)``, where ``e = probs - y_true`` and
    ``K_ij = exp(-gamma * |p_i - p_j|)``. For 1-D inputs this kernel is exactly
    ``sklearn.metrics.pairwise.laplacian_kernel``.

    Parameters
    ----------
    probs : array-like of float
        Predicted positive-class probabilities.
    y_true : array-like of float
        Labels aligned with ``probs``.
    gamma : float
        Inverse length scale of the Laplacian kernel.

    Returns
    -------
    float
        The MMCE; 0.0 for empty input.
    """
    probs = np.asarray(probs, dtype=float)
    if probs.size == 0:
        return 0.0
    err = probs - np.asarray(y_true, dtype=float)
    K = np.exp(-gamma * np.abs(probs[:, None] - probs[None, :]))
    quad = err @ K @ err / len(probs) ** 2
    # K is positive semidefinite, so quad >= 0 in exact arithmetic; clamp
    # round-off negatives so the square root never yields NaN.
    return np.sqrt(max(quad, 0.0))

def dual_LinECE_fast(v, y, eps=1e-12):
    """Smooth-calibration LP whose Lipschitz constraint is measured in logit space.

    After sorting the samples by ``v``, solves

        max_z (1/n) * sum_i (y_i - v_i) z_i
        s.t.  -1 <= z_i <= 1,
              |z_i - z_{i+1}| <= |logit(v_i) - logit(v_{i+1})| / 4.

    The 1/4 factor presumably reflects that the logistic sigmoid is
    1/4-Lipschitz.

    Parameters
    ----------
    v : array-like of float
        Predicted probabilities.
    y : array-like of {0, 1}
        Binary labels aligned with ``v``.
    eps : float, optional
        ``v`` is clipped to ``[eps, 1 - eps]`` only for the logit. This keeps
        predictions of exactly 0 or 1 from producing infinite or NaN
        constraint bounds.

    Returns
    -------
    float
        Optimal objective value as reported by cvxpy (``prob.value``); 0.0 for
        empty input.
    """
    v = np.asarray(v, dtype=float)
    y = np.asarray(y, dtype=float)
    n = len(v)
    if n == 0:
        return 0.0
    if n == 1:
        # No smoothness constraints: the optimum is z = sign(y - v).
        return float(abs(y[0] - v[0]))

    order = np.argsort(v)
    v_sorted = v[order]
    y_sorted = y[order]
    v_clipped = np.clip(v_sorted, eps, 1 - eps)
    logits = np.log(v_clipped / (1 - v_clipped))

    z = cp.Variable(n)
    # `@` is the explicit inner product; ndarray * Variable relies on cvxpy's
    # deprecated "*" -> matmul behaviour.
    objective = cp.Maximize(((y_sorted - v_sorted) @ z) / n)
    constraints = [
        z >= -1,
        z <= 1,
        cp.abs(z[:-1] - z[1:]) <= np.abs(np.diff(logits)) / 4,
    ]
    prob = cp.Problem(objective, constraints)
    # Prefer ECOS; fall back to cvxpy's default solver when ECOS is not installed.
    solver = cp.ECOS if cp.ECOS in cp.installed_solvers() else None
    prob.solve(solver=solver, verbose=False)
    return prob.value


def LinECE_fast(v, y):
    """Smooth calibration error computed through its linear-programming dual.

    After sorting the samples by ``v``, solves

        max_z (1/n) * sum_i (y_i - v_i) z_i
        s.t.  -1 <= z_i <= 1,   |z_i - z_{i+1}| <= |v_i - v_{i+1}|,

    i.e. the best witness function bounded by 1 and 1-Lipschitz in the
    prediction. On sorted 1-D points the consecutive constraints imply the
    constraints for all pairs.

    Parameters
    ----------
    v : array-like of float
        Predicted probabilities.
    y : array-like of {0, 1}
        Binary labels aligned with ``v``.

    Returns
    -------
    float
        Optimal objective value as reported by cvxpy (``prob.value``); 0.0 for
        empty input.
    """
    v = np.asarray(v, dtype=float)
    y = np.asarray(y, dtype=float)
    n = len(v)
    if n == 0:
        return 0.0
    if n == 1:
        # No Lipschitz constraints: the optimum is z = sign(y - v).
        return float(abs(y[0] - v[0]))

    order = np.argsort(v)
    v_sorted = v[order]
    y_sorted = y[order]

    z = cp.Variable(n)
    # `@` is the explicit inner product; ndarray * Variable relies on cvxpy's
    # deprecated "*" -> matmul behaviour.
    objective = cp.Maximize(((y_sorted - v_sorted) @ z) / n)
    constraints = [
        z >= -1,
        z <= 1,
        cp.abs(z[:-1] - z[1:]) <= np.abs(np.diff(v_sorted)),
    ]
    prob = cp.Problem(objective, constraints)
    # Prefer ECOS; fall back to cvxpy's default solver when ECOS is not installed.
    solver = cp.ECOS if cp.ECOS in cp.installed_solvers() else None
    prob.solve(solver=solver, verbose=False)
    return prob.value



def generate_toy_data(n_samples=20000):
    """Build a two-class 2-D Gaussian toy dataset.

    Reseeds NumPy's global RNG with 42 on every call, so the output is identical
    across calls. This also resets the global RNG state for later code. Class 0
    is centred at (-1, -1) and class 1 at (1, 1), both with scale 1.2. Class 0
    is then stretched by 1.3 along the first feature and class 1 along the
    second.

    Returns
    -------
    (features, labels)
        ``features`` has shape (2 * n_samples, 2); ``labels`` holds
        n_samples zeros followed by n_samples ones.
    """
    np.random.seed(42)
    neg = np.random.randn(n_samples, 2) * 1.2 + np.array([-1, -1])
    pos = np.random.randn(n_samples, 2) * 1.2 + np.array([1, 1])
    neg[:, 0] *= 1.3
    pos[:, 1] *= 1.3
    features = np.concatenate([neg, pos], axis=0)
    labels = np.concatenate([np.zeros(n_samples), np.ones(n_samples)])
    return features, labels

sample_sizes = np.unique(np.logspace(np.log10(10), np.log10(10000), num=5, dtype=int))
seeds = range(10)

# Metric name -> one list per seed, each holding one value per sample size.
# (compute_kernel_ece is available but O(n^2), so it is not tracked here.)
train_metrics_all = {label: [] for label in ["Cross Entropy", "Accuracy", "Grad Norm", "Binning ECE", "MMCE", "Smooth CE"]}
test_metrics_all = {label: [] for label in train_metrics_all}


for seed in seeds:
    # generate_toy_data calls np.random.seed(42), so every seed's run starts
    # from the same global RNG state.
    X, y = generate_toy_data()
    X_train_all, X_test_all, y_train_all, y_test_all = train_test_split(X, y, test_size=0.5, random_state=seed)

    train_results = {label: [] for label in train_metrics_all}
    test_results = {label: [] for label in test_metrics_all}

    for n in sample_sizes:
        X_sub, y_sub = resample(X_train_all, y_train_all, n_samples=n, random_state=seed)
        X_test, y_test = resample(X_test_all, y_test_all, n_samples=n, random_state=seed)

        # `np.int` was removed in NumPy 1.24; the builtin int is its exact replacement.
        n_estimators = int(np.sqrt(n / (sample_sizes[0])) ** 2)
        model = NormalizedGradientBoosting(n_estimators=n_estimators, learning_rate=0.9, max_depth=1, schedule=0.5)
        model.fit(X_sub, y_sub)

        for X_eval, y_eval, results in [(X_sub, y_sub, train_results), (X_test, y_test, test_results)]:
            # Keep only the final boosting stage instead of materializing every stage.
            final_stage = None
            for final_stage in model.staged_predict_proba(X_eval):
                pass
            probs = final_stage[:, 0]  # column 0 is expit(F), the positive-class probability
            # labels=[0, 1] keeps log_loss defined if a small resample has one class.
            results["Cross Entropy"].append(log_loss(y_eval, probs, labels=[0, 1]))
            results["Accuracy"].append(accuracy_score(y_eval, probs > 0.5))
            results["Grad Norm"].append(np.mean(np.abs(y_eval - probs)))
            results["Binning ECE"].append(compute_binned_ece(probs, y_eval))
            # compute_mmce already returns the square root; no second sqrt.
            results["MMCE"].append(compute_mmce(probs, y_eval))
            results["Smooth CE"].append(LinECE_fast(probs, y_eval))

    for key in train_metrics_all:
        train_metrics_all[key].append(train_results[key])
        test_metrics_all[key].append(test_results[key])


def mean_std(metrics_dict):
    """Return per-key (means, stds) computed along axis 0.

    Each value of ``metrics_dict`` is reduced with ``np.mean`` and ``np.std``
    (population std, ddof=0) over its first axis. Both returned dicts keep the
    key order of ``metrics_dict``.
    """
    means, stds = {}, {}
    for name, values in metrics_dict.items():
        means[name] = np.mean(values, axis=0)
        stds[name] = np.std(values, axis=0)
    return means, stds

train_means, train_stds = mean_std(train_metrics_all)
test_means, test_stds = mean_std(test_metrics_all)
# Per-seed absolute generalization gap |test - train| for every metric.
gap_metrics_all = {
    key: np.abs(np.array(test_metrics_all[key]) - np.array(train_metrics_all[key]))
    for key in train_metrics_all
}

gap_means, gap_stds = mean_std(gap_metrics_all)


fig, axs = plt.subplots(1, 3, figsize=(24, 6))
labels = list(train_means.keys())

# Train (panel 0) and test (panel 1): every metric, mean +/- std across seeds.
for label in labels:
    axs[0].errorbar(sample_sizes, train_means[label], yerr=train_stds[label], label=label, marker='o', capsize=3)
    axs[1].errorbar(sample_sizes, test_means[label], yerr=test_stds[label], label=label, marker='s', capsize=3)

# Gap (panel 2): cross entropy and smooth CE only.
axs[2].errorbar(sample_sizes, gap_means["Cross Entropy"], yerr=gap_stds["Cross Entropy"], label="Cross Entropy Gap", marker='o', color='red')
axs[2].errorbar(sample_sizes, gap_means["Smooth CE"], yerr=gap_stds["Smooth CE"], label="Smooth CE Gap", marker='x', color='blue')

# Shared panel styling: log-log axes, large title/label, legend, grid.
for ax, title, legend_kwargs in [
    (axs[0], "Train Metrics", {"loc": "lower left"}),
    (axs[1], "Test Metrics", {}),
    (axs[2], "Gap Between Test and Train", {}),
]:
    ax.set_title(title, fontsize=25)
    ax.set_xscale("log")
    ax.set_yscale("log")
    ax.set_xlabel("Training Sample Size", fontsize=25)
    ax.legend(fontsize=15, **legend_kwargs)
    ax.grid(True)

plt.tight_layout()
# `fontsize` is not a savefig parameter (newer Matplotlib versions reject
# unknown keyword arguments), so only the resolution is passed.
plt.savefig("m1_figure6_style_metrics.eps", dpi=300)