import numpy as np
import random
from sklearn.model_selection import train_test_split
from sklearn.datasets import fetch_openml

class SVMAgent:
    """Per-agent state for a decentralized SVM run.

    Each agent holds a weight vector, an "action" (a displacement that is kept
    inside a Euclidean ball of radius ``D``), and the most recently sampled
    gradient evaluation point.

    Attributes:
        id: Identifier stored exactly as passed in.
        lr: Step size used by ``action_grad_update``.
        D: Largest Euclidean norm the action may have after a gradient step.
        NUM_AGENTS: Factor that scales the action in ``DOC2S_get_new_weight``.
        num_param: Length of the weight, action and gradient-point vectors.
    """

    def __init__(self, num_param, id=0, lr=1e-3, D=1, NUM_AGENTS=32):
        self.id, self.lr = id, lr
        self.D, self.NUM_AGENTS = D, NUM_AGENTS
        self.num_param = num_param
        # All three state vectors start at the origin.
        self.weight = np.zeros(num_param)
        self.action = np.zeros(num_param)
        self.grad_points = np.zeros(num_param)

    def initialize_action(self):
        """Reset the action to the zero vector."""
        self.action = np.zeros(self.num_param)

    def get_action(self):
        """Return the current action (the stored array itself, not a copy)."""
        return self.action

    def get_weight(self):
        """Return the current weight (the stored array itself, not a copy)."""
        return self.weight

    def get_grad_points(self):
        """Return the point produced by the last ``get_grad_point`` call."""
        return self.grad_points

    def get_grad_point(self):
        """Sample a point on the segment from ``weight`` to ``weight + action``.

        The interpolation factor is drawn uniformly from [0, 1) using NumPy's
        global random state; the result is stored in ``grad_points``.
        """
        fraction = np.random.rand(1)
        self.grad_points = (self.weight + fraction * self.action).copy()

    def DOC2S_get_new_weight(self):
        """Return ``weight + NUM_AGENTS * action`` without modifying any state."""
        return self.weight + self.NUM_AGENTS * self.action

    def get_new_weight(self):
        """Return ``weight + action`` without modifying any state."""
        return self.weight + self.action

    def _project_action(self):
        """Rescale the action onto the radius-``D`` ball when it lies outside."""
        norm = np.linalg.norm(self.action)
        outside_ball = norm > self.D
        if outside_ball:
            self.action = self.action * self.D / norm

    def set_weight(self, x):
        """Replace the weight with a copy of ``x``."""
        self.weight = x.copy()

    def set_action(self, x):
        """Replace the action with a copy of ``x``."""
        self.action = x.copy()

    def action_grad_update(self, grad):
        """Take one gradient step of size ``lr`` on the action, then project it."""
        self.action = self.action - self.lr * grad
        self._project_action()


class FastComNetwork:
    """Gossip network that mixes agent vectors with momentum-accelerated rounds.

    ``W`` is a square mixing matrix with one row/column per agent. If no
    matrix has been set when a ``propagate_*`` method is called, the identity
    is used, i.e. agents do not exchange anything.

    NOTE(review): the momentum coefficient is derived from the second-largest
    eigenvalue of ``W``; this presumably assumes a symmetric mixing matrix with
    a real spectrum - confirm against the callers that build ``W``.
    """

    def __init__(self, W=None):
        self.W = W
        if W is not None:
            self.N = W.shape[0]

    def set_weight(self, W):
        """Install mixing matrix ``W`` and record the agent count ``N``."""
        self.W = W
        self.N = W.shape[0]

    def _ensure_mixing_matrix(self, agents):
        """Fall back to the identity matrix when no mixing matrix was set."""
        if self.W is None:
            self.set_weight(np.eye(len(agents)))

    def propagate_weights(self, agents, R):
        """Mix the agents' weights for ``R`` rounds and write the results back."""
        self._ensure_mixing_matrix(agents)
        mixed = self.propagate_data([agent.get_weight() for agent in agents], R)
        for k in range(self.N):
            agents[k].set_weight(mixed[k])

    def propagate_actions(self, agents, R):
        """Mix the agents' actions for ``R`` rounds and write the results back."""
        self._ensure_mixing_matrix(agents)
        mixed = self.propagate_data([agent.get_action() for agent in agents], R)
        for k in range(self.N):
            agents[k].set_action(mixed[k])

    def get_average_weight(self, agents):
        """Return the arithmetic mean of all agents' weight vectors."""
        return self.get_average([agent.get_weight() for agent in agents])

    def get_average(self, data):
        """Return the element-wise mean of the equally shaped arrays in ``data``."""
        count = len(data)
        average = np.zeros(data[0].shape)
        for item in data:
            average += item / count
        return average

    def _mixing_momentum(self):
        """Return the momentum coefficient ``eta`` derived from ``W``'s spectrum.

        With ``beta`` the second-largest eigenvalue of ``W``,
        ``eta = (1 - sqrt(1 - beta**2)) / (1 + sqrt(1 - beta**2))``.
        """
        eigenvalues = np.linalg.eigvals(self.W)
        if eigenvalues.size < 2:
            # A single agent has no second eigenvalue; plain mixing (eta = 0).
            return 0.0
        # eigvals may report a complex dtype; only the real parts are ordered
        # and used so the mixed vectors stay real.
        ordered = np.sort(np.real(eigenvalues))[::-1]
        beta = ordered[1]
        # Round-off can push beta marginally past 1, which would make the
        # square root NaN and poison every agent's vector; clamp at zero.
        root = np.sqrt(max(1.0 - beta ** 2, 0.0))
        return (1 - root) / (1 + root)

    def propagate_data(self, data, R):
        """Run ``R`` accelerated mixing rounds over ``data``.

        Each round computes ``Z_new = (1 + eta) * W @ Z - eta * Z_prev``, with
        ``Z_prev`` initialised to the input itself.

        Args:
            data: Sequence of equally shaped arrays, one per agent.
            R: Number of mixing rounds; ``0`` returns copies of the input.

        Returns:
            A list with one independent array per agent.
        """
        n = len(data)
        if n == 0:
            return []
        eta = self._mixing_momentum()
        # Only the leading n x n block is used, matching per-entry indexing
        # over range(n) in both dimensions.
        W = np.asarray(self.W)[:n, :n]
        current = np.stack(data)
        previous = current
        for _ in range(R):
            # One contraction over the agent axis replaces the nested loops.
            mixed = np.tensordot(W, current, axes=1)
            current, previous = (1 + eta) * mixed - eta * previous, current
        return [row.copy() for row in current]


class ComNetwork:
    """Single-round gossip network driven by a square mixing matrix ``W``.

    Every call to ``propagate_weights`` / ``propagate_actions`` replaces each
    agent's vector with the ``W``-weighted combination of all agents' vectors.
    When no matrix has been set, the identity is installed on first use.
    """

    def __init__(self, W=None):
        self.W = W
        if W is None:
            # N is only defined once a mixing matrix is known.
            return
        self.N = W.shape[0]

    def set_weight(self, W):
        """Install mixing matrix ``W`` and record the agent count ``N``."""
        self.W = W
        self.N = W.shape[0]

    def propagate_weights(self, agents):
        """Mix the agents' weights once and store the results on the agents."""
        if self.W is None:
            self.set_weight(np.eye(len(agents)))
        blended = self.propagate_data([member.get_weight() for member in agents])
        for idx in range(self.N):
            agents[idx].set_weight(blended[idx])

    def propagate_actions(self, agents):
        """Mix the agents' actions once and store the results on the agents."""
        if self.W is None:
            self.set_weight(np.eye(len(agents)))
        blended = self.propagate_data([member.get_action() for member in agents])
        for idx in range(self.N):
            agents[idx].set_action(blended[idx])

    def get_average_weight(self, agents):
        """Return the arithmetic mean of all agents' weight vectors."""
        return self.get_average([member.get_weight() for member in agents])

    def get_average(self, data):
        """Return the element-wise mean of the equally shaped arrays in ``data``."""
        count = len(data)
        running = np.zeros(data[0].shape)
        for item in data:
            running += item / count
        return running

    def propagate_data(self, data):
        """Return, for every agent ``m``, ``sum_k W[m, k] * data[k]``.

        Args:
            data: Sequence of equally shaped arrays, one per agent.

        Returns:
            A list of freshly allocated arrays, one per agent.
        """
        count = len(data)
        shape = data[0].shape

        def combine(row):
            # Accumulate in agent order so rounding matches a plain loop.
            total = np.zeros(shape)
            for col in range(count):
                total += self.W[row, col] * data[col]
            return total

        return [combine(row) for row in range(count)]


class DatasetModel:
    """Binary-classification dataset split into disjoint per-agent index shards.

    The data either comes from the caller (``X``/``y`` with ``dsname=None``)
    or from one of the built-in loaders (``'mnist'`` or ``'fashion-mnist'``).
    After loading, the sample indices are shuffled and cut into ``num_agent``
    equally sized shards; leftover samples (when the dataset size is not a
    multiple of ``num_agent``) are not assigned to any agent.

    Args:
        dsname: ``None`` to use ``X``/``y`` as given, otherwise a loader name.
        X: Feature matrix, one row per sample (used only when ``dsname`` is None).
        y: Label vector aligned with ``X`` (used only when ``dsname`` is None).
        num_agent: Number of shards to create.
        mb_size: Mini-batch size returned by ``get_sample``.
        normalize: Whether the built-in loaders standardise the features.
        max_sample: Cap on the number of samples the built-in loaders keep.

    Raises:
        ValueError: If ``dsname`` is not a known loader name, or if ``dsname``
            is None and ``X`` or ``y`` is missing.
    """

    def __init__(self, dsname=None, X=None, y=None, num_agent=3, mb_size=1, normalize=True, max_sample=10000):
        self.dsname = dsname
        self.normalize = normalize
        self.num_agent = num_agent
        self.mb_size = mb_size
        self.max_sample = max_sample

        if dsname is None:
            if X is None or y is None:
                raise ValueError("X and y must be provided when dsname is None")
            self.X = X
            self.y = y
        else:
            self.loaddataset()

        self.input_dim = self.X.shape[1]
        self.dssize = self.X.shape[0]

        # Shuffle once, then hand each agent a contiguous slice of the shuffle.
        shuffled_idx = np.random.permutation(self.dssize)
        sample_per_agent = self.dssize // self.num_agent
        self.agent_dict = {
            m: shuffled_idx[m * sample_per_agent:(m + 1) * sample_per_agent].tolist()
            for m in range(num_agent)
        }

    def loaddataset(self):
        """Load the dataset named by ``dsname`` into ``X``/``y``/``X_test``/``y_test``.

        Raises:
            ValueError: If ``dsname`` is not ``'mnist'`` or ``'fashion-mnist'``.
        """
        if self.dsname == 'mnist':
            X_train, y_train, X_test, y_test = self._load_mnist()
        elif self.dsname == 'fashion-mnist':
            X_train, y_train, X_test, y_test = self._load_fashion_mnist()
        else:
            # Without this, an unknown name silently loads nothing and the
            # constructor later fails with an unrelated AttributeError.
            raise ValueError(
                f"Unknown dataset name {self.dsname!r}; expected 'mnist' or 'fashion-mnist'"
            )

        if self.normalize:
            X_train, X_test = self._standardize(X_train, X_test)

        self.X = X_train
        self.y = y_train
        self.X_test = X_test
        self.y_test = y_test

    def _load_mnist(self):
        """Fetch MNIST from OpenML, keep digits 0 and 1, subsample, split 80/20."""
        X, y = fetch_openml('mnist_784', version=1, return_X_y=True, as_frame=False, parser='auto')
        y = y.astype(int)
        mask = (y == 0) | (y == 1)
        X, y = self._subsample(X[mask], y[mask])
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.2, random_state=42
        )
        return X_train, y_train, X_test, y_test

    def _load_fashion_mnist(self):
        """Load Fashion-MNIST via Keras, keep classes 0 and 1, subsample the train split."""
        # Imported lazily so TensorFlow is only required when this loader is used.
        from tensorflow.keras.datasets import fashion_mnist
        (X_train_full, y_train_full), (X_test_full, y_test_full) = fashion_mnist.load_data()

        # Flatten each image to a 784-long float32 vector.
        X_train_full = X_train_full.reshape(-1, 784).astype('float32')
        X_test_full = X_test_full.reshape(-1, 784).astype('float32')

        # Keep two classes only and relabel them as 0 (class0) and 1 (class1).
        class0, class1 = 0, 1
        train_mask = (y_train_full == class0) | (y_train_full == class1)
        test_mask = (y_test_full == class0) | (y_test_full == class1)

        X_train = X_train_full[train_mask]
        y_train = np.where(y_train_full[train_mask] == class0, 0, 1)
        X_test = X_test_full[test_mask]
        y_test = np.where(y_test_full[test_mask] == class0, 0, 1)

        # Only the training split is capped at max_sample.
        X_train, y_train = self._subsample(X_train, y_train)
        return X_train, y_train, X_test, y_test

    def _subsample(self, X, y):
        """Return at most ``max_sample`` rows of ``X``/``y``, drawn without replacement."""
        if X.shape[0] > self.max_sample:
            idx = np.random.choice(X.shape[0], self.max_sample, replace=False)
            return X[idx], y[idx]
        return X, y

    def _standardize(self, X_train, X_test):
        """Standardise both splits with the training split's per-feature mean/std.

        The statistics are stored in ``x_mean`` and ``x_std``. Features whose
        standard deviation is below 1e-8 use a divisor of 1 instead.
        """
        self.x_mean = X_train.mean(axis=0)
        self.x_std = X_train.std(axis=0)
        self.x_std[self.x_std < 1e-8] = 1.0
        epsilon = 1e-8
        scale = self.x_std + epsilon
        return (X_train - self.x_mean) / scale, (X_test - self.x_mean) / scale

    def get_sample(self, agent=1):
        """Return a random mini-batch ``(X_mb, y_mb)`` of size ``mb_size`` from ``agent``'s shard."""
        idx = random.sample(self.agent_dict[agent], k=self.mb_size)
        X_mb = self.X[idx, :]
        y_mb = self.y[idx]
        return X_mb, y_mb

    def get_test_set(self):
        """Return ``(X_test, y_test)``; these are only set by the built-in loaders."""
        return self.X_test, self.y_test


class MLPAgent:
    """Per-agent state for a decentralized run whose model is an ``MLP``.

    The agent keeps the network's parameters as one flat ``weight`` vector, an
    "action" displacement that is kept inside a Euclidean ball of radius
    ``D``, and the most recently sampled gradient evaluation point.

    Attributes:
        id: Identifier stored exactly as passed in.
        lr: Step size used by ``action_grad_update``.
        D: Largest Euclidean norm the action may have after a gradient step.
        NUM_AGENTS: Factor that scales the action in ``DOC2S_get_new_weight``.
        model: The ``MLP`` whose parameters ``weight`` mirrors.
        num_param: Length of the flattened parameter vector.
    """

    def __init__(self, input_dim, hidden_dim, id=0, lr=1e-3, D=1, NUM_AGENTS=32):
        self.id, self.lr = id, lr
        self.D, self.NUM_AGENTS = D, NUM_AGENTS
        self.model = MLP(input_dim, hidden_dim)
        # The freshly initialised network defines both the size and the start point.
        initial = self.model.flatten_params()
        self.num_param = len(initial)
        self.weight = initial.copy()
        self.action = np.zeros(self.num_param)
        self.grad_points = np.zeros(self.num_param)

    def initialize_action(self):
        """Reset the action to the zero vector."""
        self.action = np.zeros(self.num_param)

    def get_action(self):
        """Return the current action (the stored array itself, not a copy)."""
        return self.action

    def get_weight(self):
        """Return the flat parameter vector (the stored array itself, not a copy)."""
        return self.weight

    def get_grad_points(self):
        """Return the point produced by the last ``get_grad_point`` call."""
        return self.grad_points

    def get_grad_point(self):
        """Sample a point on the segment from ``weight`` to ``weight + action``.

        The interpolation factor is drawn uniformly from [0, 1) using NumPy's
        global random state; the result is stored in ``grad_points``.
        """
        fraction = np.random.rand(1)
        self.grad_points = (self.weight + fraction * self.action).copy()

    def DOC2S_get_new_weight(self):
        """Return ``weight + NUM_AGENTS * action`` without modifying any state."""
        return self.weight + self.NUM_AGENTS * self.action

    def get_new_weight(self):
        """Return ``weight + action`` without modifying any state."""
        return self.weight + self.action

    def _project_action(self):
        """Rescale the action onto the radius-``D`` ball when it lies outside."""
        norm = np.linalg.norm(self.action)
        outside_ball = norm > self.D
        if outside_ball:
            self.action = self.action * self.D / norm

    def set_weight(self, flat_params):
        """Store a copy of ``flat_params`` and load the same values into ``model``."""
        self.weight = flat_params.copy()
        self.model.unflatten_params(flat_params)

    def set_action(self, x):
        """Replace the action with a copy of ``x``."""
        self.action = x.copy()

    def action_grad_update(self, grad):
        """Take one gradient step of size ``lr`` on the action, then project it."""
        self.action = self.action - self.lr * grad
        self._project_action()


class MLP:
    """One-hidden-layer ReLU network with a single scalar output per sample.

    Parameters are ``W1`` (input_dim x hidden_dim), ``b1`` (hidden_dim),
    ``W2`` (hidden_dim x 1) and ``b2`` (1). The flat parameter layout used by
    ``flatten_params`` / ``unflatten_params`` is ``[W1, b1, W2, b2]`` with the
    matrices in row-major order.
    """

    def __init__(self, input_dim, hidden_dim):
        # Gaussian weights scaled by sqrt(2 / fan_in); biases start at zero.
        self.W1 = np.random.randn(input_dim, hidden_dim) * np.sqrt(2. / input_dim)
        self.b1 = np.zeros(hidden_dim)
        self.W2 = np.random.randn(hidden_dim, 1) * np.sqrt(2. / hidden_dim)
        self.b2 = np.zeros(1)

    def forward(self, X):
        """Return the pre-sigmoid outputs for the rows of ``X`` as a 1-D array.

        The intermediate values ``z1``, ``a1`` and ``z2`` are kept on the
        instance for use by a subsequent backward pass.
        """
        self.z1 = X.dot(self.W1) + self.b1
        self.a1 = np.maximum(0, self.z1)
        self.z2 = self.a1.dot(self.W2) + self.b2

        # Collapse the (batch_size, 1) output column to shape (batch_size,).
        return self.z2.flatten()

    def flatten_params(self):
        """Return all parameters as one newly allocated 1-D array."""
        return np.concatenate([self.W1.ravel(), self.b1, self.W2.ravel(), self.b2])

    def unflatten_params(self, params):
        """Load parameters from a flat vector laid out as ``[W1, b1, W2, b2]``.

        Args:
            params: Array-like holding exactly as many values as the network
                has parameters.

        Raises:
            ValueError: If ``params`` does not contain exactly that many
                values. Without this check an over-long vector would be
                accepted and silently give ``b2`` the wrong shape.
        """
        flat = np.asarray(params).reshape(-1)
        w1_end = self.W1.size
        b1_end = w1_end + self.b1.size
        w2_end = b1_end + self.W2.size
        expected = w2_end + self.b2.size
        if flat.size != expected:
            raise ValueError(
                f"expected {expected} parameters, got {flat.size}"
            )

        # Copy each slice so the model never aliases the caller's buffer;
        # later in-place edits of ``params`` must not change the network.
        self.W1 = flat[:w1_end].reshape(self.W1.shape).copy()
        self.b1 = flat[w1_end:b1_end].copy()
        self.W2 = flat[b1_end:w2_end].reshape(self.W2.shape).copy()
        self.b2 = flat[w2_end:].copy()


class MLPOracle:
    """Loss, gradient and accuracy oracle for a one-hidden-layer ``MLP``.

    The objective is the mean binary cross-entropy on sigmoid outputs plus
    ``0.5 * lam * (||W1||^2 + ||W2||^2)``. Accuracy thresholds the logits at
    zero and compares against the labels, so labels are expected to be 0/1.

    Args:
        lam: L2 regularisation strength applied to ``W1`` and ``W2``.
        hidden_dim: Hidden width used to interpret flat parameter vectors;
            it has to match the network that produced them.
    """

    def __init__(self, lam=1e-5, hidden_dim=256):
        self.lam = lam
        self.hidden_dim = hidden_dim

    def sigmoid(self, x):
        """Return the logistic function of ``x`` without overflowing.

        ``exp`` is only ever evaluated on non-positive arguments, so large
        negative inputs no longer trigger an overflow in ``exp(-x)``.
        """
        decay = np.exp(-np.abs(x))
        return np.where(np.asarray(x) >= 0, 1 / (1 + decay), decay / (1 + decay))

    def _build_model(self, w, x):
        """Create an ``MLP`` sized for ``x`` and load the flat parameters ``w``."""
        model = MLP(x.shape[1], self.hidden_dim)
        model.unflatten_params(w)
        return model

    @staticmethod
    def _as_vector(y):
        """Return the labels as a 1-D array so column vectors are accepted."""
        return np.asarray(y).reshape(-1)

    def get_gradients(self, w, x, y):
        """Return the flat gradient of the regularised loss at ``w``.

        Args:
            w: Flat parameter vector in ``[W1, b1, W2, b2]`` order.
            x: Feature matrix, one row per sample.
            y: Labels, one per row of ``x``.

        Returns:
            A 1-D array laid out like ``w``.

        Raises:
            ValueError: If the number of labels differs from the number of
                predictions.
        """
        model = self._build_model(w, x)
        probs = self.sigmoid(model.forward(x))
        y = self._as_vector(y)

        if probs.shape != y.shape:
            raise ValueError(f"维度不匹配: probs {probs.shape} vs y {y.shape}")

        # Gradient of the mean cross-entropy with respect to the logits.
        batch = x.shape[0]
        d_logits = ((probs - y) / batch).reshape(-1, 1)

        dW2 = model.a1.T.dot(d_logits)
        db2 = np.sum(d_logits)
        # Back-propagate through the ReLU: only units with z1 > 0 pass gradient.
        d_z1 = d_logits.dot(model.W2.T) * (model.z1 > 0)
        dW1 = x.T.dot(d_z1)
        db1 = np.sum(d_z1, axis=0)

        # L2 penalty acts on the weight matrices only, not on the biases.
        dW2 += self.lam * model.W2
        dW1 += self.lam * model.W1

        return np.concatenate([dW1.flatten(), db1, dW2.flatten(), np.array([db2])])

    def get_accuracy(self, w, x, y):
        """Return the fraction of rows of ``x`` whose thresholded logit equals ``y``.

        Raises:
            ValueError: If the number of labels differs from the number of
                predictions.
        """
        model = self._build_model(w, x)
        preds = (model.forward(x) > 0).astype(int)
        y = self._as_vector(y)

        if preds.shape != y.shape:
            raise ValueError(f"预测值和标签维度不匹配: preds {preds.shape} vs y {y.shape}")

        return np.mean(preds == y)

    def get_fn_val(self, w, x, y):
        """Return the regularised cross-entropy loss at ``w`` on ``(x, y)``."""
        model = self._build_model(w, x)
        probs = self.sigmoid(model.forward(x))
        y = self._as_vector(y)

        # Cross-entropy; the 1e-8 offset keeps log() finite at probabilities 0 and 1.
        loss = -np.mean(y * np.log(probs + 1e-8) + (1 - y) * np.log(1 - probs + 1e-8))
        # L2 regularisation on both weight matrices.
        reg = 0.5 * self.lam * (np.sum(model.W1 ** 2) + np.sum(model.W2 ** 2))

        return loss + reg


class DGFMAgent:
    """Gradient-tracking agent whose model is an ``MLP``.

    The agent keeps the flat parameter vector ``weight``, a tracked gradient
    estimate ``y_grad`` and the last local gradient ``prev_grad``.

    Args:
        input_dim: Number of input features of the network.
        hidden_dim: Hidden width of the network.
        mb_size: Accepted for interface compatibility; not used by this class.
        id: Identifier stored exactly as passed in.
        lr: Step size used by ``update_weight``.
    """

    def __init__(self, input_dim, hidden_dim, mb_size=16, id=0, lr=1e-3):
        self.lr = lr
        self.id = id
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        # The freshly initialised network defines the parameter count and start point.
        self.model = MLP(input_dim, hidden_dim)
        self.num_param = len(self.model.flatten_params())
        self.weight = self.model.flatten_params().copy()

        self.y_grad = np.zeros(self.num_param)
        self.prev_grad = np.zeros(self.num_param)

    def get_y_grad(self):
        """Return the tracked gradient (the stored array itself, not a copy)."""
        return self.y_grad

    def get_weight(self):
        """Return the flat parameter vector (the stored array itself, not a copy)."""
        return self.weight

    def update_y_grad(self, new_grad):
        """Apply the tracking update ``y <- y + new_grad - prev_grad``.

        ``new_grad`` is copied before being remembered as ``prev_grad`` so a
        caller that reuses its gradient buffer cannot corrupt the next update.
        """
        self.y_grad = self.y_grad + new_grad - self.prev_grad
        self.prev_grad = new_grad.copy()

    def set_y_grad(self, y_grad):
        """Replace the tracked gradient with a copy of ``y_grad``."""
        self.y_grad = y_grad.copy()

    def update_weight(self):
        """Step the weight along ``-lr * y_grad`` and sync the model parameters."""
        self.weight = self.weight - self.lr * self.y_grad
        self.model.unflatten_params(self.weight)

    def set_weight(self, flat_params):
        """Store a copy of ``flat_params`` and load it into the model."""
        self.weight = flat_params.copy()
        # Hand the model the agent's own copy rather than the caller's array.
        self.model.unflatten_params(self.weight)

class SVMOracle:
    """Loss, (sub)gradient and zeroth-order oracle for a linear SVM.

    The objective is the mean hinge loss ``max(1 - y * (x @ w), 0)`` plus a
    capped-L1 penalty ``lam * sum(min(|w_i|, alpha))``. The sign conventions
    in the hinge loss and in ``get_accuracy`` correspond to labels in {-1, +1}.
    """

    def __init__(self, alpha=2, lam=1e-5):
        self.alpha, self.lam = alpha, lam

    def get_gradients(self, w, x, y):
        """Return a subgradient of the regularised hinge loss at ``w``.

        Only samples with margin ``y * (x @ w) < 1`` contribute to the hinge
        part, which is averaged over the batch; the penalty contributes
        ``lam * sign(w_i)`` for coordinates with ``|w_i| < alpha``.
        """
        margins = (x @ w) * y
        violated = (margins < 1).astype(float)[:, np.newaxis]
        per_sample = -(x * y[:, np.newaxis]) * violated
        hinge_part = per_sample.mean(axis=0)
        penalty_part = self.lam * np.sign(w) * (np.abs(w) < self.alpha)
        return hinge_part + penalty_part

    def reg_term(self, w):
        """Return the capped-L1 penalty ``lam * sum(clip(|w|, 0, alpha))``."""
        capped = np.clip(np.abs(w), 0, self.alpha)
        return self.lam * capped.sum()

    def get_fn_val(self, w, x, y):
        """Return the mean hinge loss on ``(x, y)`` plus the penalty at ``w``."""
        hinge = np.maximum(1 - (x @ w) * y, 0).mean()
        penalty = self.lam * np.clip(np.abs(w), 0, self.alpha).sum()
        return hinge + penalty

    def _two_point_estimate(self, w, direction, x, y, delta):
        """Central-difference gradient estimate of ``get_fn_val`` along ``direction``."""
        forward = w + delta * direction
        backward = w - delta * direction
        fn_diff = self.get_fn_val(forward, x, y) - self.get_fn_val(backward, x, y)
        return (direction.shape[0] * fn_diff / (2 * delta)) * direction

    def get_zo_grad(self, w, x, y, delta=1e-3):
        """Return a two-point zeroth-order gradient estimate along a random unit direction.

        The direction is a Gaussian draw from NumPy's global random state,
        normalised to unit Euclidean length.
        """
        direction = np.random.randn(*w.shape)
        direction = direction / np.linalg.norm(direction)
        return self._two_point_estimate(w, direction, x, y, delta)

    def get_zo_grad_given_dir(self, w, w1, x, y, delta=1e-3):
        """Return the two-point zeroth-order estimate along the supplied direction ``w1``."""
        return self._two_point_estimate(w, w1, x, y, delta)

    def get_zo_grad_dgfmp(self, w, w1, x, y, delta=1e-3):
        """Return the batch mean of per-sample two-point estimates along ``w1``.

        Unlike ``get_zo_grad_given_dir``, the finite difference is formed for
        every sample separately and scaled by the feature dimension
        ``x.shape[1]`` before averaging over the batch.
        """
        forward = w + delta * w1
        backward = w - delta * w1
        loss_forward = np.maximum(1 - (x * forward).sum(axis=1) * y, 0) + self.reg_term(forward)
        loss_backward = np.maximum(1 - (x * backward).sum(axis=1) * y, 0) + self.reg_term(backward)
        per_sample_diff = (loss_forward - loss_backward)[:, np.newaxis]
        estimates = (x.shape[1] * per_sample_diff / (2 * delta)) * w1
        return estimates.mean(axis=0)

    def get_accuracy(self, w, x, y):
        """Return the fraction of samples whose score ``x @ w`` has the same sign as ``y``."""
        scores = x @ w
        return ((scores * y) > 0).mean()

class DGFMplus:
    """Agent state for a variance-reduced (SPIDER-style) gradient tracker.

    Attributes:
        lr: Step size used by ``update_weight``.
        id: Identifier stored exactly as passed in.
        num_param: Length of the weight and estimator vectors.
        mb_size: Mini-batch size, stored for the caller's use.
        weight: Current parameter vector.
        v: Running gradient estimator.
        prev_grad: Gradient (or estimator snapshot) from the previous step.
    """

    def __init__(self, num_param, id=0, lr=1e-3, mb_size=16):
        self.id, self.lr = id, lr
        self.num_param, self.mb_size = num_param, mb_size

        # Weight, estimator and previous gradient all start at the origin.
        self.weight = np.zeros(num_param)
        self.v = np.zeros(num_param)
        self.prev_grad = np.zeros(num_param)

    def get_v(self):
        """Return a copy of the gradient estimator."""
        return self.v.copy()

    def set_v(self, v_new):
        """Replace the gradient estimator with a copy of ``v_new``."""
        self.v = v_new.copy()

    def save_prev_grad(self):
        """Snapshot the current estimator into ``prev_grad``."""
        self.prev_grad = self.v.copy()

    def update_spider_grad(self, grad_new):
        """Apply ``v <- v + grad_new - prev_grad`` in place and remember ``grad_new``."""
        correction = grad_new - self.prev_grad
        self.v += correction
        self.prev_grad = grad_new.copy()

    def get_weight(self):
        """Return a copy of the parameter vector."""
        return self.weight.copy()

    def set_weight(self, w):
        """Replace the parameter vector with a copy of ``w``."""
        self.weight = w.copy()

    def update_weight(self):
        """Apply ``weight <- weight - lr * v`` in place."""
        step = self.lr * self.v
        self.weight -= step