import time
import subprocess
import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import Matern, ConstantKernel as C, WhiteKernel, RBF
from scipy.stats import norm
from scipy.optimize import minimize
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from sklearn.exceptions import ConvergenceWarning
import warnings
from pgpr_ppitc import pgpr_ppitc_ls
from pgpr_util import *
from pgpr_type import *
import matplotlib.pyplot as plt
# Silence sklearn ConvergenceWarning process-wide (e.g. from GP kernel
# hyperparameter fitting or from SVC runs capped by a small max_iter below).
warnings.filterwarnings("ignore", category=ConvergenceWarning)

def penalize_boundary_points(x, bounds, penalty=1e6, threshold=1e-3):
    """Return ``penalty`` when any coordinate of ``x`` is within ``threshold`` of its bound.

    Args:
        x: Indexable point; ``x[i]`` is checked against ``bounds[i]``.
        bounds: Sequence of ``(lower, upper)`` pairs.
        penalty: Value returned when a coordinate is too close to an edge.
        threshold: Absolute distance from a bound that counts as "too close".

    Returns:
        ``penalty`` if some coordinate is near an edge, otherwise ``0.0``.
    """
    near_edge = any(
        x[dim] - low < threshold or high - x[dim] < threshold
        for dim, (low, high) in enumerate(bounds)
    )
    return penalty if near_edge else 0.0

def penalize_repeated_points(x, evaluated_points, penalty=1e6):
    """Return ``penalty`` if ``x`` (numerically) matches any already-evaluated point.

    Matching uses ``np.allclose`` with ``atol=1e-9`` and NumPy's default ``rtol``.

    Returns:
        ``penalty`` on a repeat, otherwise ``0.0``.
    """
    is_repeat = any(np.allclose(x, seen, atol=1e-9) for seen in evaluated_points)
    return penalty if is_repeat else 0.0

def acquisition_function_wrapper_ucb(x, custom_gp, test_data_file, support_set_file, 
                                    global_cov_file, global_mean_file, bounds, kappa=2.0):
    """Negated UCB score of one point under the PGPR model, plus an edge penalty.

    The point is written to ``test_data_file`` as a single line
    ``"0.000000 0:<x0> 1:<x1> ..."`` and ``custom_gp.pitc_regr_low_core`` is run on
    it; the prediction is then read from ``custom_gp.pmu[0]`` / ``custom_gp.pvar[0]``.

    Returns:
        ``1e6`` if the predicted mean or std is NaN; otherwise
        ``-(mu + kappa * sigma)`` plus ``0.1`` for every coordinate lying within 5 %
        of its range from either bound. The sign is flipped so that a minimiser
        maximises UCB.
    """
    point = np.atleast_2d(x)
    coords = point[0]
    features = " ".join([f"{j}:{coords[j]:.6f}" for j in range(len(coords))])
    with open(test_data_file, "w") as handle:
        handle.write("0.000000 " + features + "\n")
    custom_gp.pitc_regr_low_core(support_set_file, test_data_file, global_cov_file, global_mean_file)

    mean = custom_gp.pmu[0]
    # Argument order matters: max(nan, 1e-10) keeps the NaN so the guard below sees it.
    std = np.sqrt(max(custom_gp.pvar[0], 1e-10))
    if np.isnan(mean) or np.isnan(std) or std < 1e-10:
        return 1e6

    # Soft push away from the box edges: +0.1 per coordinate within 5 % of a bound.
    edge_penalty = sum(
        (
            0.1
            for dim, (low, high) in enumerate(bounds)
            if (point[0, dim] - low) < 0.05 * (high - low)
            or (high - point[0, dim]) < 0.05 * (high - low)
        ),
        0.0,
    )
    return -(mean + kappa * std) + edge_penalty

def optimize_acquisition_function_pgpr_ucb(custom_gp, bounds, test_data_file, support_set_file,
                                           global_cov_file, global_mean_file, restarts, iteration, best_x, kappa=2.0):
    """Multi-start L-BFGS-B search for the point maximising PGPR UCB.

    ``restarts`` anchors are drawn around ``best_x`` (or around a uniform random
    point when ``best_x`` is None) with N(0, 0.5) noise. Each anchor gets extra
    N(0, 1) noise, is clipped to ``bounds`` and is used as the start of one
    L-BFGS-B run on ``acquisition_function_wrapper_ucb``. ``iteration`` is not
    used by this function.

    Returns:
        The optimiser result with the lowest (negated-UCB) value among successful
        runs, or a uniform random point inside ``bounds`` if no run succeeded.
    """
    lower = np.array([pair[0] for pair in bounds])
    upper = np.array([pair[1] for pair in bounds])
    centre = np.random.uniform(lower, upper) if best_x is None else best_x

    acq_args = (custom_gp, test_data_file, support_set_file,
                global_cov_file, global_mean_file, bounds, kappa)
    incumbent_val, incumbent_x = np.inf, None

    anchors = centre + np.random.normal(0, 0.5, size=(restarts, len(bounds)))
    for anchor in anchors:
        start = np.clip(anchor + np.random.normal(0, 1.0, size=len(anchor)), lower, upper)
        outcome = minimize(
            acquisition_function_wrapper_ucb,
            start,
            args=acq_args,
            bounds=bounds,
            method='L-BFGS-B',
            options={'maxiter': 100, 'gtol': 1e-5},
        )
        if outcome.success and outcome.fun < incumbent_val:
            incumbent_val, incumbent_x = outcome.fun, outcome.x

    return np.random.uniform(lower, upper) if incumbent_x is None else incumbent_x

def acquisition_function_wrapper_ei(x, custom_gp, best_y, test_data_file, support_set_file, 
                                   global_cov_file, global_mean_file, bounds, xi=0.01):
    """Negated Expected Improvement (maximisation) of one point under the PGPR model.

    The point is written to ``test_data_file`` as a single line
    ``"0.000000 0:<x0> 1:<x1> ..."`` and ``custom_gp.pitc_regr_low_core`` is run on
    it; the prediction is then read from ``custom_gp.pmu[0]`` / ``custom_gp.pvar[0]``.

    The objective in this pipeline (validation accuracy) is *maximised*, so the
    improvement is ``mu - best_y - xi``. This matches ``propose_with_exact_gp`` and
    the ``y_next > best_y`` incumbent update.

    Args:
        x: Candidate point (1-D or a single-row 2-D array).
        custom_gp: PGPR model exposing ``pitc_regr_low_core``, ``pmu`` and ``pvar``.
        best_y: Incumbent (best observed) objective value.
        test_data_file: Path the candidate is written to before prediction.
        support_set_file, global_cov_file, global_mean_file: Forwarded to
            ``pitc_regr_low_core``.
        bounds: Sequence of ``(lower, upper)`` pairs used for the edge penalty.
        xi: Exploration margin subtracted from the improvement.

    Returns:
        ``1e6`` if the predicted mean or variance is NaN; otherwise ``-max(EI, 0)``
        plus ``0.1`` for every coordinate within 5 % of its range from either
        bound. The sign is flipped so that a minimiser maximises EI.
    """
    x = np.atleast_2d(x)
    features = " ".join(f"{j}:{v:.6f}" for j, v in enumerate(x[0]))
    with open(test_data_file, "w") as f:
        f.write(f"0.000000 {features}\n")
    custom_gp.pitc_regr_low_core(support_set_file, test_data_file, global_cov_file, global_mean_file)

    mu = custom_gp.pmu[0]
    var = custom_gp.pvar[0]
    if np.isnan(mu) or np.isnan(var):
        return 1e6
    sigma = np.sqrt(max(var, 1e-10))  # floor keeps Z finite for near-zero variance

    improvement = mu - best_y - xi
    # Clipping only trims the far tails, where cdf/pdf are already ~0 or ~1.
    z = np.clip(improvement / sigma, -8.0, 8.0)
    ei = improvement * norm.cdf(z) + sigma * norm.pdf(z)

    threshold = 0.05
    boundary_penalty = 0.0
    for i, (lb, ub) in enumerate(bounds):
        margin = threshold * (ub - lb)
        if (x[0, i] - lb) < margin or (ub - x[0, i]) < margin:
            boundary_penalty += 0.1
    return -max(ei, 0.0) + boundary_penalty

def optimize_acquisition_function_pgpr_ei(custom_gp, best_y, bounds, test_data_file, support_set_file,
                                          global_cov_file, global_mean_file, restarts, iteration, best_x):
    """Multi-start L-BFGS-B search for the point maximising PGPR Expected Improvement.

    ``restarts`` anchors are drawn around ``best_x`` (or around a uniform random
    point when ``best_x`` is None) with N(0, 1) noise. Each anchor is perturbed
    twice, with N(0, 0.5) and N(0, 1.0) noise, clipped to ``bounds`` and used as
    the start of one L-BFGS-B run on ``acquisition_function_wrapper_ei``.
    Clipping keeps the initial guess feasible, matching the UCB variant.

    Args:
        custom_gp: PGPR model passed through to the acquisition wrapper.
        best_y: Incumbent objective value for EI.
        bounds: Sequence of ``(lower, upper)`` pairs.
        test_data_file, support_set_file, global_cov_file, global_mean_file:
            Files forwarded to the acquisition wrapper.
        restarts: Number of anchors (2 * restarts optimiser runs in total).
        iteration: Unused; kept for interface compatibility.
        best_x: Centre for the anchors, or None.

    Returns:
        The best point from the successful optimiser runs, or a uniform random
        point inside ``bounds`` if no run succeeded.
    """
    lb = np.array([b[0] for b in bounds])
    ub = np.array([b[1] for b in bounds])
    if best_x is None:
        best_x = np.random.uniform(lb, ub)
    best_acq = np.inf
    best_x_found = None
    start_points = best_x + np.random.normal(0, 1.0, size=(restarts, len(bounds)))
    for x0 in start_points:
        for exploration_factor in (0.5, 1.0):
            x0_noisy = x0 + np.random.normal(0, exploration_factor, size=len(x0))
            x0_noisy = np.clip(x0_noisy, lb, ub)  # start L-BFGS-B from a feasible point
            res = minimize(
                acquisition_function_wrapper_ei,
                x0_noisy,
                args=(custom_gp, best_y, test_data_file, support_set_file,
                      global_cov_file, global_mean_file, bounds),
                bounds=bounds,
                method='L-BFGS-B',
                options={'maxiter': 100, 'gtol': 1e-5}
            )
            if res.success and res.fun < best_acq:
                best_acq = res.fun
                best_x_found = res.x
    if best_x_found is None:
        best_x_found = np.random.uniform(lb, ub)
    return best_x_found

def latin_hypercube(n, d, lower, upper):
    """Draw ``n`` Latin-hypercube samples in ``d`` dimensions, scaled to ``[lower, upper]``.

    For each dimension in turn, the unit interval is cut into ``n`` equal strata.
    A random permutation assigns one stratum per sample and a uniform jitter places
    the sample inside it. The result is mapped affinely onto ``[lower[i], upper[i]]``.
    Uses the global NumPy RNG (``permutation`` then ``rand`` per dimension).

    Returns:
        Array of shape ``(n, d)``.
    """
    samples = np.zeros((n, d))
    for dim in range(d):
        strata = np.random.permutation(n)
        jitter = np.random.rand(n)
        unit = (strata + jitter) / n
        samples[:, dim] = lower[dim] + (upper[dim] - lower[dim]) * unit
    return samples

def propose_with_exact_gp(X_hist, y_hist, bounds, restarts, best_x, xi=0.01):
    """Propose the next point by scoring random candidates with EI under an exact GP.

    A ``ConstantKernel * RBF`` Gaussian process is fitted to ``(X_hist, y_hist)``,
    with one RBF length-scale per dimension of ``bounds``. ``restarts`` candidates
    are generated around ``best_x``: N(0, 1) noise is added twice and the result is
    clipped to ``bounds``. The candidate with the highest Expected Improvement
    (maximisation) over ``max(y_hist)`` is returned. Candidates are only scored;
    there is no gradient-based refinement.

    Args:
        X_hist: Observed inputs, shape ``(n, len(bounds))``.
        y_hist: Observed objective values (maximised).
        bounds: Sequence of ``(lower, upper)`` pairs.
        restarts: Number of candidates, and also the GP optimiser's restart count.
        best_x: Centre for the candidates; if None, a uniform random point is used
            (consistent with the PGPR optimisers).
        xi: Exploration margin subtracted from the improvement.

    Returns:
        The best-scoring candidate, or a uniform random point if none scored
        above ``-inf`` (e.g. all EI values were NaN).
    """
    lb = np.array([b[0] for b in bounds])
    ub = np.array([b[1] for b in bounds])
    kernel = C(1.0, (1e-3, 1e3)) * RBF([1.0] * len(bounds), (1e-2, 1e2))
    # NOTE: random_state is left unset, so the optimiser restarts inside fit() draw
    # from the global NumPy RNG.
    gp = GaussianProcessRegressor(kernel=kernel, alpha=1e-3, normalize_y=True, n_restarts_optimizer=restarts)
    gp.fit(X_hist, y_hist)
    if best_x is None:
        best_x = np.random.uniform(lb, ub)
    best_acq = -np.inf
    best_x_found = None
    best_y = np.max(y_hist)
    start_points = best_x + np.random.normal(0, 1.0, size=(restarts, len(bounds)))
    for x0 in start_points:
        x0_noisy = x0 + np.random.normal(0, 1.0, size=len(x0))
        x0_noisy = np.clip(x0_noisy, lb, ub)
        x0_noisy = x0_noisy.reshape(1, -1)
        mu, sigma = gp.predict(x0_noisy, return_std=True)
        sigma = max(sigma[0], 1e-10)
        improvement = mu[0] - best_y - xi
        Z = improvement / sigma
        ei = improvement * norm.cdf(Z) + sigma * norm.pdf(Z)
        if ei > best_acq:
            best_acq = ei
            best_x_found = x0_noisy[0]
    if best_x_found is None:
        best_x_found = np.random.uniform(lb, ub)
    return best_x_found

def svm_loss(params, client_data):
    """Mean validation error (``1 - accuracy``) of an RBF SVC across all clients.

    Args:
        params: ``(gamma, C)`` pair passed to ``SVC``.
        client_data: Iterable of ``(X, y)`` per client. Each client's data is split
            80/20 with a fixed ``random_state=42``.

    Returns:
        The mean of the per-client validation errors.
    """
    gamma, c_reg = params

    def _client_error(features, labels):
        X_tr, X_va, y_tr, y_va = train_test_split(features, labels, test_size=0.2, random_state=42)
        model = SVC(kernel='rbf', gamma=gamma, C=c_reg)
        model.fit(X_tr, y_tr)
        return 1 - accuracy_score(y_va, model.predict(X_va))

    return np.mean([_client_error(X, y) for X, y in client_data])

def _format_sample_line(y_val, x_row):
    """Format one sample as ``"<y> 0:<x0> 1:<x1> ..."`` (six decimals), newline-terminated."""
    features = " ".join(f"{j}:{v:.6f}" for j, v in enumerate(x_row))
    return f"{y_val:.6f} {features}\n"


def _write_samples(path, y_vals, X_rows):
    """Overwrite ``path`` with one ``_format_sample_line`` per ``(y, x)`` pair."""
    with open(path, "w") as f:
        f.writelines(_format_sample_line(y, x) for y, x in zip(y_vals, X_rows))


def _read_samples(path):
    """Parse a file written by ``_write_samples`` back into ``(X, y)`` NumPy arrays."""
    X_rows, y_vals = [], []
    with open(path, "r") as f:
        for line in f:
            parts = line.split()
            if not parts:
                continue  # tolerate blank lines
            y_vals.append(float(parts[0]))
            X_rows.append([float(tok.split(":", 1)[1]) for tok in parts[1:]])
    return np.array(X_rows), np.array(y_vals)


def _svm_val_accuracy(X, y, gamma, c_reg, max_iter):
    """Fit an RBF ``SVC`` on a fixed 70/30 split of ``(X, y)`` and return validation accuracy."""
    X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.3, random_state=42)
    clf = SVC(kernel='rbf', gamma=gamma, C=c_reg, max_iter=max_iter)
    clf.fit(X_train, y_train)
    return accuracy_score(y_val, clf.predict(X_val))


def _run_secure_aggregation(num_clients, support_size, server_outf, port_start=8080):
    """Run ``./ServerTest`` and one ``./ClientTest`` per client 1..num_clients-1, then wait.

    Clients are waited on in launch order, then the server. If an exception escapes
    (e.g. a binary is missing), processes that are still running are killed so they
    are not leaked. A non-zero exit status only emits a RuntimeWarning, because the
    caller reads the global model files afterwards regardless.
    """
    size = str(support_size)
    procs = []  # (label, Popen) in launch order, for cleanup and status reporting
    try:
        server_cmd = [
            "./ServerTest",
            "-role", "server",
            "-cov", size,
            "-mu", size,
            "-mufile", server_outf,
            "-covfile", server_outf,
            "-clients", str(num_clients),
        ]
        server_proc = subprocess.Popen(server_cmd)
        procs.append(("ServerTest", server_proc))
        time.sleep(1)  # presumably lets the server start before clients connect
        client_procs = []
        for client_id in range(1, num_clients):
            outf_file = f"temp/client_{client_id}"
            client_cmd = [
                "./ClientTest",
                "-role", f"client{client_id}",
                "-port", str(port_start + client_id - 1),
                "-cov", size,
                "-mu", size,
                "-mufile", outf_file,
                "-covfile", outf_file,
            ]
            proc = subprocess.Popen(client_cmd)
            procs.append((f"ClientTest client{client_id}", proc))
            client_procs.append(proc)
        time.sleep(1)
        for proc in client_procs:
            proc.wait()
        server_proc.wait()
    finally:
        for _, proc in procs:
            if proc.poll() is None:
                proc.kill()
                proc.wait()
    for label, proc in procs:
        if proc.returncode != 0:
            warnings.warn(
                f"{label} exited with status {proc.returncode}; "
                "global model files may be stale",
                RuntimeWarning,
            )


def bayesian_optimization_pgpr(max_iter,
                               init_samples,
                               ei_restarts,
                               track_points=None,
                               track_best_values=None,
                               p_seq=None):
    """Federated Bayesian optimisation of RBF-SVC ``(gamma, C)`` over three clients.

    Setup: the global NumPy RNG is seeded with 37. Each client gets a fresh ``.hyp``
    file and its data is loaded from ``./data/landmines_snap/client_<id>``. A shared
    Latin-hypercube design is evaluated per client. The client's ``.trn`` and
    ``.spt`` files are written, and the PGPR model is initialised via
    ``optimize_hyperparameters_init`` and ``regress_local``.

    Each round ``it`` draws ``u ~ U(0, 1)``:

    * if ``u > p_seq[it-1]``: ``./ServerTest``/``./ClientTest`` are run, then every
      client proposes a point with PGPR EI using ``temp/client_0.gmu`` /
      ``temp/client_0.gcov``;
    * otherwise every client proposes a point with a local exact GP.

    The proposal's validation accuracy is appended to ``.trn``. The
    ``max_support_points`` most accurate samples become the new ``.spt``, and
    ``regress_local`` is re-run. Accuracy is maximised.

    Args:
        max_iter: Number of BO rounds.
        init_samples: Size of the initial Latin-hypercube design.
        ei_restarts: Restart count forwarded to both proposal routines.
        track_points: Optional callable invoked with every proposed point.
        track_best_values: Optional list; after each round a list of the per-client
            best accuracies is appended to it.
        p_seq: Per-round probabilities of taking the local exact-GP branch;
            defaults to ``np.linspace(0, 1, max_iter + 1)[1:]``.

    Returns:
        ``(best_y_clients, best_x_clients)``: per-client best accuracy and the
        corresponding hyperparameters.

    Side effects: writes files under ``./data/landmines_snap`` and ``temp/``, runs
    external binaries, and shows a matplotlib figure (blocking ``plt.show()``).
    """
    np.random.seed(37)
    bounds = [(0.01, 10.0), (1e-4, 10.0)]
    num_clients = 3
    max_support_points = 8
    data_dir = "./data/landmines_snap"

    def data_path(client_id, suffix=""):
        return f"{data_dir}/client_{client_id}{suffix}"

    def local_stat_paths(client_id):
        return (f"temp/client_{client_id}.mu.{client_id}",
                f"temp/client_{client_id}.cov.{client_id}")

    for client_id in range(num_clients):
        with open(data_path(client_id, ".hyp"), "w") as f:
            f.write("0.5 0.03 0.0 2 10.0 10.0\n")

    client_data = []
    for client_id in range(num_clients):
        data = np.loadtxt(data_path(client_id))
        client_data.append((data[:, :-1], data[:, -1]))

    # Initial design, shared by all clients.
    lower = np.array([b[0] for b in bounds])
    upper = np.array([b[1] for b in bounds])
    X_hyper = latin_hypercube(init_samples, len(bounds), lower, upper)
    best_x_clients = []
    best_y_clients = []
    for client_id in range(num_clients):
        X_client, y_client = client_data[client_id]
        # NOTE(review): the initial design uses max_iter=1 while proposals use
        # max_iter=8; kept as-is, confirm this is intended.
        y_acc = np.array([
            _svm_val_accuracy(X_client, y_client, gamma, c_reg, max_iter=1)
            for gamma, c_reg in X_hyper
        ])
        _write_samples(data_path(client_id, ".trn"), y_acc, X_hyper)
        _write_samples(data_path(client_id, ".spt"), y_acc, X_hyper)
        # Accuracy is maximised, so the incumbent is the highest-accuracy sample.
        best_idx = int(np.argmax(y_acc))
        best_x_clients.append(X_hyper[best_idx])
        best_y_clients.append(y_acc[best_idx])

    for client_id in range(num_clients):
        train_data_file = data_path(client_id, ".trn")
        support_set_file = data_path(client_id, ".spt")
        local_mean_file, local_cov_file = local_stat_paths(client_id)
        custom_gp = pgpr_ppitc_ls(hypf=data_path(client_id, ".hyp"))
        custom_gp.optimize_hyperparameters_init(train_data_file, support_set_file)
        custom_gp.regress_local(train_data_file, support_set_file, local_mean_file, local_cov_file)

    best_acc_history = [[] for _ in range(num_clients)]
    if p_seq is None:
        p_seq = np.linspace(0, 1, max_iter + 1)[1:]
    global_mean_file = "temp/client_0.gmu"
    global_cov_file = "temp/client_0.gcov"

    for it in range(1, max_iter + 1):
        with open(data_path(0, ".spt"), "r") as f:
            current_support_size = sum(1 for _ in f)
        p = p_seq[it - 1]
        use_federated = np.random.uniform(0, 1) > p
        if use_federated:
            _run_secure_aggregation(num_clients, current_support_size, "temp/client_0")

        for client_id in range(num_clients):
            train_data_file = data_path(client_id, ".trn")
            support_set_file = data_path(client_id, ".spt")
            test_data_file = data_path(client_id, ".tst")
            local_mean_file, local_cov_file = local_stat_paths(client_id)
            custom_gp = pgpr_ppitc_ls(hypf=data_path(client_id, ".hyp"))
            X_hist, y_hist = _read_samples(train_data_file)
            best_x = best_x_clients[client_id]
            best_y = best_y_clients[client_id]

            if use_federated:
                x_next = optimize_acquisition_function_pgpr_ei(
                    custom_gp, best_y, bounds, test_data_file, support_set_file,
                    global_cov_file, global_mean_file, ei_restarts, iteration=it, best_x=best_x
                )
            else:
                x_next = propose_with_exact_gp(X_hist, y_hist, bounds, ei_restarts, best_x)
            if track_points is not None:
                track_points(x_next)

            X_client, y_client = client_data[client_id]
            gamma, c_reg = x_next
            y_next = _svm_val_accuracy(X_client, y_client, gamma, c_reg, max_iter=8)

            X_hist = np.vstack([X_hist, x_next])
            y_hist = np.append(y_hist, y_next)
            _write_samples(train_data_file, y_hist, X_hist)
            top = np.argsort(-y_hist)[:max_support_points]  # most accurate samples
            _write_samples(support_set_file, y_hist[top], X_hist[top])
            custom_gp.regress_local(train_data_file, support_set_file, local_mean_file, local_cov_file)

            if y_next > best_y:
                best_x_clients[client_id] = x_next
                best_y_clients[client_id] = y_next
            best_acc_history[client_id].append(best_y_clients[client_id])

        if track_best_values is not None:
            track_best_values.append(list(best_y_clients))

    # zip(*) yields one tuple of per-client best accuracies per round.
    mean_loss_per_round = [1 - np.mean(round_accs) for round_accs in zip(*best_acc_history)]
    plt.figure(figsize=(10, 5))
    plt.plot(mean_loss_per_round, label='dGP-FBO', color='red', marker='x')
    plt.xlabel('BO Round')
    plt.ylabel('Mean Loss')
    plt.title('Mean Loss Across All Clients Over BO Rounds')
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()
    return best_y_clients, best_x_clients

if __name__ == "__main__":
    bo_iter = 10
    # p_t = 1 - 1/sqrt(t): the chance of taking the local exact-GP branch grows
    # with the round index t = 1..bo_iter.
    p_seq = 1 - 1 / np.sqrt(np.arange(1, bo_iter + 1))
    explored_points = []
    best_values_custom = []

    best_vals, best_pts = bayesian_optimization_pgpr(
        max_iter=bo_iter,
        init_samples=4,
        ei_restarts=8,
        track_points=explored_points.append,
        track_best_values=best_values_custom,
        p_seq=p_seq,
    )

    for client_id, (val, pt) in enumerate(zip(best_vals, best_pts), start=1):
        print(f"Client {client_id}: Best Value: {val:.6f}, Best Hyperparameters: {pt}")
    explored_points = np.atleast_2d(explored_points)