import warnings 
warnings.filterwarnings("ignore")
warnings.filterwarnings("ignore", category=FutureWarning)

from sklearn.svm import SVC
from sklearn.svm import LinearSVC
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.colors import ListedColormap, to_rgba
import random
from sklearn.linear_model import LogisticRegression

def train_soft_svm(x, y, v=None, c=1.0, loss = 'hinge', fit_intercept=True):
    """
    Fit a (possibly kernelized) soft-margin classifier with per-sample weights.

    Parameters
    ----------
    x : numpy array of shape (n_samples, n_features)
        Feature matrix. The RBF variants read ``x.shape[1]`` and ``x.var()``.
    y : array-like of shape (n_samples,)
        Class labels.
    v : array-like of shape (n_samples,), optional
        Per-sample weights passed as ``sample_weight`` to ``fit`` (for the
        sklearn SVM estimators these rescale C per sample). Default: all ones.
    c : float
        Regularization parameter C.
    loss : str
        Which estimator to build:
          - 'hinge', 'squared_hinge': LinearSVC with that loss.
          - 'log': L2-regularized LogisticRegression (liblinear).
          - 'poly2', 'poly3': SVC with polynomial kernel (coef0=1, gamma='auto').
          - 'rbf1', 'rbf2', 'rbf3': SVC with RBF kernel and gamma equal to
            2x, 1x and 0.5x the ``1 / (n_features * x.var())`` heuristic.
    fit_intercept : bool
        Whether the linear models ('hinge', 'squared_hinge', 'log') learn a
        bias; ``False`` forces b = 0. Ignored by the kernel (SVC) variants.

    Returns
    -------
    The fitted estimator.

    Raises
    ------
    ValueError
        If ``loss`` is not one of the supported keys.
    """
    supported = ('hinge', 'squared_hinge', 'log', 'poly2', 'poly3', 'rbf1', 'rbf2', 'rbf3')
    # Validate first: previously an unknown key fell through every branch and
    # crashed later with UnboundLocalError on `model`.
    if loss not in supported:
        raise ValueError(f"Unsupported loss {loss!r}; expected one of {supported}")

    # If no sample weights provided, use all weights = 1
    if v is None:
        v = np.ones(len(y))

    if loss == 'hinge':
        model = LinearSVC(
            C=c,
            loss='hinge',
            max_iter=2000,
            fit_intercept=fit_intercept,  # False forces b = 0
            dual=True,
            random_state=0,
        )
    elif loss == 'squared_hinge':
        model = LinearSVC(
            C=c,
            loss='squared_hinge',
            fit_intercept=fit_intercept,  # False forces b = 0
            dual=True,
            random_state=0,
        )
    elif loss == 'log':
        model = LogisticRegression(
            penalty='l2',
            C=c,
            solver='liblinear',
            max_iter=2000,
            fit_intercept=fit_intercept,
            dual=False,
            random_state=0,
        )
    elif loss == 'poly2':
        model = SVC(
            C=c,
            kernel='poly',
            coef0=1, degree=2,
            gamma='auto',
            max_iter=600000,
            random_state=0,
        )
    elif loss == 'poly3':
        model = SVC(
            C=c,
            max_iter=1000000,
            kernel='poly',
            coef0=1, degree=3,
            gamma='auto',
            random_state=0,
        )
    else:  # 'rbf1', 'rbf2', 'rbf3'
        # Same heuristic as sklearn's gamma='scale', including its fallback of
        # 1.0 when the input has zero variance (avoids a division by zero).
        x_var = x.var()
        gamma_scale = 1 / (x.shape[1] * x_var) if x_var != 0 else 1.0
        gamma_factor = {'rbf1': 2, 'rbf2': 1, 'rbf3': 0.5}[loss]
        model = SVC(
            C=c,
            kernel='rbf',
            gamma=gamma_factor * gamma_scale,
            max_iter=10000,
            random_state=0,
        )

    model.fit(x, y, sample_weight=v)
    return model

def get_relevant_indices(svm_model, x, y, v, c, sigma_loss = 1.0, k=None, k_coef=1.0, tol = 0.05):
    """
    Select the correctly classified in-margin points whose margin is below their beta bound.

    A point i is "in margin" when 0 < y_i * f(x_i) <= 1, where f is
    ``svm_model.decision_function``. Among those, point i is "relevant" when
    ``y_i * f(x_i) < beta_i + tol`` with
    ``beta_i = (k * k_coef)**2 * v_i * sigma_loss / (2 * lambda)`` and
    ``lambda = 1 / c``.

    Parameters
    ----------
    svm_model : fitted model exposing ``decision_function``.
    x : numpy array (n_samples, n_features).
    y : numpy array of labels; the margin formula assumes labels in {-1, +1}
        (NOTE(review): confirm callers never pass {0, 1}).
    v : numpy array of per-sample weights (fancy-indexed here).
    c : float, regularization C.
    sigma_loss, k_coef : scaling factors in the beta bound.
    k : float, optional. Defaults to the largest row L2 norm of ``x``.
    tol : float, slack added to every beta bound.

    Returns
    -------
    (relevant_indices, support_vector_indices) : pair of int numpy arrays;
    the first is a subset of the second.
    """
    signed_margin = y * svm_model.decision_function(x)

    # correctly classified points on or inside the margin
    in_margin = np.flatnonzero((signed_margin > 0) & (signed_margin <= 1))

    radius = np.linalg.norm(x, axis=1).max() if k is None else k
    reg_strength = 1.0 / c

    beta = ((radius * k_coef) ** 2 * v[in_margin] * sigma_loss) / (2 * reg_strength)

    keep = signed_margin[in_margin] < beta + tol
    return in_margin[keep], in_margin

def get_throw(svm_model, x, y, v, c, relevant_indices, sigma_loss = 1.0, k=None, k_coef=1.0, tol = 0.1):
    """
    Return the indices of points lying far enough outside the margin to be discarded.

    With f = ``svm_model.decision_function``, lambda = 1 / c and
    ``beta_max = (k * k_coef)**2 * max(v[relevant_indices]) * sigma_loss / (2 * lambda)``,
    point i is thrown when ``y_i * f(x_i) > 1 + 2 * beta_max + tol``
    (only points with ``y_i * f(x_i) > 1`` are considered).

    Parameters
    ----------
    svm_model : fitted model exposing ``decision_function``.
    x : numpy array (n_samples, n_features).
    y : numpy array of labels (margin formula assumes {-1, +1}).
    v : numpy array of per-sample weights.
    c : float, regularization C.
    relevant_indices : sequence of int, indices returned by ``get_relevant_indices``.
    sigma_loss, k_coef : scaling factors in the beta bound.
    k : float, optional. Defaults to the largest row L2 norm of ``x``.
    tol : float, extra slack on the throw threshold.

    Returns
    -------
    numpy array of int indices. Empty when ``relevant_indices`` is empty
    (beta_max is undefined then, so nothing is thrown).
    """
    relevant_indices = np.asarray(relevant_indices, dtype=int)
    if relevant_indices.size == 0:
        # Previously max() on an empty array raised ValueError here.
        return np.array([], dtype=int)

    margin = y * svm_model.decision_function(x)

    # points outside the margin: y*f(x) > 1
    outside_margin_indices = np.flatnonzero(margin > 1)

    # lambda = 1/C
    lam = 1.0 / c

    # compute k if not provided: max L2 norm of x
    if k is None:
        k = np.linalg.norm(x, axis=1).max()

    max_relevant_beta = ((k * k_coef) ** 2 * np.max(v[relevant_indices]) * sigma_loss) / (2 * lam)

    # final outside selection: margin > 1 + 2 * beta_max + tol
    selected_mask = margin[outside_margin_indices] > (1 + 2 * max_relevant_beta) + tol
    return outside_margin_indices[selected_mask]

def update_indices(relevant_indices, removed_indices):
    """
    Map indices into an array to their positions after ``removed_indices`` are deleted.

    Each index is shifted down by the number of removed indices strictly smaller
    than it. Indices that are themselves removed are not expected in
    ``relevant_indices``.

    Parameters
    ----------
    relevant_indices : iterable of int
        Indices into the original array.
    removed_indices : sized iterable of int
        Indices deleted from the original array (any order).

    Returns
    -------
    ``relevant_indices`` unchanged if nothing was removed; otherwise a list of
    shifted indices in the same order.
    """
    if len(removed_indices) == 0:
        return relevant_indices

    relevant = list(relevant_indices)
    removed = np.array(sorted(removed_indices))
    # side="left" counts removed entries strictly less than each index, in
    # O(log m) per query instead of scanning all removed indices.
    shifts = np.searchsorted(removed, relevant, side="left")
    return [idx - int(shift) for idx, shift in zip(relevant, shifts)]

def run_random_mechanism(x_data, y, v, c, use_loss, mu=0.5, fit_intercept=True, k=None, random_state = 0):
    """
    Run a self-resampling randomized mechanism around the weighted SVM allocation rule.

    Each agent keeps its bid with probability 1 - mu. With probability mu, its bid is
    multiplied by ``chi_i = gamma_i ** (1 / (1 - mu))``, where ``gamma_i ~ U[0, 1)``.
    The classifier is trained once on the modified bids, and agent i is allocated
    (a_i = 1) iff its own point is classified correctly. It pays ``v_i * a_i`` when its
    bid was kept, and ``v_i * a_i * (1 - 1/mu)`` when the bid was resampled.

    Parameters
    ----------
    x_data : feature matrix passed to ``train_soft_svm``.
    y : numpy array of labels.
    v : numpy array of true bids/weights.
    c : float, regularization C.
    use_loss : str, loss key for ``train_soft_svm``.
    mu : float in [0, 1), resampling probability.
    fit_intercept : bool, forwarded to ``train_soft_svm``.
    k : unused; kept for signature compatibility.
    random_state : seed for ``numpy.random.default_rng``.

    Returns
    -------
    (pandas.DataFrame, fitted model). The DataFrame has one row per agent with
    columns agent, allocation, true_v, critical_v (the payment), welfare,
    utility and times_alloc_called.

    Raises
    ------
    ValueError
        If ``mu`` is outside [0, 1) (mu = 1 would divide by zero in the exponent).
    """
    if not 0.0 <= mu < 1.0:
        raise ValueError(f"mu must be in [0, 1), got {mu}")

    n = len(v)
    rng = np.random.default_rng(random_state)

    # ---- Step 1: sample chi_i ----
    # The draw order (gamma first, then the coin flips) fixes reproducibility per seed.
    gamma = rng.uniform(0, 1, size=n)
    resampled = rng.random(n) < mu

    chi = np.ones(n)
    chi[resampled] = gamma[resampled] ** (1.0 / (1.0 - mu))

    # ---- Step 2: modified bids ----
    modified_bids = chi * v

    # ---- Step 3: allocation = correctness indicator ----
    model = train_soft_svm(
        x_data,
        y,
        v=modified_bids,
        c=c,
        loss=use_loss,
        fit_intercept=fit_intercept,
    )
    allocation = (model.predict(x_data) == y).astype(float)

    # ---- Step 4: payments ----
    # With mu == 0 no agent is resampled, so the rebate factor is never used;
    # guarding it avoids a ZeroDivisionError.
    rebate = 1 - (1 / mu) if mu > 0 else 1.0
    factor = np.where(resampled, rebate, 1.0)
    payments = v * allocation * factor

    inv_n = 1 / len(y)
    records = [
        {
            "agent": i,
            "allocation": allocation[i],
            "true_v": v[i],
            "critical_v": payments[i],
            "welfare": v[i] * allocation[i],
            "utility": v[i] * allocation[i] - payments[i],
            "times_alloc_called": inv_n,
        }
        for i in range(n)
    ]

    df_exact = pd.DataFrame(records)
    return df_exact, model

def run_svm_payment(
    x, y, v, c, use_loss,
    sigma_loss=1.0, plot=False,
    is_throw=False, k=None, k_coef=1.0,
    fit_intercept=False,
    payment_mode="exact",  # "exact" or "random"
    n_random=1,  # number of draws for expectation if payment_mode="random"
    random_state=0
):
    """
    Train the weighted SVM allocation rule and compute each agent's allocation and payment.

    Only agents flagged by ``get_relevant_indices`` are charged. Every other agent
    pays 0 and is allocated iff the main model classifies its point correctly.
    When ``is_throw`` is set, points returned by ``get_throw`` are removed from the
    training set before the payments of relevant agents are recomputed.

    Parameters
    ----------
    x, y, v : numpy arrays of features, labels and bids/weights (boolean-mask indexed).
    c : float, regularization C.
    use_loss : str, loss key for ``train_soft_svm``.
    sigma_loss, k, k_coef : forwarded to ``get_relevant_indices`` / ``get_throw``.
    plot : bool, forwarded to ``compute_critical_bid``.
    is_throw : bool, whether to discard far-outside-margin points first.
    fit_intercept : bool, forwarded to ``train_soft_svm``.
    payment_mode : "exact" (binary-search critical bid) or "random" (average over
        ``n_random`` retrainings with a uniformly drawn bid).
    n_random : int, number of draws in "random" mode.
    random_state : seed for the "random" mode draws.

    Returns
    -------
    (pandas.DataFrame, fitted main model). The DataFrame has one row per agent
    with columns agent, allocation, true_v, critical_v, welfare, utility and
    times_alloc_called.

    Raises
    ------
    ValueError
        If ``payment_mode`` is not "exact" or "random".
    """
    if payment_mode not in ("exact", "random"):
        raise ValueError("Unknown payment_mode")

    rng = np.random.default_rng(random_state)
    n = len(v)
    inv_n = 1 / len(y)

    # Train main SVM
    svm_model = train_soft_svm(x, y, v, c=c, loss=use_loss, fit_intercept=fit_intercept)

    # Get relevant indices (support indices are not needed here)
    relevant_indices, _support_idx = get_relevant_indices(
        svm_model, x, y, v, c, sigma_loss=sigma_loss, k=k, k_coef=k_coef
    )

    if is_throw and len(relevant_indices) > 0:
        throw = get_throw(svm_model, x, y, v, c, relevant_indices,
                          sigma_loss=sigma_loss, k=k, k_coef=k_coef)
    else:
        throw = []

    mask = np.ones(n, dtype=bool)
    mask[throw] = False
    new_x, new_y, new_v = x[mask], y[mask], v[mask]

    # original index -> index into the reduced (post-throw) arrays
    real_to_updated = dict(zip(relevant_indices, update_indices(relevant_indices, throw)))

    # Allocation of non-relevant agents under the main model, predicted once.
    base_alloc = svm_model.predict(x) == y

    records = []
    for target_idx in range(n):
        counter = 0  # extra allocation-rule calls made for this agent

        if target_idx in real_to_updated:
            updated_idx = real_to_updated[target_idx]
            if payment_mode == "exact":
                critical_v, alloc, counter = compute_critical_bid(
                    new_x, new_y, new_v, updated_idx,
                    train_soft_svm,
                    loss=use_loss,
                    plot=plot,
                    c=c,
                    fit_intercept=fit_intercept
                )
            else:  # "random"
                # Expected allocation and payment over n_random draws.
                allocs, critical_vs = [], []
                for _ in range(n_random):
                    # Float copy so the drawn bid is not truncated for int-typed v;
                    # index with the post-throw position of the agent.
                    v_temp = new_v.astype(float)
                    v_temp[updated_idx] = rng.uniform(0, v[target_idx])
                    model_u = train_soft_svm(new_x, new_y, v_temp, c=c, loss=use_loss,
                                             fit_intercept=fit_intercept)
                    pred = model_u.predict(x[target_idx].reshape(1, -1))[0]
                    draw_alloc = int(pred == y[target_idx])
                    allocs.append(draw_alloc)
                    # NOTE(review): charges v when NOT allocated at the drawn bid;
                    # confirm this is the intended estimator.
                    critical_vs.append(0 if draw_alloc == 1 else v[target_idx])
                alloc = np.mean(allocs)
                critical_v = np.mean(critical_vs)
                counter = n_random
        else:
            alloc = int(base_alloc[target_idx])
            critical_v = 0

        records.append({
            "agent": target_idx,
            "allocation": alloc,
            "true_v": v[target_idx],
            "critical_v": critical_v,
            "welfare": v[target_idx] * alloc,
            "utility": v[target_idx] * alloc - critical_v,
            "times_alloc_called": counter + inv_n,
        })

    df = pd.DataFrame(records)
    return df, svm_model


def check_and_plot(x, y, v_mod, target_idx, allocation_rule, 
                   c=1, loss = 'hinge', plot = True, fit_intercept=True):
    """
    Refit the allocation rule with weights ``v_mod`` and test whether ``target_idx`` is allocated.

    Parameters
    ----------
    x, y : training features and labels.
    v_mod : numpy array of per-sample weights used for this fit.
    target_idx : int, the agent being checked.
    allocation_rule : callable ``(x, y, v=..., c=..., loss=..., fit_intercept=...)``
        returning a fitted model with ``predict``.
    c, loss, fit_intercept : forwarded to ``allocation_rule``.
    plot : bool, whether to draw the decision boundary with the target highlighted.

    Returns
    -------
    (alloc, model) : ``alloc`` is 1 if the model predicts ``y[target_idx]`` for
    ``x[target_idx]``, else 0; ``model`` is the fitted estimator.
    """
    fitted = allocation_rule(x, y, v=v_mod, c=c, loss=loss, fit_intercept=fit_intercept)
    predicted_label = fitted.predict(x)[target_idx]
    alloc = int(predicted_label == y[target_idx])

    if not plot:
        return alloc, fitted

    plot_title = f"Decision Boundary @ v={v_mod[target_idx]:.5f}, target = {target_idx}"
    plot_svm_decision_boundary(fitted, x, y, v=v_mod, title=plot_title, target_idx=target_idx)
    return alloc, fitted

def compute_critical_bid(x, y, v, target_idx, allocation_rule, loss = 'hinge', tol=1e-10, max_iter=100, c=1,
                          plot = True, fit_intercept=True):
    """
    Binary-search the smallest bid of agent ``target_idx`` at which it is still allocated.

    The search interval is [0, v[target_idx]]. The caller is assumed to pass an
    agent that is allocated (alloc = 1) at its own bid v[target_idx]; this is not
    re-checked. Every probe refits the model through ``check_and_plot`` with the
    same ``loss``, ``c`` and ``fit_intercept``.

    Parameters
    ----------
    x, y : training features and labels.
    v : numpy array of bids/weights; not modified (a float copy is probed).
    target_idx : int, index of the agent whose bid is varied.
    allocation_rule : callable with the signature of ``train_soft_svm``.
    loss : str, loss key forwarded to ``allocation_rule``.
    tol : float, stop once the bracket [low, high] is narrower than this (absolute).
    max_iter : int, maximum number of binary-search probes.
    c, fit_intercept : forwarded to ``allocation_rule``.
    plot : bool, plot the decision boundary at every probe.

    Returns
    -------
    critical_v : float
        0.0 if the agent is allocated even at bid 0; otherwise the midpoint of
        the final bracket.
    alloc : int
        Always 1.
    n_calls : int
        1 if the bid-0 check already allocates; otherwise the number of
        binary-search probes (the initial bid-0 probe is not included).
    """
    # Float copy: with an integer-dtype v, fractional probe bids would otherwise
    # be truncated on assignment.
    v_mod = np.array(v, dtype=float)
    low, high = 0.0, v[target_idx]

    # Early check at v = 0
    v_mod[target_idx] = low
    alloc_0, _ = check_and_plot(x, y, v_mod, target_idx, allocation_rule, loss=loss, c=c,
                                plot=plot, fit_intercept=fit_intercept)
    if alloc_0 == 1:
        return low, alloc_0, 1

    # No check at v_max: callers only pass agents already allocated at their bid.
    # Invariant: bid `low` is not allocated, bid `high` is.
    alloc_mid = 0
    critical_v = None
    n_probes = 0
    for n_probes in range(1, max_iter + 1):
        mid = (low + high) / 2.0
        v_mod[target_idx] = mid
        alloc_mid, _ = check_and_plot(x, y, v_mod, target_idx, allocation_rule, c=c,
                                      loss=loss, plot=plot, fit_intercept=fit_intercept)

        if alloc_mid == 1:
            high = mid
        else:
            low = mid

        if high - low < tol and alloc_mid == 1:
            critical_v = (low + high) / 2.0
            break

    if critical_v is None:
        # Did not converge on an allocated probe within max_iter: fall back to the
        # midpoint of the current bracket (true division, not floor division).
        alloc_mid = 1
        critical_v = (low + high) / 2.0

    return critical_v, alloc_mid, n_probes


def plot_svm_decision_boundary(clf, X, y, v=None, target_idx=None, title="SVC Decision Boundary", labels = None):
    """
    Dispatch to the 1D or 2D decision-boundary plotter based on the number of feature columns.

    ``labels`` (per-point annotations) is only honoured by the 2D plotter.

    Raises
    ------
    ValueError
        If ``X`` has a column count other than 1 or 2.
    """
    n_features = X.shape[1]

    if n_features == 1:
        # single column: hand the flattened feature to the 1D plotter
        plot_svm_decision_boundary_1d(clf, X[:, 0], y, v=v, target_idx=target_idx, title=title)
        return

    if n_features == 2:
        plot_svm_decision_boundary_2d(clf, X, y, v=v, target_idx=target_idx, title=title, labels=labels)
        return

    raise ValueError(f"Only 1D or 2D features supported. Got shape {X.shape}")

def plot_svm_decision_boundary_1d(clf, X, y, v=None, target_idx=None, title="SVC Decision Boundary"):
    """
    Plot the decision regions of a binary classifier trained on 1D data.

    Points are drawn on the x-axis. Each part of the axis is shaded with the
    colour of the class ``clf.predict`` assigns there, and a vertical line is
    drawn at every sign change of ``clf.decision_function`` on a 500-point grid.
    Kernel models may have several boundaries.

    Parameters:
    - clf: fitted classifier exposing ``predict`` and ``decision_function``
    - X: 1D features (n_samples,) or (n_samples, 1)
    - y: binary labels, shape (n_samples,)
    - v: optional weights for marker sizes (size = 20 * v)
    - target_idx: optional index of a point to highlight
    - title: plot title

    Raises:
    - ValueError if ``y`` does not contain exactly two distinct labels.
    """
    # Flatten X to 1D if needed
    X = np.asarray(X).reshape(-1)
    y = np.asarray(y)

    colors = ['tab:blue', 'tab:orange']
    cmap = ListedColormap(colors)

    unique_labels = np.unique(y)
    if len(unique_labels) != 2:
        raise ValueError("This function supports only binary classification.")
    # np.unique sorts, and the 2-colour cmap maps the smaller label to colors[0].
    label_color = dict(zip(unique_labels, colors))

    plt.figure(figsize=(8, 2.5))

    # Plot points on x-axis (y=0)
    sizes = None if v is None else np.asarray(v) * 20
    plt.scatter(X, np.zeros_like(X), c=y, s=sizes, cmap=cmap, edgecolors='k', zorder=3)

    # Annotate points by index
    for i, x_val in enumerate(X):
        plt.text(x_val, 0.05, str(i), ha='center', fontsize=9, color='black', zorder=4)

    # Highlight target point if specified
    if target_idx is not None:
        plt.scatter(X[target_idx], 0, s=150, facecolors='none', edgecolors='red', linewidths=2, zorder=5)

    ax = plt.gca()
    xlim = ax.get_xlim()
    ylim = ax.get_ylim()

    # Dense grid over the visible x-range
    grid = np.linspace(xlim[0], xlim[1], 500)
    grid_2d = grid.reshape(-1, 1)
    Z_decision = clf.decision_function(grid_2d)
    grid_preds = clf.predict(grid_2d)

    # Shade by the actually predicted class, rather than assuming the smaller
    # label is always on the left.
    for lbl, color in label_color.items():
        ax.fill_between(grid, ylim[0], ylim[1], where=(grid_preds == lbl),
                        color=color, alpha=0.15, linewidth=0)

    # Decision boundaries: every sign change of f, linearly interpolated
    # between neighbouring grid points.
    positive = Z_decision > 0
    for j in np.flatnonzero(positive[1:] != positive[:-1]):
        x0, x1 = grid[j], grid[j + 1]
        z0, z1 = Z_decision[j], Z_decision[j + 1]
        ax.axvline(x0 - z0 * (x1 - x0) / (z1 - z0), color='k', linestyle='-')

    ax.set_xlim(xlim)
    ax.set_ylim(ylim)

    # Format plot
    plt.yticks([])  # Hide y-axis ticks
    plt.xlabel('Feature 1')
    plt.title(title)
    plt.grid(axis='x')

    # Legend
    legend_handles = [mpatches.Patch(color=color, label=f'Label {int(lbl)}')
                      for lbl, color in label_color.items()]
    if target_idx is not None:
        legend_handles.append(plt.Line2D([0], [0], marker='o', color='w', label='Target',
                                         markerfacecolor='none', markeredgecolor='r', markersize=10, linewidth=2))
    plt.legend(handles=legend_handles, loc='upper right')

    plt.tight_layout()
    plt.show()

def plot_svm_decision_boundary_2d(clf, X, y, v=None, target_idx=None, title="SVC Decision Boundary", labels = None):
    """
    Plot the decision regions, boundary and margins of a binary classifier trained on 2D data.

    Background regions are coloured by ``clf.predict`` on a 500x500 grid. The
    solid contour is ``decision_function == 0`` and the dashed contours are the
    +/-1 margins.

    Parameters:
    - clf: fitted classifier exposing ``predict`` and ``decision_function``
    - X: feature matrix (n_samples, 2)
    - y: binary labels (n_samples,)
    - v: optional weights for data points; marker size = 10 * v
    - target_idx: optional index of a data point to highlight
    - title: plot title
    - labels: optional per-point annotations indexed like X; defaults to the point index

    Raises:
    - ValueError if ``y`` does not contain exactly two distinct labels.
    """
    # Define two distinct colors for binary classification
    colors = ['tab:blue', 'tab:orange']
    cmap = ListedColormap(colors)
    light_cmap = ListedColormap([to_rgba(col, alpha=0.15) for col in colors])

    unique_labels = np.unique(y)
    if len(unique_labels) != 2:
        raise ValueError("This function only supports binary classification.")

    plt.figure(figsize=(8, 6))

    # Plot all data points
    sizes = None if v is None else np.asarray(v) * 10
    plt.scatter(X[:, 0], X[:, 1], c=y, s=sizes, cmap=cmap, edgecolors='k')

    # Annotate each point with its label (or index)
    point_names = labels if labels is not None and len(labels) > 0 else range(len(X))
    for i, (x0, x1) in enumerate(X):
        plt.text(x0 + 0.02, x1 + 0.02, str(point_names[i]), fontsize=9, color='black')

    # Highlight the target point
    if target_idx is not None:
        plt.scatter(X[target_idx, 0], X[target_idx, 1], s=100, facecolors='none', edgecolors='red', linewidths=2)

    # Background classification regions
    ax = plt.gca()
    xlim = ax.get_xlim()
    ylim = ax.get_ylim()

    xx, yy = np.meshgrid(np.linspace(xlim[0], xlim[1], 500),
                         np.linspace(ylim[0], ylim[1], 500))
    grid = np.c_[xx.ravel(), yy.ravel()]
    Z_labels = clf.predict(grid).reshape(xx.shape)
    ax.contourf(xx, yy, Z_labels, alpha=0.15, cmap=light_cmap)

    # Decision boundary and margins
    Z = clf.decision_function(grid).reshape(xx.shape)
    ax.contour(xx, yy, Z, colors='k', levels=[-1, 0, 1],
               linestyles=['--', '-', '--'], linewidths=1.5)

    # Custom legend
    handles = [
        plt.Line2D([0], [0], marker='o', color='w', label=f'Label {int(lbl)}',
                   markerfacecolor=col, markersize=10, markeredgecolor='k')
        for lbl, col in zip(unique_labels, colors)
    ]
    if target_idx is not None:
        handles.append(
            plt.Line2D([0], [0], marker='o', color='w', label='Target',
                       markerfacecolor='none', markeredgecolor='r', markersize=10, linewidth=2)
        )

    # No global `random` draw here: plotting must not perturb the caller's RNG state.
    plt.legend(handles=handles)
    plt.xlabel('Feature 1')
    plt.ylabel('Feature 2')
    plt.title(title)
    plt.grid(True)
    plt.show()