import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from sklearn.metrics import log_loss, accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.metrics.pairwise import laplacian_kernel
import cvxpy as cp
from sklearn.datasets import load_breast_cancer
from sklearn.preprocessing import StandardScaler

# Breast-cancer dataset from scikit-learn, loaded once at import time;
# UCI() below reads its `.data` (features) and `.target` (labels).
data = load_breast_cancer()


def UCI(test_size=0.7, seed=42):
    """Split and standardize the module-level breast-cancer dataset.

    Args:
        test_size: Forwarded to ``train_test_split`` as the held-out share.
            It was previously ignored in favour of a hard-coded ``0.7``; the
            default is unchanged, so existing callers get the same split.
        seed: Seeds NumPy's global RNG and is used as the split's
            ``random_state``.

    Returns:
        ``(X_train, X_test, y_train, y_test)`` as float32 tensors. The label
        tensors are reshaped to column vectors of shape ``(n, 1)``.
    """
    np.random.seed(seed)
    X_train, X_test, y_train, y_test = train_test_split(
        data.data, data.target, test_size=test_size, random_state=seed
    )

    # Fit the scaler on the training split only so no test statistics leak
    # into the features the model is trained on.
    scaler = StandardScaler()
    X_train = scaler.fit_transform(X_train)
    X_test = scaler.transform(X_test)

    return (
        torch.tensor(X_train, dtype=torch.float32),
        torch.tensor(X_test, dtype=torch.float32),
        torch.tensor(y_train, dtype=torch.float32).reshape(-1, 1),
        torch.tensor(y_test, dtype=torch.float32).reshape(-1, 1),
    )

def generate_toy_data(n_samples=200, seed=42):
    """Build a two-blob 2-D binary classification set and split it 50/50.

    Class 0 is centred at (-1, -1) and class 1 at (1, 1); both are Gaussian
    with scale 1.2, after which class 0 is stretched along the first axis
    and class 1 along the second by a factor of 1.3.

    Args:
        n_samples: Number of points drawn *per class*.
        seed: Seeds NumPy's global RNG and the train/test split.

    Returns:
        ``(X_train, X_test, y_train, y_test)`` as float32 tensors, with the
        labels as ``(n, 1)`` column vectors.
    """
    np.random.seed(seed)
    spread = 1.2
    stretch = 1.3

    # Draw order (class 0 first, then class 1) fixes the random stream.
    negatives = np.random.randn(n_samples, 2) * spread + np.array([-1, -1])
    positives = np.random.randn(n_samples, 2) * spread + np.array([1, 1])
    negatives[:, 0] *= stretch
    positives[:, 1] *= stretch

    features = np.concatenate([negatives, positives], axis=0)
    labels = np.concatenate([np.zeros(n_samples), np.ones(n_samples)])

    X_train, X_test, y_train, y_test = train_test_split(
        features, labels, test_size=0.5, random_state=seed
    )
    inputs = [torch.tensor(part, dtype=torch.float32) for part in (X_train, X_test)]
    targets = [
        torch.tensor(part, dtype=torch.float32).reshape(-1, 1)
        for part in (y_train, y_test)
    ]
    return (*inputs, *targets)



def generate_non_separable_ring_data(n_samples=200, seed=42):
    """Build a noisy ring whose labels alternate from point to point.

    Points are placed at evenly spaced angles on a circle of radius 1.5,
    perturbed by Gaussian noise (std 0.05) in each coordinate, and labelled
    0, 1, 0, 1, ... in angular order, so neighbouring points on the ring
    carry opposite labels.

    Args:
        n_samples: Total number of points; rounded down to an even number.
        seed: Seeds NumPy's global RNG and the train/test split.

    Returns:
        ``(X_train, X_test, y_train, y_test)`` as float32 tensors from a
        50/50 split, with the labels as ``(n, 1)`` column vectors.
    """
    np.random.seed(seed)
    n_samples = 2 * (n_samples // 2)

    radius, noise = 1.5, 0.05
    theta = np.linspace(0, 2 * np.pi, n_samples, endpoint=False)

    # The x-noise is drawn before the y-noise; keep this order so the
    # generated data is reproducible for a given seed.
    xs = radius * np.cos(theta) + np.random.normal(0, noise, size=n_samples)
    ys = radius * np.sin(theta) + np.random.normal(0, noise, size=n_samples)
    points = np.column_stack((xs, ys))

    labels = np.arange(n_samples) % 2

    X_train, X_test, y_train, y_test = train_test_split(
        points, labels, test_size=0.5, random_state=seed
    )

    def as_float(array):
        return torch.tensor(array, dtype=torch.float32)

    return (
        as_float(X_train),
        as_float(X_test),
        as_float(y_train).reshape(-1, 1),
        as_float(y_test).reshape(-1, 1),
    )

def generate_unseparable_data(n_samples=200, seed=0):
    """Build two small, mirrored 2-D clusters and split them 50/50.

    A shared Gaussian ``base`` (std 0.05) is shifted to (0.1, 0.1) for
    class 0 and mirrored (``-base``) and shifted to (-0.1, -0.1) for
    class 1; each class additionally receives independent noise of std 0.01.

    Args:
        n_samples: Requested total number of points. Each class gets
            ``n_samples // 2`` points, so an odd request yields
            ``n_samples - 1`` points in total.
        seed: Seeds NumPy's global RNG and the train/test split.

    Returns:
        ``(X_train, X_test, y_train, y_test)`` as float32 tensors, with the
        labels as ``(n, 1)`` column vectors.
    """
    np.random.seed(seed)
    n_half = n_samples // 2
    base = np.random.randn(n_half, 2) * 0.05
    class_0 = base + np.array([[0.1, 0.1]]) + np.random.normal(0, 0.01, size=base.shape)
    class_1 = -base + np.array([[-0.1, -0.1]]) + np.random.normal(0, 0.01, size=base.shape)
    X = np.vstack([class_0, class_1])
    y = np.array([0] * n_half + [1] * n_half)

    # Permute over the rows that actually exist. Using `n_samples` here
    # raised IndexError for odd requests, because only 2 * n_half rows are
    # generated. For even `n_samples` the permutation is unchanged.
    idx = np.random.permutation(len(X))
    X = X[idx]
    y = y[idx]

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.5, random_state=seed
    )

    return (
        torch.tensor(X_train, dtype=torch.float32),
        torch.tensor(X_test, dtype=torch.float32),
        torch.tensor(y_train, dtype=torch.float32).reshape(-1, 1),
        torch.tensor(y_test, dtype=torch.float32).reshape(-1, 1),
    )


def compute_binned_ece(probs, y_true):
    """Return the equal-width binned expected calibration error.

    The unit interval is cut into ``max(int(n ** (1/3)), 1)`` equal-width
    bins. For every non-empty bin the absolute difference between the mean
    label and the mean predicted probability is weighted by the bin's share
    of the samples, and the weighted differences are summed.

    Args:
        probs: 1-D array-like of predicted probabilities for the positive
            class.
        y_true: 1-D array-like of binary labels, aligned with ``probs``.

    Returns:
        The binned ECE as a float; ``0.0`` for empty input.
    """
    probs = np.asarray(probs)
    y_true = np.asarray(y_true)
    n = len(probs)
    num_bins = max(int(n ** (1 / 3)), 1)
    bin_boundaries = np.linspace(0, 1, num_bins + 1)

    # With right=True a probability of exactly 0.0 is assigned index 0, so
    # subtracting 1 used to give -1 and the sample was counted in no bin.
    # Clamp the indices so boundary values land in the first / last bin.
    bin_indices = np.digitize(probs, bin_boundaries, right=True) - 1
    bin_indices = np.clip(bin_indices, 0, num_bins - 1)

    ece = 0.0
    for i in range(num_bins):
        bin_mask = bin_indices == i
        bin_count = np.sum(bin_mask)
        if bin_count > 0:
            bin_acc = np.mean(y_true[bin_mask])
            bin_conf = np.mean(probs[bin_mask])
            ece += (bin_count / n) * np.abs(bin_acc - bin_conf)
    return ece

def compute_kernel_ece(probs, y_true, bandwidth=0.1):
    """Return a Gaussian-kernel-smoothed expected calibration error.

    For every sample, all samples are weighted by a Gaussian kernel on the
    distance between predicted probabilities (weights normalised to sum to
    one). The absolute gap between the weighted mean label and the weighted
    mean prediction is averaged over all samples.

    Args:
        probs: 1-D array of predicted probabilities; clipped to
            ``[1e-6, 1 - 1e-6]`` before use.
        y_true: 1-D array of binary labels, aligned with ``probs``.
        bandwidth: Standard deviation of the Gaussian kernel.

    Returns:
        The mean absolute local calibration gap.
    """
    clipped = np.clip(probs, 1e-6, 1 - 1e-6)
    count = len(clipped)
    denom = 2 * bandwidth**2

    total_gap = 0.0
    for centre in clipped:
        weights = np.exp(-((clipped - centre) ** 2) / denom)
        weights = weights / np.sum(weights)
        local_acc = np.sum(weights * y_true)
        local_conf = np.sum(weights * clipped)
        total_gap += np.abs(local_acc - local_conf)
    return total_gap / count

def compute_mmce(probs, y_true, gamma=1.0):
    """Return a kernel calibration error built on a Laplacian kernel.

    Computes ``sqrt(sum_ij e_i * K_ij * e_j / n**2)`` where
    ``e = probs - y_true`` and ``K`` is the Laplacian kernel matrix between
    the predicted probabilities.

    Args:
        probs: 1-D NumPy array of predicted probabilities.
        y_true: 1-D NumPy array of binary labels, aligned with ``probs``.
        gamma: Kernel parameter forwarded to ``laplacian_kernel``.

    Returns:
        The square root of the normalised quadratic form.
    """
    column = probs.reshape(-1, 1)
    gram = laplacian_kernel(column, column, gamma=gamma)
    residual = probs - y_true
    pairwise = residual[:, None] * gram * residual[None, :]
    return np.sqrt(np.sum(pairwise) / (len(probs) ** 2))

def dual_LinECE_fast(v, y):
    """Solve the calibration-witness LP with a logit-space smoothness bound.

    Maximises ``(1/n) * sum((y_i - v_i) * z_i)`` over ``z`` in ``[-1, 1]^n``
    subject to ``|z_i - z_{i+1}| <= |logit(v_i) - logit(v_{i+1})| / 4`` for
    neighbouring predictions in sorted order.

    Args:
        v: 1-D array-like of predicted probabilities.
        y: 1-D array-like of binary labels, aligned with ``v``.

    Returns:
        The optimal objective value reported by the solver (``prob.value``).
    """
    n = len(v)
    v = np.array(v)
    y = np.array(y)

    sorted_indices = np.argsort(v)
    v_sorted = v[sorted_indices]
    y_sorted = y[sorted_indices]

    # A prediction of exactly 0 or 1 has an infinite logit, which used to
    # produce inf/nan constraint bounds. Clamp (for the logit only) to the
    # closest finite neighbours; values strictly inside (0, 1) are untouched
    # apart from the logit being evaluated in float64.
    lower = np.finfo(np.float64).tiny
    upper = np.nextafter(1.0, 0.0)
    v_safe = np.clip(v_sorted.astype(np.float64), lower, upper)
    v_logit = np.log(v_safe / (1.0 - v_safe))

    # Optimisation variable: one witness value per sample.
    z = cp.Variable(n)
    objective = cp.Maximize((1/n) * cp.sum((y_sorted - v_sorted) * z))

    # Box constraint plus the neighbour-to-neighbour smoothness constraint.
    constraints = [z >= -1, z <= 1]
    v_diff = np.abs(np.diff(v_logit))   # |logit(v_i) - logit(v_(i+1))|
    z_diff = cp.abs(z[:-1] - z[1:])     # |z_i - z_(i+1)|
    constraints.append(z_diff <= v_diff/4)

    prob = cp.Problem(objective, constraints)
    prob.solve(solver=cp.ECOS, verbose=False)

    return prob.value

def LinECE_fast(v, y):
    """Solve the calibration-witness LP with a probability-space bound.

    Maximises ``(1/n) * sum((y_i - v_i) * z_i)`` over ``z`` in ``[-1, 1]^n``
    subject to ``|z_i - z_{i+1}| <= |v_i - v_{i+1}|`` for neighbouring
    predictions in sorted order.

    Args:
        v: 1-D array-like of predicted probabilities.
        y: 1-D array-like of binary labels, aligned with ``v``.

    Returns:
        The optimal objective value reported by the solver.
    """
    n = len(v)
    preds = np.array(v)
    labels = np.array(y)

    # Sorting by prediction lets the smoothness constraint be written on
    # consecutive entries only.
    order = np.argsort(preds)
    preds, labels = preds[order], labels[order]

    witness = cp.Variable(n)
    neighbour_gap = np.abs(np.diff(preds))
    constraints = [
        witness >= -1,
        witness <= 1,
        cp.abs(witness[:-1] - witness[1:]) <= neighbour_gap,
    ]

    target = cp.Maximize((1/n) * cp.sum((labels - preds) * witness))
    problem = cp.Problem(target, constraints)
    problem.solve(solver=cp.ECOS, verbose=False)
    return problem.value


class BinaryClassifier(nn.Module):
    """Two-layer sigmoid network with a symmetric initialisation.

    The first layer maps ``dim`` inputs to ``hidden_units`` sigmoid units;
    the second is a bias-free linear read-out followed by a sigmoid.

    At construction the second half of the first-layer weight rows is a copy
    of the first half, the first-layer bias is zero, and the read-out weights
    are +1 for the first half of the units and -1 for the rest, so paired
    units cancel in the read-out at initialisation.
    """

    def __init__(self, dim=2, hidden_units=300):
        super().__init__()
        self.fc1 = nn.Linear(dim, hidden_units)
        self.fc2 = nn.Linear(hidden_units, 1, bias=False)
        self.sigmoid = nn.Sigmoid()

        half = hidden_units // 2
        with torch.no_grad():
            # First layer: Gaussian rows, with the upper half mirrored into
            # the lower half, and a zero bias.
            first_layer = torch.randn_like(self.fc1.weight)
            first_layer[half:] = first_layer[:half]
            self.fc1.weight.copy_(first_layer)
            self.fc1.bias.zero_()

            # Read-out: +1 on the first half of the units, -1 on the rest.
            signs = torch.cat(
                [torch.ones(half), -torch.ones(hidden_units - half)]
            )
            self.fc2.weight.copy_(signs.view(1, -1))

    def forward(self, x):
        """Return the predicted positive-class probability for each row of ``x``."""
        hidden = self.sigmoid(self.fc1(x))
        return self.sigmoid(self.fc2(hidden))

# Number of full-batch gradient steps per training run.
num_epochs = 15_000
# Metrics are recorded every `step` iterations (20 snapshots per run).
step = num_epochs // 20

def _evaluate_model(model, X, y):
    """Return the metric dictionary that ``run`` logs for one data split."""
    with torch.no_grad():
        out = model(X).flatten().numpy()
    y_np = y.numpy()
    y_flat = y_np.flatten()
    pred = (out > 0.5).astype(int)
    return {
        "Cross Entropy": log_loss(y_np, out),
        "Accuracy": accuracy_score(y_np, pred),
        "Grad Norm": np.mean(np.abs(y_flat - out)),
        "Binning ECE": compute_binned_ece(out, y_flat),
        "kenelECE": compute_kernel_ece(out, y_flat),
        "Smooth CE": LinECE_fast(out, y_flat),
        "MMCE": compute_mmce(out, y_flat),
    }


def run(step=step, num_epochs=num_epochs, seed=42):
    """Train one ``BinaryClassifier`` with full-batch gradient descent.

    Args:
        step: Logging interval in iterations; values below 1 are treated
            as 1 so the modulo check cannot divide by zero.
        num_epochs: Number of gradient steps.
        seed: Seeds the data split and, additionally, the torch RNG so that
            the random weight initialisation is reproducible per seed
            (previously the torch RNG was left unseeded).

    Returns:
        ``(metrics_train, metrics_test, cumulative_avg_grad_train,
        cumulative_avg_grad_train_yminusout, model, X_train, y_train)``.
        The first two are lists with one metric dict per logged iteration;
        the next two are running means of, respectively, the parameter
        gradient norm and the squared mean absolute residual ``|y - output|``.
    """
    # Alternative datasets: generate_toy_data(seed=seed),
    # generate_unseparable_data(seed=seed).
    X_train, X_test, y_train, y_test = UCI(seed=seed)

    torch.manual_seed(seed)
    step = max(step, 1)

    model = BinaryClassifier(dim=X_train.shape[1], hidden_units=300)
    criterion = nn.BCELoss()
    eta = 0.001
    metrics_train, metrics_test = [], []

    grad_norms_train = []
    cumulative_avg_grad_train = []
    grad_norms_train_yminusout = []
    cumulative_avg_grad_train_yminusout = []

    targets = y_train.flatten()

    for epoch in range(num_epochs):
        model.zero_grad()
        output_train = model(X_train).flatten()
        loss = criterion(output_train, targets)
        loss.backward()

        # Plain gradient-descent update; the squared gradient norm is
        # accumulated over all parameters in the same pass.
        total_norm = 0.0
        with torch.no_grad():
            for param in model.parameters():
                if param.grad is not None:
                    total_norm += param.grad.norm().item() ** 2
                    param -= eta * param.grad

        if epoch % step == 0:
            grad_norms_train.append(np.sqrt(total_norm))
            cumulative_avg_grad_train.append(np.mean(grad_norms_train))

            # `output_train` was computed before this iteration's update.
            grad_yout = (targets - output_train).detach().numpy()
            grad_norms_train_yminusout.append(np.mean(np.abs(grad_yout))**2)
            cumulative_avg_grad_train_yminusout.append(np.mean(grad_norms_train_yminusout))

            # Metrics are evaluated with the freshly updated parameters.
            metrics_train.append(_evaluate_model(model, X_train, y_train))
            metrics_test.append(_evaluate_model(model, X_test, y_test))

    return metrics_train, metrics_test, cumulative_avg_grad_train, cumulative_avg_grad_train_yminusout, model, X_train, y_train


def collect_metrics_across_seeds(step=step, num_epochs=num_epochs, seed=42):
    """Run ``run`` once per seed and return aggregation closures.

    Args:
        step: Logging interval forwarded to ``run``.
        num_epochs: Number of gradient steps forwarded to ``run``.
        seed: A single integer seed or an iterable of seeds. The argument
            was previously shadowed by the loop variable and the
            module-level ``seeds`` was used instead; it is now honoured.

    Returns:
        ``((aggregate, aggregate_cumgrad), model, X_sample, y_sample)``.
        ``aggregate(metric)`` returns ``(train_mean, train_std, test_mean,
        test_std, gap_mean, gap_std)`` across seeds for each logged
        iteration; ``aggregate_cumgrad()`` returns
        ``([mean1, mean2], [std1, std2])`` for the two cumulative gradient
        curves. The model and training data come from the first seed.

    Raises:
        ValueError: If ``seed`` is an empty iterable.
    """
    if isinstance(seed, (int, np.integer)):
        seed_list = [seed]
    else:
        seed_list = list(seed)
    if not seed_list:
        raise ValueError("at least one seed is required")

    all_train, all_test = [], []
    all_cumgrad1, all_cumgrad2 = [], []
    model, X_sample, y_sample = None, None, None

    for i, current_seed in enumerate(seed_list):
        train, test, cum1, cum2, m, X, y = run(
            step=step, num_epochs=num_epochs, seed=current_seed
        )
        all_train.append(train)
        all_test.append(test)
        all_cumgrad1.append(cum1)
        all_cumgrad2.append(cum2)
        if i == 0:
            model, X_sample, y_sample = m, X, y

    def aggregate(metric):
        # Rows are seeds, columns are logged iterations.
        train_arr = np.array([[entry[metric] for entry in history] for history in all_train])
        test_arr = np.array([[entry[metric] for entry in history] for history in all_test])
        train_mean = np.mean(train_arr, axis=0)
        test_mean = np.mean(test_arr, axis=0)
        train_std = np.std(train_arr, axis=0)
        test_std = np.std(test_arr, axis=0)
        # NOTE(review): gap_mean is |mean(test) - mean(train)| while gap_std
        # is the std of the per-seed |test - train|; confirm this pairing is
        # intended.
        gap_mean = np.abs(test_mean - train_mean)
        gap_std = np.std(np.abs(test_arr - train_arr), axis=0)
        return train_mean, train_std, test_mean, test_std, gap_mean, gap_std

    def aggregate_cumgrad():
        curves = (np.array(all_cumgrad1), np.array(all_cumgrad2))
        means = [np.mean(curve, axis=0) for curve in curves]
        stds = [np.std(curve, axis=0) for curve in curves]
        return means, stds

    return (aggregate, aggregate_cumgrad), model, X_sample, y_sample




def plot_with_errorbars(
    aggregator_pair,
    step=step,
    metrics=("Cross Entropy", "Accuracy", "Grad Norm", "Binning ECE", "Smooth CE", "MMCE"),
    save_path="toy_NN_iteration_unit300.eps",
):
    """Plot train, test and gap curves with error bars and save the figure.

    Args:
        aggregator_pair: The ``(aggregate, aggregate_cumgrad)`` pair returned
            by ``collect_metrics_across_seeds``.
        step: Logging interval, used to scale the x-axis to iterations.
        metrics: Names of the metrics drawn in the train and test panels.
            The default is now an immutable tuple with the same entries.
        save_path: Output file for the figure. It defaults to the file name
            that was previously hard-coded.
    """
    aggregator, aggregate_cumgrad = aggregator_pair
    # The cumulative-gradient curves are only used here to obtain the number
    # of logged iterations; they are not plotted.
    cumgrad_mean, _ = aggregate_cumgrad()
    x = np.arange(len(cumgrad_mean[0])) * step

    # Each aggregator call rebuilds the per-seed arrays, so cache per metric.
    cache = {}

    def stats_for(name):
        if name not in cache:
            cache[name] = aggregator(name)
        return cache[name]

    fig, axs = plt.subplots(1, 3, figsize=(24, 6))

    # (axis, title, index of the mean in the aggregate tuple, legend position)
    panels = (
        (axs[0], "Train Metrics", 0, "lower left"),
        (axs[1], "Test Metrics", 2, "upper right"),
    )
    for ax, title, mean_idx, legend_loc in panels:
        for name in metrics:
            stats = stats_for(name)
            ax.errorbar(x, stats[mean_idx], yerr=stats[mean_idx + 1], label=name, capsize=3)
        ax.set_title(title, fontsize=25)
        ax.set_yscale("log")
        ax.set_xlabel("Iteration", fontsize=25)
        ax.legend(loc=legend_loc, fontsize=15)
        ax.grid()

    # Gap panel: linear y-axis, restricted to the two loss-type metrics.
    for name in ("Cross Entropy", "Smooth CE"):
        *_, gap_mean, gap_std = stats_for(name)
        axs[2].errorbar(x, gap_mean, yerr=gap_std, label=f"{name} gap", capsize=3)
    axs[2].set_title("Gap Between Test and Train", fontsize=25)
    axs[2].set_xlabel("Iteration", fontsize=25)
    axs[2].legend(fontsize=15)
    axs[2].grid()

    plt.tight_layout()
    plt.savefig(save_path, dpi=300)


def plot_decision_boundary(model, X, y):
    """Show the model's 0.5 decision regions over a 2-D scatter of the data.

    Args:
        model: Callable mapping an ``(n, 2)`` float32 tensor to per-row
            scores.
        X: 2-D array whose first two columns are used as plot coordinates.
        y: Labels used to colour the scatter points; flattened before use.
    """
    margin = 1
    x_lo, x_hi = X[:, 0].min() - margin, X[:, 0].max() + margin
    y_lo, y_hi = X[:, 1].min() - margin, X[:, 1].max() + margin

    # Evaluate the model on a 100 x 100 grid covering the padded data range.
    grid_x, grid_y = np.meshgrid(
        np.linspace(x_lo, x_hi, 100),
        np.linspace(y_lo, y_hi, 100),
    )
    points = np.column_stack((grid_x.ravel(), grid_y.ravel()))
    scores = model(torch.tensor(points, dtype=torch.float32)).detach().numpy()
    surface = scores.reshape(grid_x.shape)

    plt.contourf(grid_x, grid_y, surface, levels=[0, 0.5, 1], alpha=0.3)
    plt.scatter(X[:, 0], X[:, 1], c=y.flatten(), cmap='coolwarm', edgecolors='k')
    plt.title("Decision Boundary")
    plt.show()

# Seeds of the independent training runs whose metrics are averaged.
seeds = range(10)
# Train once per seed; also keeps the first run's model and training data.
aggregator_pair, model, X_sample, y_sample = collect_metrics_across_seeds(step=step,num_epochs=num_epochs,seed=seeds)
# Draw the train / test / gap panels and write the figure to disk.
plot_with_errorbars(aggregator_pair,step=step)
# Optional decision-boundary plot. It feeds a 2-column grid to the model, so
# it only applies when the model was trained on 2-D inputs.
#plot_decision_boundary(model, X_sample.numpy(), y_sample.numpy())