import numpy as np
import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import log_loss, accuracy_score
from sklearn.metrics.pairwise import laplacian_kernel
from scipy.special import expit
import cvxpy as cp


def compute_binned_ece(probs, y_true, num_bins=10):
    """Return the expected calibration error over equal-width bins.

    The unit interval is split into ``num_bins`` bins that are open on the
    left and closed on the right. For every non-empty bin the absolute gap
    between the mean label and the mean predicted probability is weighted by
    the fraction of samples that fall into the bin.

    Args:
        probs: 1-D array-like of predicted probabilities.
        y_true: 1-D array-like of labels aligned with ``probs``.
        num_bins: Number of equal-width bins on [0, 1].

    Returns:
        The weighted sum of per-bin ``|mean label - mean probability|``
        (``0.0`` for empty input).
    """
    probs = np.asarray(probs, dtype=float)
    y_true = np.asarray(y_true, dtype=float)
    n = len(probs)
    bin_boundaries = np.linspace(0, 1, num_bins + 1)
    bin_indices = np.digitize(probs, bin_boundaries, right=True) - 1
    # With right=True a probability of exactly 0 maps to index -1 and would be
    # left out of every bin even though it still counts towards n. Clamping
    # assigns it (and any out-of-range value) to the nearest edge bin.
    bin_indices = np.clip(bin_indices, 0, num_bins - 1)
    ece = 0.0
    for i in range(num_bins):
        bin_mask = bin_indices == i
        bin_size = np.sum(bin_mask)
        if bin_size > 0:
            bin_acc = np.mean(y_true[bin_mask])
            bin_conf = np.mean(probs[bin_mask])
            ece += (bin_size / n) * np.abs(bin_acc - bin_conf)
    return ece

def compute_kernel_ece(probs, y_true, bandwidth=0.1):
    """Return a kernel-smoothed calibration error.

    Predictions are first clipped to [1e-6, 1 - 1e-6]. Around every
    prediction a Gaussian window of width ``bandwidth`` is placed over all
    predictions; the window weights are normalised to sum to one and used to
    average both the labels and the predictions. The result is the mean, over
    all window centres, of the absolute gap between those two averages.

    Args:
        probs: 1-D array of predicted probabilities.
        y_true: 1-D array of labels aligned with ``probs``.
        bandwidth: Standard deviation of the Gaussian window.

    Returns:
        The average absolute gap between locally averaged labels and
        locally averaged predictions.
    """
    probs = np.clip(probs, 1e-6, 1 - 1e-6)
    n = len(probs)
    total_gap = 0.0
    for center in probs:
        # Gaussian weights of every prediction relative to this centre.
        weights = np.exp(-((probs - center) ** 2) / (2 * bandwidth**2))
        weights /= np.sum(weights)
        local_label_mean = np.sum(weights * y_true)
        local_pred_mean = np.sum(weights * probs)
        total_gap += np.abs(local_label_mean - local_pred_mean)
    return total_gap / n

def compute_mmce(probs, y_true, gamma=1.0):
    """Return a kernel-based calibration error built on a Laplacian kernel.

    With residuals ``r = probs - y_true`` and the Laplacian kernel matrix
    ``K`` evaluated on the predictions, the result is
    ``sqrt(r^T K r / n**2)``.

    Args:
        probs: 1-D array-like of predicted probabilities.
        y_true: 1-D array-like of labels aligned with ``probs``.
        gamma: Kernel parameter forwarded to ``laplacian_kernel``.

    Returns:
        The non-negative calibration error.
    """
    probs = np.asarray(probs, dtype=float)
    y_true = np.asarray(y_true, dtype=float)
    n = len(probs)
    column = probs.reshape(-1, 1)
    K = laplacian_kernel(column, column, gamma=gamma)
    residual = probs - y_true
    # Same value as summing r_i * K_ij * r_j over all pairs, without
    # materialising a second n x n array.
    quad_form = residual @ K @ residual
    # The quadratic form is mathematically non-negative, but with tiny
    # residuals it can round slightly below zero, which would turn the square
    # root into NaN.
    quad_form = np.maximum(quad_form, 0.0)
    mmce = np.sqrt(quad_form / (n ** 2))
    return mmce

def dual_LinECE_fast(v, y):
    """Solve the calibration linear program with logit-space smoothness.

    After sorting the samples by prediction, maximises
    ``(1/n) * sum((y_i - v_i) * z_i)`` over ``z`` in ``[-1, 1]^n`` subject to
    ``|z_i - z_{i+1}| <= |logit(v_i) - logit(v_{i+1})| / 4`` for consecutive
    samples.

    Args:
        v: 1-D array-like of predicted probabilities.
        y: 1-D array-like of labels aligned with ``v``.

    Returns:
        The optimal objective value reported by the solver.
    """
    v = np.asarray(v, dtype=float)
    y = np.asarray(y, dtype=float)
    n = len(v)
    order = np.argsort(v)
    v_sorted = v[order]
    y_sorted = y[order]

    # logit(0) and logit(1) are infinite and would put inf/NaN into the
    # constraint bounds. Moving only exact 0/1 to the nearest representable
    # interior value keeps every other prediction untouched.
    float_info = np.finfo(float)
    v_interior = np.clip(v_sorted, float_info.tiny, 1.0 - float_info.epsneg)
    logits = np.log(v_interior / (1 - v_interior))

    z = cp.Variable(n)
    # cp.multiply is the explicit elementwise product; `*` between two
    # vectors is interpreted by cvxpy as a deprecated matrix product.
    objective = cp.Maximize((1 / n) * cp.sum(cp.multiply(y_sorted - v_sorted, z)))
    constraints = [z >= -1, z <= 1]
    if n > 1:
        # A single sample has no neighbouring pair to constrain.
        logit_gaps = np.abs(np.diff(logits))
        constraints.append(cp.abs(z[:-1] - z[1:]) <= logit_gaps / 4)
    prob = cp.Problem(objective, constraints)
    prob.solve(solver=cp.ECOS, verbose=False)
    return prob.value

def LinECE_fast(v, y):
    """Solve the calibration linear program with prediction-space smoothness.

    After sorting the samples by prediction, maximises
    ``(1/n) * sum((y_i - v_i) * z_i)`` over ``z`` in ``[-1, 1]^n`` subject to
    ``|z_i - z_{i+1}| <= |v_i - v_{i+1}|`` for consecutive samples.

    Args:
        v: 1-D array-like of predicted probabilities.
        y: 1-D array-like of labels aligned with ``v``.

    Returns:
        The optimal objective value reported by the solver.
    """
    v = np.asarray(v, dtype=float)
    y = np.asarray(y, dtype=float)
    n = len(v)
    order = np.argsort(v)
    v_sorted = v[order]
    y_sorted = y[order]

    z = cp.Variable(n)
    # cp.multiply is the explicit elementwise product; `*` between two
    # vectors is interpreted by cvxpy as a deprecated matrix product.
    objective = cp.Maximize((1 / n) * cp.sum(cp.multiply(y_sorted - v_sorted, z)))
    constraints = [z >= -1, z <= 1]
    if n > 1:
        # A single sample has no neighbouring pair to constrain.
        gaps = np.abs(np.diff(v_sorted))
        constraints.append(cp.abs(z[:-1] - z[1:]) <= gaps)
    prob = cp.Problem(objective, constraints)
    prob.solve(solver=cp.ECOS, verbose=False)
    return prob.value



def generate_toy_data(n_samples=200):
    """Draw a two-class, two-feature Gaussian toy data set.

    Class 0 is centred at (-1, -1) and class 1 at (1, 1), both with standard
    deviation 1.2 before one coordinate per class is stretched by 1.3
    (feature 0 for class 0, feature 1 for class 1). Samples come from NumPy's
    global random state, class 0 first.

    Args:
        n_samples: Number of points drawn per class.

    Returns:
        A tuple ``(X, y)`` where ``X`` has shape ``(2 * n_samples, 2)`` with
        the class-0 rows first, and ``y`` holds the matching 0.0 / 1.0 labels.
    """
    spread = 1.2
    stretch = 1.3
    negatives = spread * np.random.randn(n_samples, 2) + np.array([-1, -1])
    positives = spread * np.random.randn(n_samples, 2) + np.array([1, 1])
    negatives[:, 0] = negatives[:, 0] * stretch
    positives[:, 1] = positives[:, 1] * stretch
    features = np.concatenate([negatives, positives], axis=0)
    labels = np.concatenate([np.zeros(n_samples), np.ones(n_samples)])
    return features, labels


class NormalizedGradientBoosting:
    """Boosted regression trees on sigmoid residuals with a decaying step size.

    Every round fits a ``DecisionTreeRegressor`` to the residuals
    ``y - sigmoid(F)`` of the current raw score ``F`` and then updates
    ``F <- (F + alpha_t * (h_t - F)) / (1 - alpha_t)``, which is the same as
    ``F + alpha_t / (1 - alpha_t) * h_t``, where ``h_t`` is the tree output
    and ``alpha_t = learning_rate / (t + 2) ** schedule``. The update divides
    by ``1 - alpha_t``, so the settings must keep every ``alpha_t`` below 1.

    Probability matrices produced by this class hold ``sigmoid(F)`` in
    column 0 and its complement in column 1.
    """

    # Bound applied to the raw score at the start of every boosting round.
    _SCORE_CLIP = 20

    def __init__(self, n_estimators=5000, learning_rate=0.9, max_depth=3, schedule=0.5):
        """Store the hyper-parameters and start with an empty ensemble.

        Args:
            n_estimators: Number of boosting rounds run by ``fit``.
            learning_rate: Numerator of the step-size schedule.
            max_depth: Maximum depth of every regression tree.
            schedule: Exponent of the step-size decay.
        """
        self.n_estimators = n_estimators
        self.learning_rate = learning_rate
        self.max_depth = max_depth
        self.models = []
        self.alphas = []
        self.beta = schedule

    def fit(self, X, y):
        """Fit ``n_estimators`` trees on ``X`` against the targets ``y``.

        Args:
            X: Feature matrix with one row per sample.
            y: Targets aligned with the rows of ``X``; they are compared with
                ``sigmoid(F)``, so presumably 0/1 labels.

        Returns:
            The fitted estimator itself.
        """
        # Start from a clean ensemble so that a second call to fit() does not
        # stack new trees on top of the ones from the previous call.
        self.models = []
        self.alphas = []
        F = np.zeros(X.shape[0])
        for t in range(1, self.n_estimators + 1):
            # Keep the raw score bounded before it goes through the sigmoid.
            F = np.clip(F, -self._SCORE_CLIP, self._SCORE_CLIP)
            residuals = y - expit(F)
            model = DecisionTreeRegressor(max_depth=self.max_depth)
            model.fit(X, residuals)
            alpha = self.learning_rate / (t + 2) ** self.beta
            F = (F + alpha * (model.predict(X) - F)) / (1 - alpha)
            self.models.append(model)
            self.alphas.append(alpha)
        return self

    def staged_predict_proba(self, X):
        """Yield the probability matrix after each boosting round.

        Args:
            X: Feature matrix with one row per sample.

        Yields:
            An array of shape ``(n_samples, 2)`` per round, with
            ``sigmoid(F)`` in column 0 and ``1 - sigmoid(F)`` in column 1.
        """
        F = np.zeros(X.shape[0])
        for model, alpha in zip(self.models, self.alphas):
            # Apply the same bound as fit() so the replayed scores follow the
            # recursion the trees were actually fitted against.
            F = np.clip(F, -self._SCORE_CLIP, self._SCORE_CLIP)
            F = (F + alpha * (model.predict(X) - F)) / (1 - alpha)
            p = expit(F)
            yield np.vstack([p, 1 - p]).T


# Experiment configuration: number of independent repetitions, spacing between
# recorded boosting stages, and number of boosting rounds per model.
NUM_SEEDS = 10
INTERVAL = 50
MAX_ITER = 5000

# metrics[name] collects one list per seed; every inner list holds the value
# of that metric at each recorded boosting stage.
metrics = {k: [] for k in [
    "entropies_train", "entropies_test", "accuracies_train", "accuracies_test",
    "binned_eces_train", "binned_eces_test", "kernel_eces_train", "kernel_eces_test",
    "mmce_values_train", "mmce_values_test", "linear_ece_train", "linear_ece_test",
    "F_gradient_norms_train", "F_gradient_norms_test"
]}

for seed in range(NUM_SEEDS):
    np.random.seed(seed)
    X, y = generate_toy_data()
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=seed)
    # Tie the ensemble size to MAX_ITER explicitly: the number of recorded
    # stages must match the iteration axis built from MAX_ITER and INTERVAL,
    # which previously held only because the constructor default is 5000.
    gb = NormalizedGradientBoosting(n_estimators=MAX_ITER, max_depth=3)
    gb.fit(X_train, y_train)
    tmp = {k: [] for k in metrics}

    staged_pairs = zip(gb.staged_predict_proba(X_train), gb.staged_predict_proba(X_test))
    for i, (y_probs_train, y_probs_test) in enumerate(staged_pairs):
        # Only every INTERVAL-th stage is evaluated.
        if i % INTERVAL != 0:
            continue
        for tag, probs, y_true in [("train", y_probs_train, y_train), ("test", y_probs_test, y_test)]:
            # Column 0 of the staged output is the positive-class probability.
            pos_probs = probs[:, 0]
            tmp[f'entropies_{tag}'].append(log_loss(y_true, pos_probs))
            tmp[f'accuracies_{tag}'].append(accuracy_score(y_true, (pos_probs > 0.5).astype(int)))
            tmp[f'binned_eces_{tag}'].append(compute_binned_ece(pos_probs, y_true))
            tmp[f'kernel_eces_{tag}'].append(compute_kernel_ece(pos_probs, y_true))
            tmp[f'mmce_values_{tag}'].append(compute_mmce(pos_probs, y_true))
            tmp[f'linear_ece_{tag}'].append(LinECE_fast(pos_probs, y_true))
            # Mean absolute residual between labels and predictions.
            F_grad = y_true - pos_probs
            tmp[f'F_gradient_norms_{tag}'].append(np.mean(np.abs(F_grad)))
    for k in metrics:
        metrics[k].append(tmp[k])


# Iteration index of every recorded stage: 0, INTERVAL, 2 * INTERVAL, ...
x = np.arange(0, MAX_ITER, step=INTERVAL)

# One row of three panels: train metrics, test metrics, test/train gap.
fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(24, 6))

def plot_with_errorbar(ax, y_vals, label, marker=None, linestyle=None, color=None):
    """Draw the mean curve of ``y_vals`` with standard-deviation error bars.

    Args:
        ax: Axes object whose ``errorbar`` method is called.
        y_vals: Nested sequence whose first axis is averaged out; the
            remaining axis is plotted against the module-level ``x``.
        label: Legend label of the curve.
        marker: Marker style forwarded to ``errorbar``.
        linestyle: Line style forwarded to ``errorbar``.
        color: Colour forwarded to ``errorbar``.
    """
    runs = np.array(y_vals)
    ax.errorbar(
        x,
        runs.mean(axis=0),
        yerr=runs.std(axis=0),
        label=label,
        marker=marker,
        linestyle=linestyle,
        capsize=3,
        color=color,
    )

def _draw_metric_panel(ax, curves, title, **legend_kwargs):
    """Draw the listed metric curves on one panel and style its axes.

    ``curves`` is a sequence of ``(metrics key, legend label, style kwargs)``
    tuples drawn in order; ``legend_kwargs`` is forwarded to ``ax.legend``.
    """
    for key, label, style in curves:
        plot_with_errorbar(ax, metrics[key], label, **style)
    ax.set_title(title, fontsize=25)
    ax.set_xlabel("Iteration", fontsize=25)
    ax.set_yscale("log")  # only the y axis is logarithmic
    ax.legend(**legend_kwargs)
    ax.grid(True)


def _train_test_gap(test_key, train_key):
    """Return the absolute difference of the seed-averaged test and train curves."""
    return np.abs(np.mean(metrics[test_key], axis=0) - np.mean(metrics[train_key], axis=0))


# --- Train ---
# The kernel ECE is collected in `metrics` as well but is not drawn.
_draw_metric_panel(
    axs[0],
    [
        ("entropies_train", "Cross Entropy", {"marker": 'o'}),
        ("accuracies_train", "Accuracy", {"marker": 'd'}),
        ("F_gradient_norms_train", "Grad Norm", {"linestyle": 'dashed', "color": 'blue'}),
        ("binned_eces_train", "Binning ECE", {"linestyle": 'dashed', "color": 'orange'}),
        ("mmce_values_train", "MMCE", {"linestyle": 'dashed', "color": 'purple'}),
        ("linear_ece_train", "Smooth CE", {"linestyle": 'dashed', "color": 'black'}),
    ],
    "Train Metrics",
    loc='upper right',
    fontsize=15,
)

# --- Test ---
_draw_metric_panel(
    axs[1],
    [
        ("entropies_test", "Cross Entropy", {"marker": 'o'}),
        ("accuracies_test", "Accuracy", {"marker": 'd'}),
        ("F_gradient_norms_test", "Grad Norm", {"linestyle": 'dashed', "color": 'blue'}),
        ("binned_eces_test", "Binning ECE", {"linestyle": 'dashed', "color": 'orange'}),
        ("mmce_values_test", "MMCE", {"linestyle": 'dashed', "color": 'purple'}),
        ("linear_ece_test", "Smooth CE", {"linestyle": 'dashed', "color": 'black'}),
    ],
    "Test Metrics",
    fontsize=15,
)

# --- Gap ---
smooth_gap = _train_test_gap("linear_ece_test", "linear_ece_train")
ce_gap = _train_test_gap("entropies_test", "entropies_train")
gap_ax = axs[2]
gap_ax.errorbar(x, ce_gap, label="Cross Entropy Gap", marker='o', color='red', capsize=3)
gap_ax.errorbar(x, smooth_gap, label="Smooth CE Gap", marker='x', color='blue', capsize=3)

gap_ax.set_title("Gap Between Test and Train", fontsize=25)
gap_ax.set_xlabel("Iteration", fontsize=25)
gap_ax.set_yscale("log")
gap_ax.legend(fontsize=15)
gap_ax.grid(True)

plt.tight_layout(rect=[0, 0, 1, 0.95])
plt.savefig("figure2_style_metrics_avg.eps", dpi=300)
