import numpy as np
import gurobipy as gp

def gaussian_kernel(x1, x2, sigma=1.0):
    """Compute the Gaussian (RBF) kernel exp(-||x1 - x2||^2 / (2 * sigma^2)).

    Args:
        x1: A single point or a batch of points (array-like).
        x2: A single point or a batch of points; must broadcast against x1.
        sigma: Kernel bandwidth.

    Returns:
        A scalar when the difference ``x1 - x2`` is 0-D or 1-D (two single
        points); otherwise an array obtained by reducing the difference
        over axis 1, e.g. shape ``(m,)`` for a ``(d,)`` point against an
        ``(m, d)`` batch.
    """
    diff = np.asarray(x1) - np.asarray(x2)
    # A pair of plain vectors has no axis 1; reduce over everything instead
    # of raising. Batched (>= 2-D) input keeps the original axis-1 reduction.
    reduce_axis = 1 if diff.ndim >= 2 else None
    distance = np.linalg.norm(diff, axis=reduce_axis)
    return np.exp(-distance ** 2 / (2 * sigma ** 2))

def compute_kernel_matrix(X, Z, kernel_function):
    """Build the kernel matrix between the rows of X and the rows of Z.

    ``kernel_function`` is called once per row of X as
    ``kernel_function(X[i, :], Z)`` and its result is stored as row ``i``
    of the output, so it must yield one value per row of Z.

    Returns:
        Array of shape ``(X.shape[0], Z.shape[0])``.
    """
    gram = np.zeros((X.shape[0], Z.shape[0]))
    for row_idx in range(gram.shape[0]):
        # One kernel call evaluates X's row against every row of Z.
        gram[row_idx, :] = kernel_function(X[row_idx, :], Z)
    return gram

def kernel_mean_matching(source, target, eps = 1, kernel_function=gaussian_kernel, lambda_reg=0.1, sigma=1.0, weights_ub = 1000):
    """Estimate importance weights for the source samples via kernel mean matching.

    Solves, with Gurobi, the quadratic program

        min_w  w' K_ss w - 2 * (n_s / n_t) * w' K_st 1 + lambda_reg * w' w
        s.t.   0 <= w_i <= weights_ub
               |sum_i w_i - n_s| <= eps * n_s

    where K_ss is the source/source kernel matrix and K_st the
    source/target kernel matrix.

    Args:
        source: 2-D array of source samples, one per row.
        target: 2-D array of target samples, one per row.
        eps: Relative tolerance on how far the weight sum may deviate from n_s.
        kernel_function: Kernel used to build the matrices. ``gaussian_kernel``
            is called with ``sigma``; any other callable is called as
            ``kernel_function(x, Z)`` and must return one value per row of Z.
        lambda_reg: Coefficient of the ridge penalty on the weights.
        sigma: Bandwidth passed to ``gaussian_kernel``; ignored for other kernels.
        weights_ub: Upper bound applied to every weight.

    Returns:
        Array of n_s weights if the solver reports an optimal solution,
        otherwise ``None``.
    """
    n_s = source.shape[0]
    n_t = target.shape[0]

    # Compute kernel matrices
    if kernel_function is gaussian_kernel:
        def pairwise(x, y):
            return gaussian_kernel(x, y, sigma)

        print("Computing Source to Source Gaussian kernel matrix")
        K_ss = compute_kernel_matrix(source, source, pairwise)
        print("Computing Source to Target Gaussian kernel matrix")
        K_st = compute_kernel_matrix(source, target, pairwise)
    else:
        # Custom kernels previously left K_ss / K_st undefined, which made
        # the objective construction below fail with UnboundLocalError.
        K_ss = compute_kernel_matrix(source, source, kernel_function)
        K_st = compute_kernel_matrix(source, target, kernel_function)

    # Gurobi model setup
    model = gp.Model("ImportanceSampling")
    # Set the OutputFlag to 0 to turn off solver output
    model.setParam('OutputFlag', 0)
    weights = model.addMVar(shape=n_s, lb=0, ub=weights_ub, name="weights")

    # Objective function.
    # kappa_i = (n_s / n_t) * sum_j k(source_i, target_j) is evaluated in
    # NumPy up front so the solver only sees a length-n_s linear term.
    kappa = K_st.sum(axis=1) * n_s / n_t
    # The ridge term lambda_reg * w'w is folded into the quadratic matrix.
    quad_matrix = K_ss + lambda_reg * np.eye(n_s)
    obj = weights @ quad_matrix @ weights + (-2 * kappa) @ weights
    model.setObjective(obj, gp.GRB.MINIMIZE)

    # Constraint: |\sum_i w_i - n_s| <= eps * n_s
    total_weight = np.ones(n_s) @ weights
    model.addConstr(total_weight - n_s <= eps * n_s)
    model.addConstr(total_weight - n_s >= -eps * n_s)
    print(f"Optimizing Kernel Mean Matching...\n")
    model.optimize()

    if model.status == gp.GRB.OPTIMAL:
        return weights.X
    # Infeasible, unbounded, interrupted, etc.: signal "no weights" to the caller.
    return None
