import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from sklearn.metrics import log_loss, accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.metrics.pairwise import laplacian_kernel
import cvxpy as cp


def generate_non_separable_ring_data(n_samples=1000, seed=42):
    """Build a noisy two-class ring dataset and split it 50/50 into train/test.

    Points are placed at evenly spaced angles with alternating labels.
    Label-1 points lie on a circle of radius 1.5. Label-0 points lie at
    radius 1.5 +/- delta, where the sign is chosen at random per point and
    delta is a single offset drawn once from U(0.05, 0.2). Gaussian noise
    with standard deviation 0.05 is added to each coordinate.

    Note: this seeds NumPy's global random generator with ``seed``.

    Args:
        n_samples: Total number of points generated before the split.
        seed: Seed for NumPy's global RNG and for the train/test split.

    Returns:
        Tuple ``(X_train, X_test, y_train, y_test)`` of float32 tensors;
        the label tensors are reshaped to column vectors.
    """
    np.random.seed(seed)
    ring_radius = 1.5
    noise_std = 0.05
    thetas = np.linspace(0, 2 * np.pi, n_samples, endpoint=False)
    offset = np.random.uniform(0.05, 0.2)
    labels = np.array([k % 2 for k in range(n_samples)])

    points = np.empty((n_samples, 2))
    # The random draws per point happen in a fixed order: optional side
    # choice (label 0 only), then x noise, then y noise.
    for k, (theta, label) in enumerate(zip(thetas, labels)):
        if label == 0:
            if np.random.rand() < 0.5:
                r = ring_radius - offset
            else:
                r = ring_radius + offset
        else:
            r = ring_radius
        points[k, 0] = r * np.cos(theta) + np.random.normal(0, noise_std)
        points[k, 1] = r * np.sin(theta) + np.random.normal(0, noise_std)

    order = np.random.permutation(n_samples)
    points = points[order]
    labels = labels[order]

    splits = train_test_split(points, labels, test_size=0.5, random_state=seed)
    X_train, X_test, y_train, y_test = (
        torch.tensor(part, dtype=torch.float32) for part in splits
    )
    return X_train, X_test, y_train.reshape(-1, 1), y_test.reshape(-1, 1)

def generate_unseparable_data(n_samples=200, seed=0):
    """Generate two small 2-D point clouds and split them 50/50 into train/test.

    A shared base cloud ``base ~ 0.05 * N(0, I)`` with ``n_samples // 2`` rows
    is drawn once. Class 0 is ``base + (0.1, 0.1)`` and class 1 is
    ``-base + (-0.1, -0.1)``; each class gets its own independent Gaussian
    jitter with standard deviation 0.01.

    Note: this seeds NumPy's global random generator with ``seed``.

    Args:
        n_samples: Requested total number of points. Each class receives
            ``n_samples // 2`` points, so an odd value yields
            ``n_samples - 1`` points in total.
        seed: Seed for NumPy's global RNG and for the train/test split.

    Returns:
        Tuple ``(X_train, X_test, y_train, y_test)`` of float32 tensors;
        the label tensors are reshaped to column vectors.
    """
    np.random.seed(seed)
    n_half = n_samples // 2
    base = np.random.randn(n_half, 2) * 0.05
    class_0 = base + np.array([[0.1, 0.1]]) + np.random.normal(0, 0.01, size=base.shape)
    class_1 = -base + np.array([[-0.1, -0.1]]) + np.random.normal(0, 0.01, size=base.shape)

    X = np.vstack([class_0, class_1])
    # n_half zeros followed by n_half ones.
    y = np.repeat([0, 1], n_half)

    # Shuffle over the rows that were actually generated. Permuting over
    # n_samples indexes out of bounds when n_samples is odd, because only
    # 2 * (n_samples // 2) rows exist.
    idx = np.random.permutation(len(X))
    X = X[idx]
    y = y[idx]

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.5, random_state=seed
    )
    return (
        torch.tensor(X_train, dtype=torch.float32),
        torch.tensor(X_test, dtype=torch.float32),
        torch.tensor(y_train, dtype=torch.float32).reshape(-1, 1),
        torch.tensor(y_test, dtype=torch.float32).reshape(-1, 1),
    )
def generate_toy_data(n_samples=200, seed=42):
    """Draw two overlapping Gaussian blobs and split them 50/50 into train/test.

    Each class has ``n_samples`` points (``2 * n_samples`` in total), drawn
    with standard deviation 1.2 around (-1, -1) for class 0 and (1, 1) for
    class 1. Afterwards the first coordinate of class 0 and the second
    coordinate of class 1 are multiplied by 1.3 (the scaling is applied
    after the shift, so it also moves the blob centre along that axis).

    Note: this seeds NumPy's global random generator with ``seed``.

    Returns:
        Tuple ``(X_train, X_test, y_train, y_test)`` of float32 tensors;
        the label tensors are reshaped to column vectors.
    """
    np.random.seed(seed)
    spread = 1.2
    stretch = 1.3

    neg_blob = spread * np.random.randn(n_samples, 2) + np.array([-1, -1])
    pos_blob = spread * np.random.randn(n_samples, 2) + np.array([1, 1])
    neg_blob[:, 0] *= stretch
    pos_blob[:, 1] *= stretch

    features = np.vstack((neg_blob, pos_blob))
    targets = np.concatenate((np.zeros(n_samples), np.ones(n_samples)))

    splits = train_test_split(features, targets, test_size=0.5, random_state=seed)
    X_train, X_test, y_train, y_test = (
        torch.tensor(part, dtype=torch.float32) for part in splits
    )
    return X_train, X_test, y_train.reshape(-1, 1), y_test.reshape(-1, 1)

def compute_binned_ece(probs, y_true):
    """Return the equal-width binned expected calibration error.

    The interval [0, 1] is cut into ``max(int(n ** (1/3)), 1)`` equal-width
    bins, where ``n = len(probs)``. For every non-empty bin the absolute
    difference between the mean label and the mean predicted probability is
    weighted by the fraction of samples in that bin, and the weighted
    differences are summed.

    Args:
        probs: 1-D array-like of predicted probabilities for the positive class.
        y_true: 1-D array-like of binary labels aligned with ``probs``.

    Returns:
        The binned ECE as a float (0.0 for empty input).
    """
    probs = np.asarray(probs)
    y_true = np.asarray(y_true)
    n = len(probs)
    num_bins = max(int(n ** (1/3)), 1)
    bin_boundaries = np.linspace(0, 1, num_bins + 1)
    bin_indices = np.digitize(probs, bin_boundaries, right=True) - 1
    # With right=True a probability of exactly 0.0 maps to index -1 (and a
    # value above 1.0 to num_bins); neither index is visited by the loop
    # below, so such samples would be silently dropped. Clamp them into the
    # first / last bin instead.
    bin_indices = np.clip(bin_indices, 0, num_bins - 1)

    ece = 0.0
    for i in range(num_bins):
        bin_mask = bin_indices == i
        bin_count = np.sum(bin_mask)
        if bin_count > 0:
            bin_acc = np.mean(y_true[bin_mask])
            bin_conf = np.mean(probs[bin_mask])
            ece += (bin_count / n) * np.abs(bin_acc - bin_conf)
    return ece

def compute_kernel_ece(probs, y_true, bandwidth=0.1):
    """Return a Gaussian-kernel-smoothed expected calibration error.

    Probabilities are first clipped to [1e-6, 1 - 1e-6]. For every sample a
    normalised Gaussian weighting (width ``bandwidth``) centred on that
    sample's probability is formed over all samples; the absolute difference
    between the weighted mean label and the weighted mean probability is
    then averaged over all samples.

    Args:
        probs: 1-D array of predicted probabilities.
        y_true: 1-D array of labels aligned with ``probs``.
        bandwidth: Standard deviation of the Gaussian kernel.

    Returns:
        The kernel-smoothed calibration error.
    """
    clipped = np.clip(probs, 1e-6, 1 - 1e-6)
    denom = 2 * bandwidth**2
    total = 0.0
    for center in clipped:
        weights = np.exp(-((clipped - center) ** 2) / denom)
        weights /= np.sum(weights)
        local_acc = np.sum(weights * y_true)
        local_conf = np.sum(weights * clipped)
        total += np.abs(local_acc - local_conf)
    return total / len(clipped)

def compute_mmce(probs, y_true, gamma=1.0):
    """Return a kernel calibration error based on a Laplacian kernel.

    With residuals ``e = probs - y_true`` and ``K`` the Laplacian kernel
    matrix between the predicted probabilities (parameter ``gamma``), the
    result is ``sqrt(sum_ij e_i * K_ij * e_j / n**2)`` with ``n = len(probs)``.

    Args:
        probs: 1-D array of predicted probabilities.
        y_true: 1-D array of labels aligned with ``probs``.
        gamma: Kernel parameter forwarded to ``laplacian_kernel``.
    """
    column = probs.reshape(-1, 1)
    gram = laplacian_kernel(column, column, gamma=gamma)
    residual = probs - y_true
    weighted = residual[:, None] * gram * residual[None, :]
    n = len(probs)
    return np.sqrt(np.sum(weighted) / (n**2))

def LinECE_fast(v, y):
    """Solve a linear program measuring calibration error against smooth tests.

    After sorting the samples by predicted value, this maximises
    ``(1/n) * sum_i (y_i - v_i) * z_i`` over a vector ``z`` subject to
    ``-1 <= z_i <= 1`` and ``|z_i - z_{i+1}| <= |v_i - v_{i+1}|`` for
    neighbouring samples in sorted order.

    Args:
        v: Sequence of predicted values.
        y: Sequence of labels aligned with ``v``.

    Returns:
        The optimal objective value reported by the solver (``prob.value``).
    """
    n = len(v)
    scores = np.array(v)
    targets = np.array(y)

    order = np.argsort(scores)
    scores_sorted = scores[order]
    targets_sorted = targets[order]

    z = cp.Variable(n)
    objective = cp.Maximize((1/n) * cp.sum((targets_sorted - scores_sorted) * z))

    # Box bounds on z, followed by the neighbour-difference bound.
    constraints = [z >= -1, z <= 1]
    score_gaps = np.abs(np.diff(scores_sorted))
    constraints.append(cp.abs(z[:-1] - z[1:]) <= score_gaps)

    problem = cp.Problem(objective, constraints)
    problem.solve(solver=cp.ECOS, verbose=False)
    return problem.value


class BinaryClassifier(nn.Module):
    """Two-layer sigmoid network on 2-D inputs with a mirrored initialisation.

    The hidden layer has ``hidden_units`` sigmoid units; the output layer has
    no bias and is followed by a sigmoid. At construction the second half of
    the first-layer weight rows is overwritten with the first half, the
    first-layer bias is set to zero, and the output weights are +1 for the
    first half of the hidden units and -1 for the rest. For an even
    ``hidden_units`` the paired units therefore cancel and the freshly
    initialised network outputs 0.5 for every input.
    """

    def __init__(self, hidden_units=6):
        super().__init__()
        self.fc1 = nn.Linear(2, hidden_units)
        self.fc2 = nn.Linear(hidden_units, 1, bias=False)
        self.sigmoid = nn.Sigmoid()

        half = hidden_units // 2
        with torch.no_grad():
            # First layer: random rows, with the upper half mirrored from the lower half.
            first_layer = torch.randn_like(self.fc1.weight)
            first_layer[half:] = first_layer[:half]
            self.fc1.weight.copy_(first_layer)
            self.fc1.bias.fill_(0.0)

            # Output layer: +1 for the first half of the units, -1 for the rest.
            signs = torch.ones(hidden_units)
            signs[half:] = -1
            self.fc2.weight.copy_(signs.unsqueeze(0))

    def forward(self, x):
        """Return sigmoid(fc2(sigmoid(fc1(x))))."""
        hidden = self.sigmoid(self.fc1(x))
        return self.sigmoid(self.fc2(hidden))


def run_with_sample_size(n_samples, seed=42):
    """Train a BinaryClassifier on generated data and report train/test metrics.

    The data come from ``generate_unseparable_data``. A 300-unit
    ``BinaryClassifier`` is trained with full-batch gradient descent
    (step size 0.01) on the binary cross-entropy for at most
    ``100 * n_samples / 10`` epochs. Training stops early once the loss has
    failed to improve by more than 1e-4 for 20 consecutive epochs; the epoch
    index at which that happens is printed.

    Args:
        n_samples: Number of samples requested from the data generator.
        seed: Seed used for data generation and for torch's global RNG, so
            that the model initialisation is reproducible as well.

    Returns:
        Tuple ``(train_metrics, test_metrics)``; each is a dict with the keys
        "Cross Entropy", "Accuracy", "Grad Norm", "Binning ECE", "MMCE" and
        "Smooth CE".
    """
    # Swap in generate_toy_data(n_samples=n_samples, seed=seed) to use the
    # Gaussian-blob dataset instead.
    X_train, X_test, y_train, y_test = generate_unseparable_data(n_samples=n_samples, seed=seed)

    # The data generator only seeds NumPy; BinaryClassifier draws its initial
    # weights from torch's RNG, so seed that too for reproducible runs.
    torch.manual_seed(int(seed))
    model = BinaryClassifier(hidden_units=300)
    criterion = nn.BCELoss()
    eta = 0.01
    num_epochs = int(100 * (n_samples / 10))

    best_loss = float('inf')
    patience_counter = 0
    early_stop_patience = 20
    tol = 1e-4

    train_targets = y_train.flatten()
    for epoch in range(num_epochs):
        model.zero_grad()
        output_train = model(X_train).flatten()
        loss = criterion(output_train, train_targets)
        loss.backward()

        # Early stopping is checked before the update, so the step for the
        # epoch that triggers the stop is skipped.
        current_loss = loss.item()
        if best_loss - current_loss > tol:
            best_loss = current_loss
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= early_stop_patience:
                print(epoch)
                break

        # Plain gradient-descent step, applied in place without autograd tracking.
        with torch.no_grad():
            for param in model.parameters():
                if param.grad is not None:
                    param -= eta * param.grad

    def evaluate(X, y):
        """Compute all reported metrics for the trained model on (X, y)."""
        with torch.no_grad():
            out = model(X).flatten().numpy()
        y_np = y.numpy().flatten()
        pred = (out > 0.5).astype(int)
        residual = y_np - out
        return {
            # Pass the label set explicitly: a small split may contain only
            # one class, in which case log_loss cannot infer the labels.
            "Cross Entropy": log_loss(y_np, out, labels=[0, 1]),
            "Accuracy": accuracy_score(y_np, pred),
            "Grad Norm": np.mean(np.abs(residual)),
            "Binning ECE": compute_binned_ece(out, y_np),
            "MMCE": compute_mmce(out, y_np),
            "Smooth CE": LinECE_fast(out, y_np),
        }

    return evaluate(X_train, y_train), evaluate(X_test, y_test)

if __name__ == "__main__":
    # Sweep the sample size on a log grid, repeat each size over several
    # seeds, and plot the mean +/- std of every metric for train, test, and
    # the train/test gap.
    sample_sizes = np.unique(np.logspace(np.log10(10), np.log10(10000), num=5, dtype=int))
    seeds = np.arange(10)
    metrics = ["Cross Entropy", "Accuracy", "Grad Norm", "Binning ECE", "MMCE", "Smooth CE"]

    # per_seed_*[m] holds one array per sample size, containing metric m for
    # every seed. The raw values are kept so the gap statistics can be
    # computed per sample size across seeds.
    per_seed_train = {m: [] for m in metrics}
    per_seed_test = {m: [] for m in metrics}
    for N in sample_sizes:
        runs = [run_with_sample_size(N, seed) for seed in seeds]
        for m in metrics:
            per_seed_train[m].append(np.array([train_metrics[m] for train_metrics, _ in runs]))
            per_seed_test[m].append(np.array([test_metrics[m] for _, test_metrics in runs]))

    # (mean, std) across seeds for every metric and sample size.
    results_train = {m: [(np.mean(v), np.std(v)) for v in per_seed_train[m]] for m in metrics}
    results_test = {m: [(np.mean(v), np.std(v)) for v in per_seed_test[m]] for m in metrics}

    fig, axs = plt.subplots(1, 3, figsize=(24, 6))

    # Train and test panels share the same layout.
    panels = [
        (axs[0], results_train, "Train Metrics", 'o'),
        (axs[1], results_test, "Test Metrics", 's'),
    ]
    for ax, results, title, marker in panels:
        for m in metrics:
            mean, std = zip(*results[m])
            ax.errorbar(sample_sizes, mean, yerr=std, label=m, marker=marker)
        ax.set_title(title, fontsize=20)

    # Gap panel: |mean over seeds of (test - train)| per sample size, with the
    # std across seeds of that per-seed difference as the error bar. The
    # error bar is therefore specific to each sample size rather than one
    # value shared by all of them.
    for m in ["Cross Entropy", "Smooth CE"]:
        diffs = np.array(per_seed_test[m]) - np.array(per_seed_train[m])  # (sizes, seeds)
        gap = np.abs(diffs.mean(axis=1))
        gap_std = diffs.std(axis=1)
        axs[2].errorbar(sample_sizes, gap, yerr=gap_std, label=f"{m} Gap", marker='^')
    axs[2].set_title("Gap Between Test and Train", fontsize=20)

    for ax in axs:
        ax.set_xlabel("Sample Size N", fontsize=15)
        ax.set_xscale("log")
        ax.set_yscale("log")
        ax.grid()
        ax.legend()

    plt.tight_layout()
    plt.savefig("NN_sample_size_sweep_unsep.eps", dpi=300)