import numpy as np
import random
from sklearn.model_selection import train_test_split
from sklearn.datasets import fetch_openml


# --- Agent Classes ---

class SVMAgent:
    """A single agent holding a parameter vector (``weight``) and an update direction (``action``).

    The action is updated by projected gradient steps and is kept inside the
    Euclidean ball of radius ``D``.

    Args:
        num_param: length of the weight / action vectors.
        id: agent identifier (stored only).
        lr: step size used by :meth:`action_grad_update`.
        D: radius of the ball the action is projected onto.
        NUM_AGENTS: multiplier used by :meth:`DOC2S_get_new_weight`.
    """

    def __init__(self, num_param, id=0, lr=1e-3, D=1, NUM_AGENTS=16):
        self.num_param = num_param
        self.id = id
        self.lr = lr
        self.D = D
        self.NUM_AGENTS = NUM_AGENTS
        self.weight = np.zeros(num_param)
        self.action = np.zeros(num_param)
        self.grad_points = np.zeros(num_param)

    def initialize_action(self):
        """Reset the action to the zero vector."""
        self.action = np.zeros(self.num_param)

    def get_action(self):
        return self.action

    def get_weight(self):
        return self.weight

    def get_grad_points(self):
        """Return the point most recently stored by :meth:`get_grad_point`."""
        return self.grad_points

    def get_grad_point(self):
        """Store ``weight + s * action`` with ``s ~ U[0, 1)`` in ``grad_points`` (returns None)."""
        step = np.random.rand(1)  # shape (1,), broadcasts over the action vector
        self.grad_points = self.weight + step * self.action

    def DOC2S_get_new_weight(self):
        """Return ``weight + NUM_AGENTS * action`` without modifying the agent."""
        scaled_action = self.NUM_AGENTS * self.action
        return self.weight + scaled_action

    def get_new_weight(self):
        """Return ``weight + action`` without modifying the agent."""
        return self.weight + self.action

    def _project_action(self):
        """Rescale the action onto the radius-``D`` ball when its norm exceeds ``D``."""
        length = np.linalg.norm(self.action)
        if not length > self.D:
            return
        self.action = self.action * self.D / length

    def set_weight(self, x):
        """Store a private copy of ``x`` as the weight."""
        self.weight = x.copy()

    def set_action(self, x):
        """Store a private copy of ``x`` as the action."""
        self.action = x.copy()

    def action_grad_update(self, grad):
        """Take one gradient step ``action -= lr * grad`` and project back onto the ball."""
        step = self.lr * grad
        self.action = self.action - step
        self._project_action()


class MLPAgent(SVMAgent):
    """An :class:`SVMAgent` whose parameter vector is the flattened weights of an ``MLP``."""

    def __init__(self, input_dim, hidden_dim, id=0, lr=1e-3, D=1, NUM_AGENTS=16):
        # A throwaway network supplies both the parameter count and the random
        # initial values of the flattened parameter vector.
        template = MLP(input_dim, hidden_dim)
        initial_params = template.flatten_params()
        super().__init__(len(initial_params), id, lr, D, NUM_AGENTS)

        # The base constructor zero-fills ``weight``; replace it with the network's initialization.
        self.weight = initial_params.copy()
        self.model_struct = template  # kept so the parameter layout can be inspected later


# --- Model Classes (for MLP) ---

class MLP:
    """Two-layer perceptron: ``logits = relu(X @ W1 + b1) @ W2 + b2``.

    ``flatten_params`` / ``unflatten_params`` convert between the separate arrays and
    one vector laid out as ``[W1.ravel(), b1, W2.ravel(), b2]`` (row-major).

    Args:
        input_dim: number of input features.
        hidden_dim: width of the ReLU hidden layer.
        output_dim: number of output logits (default 10, the previously hard-coded value).
    """

    def __init__(self, input_dim, hidden_dim, output_dim=10):
        # He (Kaiming) initialization for the ReLU layer: std = sqrt(2 / fan_in).
        self.W1 = np.random.randn(input_dim, hidden_dim) * np.sqrt(2. / input_dim)
        self.b1 = np.zeros(hidden_dim)
        # The output layer is linear, so scale by sqrt(1 / fan_in) (LeCun). An unscaled
        # N(0, 1) draw makes the initial logit std grow like sqrt(hidden_dim), which
        # saturates the softmax at initialization.
        self.W2 = np.random.randn(hidden_dim, output_dim) * np.sqrt(1. / hidden_dim)
        self.b2 = np.zeros(output_dim)

    def forward(self, X):
        """Return the logits for rows of ``X``.

        The pre-activation ``z1`` and activation ``a1`` are cached on the instance.
        """
        self.z1 = X.dot(self.W1) + self.b1
        self.a1 = np.maximum(0, self.z1)
        self.z2 = self.a1.dot(self.W2) + self.b2
        return self.z2

    def flatten_params(self):
        """Return a new 1-D array ``[W1.ravel(), b1, W2.ravel(), b2]``."""
        return np.concatenate([self.W1.ravel(), self.b1, self.W2.ravel(), self.b2])

    def unflatten_params(self, params):
        """Load parameters from a vector laid out as by :meth:`flatten_params`.

        When ``params`` is already an ndarray, the new arrays are views into it (no copy).

        Raises:
            ValueError: if ``params`` is not 1-D with exactly the model's parameter count.
        """
        params = np.asarray(params)
        W1_size = self.W1.size
        b1_size = self.b1.size
        W2_size = self.W2.size
        b2_size = self.b2.size
        expected = W1_size + b1_size + W2_size + b2_size
        # Without this check a too-short/too-long vector silently yields a mis-sized b2.
        if params.shape != (expected,):
            raise ValueError(
                f"expected a flat parameter vector of length {expected}, got shape {params.shape}"
            )

        end_W1 = W1_size
        end_b1 = end_W1 + b1_size
        end_W2 = end_b1 + W2_size
        self.W1 = params[:end_W1].reshape(self.W1.shape)
        self.b1 = params[end_W1:end_b1]
        self.W2 = params[end_b1:end_W2].reshape(self.W2.shape)
        self.b2 = params[end_W2:]


# --- Network Class ---

class FastComNetwork:
    """Gossip network that mixes agent vectors with a mixing matrix ``W``.

    Mixing uses a momentum-accelerated iteration
    ``Z_r = (1 + eta) * W @ Z_{r-1} - eta * Z_{r-2}`` with ``Z_{-1} = Z_0``. Here ``eta``
    is derived from the second-largest eigenvalue magnitude ``beta`` of ``W``, with
    ``beta`` clamped to 0.999 when it is >= 1.
    """

    def __init__(self, W=None):
        self.W = W
        if W is None:
            return
        self.N = W.shape[0]

    def set_weight(self, W):
        """Install mixing matrix ``W`` and record the number of agents it describes."""
        self.W = W
        self.N = W.shape[0]

    def _ensure_matrix(self, agents):
        """Fall back to the identity mixing matrix when none has been set."""
        if self.W is None:
            self.set_weight(np.eye(len(agents)))

    def propagate_weights(self, agents, R):
        """Replace each agent's weight with the result of ``R`` mixing rounds."""
        self._ensure_matrix(agents)
        mixed = self.propagate_data([agent.get_weight() for agent in agents], R)
        for idx in range(self.N):
            agents[idx].set_weight(mixed[idx])

    def propagate_actions(self, agents, R):
        """Replace each agent's action with the result of ``R`` mixing rounds."""
        self._ensure_matrix(agents)
        mixed = self.propagate_data([agent.get_action() for agent in agents], R)
        for idx in range(self.N):
            agents[idx].set_action(mixed[idx])

    def get_average_weight(self, agents):
        """Return the arithmetic mean of the agents' weights."""
        return self.get_average([agent.get_weight() for agent in agents])

    def calculate_consensus_error(self, agents):
        """Calculates 1/N * sum ||w_i - w_bar||"""
        weights = [agent.get_weight() for agent in agents]
        center = self.get_average(weights)
        return sum(np.linalg.norm(w - center) for w in weights) / len(agents)

    def get_average(self, data):
        """Return the mean of ``data`` (a sequence of arrays or of scalars)."""
        count = len(data)
        first = data[0]
        total = np.zeros_like(first) if isinstance(first, np.ndarray) else 0
        for item in data:
            total += item / count
        return total

    def propagate_data(self, data, R):
        """Run ``R`` rounds of accelerated gossip on ``data`` and return the mixed list.

        ``data`` is one array per agent; ``R <= 0`` returns copies of the inputs.
        """
        spectrum = np.sort(np.abs(np.linalg.eigvals(self.W)))[::-1]
        beta = spectrum[1] if len(spectrum) > 1 else 0
        # Keep beta strictly below 1 so the square root below stays real and positive.
        beta = 0.999 if beta >= 1.0 else beta

        root = np.sqrt(1 - beta ** 2)
        eta = (1 - root) / (1 + root)

        count = len(data)
        current = [d.copy() for d in data]
        previous = [d.copy() for d in data]

        for _ in range(R):
            upcoming = []
            for m in range(count):
                # Row m of W applied to the stacked vectors; zero entries are skipped.
                acc = np.zeros_like(current[0])
                for k in range(count):
                    coef = self.W[m, k]
                    if coef != 0:
                        acc += coef * current[k]
                upcoming.append((1 + eta) * acc - eta * previous[m])
            previous = [z.copy() for z in current]
            current = upcoming
        return current


# --- Oracle Classes ---

class SVMOracle:
    """First- and zeroth-order oracle for a linear hinge-loss classifier with a capped-L1 penalty.

    Objective: ``mean(max(0, 1 - y * (x @ w))) + lam * sum(min(|w|, alpha))``.
    The hinge form assumes labels ``y`` in {-1, +1}.
    """

    def __init__(self, alpha=2, lam=1e-5):
        self.alpha = alpha
        self.lam = lam

    def get_gradients(self, w, x, y):
        """Return a subgradient of the objective at ``w`` for the batch ``(x, y)``.

        The hinge term contributes ``-y_i * x_i`` for samples with positive margin
        violation. The penalty contributes ``lam * sign(w)`` only where ``|w| < alpha``.
        """
        active = (1 - y * (x @ w) > 0).astype(float)
        hinge_grad = -(x.T @ (y * active)) / x.shape[0]
        inside_cap = np.abs(w) < self.alpha
        penalty_grad = self.lam * np.sign(w) * inside_cap
        return hinge_grad + penalty_grad

    def get_fn_val(self, w, x, y):
        """Return the objective value at ``w`` on the batch ``(x, y)``."""
        margins = 1 - y * (x @ w)
        hinge = np.maximum(margins, 0).mean()
        penalty = self.lam * np.clip(np.abs(w), 0, self.alpha).sum()
        return hinge + penalty

    def get_zo_grad(self, w, x, y, delta=1e-3):
        """Two-point zeroth-order gradient estimate along one random unit direction.

        Returns ``d * (f(w + delta*u) - f(w - delta*u)) / (2*delta) * u``,
        where ``d = w.shape[0]``.
        """
        direction = np.random.randn(*w.shape)
        length = np.linalg.norm(direction)
        if length == 0:
            length = 1
        direction = direction / length
        f_plus = self.get_fn_val(w + delta * direction, x, y)
        f_minus = self.get_fn_val(w - delta * direction, x, y)
        scale = w.shape[0] * (f_plus - f_minus) / (2 * delta)
        return scale * direction


class MLPOracle:
    """First- and zeroth-order oracle for a two-layer ReLU MLP with softmax cross-entropy.

    The flat parameter vector ``w`` uses the layout of ``MLP.flatten_params``:
    ``[W1.ravel(), b1, W2.ravel(), b2]``. Here ``W1`` has shape ``(input_dim, hidden_dim)``
    and ``W2`` has shape ``(hidden_dim, num_classes)``.

    Objective: ``mean cross-entropy + 0.5 * lam * (||W1||^2 + ||W2||^2)``; biases are
    not penalized.

    The network is evaluated directly from views of ``w``. No model object is built,
    so oracle calls do not consume the global NumPy RNG.
    """

    def __init__(self, lam=1e-5, hidden_dim=256, num_classes=10):
        self.lam = lam
        self.hidden_dim = hidden_dim
        self.num_classes = num_classes

    def softmax(self, x):
        """Row-wise softmax of a 2-D array, shifted by the row max for numerical stability."""
        e_x = np.exp(x - np.max(x, axis=1, keepdims=True))
        return e_x / np.sum(e_x, axis=1, keepdims=True)

    def _unpack(self, w, input_dim):
        """Split the flat vector ``w`` into ``(W1, b1, W2, b2)`` views.

        Raises:
            ValueError: if ``w`` is not 1-D with the length implied by ``input_dim``,
                ``hidden_dim`` and ``num_classes``.
        """
        h, c = self.hidden_dim, self.num_classes
        w = np.asarray(w)
        end_w1 = input_dim * h
        end_b1 = end_w1 + h
        end_w2 = end_b1 + h * c
        expected = end_w2 + c
        if w.shape != (expected,):
            raise ValueError(
                f"expected a flat parameter vector of length {expected}, got shape {w.shape}"
            )
        return (w[:end_w1].reshape(input_dim, h), w[end_w1:end_b1],
                w[end_b1:end_w2].reshape(h, c), w[end_w2:])

    def _forward(self, w, x):
        """Evaluate the network defined by ``w`` on ``x``; return ``(W1, W2, z1, a1, logits)``."""
        W1, b1, W2, b2 = self._unpack(w, x.shape[1])
        z1 = x.dot(W1) + b1
        a1 = np.maximum(0, z1)
        logits = a1.dot(W2) + b2
        return W1, W2, z1, a1, logits

    def get_gradients(self, w, x, y):
        """Return the exact gradient of the objective at ``w`` for the batch ``(x, y)``.

        ``y`` holds integer class labels in ``[0, num_classes)``. The result is a flat
        vector in the same layout as ``w``.
        """
        W1, W2, z1, a1, logits = self._forward(w, x)
        probs = self.softmax(logits)
        m = x.shape[0]
        y_one_hot = np.zeros((m, self.num_classes))
        y_one_hot[np.arange(m), np.asarray(y).astype(int)] = 1

        d_logits = (probs - y_one_hot) / m
        dW2 = a1.T.dot(d_logits) + self.lam * W2
        db2 = np.sum(d_logits, axis=0)
        d_z1 = d_logits.dot(W2.T) * (z1 > 0)  # back-propagate through the ReLU
        dW1 = x.T.dot(d_z1) + self.lam * W1
        db1 = np.sum(d_z1, axis=0)
        return np.concatenate([dW1.ravel(), db1, dW2.ravel(), db2])

    def get_zo_grad(self, w, x, y, delta=1e-3):
        """Two-point zeroth-order gradient estimate along one random unit direction.

        Returns ``d * (f(w + delta*u) - f(w - delta*u)) / (2*delta) * u``,
        where ``d = w.shape[0]``.
        """
        w1 = np.random.randn(*w.shape)
        norm_w1 = np.linalg.norm(w1)
        if norm_w1 == 0: norm_w1 = 1
        w1 = w1 / norm_w1
        w2 = w + delta * w1
        w3 = w - delta * w1
        fn_diff = self.get_fn_val(w2, x, y) - self.get_fn_val(w3, x, y)
        grad_est = (w.shape[0] * fn_diff / (2 * delta)) * w1
        return grad_est

    def get_fn_val(self, w, x, y):
        """Return the objective value (mean cross-entropy + L2 penalty) at ``w``.

        Uses a log-softmax rather than ``log(softmax + eps)``. This keeps the value finite
        without capping the per-sample loss, so it matches :meth:`get_gradients` exactly.
        """
        W1, W2, _, _, logits = self._forward(w, x)
        m = x.shape[0]
        shifted = logits - np.max(logits, axis=1, keepdims=True)
        log_probs = shifted - np.log(np.sum(np.exp(shifted), axis=1, keepdims=True))
        labels = np.asarray(y).astype(int)
        loss = -np.mean(log_probs[np.arange(m), labels])
        reg = 0.5 * self.lam * (np.sum(W1 ** 2) + np.sum(W2 ** 2))
        return loss + reg


# --- Dataset Class ---

class DatasetModel:
    """Load a dataset, standardize it, and split the training rows across agents.

    Args:
        dsname: ``'mnist'`` or ``'a9a'``.
        num_agent: number of agents sharing the training set.
        mb_size: minibatch size returned by :meth:`get_sample`.
        max_sample: cap on the number of rows kept (random subsample taken before the split).
    """

    def __init__(self, dsname=None, num_agent=3, mb_size=1, max_sample=10000):
        self.dsname = dsname
        self.num_agent = num_agent
        self.mb_size = mb_size
        self.max_sample = max_sample
        self.loaddataset()
        self.input_dim = self.X.shape[1]
        self.dssize = self.X.shape[0]
        # Disjoint, equal-sized shards of a random permutation of the training rows.
        # The remaining ``dssize % num_agent`` shuffled rows are not assigned to any agent.
        shuffled_idx = np.random.permutation(self.dssize)
        sample_per_agent = self.dssize // self.num_agent
        self.agent_dict = {
            m: shuffled_idx[m * sample_per_agent:(m + 1) * sample_per_agent].tolist()
            for m in range(num_agent)
        }

    @staticmethod
    def _to_pm1_labels(y):
        """Map binary labels to {-1, +1}: non-positive -> -1, positive -> +1.

        Handles both the {-1, +1} and the {0, 1} encodings. A test of ``y == -1`` alone
        would map every label of a {0, 1} encoding to +1.
        """
        y = np.asarray(y).astype(int)
        return np.where(y <= 0, -1, 1)

    def loaddataset(self):
        """Fetch, subsample, split 80/20 and standardize the dataset named by ``self.dsname``.

        Sets ``X``/``y`` (training), ``X_test``/``y_test``, and the training-set
        ``x_mean``/``x_std`` used for standardization.

        Raises:
            ValueError: if ``dsname`` is neither ``'mnist'`` nor ``'a9a'``.
        """
        if self.dsname == 'mnist':
            X, y = fetch_openml('mnist_784', version=1, return_X_y=True, as_frame=False, parser='auto')
            y = y.astype(int)
        elif self.dsname == 'a9a':
            X, y = fetch_openml('a9a', version=1, return_X_y=True, as_frame=False, parser='auto')
            # The hinge-loss oracle needs {-1, +1} labels whatever the source encoding.
            y = self._to_pm1_labels(y)
        else:
            raise ValueError("Unknown dataset")

        # Sparse matrices (e.g. a9a) are converted to dense arrays.
        if hasattr(X, 'toarray'):
            X = X.toarray()

        # Ensure float32 for consistent calculations.
        X = X.astype(np.float32)

        if X.shape[0] > self.max_sample:
            idx = np.random.choice(X.shape[0], self.max_sample, replace=False)
            X = X[idx]
            y = y[idx]

        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

        # Standardize with training statistics only; constant features keep std 1
        # to avoid division by zero.
        self.x_mean = X_train.mean(axis=0)
        self.x_std = X_train.std(axis=0)
        self.x_std[self.x_std < 1e-8] = 1.0
        X_train = (X_train - self.x_mean) / self.x_std
        X_test = (X_test - self.x_mean) / self.x_std

        self.X = X_train
        self.y = y_train
        self.X_test = X_test
        self.y_test = y_test

    def get_sample(self, agent=0):
        """Return a minibatch ``(X_batch, y_batch)`` for ``agent`` (taken modulo ``num_agent``).

        Rows are drawn without replacement from the agent's shard. If the shard has
        fewer than ``mb_size`` rows, rows are drawn with replacement from the whole
        training set instead.
        """
        agent = agent % self.num_agent
        if len(self.agent_dict[agent]) < self.mb_size:
            idx = np.random.choice(self.dssize, self.mb_size)
        else:
            idx = random.sample(self.agent_dict[agent], k=self.mb_size)
        return self.X[idx, :], self.y[idx]

    def get_test_set(self):
        """Return the standardized held-out split ``(X_test, y_test)``."""
        return self.X_test, self.y_test